package bun
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
const (
wherePKFlag internal .Flag = 1 << iota
forceDeleteFlag
deletedFlag
allWithDeletedFlag
)
type withQuery struct {
name string
query schema .QueryAppender
}
type IConn interface {
QueryContext (ctx context .Context , query string , args ...interface {}) (*sql .Rows , error )
ExecContext (ctx context .Context , query string , args ...interface {}) (sql .Result , error )
QueryRowContext (ctx context .Context , query string , args ...interface {}) *sql .Row
}
var (
_ IConn = (*sql .DB )(nil )
_ IConn = (*sql .Conn )(nil )
_ IConn = (*sql .Tx )(nil )
_ IConn = (*DB )(nil )
_ IConn = (*Conn )(nil )
_ IConn = (*Tx )(nil )
)
type IDB interface {
IConn
NewValues (model interface {}) *ValuesQuery
NewSelect () *SelectQuery
NewInsert () *InsertQuery
NewUpdate () *UpdateQuery
NewDelete () *DeleteQuery
NewCreateTable () *CreateTableQuery
NewDropTable () *DropTableQuery
NewCreateIndex () *CreateIndexQuery
NewDropIndex () *DropIndexQuery
NewTruncateTable () *TruncateTableQuery
NewAddColumn () *AddColumnQuery
NewDropColumn () *DropColumnQuery
}
var (
_ IConn = (*DB )(nil )
_ IConn = (*Conn )(nil )
_ IConn = (*Tx )(nil )
)
type baseQuery struct {
db *DB
conn IConn
model model
err error
tableModel tableModel
table *schema .Table
with []withQuery
modelTable schema .QueryWithArgs
tables []schema .QueryWithArgs
columns []schema .QueryWithArgs
flags internal .Flag
}
func (q *baseQuery ) DB () *DB {
return q .db
}
func (q *baseQuery ) GetModel () Model {
return q .model
}
func (q *baseQuery ) setConn (db IConn ) {
switch db := db .(type ) {
case *DB :
q .conn = db .DB
case Conn :
q .conn = db .Conn
case Tx :
q .conn = db .Tx
default :
q .conn = db
}
}
func (q *baseQuery ) setTableModel (modeli interface {}) {
model , err := newSingleModel (q .db , modeli )
if err != nil {
q .setErr (err )
return
}
q .model = model
if tm , ok := model .(tableModel ); ok {
q .tableModel = tm
q .table = tm .Table ()
}
}
func (q *baseQuery ) setErr (err error ) {
if q .err == nil {
q .err = err
}
}
func (q *baseQuery ) getModel (dest []interface {}) (model , error ) {
if len (dest ) == 0 {
if q .model != nil {
return q .model , nil
}
return nil , errNilModel
}
return newModel (q .db , dest )
}
func (q *baseQuery ) checkSoftDelete () error {
if q .table == nil {
return errors .New ("bun: can't use soft deletes without a table" )
}
if q .table .SoftDeleteField == nil {
return fmt .Errorf ("%s does not have a soft delete field" , q .table )
}
if q .tableModel == nil {
return errors .New ("bun: can't use soft deletes without a table model" )
}
return nil
}
func (q *baseQuery ) whereDeleted () {
if err := q .checkSoftDelete (); err != nil {
q .setErr (err )
return
}
q .flags = q .flags .Set (deletedFlag )
q .flags = q .flags .Remove (allWithDeletedFlag )
}
func (q *baseQuery ) whereAllWithDeleted () {
if err := q .checkSoftDelete (); err != nil {
q .setErr (err )
return
}
q .flags = q .flags .Set (allWithDeletedFlag )
q .flags = q .flags .Remove (deletedFlag )
}
func (q *baseQuery ) isSoftDelete () bool {
if q .table != nil {
return q .table .SoftDeleteField != nil && !q .flags .Has (allWithDeletedFlag )
}
return false
}
func (q *baseQuery ) addWith (name string , query schema .QueryAppender ) {
q .with = append (q .with , withQuery {
name : name ,
query : query ,
})
}
func (q *baseQuery ) appendWith (fmter schema .Formatter , b []byte ) (_ []byte , err error ) {
if len (q .with ) == 0 {
return b , nil
}
b = append (b , "WITH " ...)
for i , with := range q .with {
if i > 0 {
b = append (b , ", " ...)
}
b = fmter .AppendIdent (b , with .name )
if q , ok := with .query .(schema .ColumnsAppender ); ok {
b = append (b , " (" ...)
b , err = q .AppendColumns (fmter , b )
if err != nil {
return nil , err
}
b = append (b , ")" ...)
}
b = append (b , " AS (" ...)
b , err = with .query .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
b = append (b , ')' )
}
b = append (b , ' ' )
return b , nil
}
func (q *baseQuery ) addTable (table schema .QueryWithArgs ) {
q .tables = append (q .tables , table )
}
func (q *baseQuery ) addColumn (column schema .QueryWithArgs ) {
q .columns = append (q .columns , column )
}
func (q *baseQuery ) excludeColumn (columns []string ) {
if q .columns == nil {
for _ , f := range q .table .Fields {
q .columns = append (q .columns , schema .UnsafeIdent (f .Name ))
}
}
if len (columns ) == 1 && columns [0 ] == "*" {
q .columns = make ([]schema .QueryWithArgs , 0 )
return
}
for _ , column := range columns {
if !q ._excludeColumn (column ) {
q .setErr (fmt .Errorf ("bun: can't find column=%q" , column ))
return
}
}
}
func (q *baseQuery ) _excludeColumn (column string ) bool {
for i , col := range q .columns {
if col .Args == nil && col .Query == column {
q .columns = append (q .columns [:i ], q .columns [i +1 :]...)
return true
}
}
return false
}
func (q *baseQuery ) modelHasTableName () bool {
if !q .modelTable .IsZero () {
return q .modelTable .Query != ""
}
return q .table != nil
}
func (q *baseQuery ) hasTables () bool {
return q .modelHasTableName () || len (q .tables ) > 0
}
func (q *baseQuery ) appendTables (
fmter schema .Formatter , b []byte ,
) (_ []byte , err error ) {
return q ._appendTables (fmter , b , false )
}
func (q *baseQuery ) appendTablesWithAlias (
fmter schema .Formatter , b []byte ,
) (_ []byte , err error ) {
return q ._appendTables (fmter , b , true )
}
func (q *baseQuery ) _appendTables (
fmter schema .Formatter , b []byte , withAlias bool ,
) (_ []byte , err error ) {
startLen := len (b )
if q .modelHasTableName () {
if !q .modelTable .IsZero () {
b , err = q .modelTable .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
} else {
b = fmter .AppendQuery (b , string (q .table .SQLNameForSelects ))
if withAlias && q .table .SQLAlias != q .table .SQLNameForSelects {
b = append (b , " AS " ...)
b = append (b , q .table .SQLAlias ...)
}
}
}
for _ , table := range q .tables {
if len (b ) > startLen {
b = append (b , ", " ...)
}
b , err = table .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
}
return b , nil
}
func (q *baseQuery ) appendFirstTable (fmter schema .Formatter , b []byte ) ([]byte , error ) {
return q ._appendFirstTable (fmter , b , false )
}
func (q *baseQuery ) appendFirstTableWithAlias (
fmter schema .Formatter , b []byte ,
) ([]byte , error ) {
return q ._appendFirstTable (fmter , b , true )
}
func (q *baseQuery ) _appendFirstTable (
fmter schema .Formatter , b []byte , withAlias bool ,
) ([]byte , error ) {
if !q .modelTable .IsZero () {
return q .modelTable .AppendQuery (fmter , b )
}
if q .table != nil {
b = fmter .AppendQuery (b , string (q .table .SQLName ))
if withAlias {
b = append (b , " AS " ...)
b = append (b , q .table .SQLAlias ...)
}
return b , nil
}
if len (q .tables ) > 0 {
return q .tables [0 ].AppendQuery (fmter , b )
}
return nil , errors .New ("bun: query does not have a table" )
}
func (q *baseQuery ) hasMultiTables () bool {
if q .modelHasTableName () {
return len (q .tables ) >= 1
}
return len (q .tables ) >= 2
}
func (q *baseQuery ) appendOtherTables (fmter schema .Formatter , b []byte ) (_ []byte , err error ) {
tables := q .tables
if !q .modelHasTableName () {
tables = tables [1 :]
}
for i , table := range tables {
if i > 0 {
b = append (b , ", " ...)
}
b , err = table .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
}
return b , nil
}
func (q *baseQuery ) appendColumns (fmter schema .Formatter , b []byte ) (_ []byte , err error ) {
for i , f := range q .columns {
if i > 0 {
b = append (b , ", " ...)
}
b , err = f .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
}
return b , nil
}
func (q *baseQuery ) getFields () ([]*schema .Field , error ) {
if len (q .columns ) == 0 {
return q .table .Fields , nil
}
return q ._getFields (false )
}
func (q *baseQuery ) getDataFields () ([]*schema .Field , error ) {
if len (q .columns ) == 0 {
return q .table .DataFields , nil
}
return q ._getFields (true )
}
func (q *baseQuery ) _getFields (omitPK bool ) ([]*schema .Field , error ) {
fields := make ([]*schema .Field , 0 , len (q .columns ))
for _ , col := range q .columns {
if col .Args != nil {
continue
}
field , err := q .table .Field (col .Query )
if err != nil {
return nil , err
}
if omitPK && field .IsPK {
continue
}
fields = append (fields , field )
}
return fields , nil
}
func (q *baseQuery ) scan (
ctx context .Context ,
queryApp schema .Query ,
query string ,
model model ,
hasDest bool ,
) (sql .Result , error ) {
ctx , event := q .db .beforeQuery (ctx , queryApp , query , nil )
rows , err := q .conn .QueryContext (ctx , query )
if err != nil {
q .db .afterQuery (ctx , event , nil , err )
return nil , err
}
defer rows .Close ()
numRow , err := model .ScanRows (ctx , rows )
if err != nil {
q .db .afterQuery (ctx , event , nil , err )
return nil , err
}
if numRow == 0 && hasDest && isSingleRowModel (model ) {
err = sql .ErrNoRows
}
res := driver .RowsAffected (numRow )
q .db .afterQuery (ctx , event , res , err )
return res , err
}
func (q *baseQuery ) exec (
ctx context .Context ,
queryApp schema .Query ,
query string ,
) (sql .Result , error ) {
ctx , event := q .db .beforeQuery (ctx , queryApp , query , nil )
res , err := q .conn .ExecContext (ctx , query )
if err != nil {
q .db .afterQuery (ctx , event , nil , err )
return res , err
}
q .db .afterQuery (ctx , event , res , err )
return res , nil
}
func (q *baseQuery ) AppendNamedArg (fmter schema .Formatter , b []byte , name string ) ([]byte , bool ) {
if q .table == nil {
return b , false
}
if m , ok := q .tableModel .(*structTableModel ); ok {
if b , ok := m .AppendNamedArg (fmter , b , name ); ok {
return b , ok
}
}
switch name {
case "TableName" :
b = fmter .AppendQuery (b , string (q .table .SQLName ))
return b , true
case "TableAlias" :
b = fmter .AppendQuery (b , string (q .table .SQLAlias ))
return b , true
case "PKs" :
b = appendColumns (b , "" , q .table .PKs )
return b , true
case "TablePKs" :
b = appendColumns (b , q .table .SQLAlias , q .table .PKs )
return b , true
case "Columns" :
b = appendColumns (b , "" , q .table .Fields )
return b , true
case "TableColumns" :
b = appendColumns (b , q .table .SQLAlias , q .table .Fields )
return b , true
}
return b , false
}
func appendColumns (b []byte , table schema .Safe , fields []*schema .Field ) []byte {
for i , f := range fields {
if i > 0 {
b = append (b , ", " ...)
}
if len (table ) > 0 {
b = append (b , table ...)
b = append (b , '.' )
}
b = append (b , f .SQLName ...)
}
return b
}
func formatterWithModel (fmter schema .Formatter , model schema .NamedArgAppender ) schema .Formatter {
if fmter .IsNop () {
return fmter
}
return fmter .WithArg (model )
}
type whereBaseQuery struct {
baseQuery
where []schema .QueryWithSep
}
func (q *whereBaseQuery ) addWhere (where schema .QueryWithSep ) {
q .where = append (q .where , where )
}
func (q *whereBaseQuery ) addWhereGroup (sep string , where []schema .QueryWithSep ) {
if len (where ) == 0 {
return
}
q .addWhere (schema .SafeQueryWithSep ("" , nil , sep ))
q .addWhere (schema .SafeQueryWithSep ("" , nil , "(" ))
where [0 ].Sep = ""
q .where = append (q .where , where ...)
q .addWhere (schema .SafeQueryWithSep ("" , nil , ")" ))
}
func (q *whereBaseQuery ) mustAppendWhere (
fmter schema .Formatter , b []byte , withAlias bool ,
) ([]byte , error ) {
if len (q .where ) == 0 && !q .flags .Has (wherePKFlag ) {
err := errors .New ("bun: Update and Delete queries require at least one Where" )
return nil , err
}
return q .appendWhere (fmter , b , withAlias )
}
func (q *whereBaseQuery ) appendWhere (
fmter schema .Formatter , b []byte , withAlias bool ,
) (_ []byte , err error ) {
if len (q .where ) == 0 && !q .isSoftDelete () && !q .flags .Has (wherePKFlag ) {
return b , nil
}
b = append (b , " WHERE " ...)
startLen := len (b )
if len (q .where ) > 0 {
b , err = appendWhere (fmter , b , q .where )
if err != nil {
return nil , err
}
}
if q .isSoftDelete () {
if len (b ) > startLen {
b = append (b , " AND " ...)
}
if withAlias {
b = append (b , q .tableModel .Table ().SQLAlias ...)
b = append (b , '.' )
}
b = append (b , q .tableModel .Table ().SoftDeleteField .SQLName ...)
if q .flags .Has (deletedFlag ) {
b = append (b , " IS NOT NULL" ...)
} else {
b = append (b , " IS NULL" ...)
}
}
if q .flags .Has (wherePKFlag ) {
if len (b ) > startLen {
b = append (b , " AND " ...)
}
b , err = q .appendWherePK (fmter , b , withAlias )
if err != nil {
return nil , err
}
}
return b , nil
}
func appendWhere (
fmter schema .Formatter , b []byte , where []schema .QueryWithSep ,
) (_ []byte , err error ) {
for i , where := range where {
if i > 0 {
b = append (b , where .Sep ...)
}
if where .Query == "" {
continue
}
b = append (b , '(' )
b , err = where .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
b = append (b , ')' )
}
return b , nil
}
func (q *whereBaseQuery ) appendWherePK (
fmter schema .Formatter , b []byte , withAlias bool ,
) (_ []byte , err error ) {
if q .table == nil {
err := fmt .Errorf ("bun: got %T, but WherePK requires a struct or slice-based model" , q .model )
return nil , err
}
if err := q .table .CheckPKs (); err != nil {
return nil , err
}
switch model := q .tableModel .(type ) {
case *structTableModel :
return q .appendWherePKStruct (fmter , b , model , withAlias )
case *sliceTableModel :
return q .appendWherePKSlice (fmter , b , model , withAlias )
}
return nil , fmt .Errorf ("bun: WherePK does not support %T" , q .tableModel )
}
func (q *whereBaseQuery ) appendWherePKStruct (
fmter schema .Formatter , b []byte , model *structTableModel , withAlias bool ,
) (_ []byte , err error ) {
if !model .strct .IsValid () {
return nil , errNilModel
}
isTemplate := fmter .IsNop ()
b = append (b , '(' )
for i , f := range q .table .PKs {
if i > 0 {
b = append (b , " AND " ...)
}
if withAlias {
b = append (b , q .table .SQLAlias ...)
b = append (b , '.' )
}
b = append (b , f .SQLName ...)
b = append (b , " = " ...)
if isTemplate {
b = append (b , '?' )
} else {
b = f .AppendValue (fmter , b , model .strct )
}
}
b = append (b , ')' )
return b , nil
}
func (q *whereBaseQuery ) appendWherePKSlice (
fmter schema .Formatter , b []byte , model *sliceTableModel , withAlias bool ,
) (_ []byte , err error ) {
if len (q .table .PKs ) > 1 {
b = append (b , '(' )
}
if withAlias {
b = appendColumns (b , q .table .SQLAlias , q .table .PKs )
} else {
b = appendColumns (b , "" , q .table .PKs )
}
if len (q .table .PKs ) > 1 {
b = append (b , ')' )
}
b = append (b , " IN (" ...)
isTemplate := fmter .IsNop ()
slice := model .slice
sliceLen := slice .Len ()
for i := 0 ; i < sliceLen ; i ++ {
if i > 0 {
if isTemplate {
break
}
b = append (b , ", " ...)
}
el := indirect (slice .Index (i ))
if len (q .table .PKs ) > 1 {
b = append (b , '(' )
}
for i , f := range q .table .PKs {
if i > 0 {
b = append (b , ", " ...)
}
if isTemplate {
b = append (b , '?' )
} else {
b = f .AppendValue (fmter , b , el )
}
}
if len (q .table .PKs ) > 1 {
b = append (b , ')' )
}
}
b = append (b , ')' )
return b , nil
}
type returningQuery struct {
returning []schema .QueryWithArgs
returningFields []*schema .Field
}
func (q *returningQuery ) addReturning (ret schema .QueryWithArgs ) {
q .returning = append (q .returning , ret )
}
func (q *returningQuery ) addReturningField (field *schema .Field ) {
if len (q .returning ) > 0 {
return
}
for _ , f := range q .returningFields {
if f == field {
return
}
}
q .returningFields = append (q .returningFields , field )
}
func (q *returningQuery ) hasReturning () bool {
if len (q .returning ) == 1 {
switch q .returning [0 ].Query {
case "null" , "NULL" :
return false
}
}
return len (q .returning ) > 0 || len (q .returningFields ) > 0
}
func (q *returningQuery ) appendReturning (
fmter schema .Formatter , b []byte ,
) (_ []byte , err error ) {
if !q .hasReturning () {
return b , nil
}
b = append (b , " RETURNING " ...)
for i , f := range q .returning {
if i > 0 {
b = append (b , ", " ...)
}
b , err = f .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
}
if len (q .returning ) > 0 {
return b , nil
}
b = appendColumns (b , "" , q .returningFields )
return b , nil
}
type columnValue struct {
column string
value schema .QueryWithArgs
}
type customValueQuery struct {
modelValues map [string ]schema .QueryWithArgs
extraValues []columnValue
}
func (q *customValueQuery ) addValue (
table *schema .Table , column string , value string , args []interface {},
) {
if _ , ok := table .FieldMap [column ]; ok {
if q .modelValues == nil {
q .modelValues = make (map [string ]schema .QueryWithArgs )
}
q .modelValues [column ] = schema .SafeQuery (value , args )
} else {
q .extraValues = append (q .extraValues , columnValue {
column : column ,
value : schema .SafeQuery (value , args ),
})
}
}
type setQuery struct {
set []schema .QueryWithArgs
}
func (q *setQuery ) addSet (set schema .QueryWithArgs ) {
q .set = append (q .set , set )
}
func (q setQuery ) appendSet (fmter schema .Formatter , b []byte ) (_ []byte , err error ) {
for i , f := range q .set {
if i > 0 {
b = append (b , ", " ...)
}
b , err = f .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
}
return b , nil
}
type cascadeQuery struct {
restrict bool
}
func (q cascadeQuery ) appendCascade (fmter schema .Formatter , b []byte ) []byte {
if !fmter .HasFeature (feature .TableCascade ) {
return b
}
if q .restrict {
b = append (b , " RESTRICT" ...)
} else {
b = append (b , " CASCADE" ...)
}
return b
}
The pages are generated with Golds v0.3.6 . (GOOS=darwin GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu .
PR and bug reports are welcome and can be submitted to the issue list .
Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds .