package pgdriver

import (
	"bufio"
	"context"
	"crypto/md5"
	"crypto/tls"
	"database/sql"
	"database/sql/driver"
	"encoding/binary"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"math"
	"strconv"
	"sync"
	"time"
	"unicode/utf8"

	"mellium.im/sasl"
)

// Message type bytes and authentication codes of the PostgreSQL
// frontend/backend protocol (version 3).
// https://www.postgresql.org/docs/current/protocol-message-formats.html
//
// The protocol reuses the same byte for different messages depending on the
// direction of travel (e.g. 'D' is Describe when sent by the client and
// DataRow when sent by the server), so equal values below are intentional.
//
//nolint:deadcode,varcheck,unused
const (
	// Sent by the server.
	authenticationOKMsg           = 'R'
	authenticationSASLContinueMsg = 'R'
	authenticationSASLFinalMsg    = 'R'
	backendKeyDataMsg             = 'K'
	parameterStatusMsg            = 'S'
	readyForQueryMsg              = 'Z'
	errorResponseMsg              = 'E'
	noticeResponseMsg             = 'N'
	notificationResponseMsg       = 'A'
	commandCompleteMsg            = 'C'
	emptyQueryResponseMsg         = 'I'
	rowDescriptionMsg             = 'T'
	dataRowMsg                    = 'D'
	noDataMsg                     = 'n'
	parameterDescriptionMsg       = 't'
	parseCompleteMsg              = '1'
	bindCompleteMsg               = '2'
	closeCompleteMsg              = '3'
	copyInResponseMsg             = 'G'
	copyOutResponseMsg            = 'H'

	// Sent by the client.
	passwordMessageMsg     = 'p'
	saslInitialResponseMsg = 'p'
	saslResponseMsg        = 'p'
	queryMsg               = 'Q'
	parseMsg               = 'P'
	bindMsg                = 'B'
	describeMsg            = 'D'
	executeMsg             = 'E'
	syncMsg                = 'S'
	flushMsg               = 'H'
	closeMsg               = 'C'
	terminateMsg           = 'X'

	// Sent in either direction during COPY.
	copyDataMsg = 'd'
	copyDoneMsg = 'c'

	// Int32 codes carried inside an Authentication ('R') message.
	authenticationOK                = 0
	authenticationCleartextPassword = 3
	authenticationMD5Password       = 5
	authenticationSASL              = 10
)

// errEmptyQuery is reported when the server answers with EmptyQueryResponse.
var errEmptyQuery = errors.New("pgdriver: query is empty")

// reader is a buffered reader for backend protocol messages. The small scratch
// buffer lets short fixed-size reads (message headers, integers) avoid
// allocating.
type reader struct {
	*bufio.Reader
	buf []byte
}

// newReader wraps r in a buffered protocol reader with a 128-byte scratch
// buffer.
func newReader(r io.Reader) *reader {
	return &reader{
		Reader: bufio.NewReader(r),
		buf:    make([]byte, 128),
	}
}

// ReadTemp reads exactly n bytes. When n fits into the scratch buffer the
// returned slice aliases it and is only valid until the next read through r;
// larger reads return a freshly allocated slice. Errors follow io.ReadFull
// (io.EOF when nothing was read, io.ErrUnexpectedEOF on a short read). A
// negative n (e.g. from a malformed length field) is reported as an error
// instead of panicking on the slice expression.
func (r *reader) ReadTemp(n int) ([]byte, error) {
	if n < 0 {
		return nil, fmt.Errorf("pgdriver: invalid read length %d", n)
	}

	var dst []byte
	if n <= len(r.buf) {
		dst = r.buf[:n]
	} else {
		dst = make([]byte, n)
	}
	_, err := io.ReadFull(r.Reader, dst)
	return dst, err
}

// Discard skips exactly n bytes. It delegates to bufio.Reader.Discard, which
// skips data without copying it and never allocates, so large payloads (e.g.
// DataRow bodies) are dropped cheaply regardless of the length the server
// announced. As with io.ReadFull, running out of input after partial progress
// is reported as io.ErrUnexpectedEOF.
func (r *reader) Discard(n int) error {
	skipped, err := r.Reader.Discard(n)
	if err == io.EOF && skipped > 0 {
		return io.ErrUnexpectedEOF
	}
	return err
}

func ( context.Context,  *Conn,  *tls.Config) error {
	if  := writeSSLMsg(, );  != nil {
		return 
	}

	 := .reader(, -1)

	,  := .ReadByte()
	if  != nil {
		return 
	}
	if  != 'S' {
		return errors.New("pgdriver: SSL is not enabled on the server")
	}

	.netConn = tls.Client(.netConn, )
	.Reset(.netConn)

	return nil
}

func ( context.Context,  *Conn) error {
	 := getWriteBuffer()
	defer putWriteBuffer()

	.StartMessage(0)
	.WriteInt32(80877103)
	.FinishMessage()

	return .write(, )
}

//------------------------------------------------------------------------------

func ( context.Context,  *Conn) error {
	if  := writeStartup(, );  != nil {
		return 
	}

	 := .reader(, -1)

	for {
		, ,  := readMessageType()
		if  != nil {
			return 
		}

		switch  {
		case backendKeyDataMsg:
			,  := readInt32()
			if  != nil {
				return 
			}
			,  := readInt32()
			if  != nil {
				return 
			}
			.processID = 
			.secretKey = 
		case authenticationOKMsg:
			if  := auth(, , );  != nil {
				return 
			}
		case readyForQueryMsg:
			return .Discard()
		case parameterStatusMsg, noticeResponseMsg:
			if  := .Discard();  != nil {
				return 
			}
		case errorResponseMsg:
			,  := readError()
			if  != nil {
				return 
			}
			return 
		default:
			return fmt.Errorf("pgdriver: unexpected startup message: %q", )
		}
	}
}

func ( context.Context,  *Conn) error {
	 := getWriteBuffer()
	defer putWriteBuffer()

	.StartMessage(0)
	.WriteInt32(196608)
	.WriteString("user")
	.WriteString(.driver.cfg.User)
	.WriteString("database")
	.WriteString(.driver.cfg.Database)
	if .driver.cfg.AppName != "" {
		.WriteString("application_name")
		.WriteString(.driver.cfg.AppName)
	}
	.WriteString("")
	.FinishMessage()

	return .write(, )
}

//------------------------------------------------------------------------------

func ( context.Context,  *Conn,  *reader) error {
	,  := readInt32()
	if  != nil {
		return 
	}

	switch  {
	case authenticationOK:
		return nil
	case authenticationCleartextPassword:
		return authCleartext(, , )
	case authenticationMD5Password:
		return authMD5(, , )
	case authenticationSASL:
		if  := authSASL(, , );  != nil {
			return fmt.Errorf("pgdriver: SASL: %w", )
		}
		return nil
	default:
		return fmt.Errorf("pgdriver: unknown authentication message: %q", )
	}
}

func ( context.Context,  *Conn,  *reader) error {
	if  := writePassword(, , .driver.cfg.Password);  != nil {
		return 
	}
	return readAuthOK(, )
}

func ( *Conn,  *reader) error {
	, ,  := readMessageType()
	if  != nil {
		return 
	}

	switch  {
	case authenticationOKMsg:
		,  := readInt32()
		if  != nil {
			return 
		}
		if  != 0 {
			return fmt.Errorf("pgdriver: unexpected authentication code: %q", )
		}
		return nil
	case errorResponseMsg:
		,  := readError()
		if  != nil {
			return 
		}
		return 
	default:
		return fmt.Errorf("pgdriver: unknown password message: %q", )
	}
}

//------------------------------------------------------------------------------

func ( context.Context,  *Conn,  *reader) error {
	,  := .ReadTemp(4)
	if  != nil {
		return 
	}

	 := "md5" + md5s(md5s(.driver.cfg.Password+.driver.cfg.User)+string())
	if  := writePassword(, , );  != nil {
		return 
	}

	return readAuthOK(, )
}

func ( context.Context,  *Conn,  string) error {
	 := getWriteBuffer()
	defer putWriteBuffer()

	.StartMessage(passwordMessageMsg)
	.WriteString()
	.FinishMessage()

	return .write(, )
}

// md5s returns the lowercase hexadecimal MD5 digest of s.
func md5s(s string) string {
	sum := md5.Sum([]byte(s))
	return hex.EncodeToString(sum[:])
}

//------------------------------------------------------------------------------

func ( context.Context,  *Conn,  *reader) error {
	var  sasl.Mechanism

:
	for {
		,  := readString()
		if  != nil {
			return 
		}

		switch  {
		case "":
			break 
		case sasl.ScramSha256.Name:
			 = sasl.ScramSha256
		case sasl.ScramSha256Plus.Name:
			// ignore
		default:
			return fmt.Errorf("got %q, wanted %q", , sasl.ScramSha256.Name)
		}
	}

	 := sasl.Credentials(func() (, ,  []byte) {
		return []byte(.driver.cfg.User), []byte(.driver.cfg.Password), nil
	})
	 := sasl.NewClient(, )

	, ,  := .Step(nil)
	if  != nil {
		return fmt.Errorf("client.Step 1 failed: %w", )
	}

	if  := saslWriteInitialResponse(, , , );  != nil {
		return 
	}

	, ,  := readMessageType()
	if  != nil {
		return 
	}

	switch  {
	case authenticationSASLContinueMsg:
		,  := readInt32()
		if  != nil {
			return 
		}
		if  != 11 {
			return fmt.Errorf("got %q, wanted %q", , 11)
		}

		,  := .ReadTemp( - 4)
		if  != nil {
			return 
		}

		_, ,  = .Step()
		if  != nil {
			return fmt.Errorf("client.Step 2 failed: %w", )
		}

		if  := saslWriteResponse(, , );  != nil {
			return 
		}

		,  = saslReadAuthFinal(, )
		if  != nil {
			return 
		}

		if , ,  := .Step();  != nil {
			return fmt.Errorf("client.Step 3 failed: %w", )
		}

		if .State() != sasl.ValidServerResponse {
			return fmt.Errorf("got state=%q, wanted %q", .State(), sasl.ValidServerResponse)
		}

		return nil
	case errorResponseMsg:
		,  := readError()
		if  != nil {
			return 
		}
		return 
	default:
		return fmt.Errorf("got %q, wanted %q", , authenticationSASLContinueMsg)
	}
}

func (
	 context.Context,  *Conn,  sasl.Mechanism,  []byte,
) error {
	 := getWriteBuffer()
	defer putWriteBuffer()

	.StartMessage(saslInitialResponseMsg)
	.WriteString(.Name)
	.WriteInt32(int32(len()))
	if ,  := .Write();  != nil {
		return 
	}
	.FinishMessage()

	return .write(, )
}

func ( context.Context,  *Conn,  []byte) error {
	 := getWriteBuffer()
	defer putWriteBuffer()

	.StartMessage(saslResponseMsg)
	if ,  := .Write();  != nil {
		return 
	}
	.FinishMessage()

	return .write(, )
}

func ( *Conn,  *reader) ([]byte, error) {
	, ,  := readMessageType()
	if  != nil {
		return nil, 
	}

	switch  {
	case authenticationSASLFinalMsg:
		,  := readInt32()
		if  != nil {
			return nil, 
		}
		if  != 12 {
			return nil, fmt.Errorf("got %q, wanted %q", , 12)
		}

		 := make([]byte, -4)
		if ,  := io.ReadFull(, );  != nil {
			return nil, 
		}

		if  := readAuthOK(, );  != nil {
			return nil, 
		}

		return , nil
	case errorResponseMsg:
		,  := readError()
		if  != nil {
			return nil, 
		}
		return nil, 
	default:
		return nil, fmt.Errorf("got %q, wanted %q", , authenticationSASLFinalMsg)
	}
}

//------------------------------------------------------------------------------

func ( context.Context,  *Conn,  string) error {
	 := getWriteBuffer()
	defer putWriteBuffer()

	.StartMessage(queryMsg)
	.WriteString()
	.FinishMessage()

	return .write(, )
}

func ( context.Context,  *Conn) (sql.Result, error) {
	 := .reader(, -1)

	var  driver.Result
	var  error
	for {
		, ,  := readMessageType()
		if  != nil {
			return nil, 
		}

		switch  {
		case errorResponseMsg:
			,  := readError()
			if  != nil {
				return nil, 
			}
			if  == nil {
				 = 
			}
		case emptyQueryResponseMsg:
			if  == nil {
				 = errEmptyQuery
			}
		case commandCompleteMsg:
			,  := .ReadTemp()
			if  != nil {
				 = 
				break
			}

			,  := parseResult()
			if  != nil {
				 = 
			} else {
				 = 
			}
		case describeMsg,
			rowDescriptionMsg,
			noticeResponseMsg,
			parameterStatusMsg:
			if  := .Discard();  != nil {
				return nil, 
			}
		case readyForQueryMsg:
			if  := .Discard();  != nil {
				return nil, 
			}
			return , 
		default:
			return nil, fmt.Errorf("pgdriver: Exec: unexpected message %q", )
		}
	}
}

func ( context.Context,  *Conn) (*rows, error) {
	 := .reader(, -1)
	var  error
	for {
		, ,  := readMessageType()
		if  != nil {
			return nil, 
		}

		switch  {
		case rowDescriptionMsg:
			,  := readRowDescription()
			if  != nil {
				return nil, 
			}
			return newRows(, , true), nil
		case commandCompleteMsg:
			if  := .Discard();  != nil {
				return nil, 
			}
		case readyForQueryMsg:
			if  := .Discard();  != nil {
				return nil, 
			}
			if  != nil {
				return nil, 
			}
			return &rows{closed: true}, nil
		case errorResponseMsg:
			,  := readError()
			if  != nil {
				return nil, 
			}
			if  == nil {
				 = 
			}
		case emptyQueryResponseMsg:
			if  == nil {
				 = errEmptyQuery
			}
		case noticeResponseMsg, parameterStatusMsg:
			if  := .Discard();  != nil {
				return nil, 
			}
		default:
			return nil, fmt.Errorf("pgdriver: newRows: unexpected message %q", )
		}
	}
}

//------------------------------------------------------------------------------

var rowDescPool sync.Pool

type rowDescription struct {
	buf      []byte
	names    []string
	types    []int32
	numInput int16
}

func ( int) *rowDescription {
	if  < 16 {
		 = 16
	}
	return &rowDescription{
		buf:      make([]byte, 0, 16*),
		names:    make([]string, 0, ),
		types:    make([]int32, 0, ),
		numInput: -1,
	}
}

func ( *rowDescription) ( int) {
	.buf = make([]byte, 0, 16*)
	.names = .names[:0]
	.types = .types[:0]
	.numInput = -1
}

func ( *rowDescription) ( []byte) {
	if len(.buf)+len() > cap(.buf) {
		.buf = make([]byte, 0, cap(.buf))
	}

	 := len(.buf)
	.buf = append(.buf, ...)
	.names = append(.names, bytesToString(.buf[:]))
}

func ( *rowDescription) ( int32) {
	.types = append(.types, )
}

func ( *reader) (*rowDescription, error) {
	,  := readInt16()
	if  != nil {
		return nil, 
	}

	,  := rowDescPool.Get().(*rowDescription)
	if ! {
		 = newRowDescription(int())
	} else {
		.reset(int())
	}

	for  := 0;  < int(); ++ {
		,  := .ReadSlice(0)
		if  != nil {
			return nil, 
		}
		.addName([:len()-1])

		if ,  := .ReadTemp(6);  != nil {
			return nil, 
		}

		,  := readInt32()
		if  != nil {
			return nil, 
		}
		.addType()

		if ,  := .ReadTemp(8);  != nil {
			return nil, 
		}
	}

	return , nil
}

//------------------------------------------------------------------------------

func ( context.Context,  *reader) (,  string,  error) {
	for {
		, ,  := readMessageType()
		if  != nil {
			return "", "", 
		}

		switch  {
		case commandCompleteMsg, readyForQueryMsg, noticeResponseMsg:
			if  := .Discard();  != nil {
				return "", "", 
			}
		case errorResponseMsg:
			,  := readError()
			if  != nil {
				return "", "", 
			}
			return "", "", 
		case notificationResponseMsg:
			if  := .Discard(4);  != nil {
				return "", "", 
			}
			,  = readString()
			if  != nil {
				return "", "", 
			}
			,  = readString()
			if  != nil {
				return "", "", 
			}
			return , , nil
		default:
			return "", "", fmt.Errorf("pgdriver: readNotification: unexpected message %q", )
		}
	}
}

//------------------------------------------------------------------------------

func ( context.Context,  *Conn, ,  string) error {
	 := getWriteBuffer()
	defer putWriteBuffer()

	.StartMessage(parseMsg)
	.WriteString()
	.WriteString()
	.WriteInt16(0)
	.FinishMessage()

	.StartMessage(describeMsg)
	.WriteByte('S')
	.WriteString()
	.FinishMessage()

	.StartMessage(syncMsg)
	.FinishMessage()

	return .write(, )
}

func ( context.Context,  *Conn) (*rowDescription, error) {
	 := .reader(, -1)
	var  int16
	var  *rowDescription
	var  error
	for {
		, ,  := readMessageType()
		if  != nil {
			return nil, 
		}

		switch  {
		case parseCompleteMsg:
			if  := .Discard();  != nil {
				return nil, 
			}
		case rowDescriptionMsg: // response to DESCRIBE message.
			,  = readRowDescription()
			if  != nil {
				return nil, 
			}
			.numInput = 
		case parameterDescriptionMsg: // response to DESCRIBE message.
			,  = readInt16()
			if  != nil {
				return nil, 
			}

			for  := 0;  < int(); ++ {
				if ,  := readInt32();  != nil {
					return nil, 
				}
			}
		case noDataMsg: // response to DESCRIBE message.
			if  := .Discard();  != nil {
				return nil, 
			}
		case readyForQueryMsg:
			if  := .Discard();  != nil {
				return nil, 
			}
			if  != nil {
				return nil, 
			}
			return , 
		case errorResponseMsg:
			,  := readError()
			if  != nil {
				return nil, 
			}
			if  == nil {
				 = 
			}
		case noticeResponseMsg, parameterStatusMsg:
			if  := .Discard();  != nil {
				return nil, 
			}
		default:
			return nil, fmt.Errorf("pgdriver: readParseDescribeSync: unexpected message %q", )
		}
	}
}

func ( context.Context,  *Conn,  string,  []driver.NamedValue) error {
	 := getWriteBuffer()
	defer putWriteBuffer()

	.StartMessage(bindMsg)
	.WriteString("")
	.WriteString()
	.WriteInt16(0)
	.WriteInt16(int16(len()))
	for  := range  {
		.StartParam()
		,  := appendStmtArg(.Bytes, [].Value)
		if  != nil {
			return 
		}
		if  != nil {
			.Bytes = 
			.FinishParam()
		} else {
			.FinishNullParam()
		}
	}
	.WriteInt16(0)
	.FinishMessage()

	.StartMessage(executeMsg)
	.WriteString("")
	.WriteInt32(0)
	.FinishMessage()

	.StartMessage(syncMsg)
	.FinishMessage()

	return .write(, )
}

func ( context.Context,  *Conn) (driver.Result, error) {
	 := .reader(, -1)
	var  driver.Result
	var  error
	for {
		, ,  := readMessageType()
		if  != nil {
			return nil, 
		}

		switch  {
		case bindCompleteMsg, dataRowMsg:
			if  := .Discard();  != nil {
				return nil, 
			}
		case commandCompleteMsg: // response to EXECUTE message.
			,  := .ReadTemp()
			if  != nil {
				return nil, 
			}

			,  := parseResult()
			if  != nil {
				if  == nil {
					 = 
				}
			} else {
				 = 
			}
		case readyForQueryMsg: // Response to SYNC message.
			if  := .Discard();  != nil {
				return nil, 
			}
			if  != nil {
				return nil, 
			}
			return , nil
		case errorResponseMsg:
			,  := readError()
			if  != nil {
				return nil, 
			}
			if  == nil {
				 = 
			}
		case emptyQueryResponseMsg:
			if  == nil {
				 = errEmptyQuery
			}
		case noticeResponseMsg, parameterStatusMsg:
			if  := .Discard();  != nil {
				return nil, 
			}
		default:
			return nil, fmt.Errorf("pgdriver: readExtQuery: unexpected message %q", )
		}
	}
}

func ( context.Context,  *Conn,  *rowDescription) (*rows, error) {
	 := .reader(, -1)
	var  error
	for {
		, ,  := readMessageType()
		if  != nil {
			return nil, 
		}

		switch  {
		case bindCompleteMsg:
			if  := .Discard();  != nil {
				return nil, 
			}
			return newRows(, , false), nil
		case commandCompleteMsg: // response to EXECUTE message.
			if  := .Discard();  != nil {
				return nil, 
			}
		case readyForQueryMsg: // Response to SYNC message.
			if  := .Discard();  != nil {
				return nil, 
			}
			if  != nil {
				return nil, 
			}
			return &rows{closed: true}, nil
		case errorResponseMsg:
			,  := readError()
			if  != nil {
				return nil, 
			}
			if  == nil {
				 = 
			}
		case emptyQueryResponseMsg:
			if  == nil {
				 = errEmptyQuery
			}
		case noticeResponseMsg, parameterStatusMsg:
			if  := .Discard();  != nil {
				return nil, 
			}
		default:
			return nil, fmt.Errorf("pgdriver: readExtQueryData: unexpected message %q", )
		}
	}
}

func ( context.Context,  *Conn,  string) error {
	 := getWriteBuffer()
	defer putWriteBuffer()

	.StartMessage(closeMsg)
	.WriteByte('S') //nolint
	.WriteString()
	.FinishMessage()

	.StartMessage(flushMsg)
	.FinishMessage()

	return .write(, )
}

func ( context.Context,  *Conn) error {
	 := .reader(, -1)
	for {
		, ,  := readMessageType()
		if  != nil {
			return 
		}

		switch  {
		case closeCompleteMsg:
			return .Discard()
		case errorResponseMsg:
			,  := readError()
			if  != nil {
				return 
			}
			return 
		case noticeResponseMsg, parameterStatusMsg:
			if  := .Discard();  != nil {
				return 
			}
		default:
			return fmt.Errorf("pgdriver: readCloseCompleteMsg: unexpected message %q", )
		}
	}
}

//------------------------------------------------------------------------------

func ( *reader) (byte, int, error) {
	,  := .ReadByte()
	if  != nil {
		return 0, 0, 
	}
	,  := readInt32()
	if  != nil {
		return 0, 0, 
	}
	return , int() - 4, nil
}

func ( *reader) (int16, error) {
	,  := .ReadTemp(2)
	if  != nil {
		return 0, 
	}
	return int16(binary.BigEndian.Uint16()), nil
}

func ( *reader) (int32, error) {
	,  := .ReadTemp(4)
	if  != nil {
		return 0, 
	}
	return int32(binary.BigEndian.Uint32()), nil
}

func ( *reader) (string, error) {
	,  := .ReadSlice(0)
	if  != nil {
		return "", 
	}
	return string([:len()-1]), nil
}

func ( *reader) (error, error) {
	 := make(map[byte]string)
	for {
		,  := .ReadByte()
		if  != nil {
			return nil, 
		}
		if  == 0 {
			break
		}
		,  := readString()
		if  != nil {
			return nil, 
		}
		[] = 
	}
	return Error{m: }, nil
}

//------------------------------------------------------------------------------

// appendStmtArg appends the text-format encoding of v to b and returns the
// extended slice. A nil result means SQL NULL; this is returned for nil, a
// nil []byte and the zero time.Time.
//
// Encodings:
//   - int64 and float64 in decimal; NaN and ±Inf use PostgreSQL's spellings.
//   - bool as TRUE/FALSE.
//   - []byte as bytea hex (\x...).
//   - string as-is, except that NUL characters are dropped because text values
//     cannot contain them.
//   - time.Time in UTC with microsecond precision.
//
// Any other type is an error.
func appendStmtArg(b []byte, v driver.Value) ([]byte, error) {
	switch v := v.(type) {
	case nil:
		return nil, nil
	case int64:
		return strconv.AppendInt(b, v, 10), nil
	case float64:
		switch {
		case math.IsNaN(v):
			return append(b, "NaN"...), nil
		case math.IsInf(v, 1):
			return append(b, "Infinity"...), nil
		case math.IsInf(v, -1):
			return append(b, "-Infinity"...), nil
		default:
			return strconv.AppendFloat(b, v, 'f', -1, 64), nil
		}
	case bool:
		if v {
			return append(b, "TRUE"...), nil
		}
		return append(b, "FALSE"...), nil
	case []byte:
		if v == nil {
			return nil, nil
		}

		b = append(b, `\x`...)

		start := len(b)
		b = append(b, make([]byte, hex.EncodedLen(len(v)))...)
		hex.Encode(b[start:], v)

		return b, nil
	case string:
		for _, r := range v {
			if r == 0 {
				continue
			}
			if r < utf8.RuneSelf {
				b = append(b, byte(r))
				continue
			}
			// Make room for the widest encoding, then trim to the real size.
			l := len(b)
			if cap(b)-l < utf8.UTFMax {
				b = append(b, make([]byte, utf8.UTFMax)...)
			}
			n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
			b = b[:l+n]
		}
		return b, nil
	case time.Time:
		if v.IsZero() {
			return nil, nil
		}
		return v.UTC().AppendFormat(b, "2006-01-02 15:04:05.999999-07:00"), nil
	default:
		return nil, fmt.Errorf("pgdriver: unexpected arg: %T", v)
	}
}