// Copyright 2016 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.

package sasl

import (
	"bytes"
	"crypto/hmac"
	"encoding/base64"
	"errors"
	"hash"
	"strconv"
	"strings"

	"golang.org/x/crypto/pbkdf2"
)

// Constants used when building SCRAM client messages.
const (
	// GS2 header prefixes that open the client-first message. Which one is
	// sent depends on whether this side attempts channel binding and whether
	// the server signalled support for it; "p=tls-unique," requests
	// tls-unique channel binding.
	gs2HeaderCBSupport         = "p=tls-unique,"
	gs2HeaderNoServerCBSupport = "y,"
	gs2HeaderNoCBSupport       = "n,"

	// noncerandlen is the number of random bytes to generate for a nonce.
	noncerandlen = 16
)

// Fixed messages HMAC'd with the salted password to derive the server key and
// the client key respectively.
var (
	serverKeyInput = []byte("Server Key")
	clientKeyInput = []byte("Client Key")
)

func ( string,  *Negotiator) ( []byte) {
	, ,  := .Credentials()
	switch {
	case .TLSState() == nil || !strings.HasSuffix(, "-PLUS"):
		// We do not support channel binding
		 = []byte(gs2HeaderNoCBSupport)
	case .State()&RemoteCB == RemoteCB:
		// We support channel binding and the server does too
		 = []byte(gs2HeaderCBSupport)
	case .State()&RemoteCB != RemoteCB:
		// We support channel binding but the server does not
		 = []byte(gs2HeaderNoServerCBSupport)
	}
	if len() > 0 {
		 = append(, []byte(`a=`)...)
		 = append(, ...)
	}
	 = append(, ',')
	return
}

func ( string,  func() hash.Hash) Mechanism {
	// BUG(ssw): We need a way to cache the SCRAM client and server key
	// calculations.
	return Mechanism{
		Name: ,
		Start: func( *Negotiator) (bool, []byte, interface{}, error) {
			, ,  := .Credentials()

			// Escape "=" and ",". This is mostly the same as bytes.Replace but
			// faster because we can do both replacements in a single pass.
			 := bytes.Count(, []byte{'='}) + bytes.Count(, []byte{','})
			 := make([]byte, len()+(*2))
			 := 0
			 := 0
			for  := 0;  < ; ++ {
				 := 
				 += bytes.IndexAny([:], "=,")
				 += copy([:], [:])
				switch [] {
				case '=':
					 += copy([:], "=3D")
				case ',':
					 += copy([:], "=2C")
				}
				 =  + 1
			}
			copy([:], [:])

			 := make([]byte, 5+len(.Nonce())+len())
			copy(, "n=")
			copy([2:], )
			copy([2+len():], ",r=")
			copy([5+len():], .Nonce())

			return true, append(getGS2Header(, ), ...), , nil
		},
		Next: func( *Negotiator,  []byte,  interface{}) ( bool,  []byte,  interface{},  error) {
			if  == nil || len() == 0 {
				return , , , ErrInvalidChallenge
			}

			if .State()&Receiving == Receiving {
				panic("not yet implemented")
			}
			return scramClientNext(, , , , )
		},
	}
}

func ( string,  func() hash.Hash,  *Negotiator,  []byte,  interface{}) ( bool,  []byte,  interface{},  error) {
	, ,  := .Credentials()
	 := .State()

	switch  & StepMask {
	case AuthTextSent:
		 := -1
		var ,  []byte
		for ,  := range bytes.Split(, []byte{','}) {
			if len() < 3 || (len() >= 2 && [1] != '=') {
				continue
			}
			switch [0] {
			case 'i':
				 := string(bytes.TrimRight([2:], "\x00"))

				if ,  = strconv.Atoi();  != nil {
					return
				}
			case 's':
				 = make([]byte, base64.StdEncoding.DecodedLen(len()-2))
				var  int
				,  = base64.StdEncoding.Decode(, [2:])
				 = [:]
				if  != nil {
					return
				}
			case 'r':
				 = [2:]
			case 'm':
				// RFC 5802:
				// m: This attribute is reserved for future extensibility.  In this
				// version of SCRAM, its presence in a client or a server message
				// MUST cause authentication failure when the attribute is parsed by
				// the other end.
				 = errors.New("Server sent reserved attribute `m'")
				return
			}
		}

		switch {
		case  < 0:
			 = errors.New("Iteration count is missing")
			return
		case  < 0:
			 = errors.New("Iteration count is invalid")
			return
		case  == nil || !bytes.HasPrefix(, .Nonce()):
			 = errors.New("Server nonce does not match client nonce")
			return
		case  == nil:
			 = errors.New("Server sent empty salt")
			return
		}

		 := getGS2Header(, )
		 := .TLSState()
		var  []byte
		if  != nil && strings.HasSuffix(, "-PLUS") {
			 = make(
				[]byte,
				2+base64.StdEncoding.EncodedLen(len()+len(.TLSUnique)),
			)
			base64.StdEncoding.Encode([2:], append(, .TLSUnique...))
			[0] = 'c'
			[1] = '='
		} else {
			 = make(
				[]byte,
				2+base64.StdEncoding.EncodedLen(len()),
			)
			base64.StdEncoding.Encode([2:], )
			[0] = 'c'
			[1] = '='
		}
		 := append(, []byte(",r=")...)
		 = append(, ...)

		 := .([]byte)
		 := append(, ',')
		 = append(, ...)
		 = append(, ',')
		 = append(, ...)

		 := pbkdf2.Key(, , , ().Size(), )

		 := hmac.New(, )
		.Write(serverKeyInput)
		 := .Sum(nil)
		.Reset()

		.Write(clientKeyInput)
		 := .Sum(nil)

		 = hmac.New(, )
		.Write()
		 := .Sum(nil)

		 = ()
		.Write()
		 := .Sum(nil)
		 = hmac.New(, )
		.Write()
		 := .Sum(nil)
		 := make([]byte, len())
		xorBytes(, , )

		 := make([]byte, base64.StdEncoding.EncodedLen(len()))
		base64.StdEncoding.Encode(, )
		 := append(, []byte(",p=")...)
		 = append(, ...)

		return true, , , nil
	case ResponseSent:
		 := "v=" + base64.StdEncoding.EncodeToString(.([]byte))
		if  != string() {
			 = ErrAuthn
			return
		}
		// Success!
		return false, nil, nil, nil
	}
	 = ErrInvalidState
	return
}