package bun
import (
"bytes"
"context"
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
type union struct {
expr string
query *SelectQuery
}
type SelectQuery struct {
whereBaseQuery
distinctOn []schema .QueryWithArgs
joins []joinQuery
group []schema .QueryWithArgs
having []schema .QueryWithArgs
order []schema .QueryWithArgs
limit int32
offset int32
selFor schema .QueryWithArgs
union []union
}
func NewSelectQuery (db *DB ) *SelectQuery {
return &SelectQuery {
whereBaseQuery : whereBaseQuery {
baseQuery : baseQuery {
db : db ,
conn : db .DB ,
},
},
}
}
func (q *SelectQuery ) Conn (db IConn ) *SelectQuery {
q .setConn (db )
return q
}
func (q *SelectQuery ) Model (model interface {}) *SelectQuery {
q .setTableModel (model )
return q
}
func (q *SelectQuery ) Apply (fn func (*SelectQuery ) *SelectQuery ) *SelectQuery {
return fn (q )
}
func (q *SelectQuery ) With (name string , query schema .QueryAppender ) *SelectQuery {
q .addWith (name , query )
return q
}
func (q *SelectQuery ) Distinct () *SelectQuery {
q .distinctOn = make ([]schema .QueryWithArgs , 0 )
return q
}
func (q *SelectQuery ) DistinctOn (query string , args ...interface {}) *SelectQuery {
q .distinctOn = append (q .distinctOn , schema .SafeQuery (query , args ))
return q
}
func (q *SelectQuery ) Table (tables ...string ) *SelectQuery {
for _ , table := range tables {
q .addTable (schema .UnsafeIdent (table ))
}
return q
}
func (q *SelectQuery ) TableExpr (query string , args ...interface {}) *SelectQuery {
q .addTable (schema .SafeQuery (query , args ))
return q
}
func (q *SelectQuery ) ModelTableExpr (query string , args ...interface {}) *SelectQuery {
q .modelTable = schema .SafeQuery (query , args )
return q
}
func (q *SelectQuery ) Column (columns ...string ) *SelectQuery {
for _ , column := range columns {
q .addColumn (schema .UnsafeIdent (column ))
}
return q
}
func (q *SelectQuery ) ColumnExpr (query string , args ...interface {}) *SelectQuery {
q .addColumn (schema .SafeQuery (query , args ))
return q
}
func (q *SelectQuery ) ExcludeColumn (columns ...string ) *SelectQuery {
q .excludeColumn (columns )
return q
}
func (q *SelectQuery ) WherePK () *SelectQuery {
q .flags = q .flags .Set (wherePKFlag )
return q
}
func (q *SelectQuery ) Where (query string , args ...interface {}) *SelectQuery {
q .addWhere (schema .SafeQueryWithSep (query , args , " AND " ))
return q
}
func (q *SelectQuery ) WhereOr (query string , args ...interface {}) *SelectQuery {
q .addWhere (schema .SafeQueryWithSep (query , args , " OR " ))
return q
}
func (q *SelectQuery ) WhereGroup (sep string , fn func (*SelectQuery ) *SelectQuery ) *SelectQuery {
saved := q .where
q .where = nil
q = fn (q )
where := q .where
q .where = saved
q .addWhereGroup (sep , where )
return q
}
func (q *SelectQuery ) WhereDeleted () *SelectQuery {
q .whereDeleted ()
return q
}
func (q *SelectQuery ) WhereAllWithDeleted () *SelectQuery {
q .whereAllWithDeleted ()
return q
}
func (q *SelectQuery ) Group (columns ...string ) *SelectQuery {
for _ , column := range columns {
q .group = append (q .group , schema .UnsafeIdent (column ))
}
return q
}
func (q *SelectQuery ) GroupExpr (group string , args ...interface {}) *SelectQuery {
q .group = append (q .group , schema .SafeQuery (group , args ))
return q
}
func (q *SelectQuery ) Having (having string , args ...interface {}) *SelectQuery {
q .having = append (q .having , schema .SafeQuery (having , args ))
return q
}
func (q *SelectQuery ) Order (orders ...string ) *SelectQuery {
for _ , order := range orders {
if order == "" {
continue
}
index := strings .IndexByte (order , ' ' )
if index == -1 {
q .order = append (q .order , schema .UnsafeIdent (order ))
continue
}
field := order [:index ]
sort := order [index +1 :]
switch strings .ToUpper (sort ) {
case "ASC" , "DESC" , "ASC NULLS FIRST" , "DESC NULLS FIRST" ,
"ASC NULLS LAST" , "DESC NULLS LAST" :
q .order = append (q .order , schema .SafeQuery ("? ?" , []interface {}{
Ident (field ),
Safe (sort ),
}))
default :
q .order = append (q .order , schema .UnsafeIdent (order ))
}
}
return q
}
func (q *SelectQuery ) OrderExpr (query string , args ...interface {}) *SelectQuery {
q .order = append (q .order , schema .SafeQuery (query , args ))
return q
}
func (q *SelectQuery ) Limit (n int ) *SelectQuery {
q .limit = int32 (n )
return q
}
func (q *SelectQuery ) Offset (n int ) *SelectQuery {
q .offset = int32 (n )
return q
}
func (q *SelectQuery ) For (s string , args ...interface {}) *SelectQuery {
q .selFor = schema .SafeQuery (s , args )
return q
}
func (q *SelectQuery ) Union (other *SelectQuery ) *SelectQuery {
return q .addUnion (" UNION " , other )
}
func (q *SelectQuery ) UnionAll (other *SelectQuery ) *SelectQuery {
return q .addUnion (" UNION ALL " , other )
}
func (q *SelectQuery ) Intersect (other *SelectQuery ) *SelectQuery {
return q .addUnion (" INTERSECT " , other )
}
func (q *SelectQuery ) IntersectAll (other *SelectQuery ) *SelectQuery {
return q .addUnion (" INTERSECT ALL " , other )
}
func (q *SelectQuery ) Except (other *SelectQuery ) *SelectQuery {
return q .addUnion (" EXCEPT " , other )
}
func (q *SelectQuery ) ExceptAll (other *SelectQuery ) *SelectQuery {
return q .addUnion (" EXCEPT ALL " , other )
}
func (q *SelectQuery ) addUnion (expr string , other *SelectQuery ) *SelectQuery {
q .union = append (q .union , union {
expr : expr ,
query : other ,
})
return q
}
func (q *SelectQuery ) Join (join string , args ...interface {}) *SelectQuery {
q .joins = append (q .joins , joinQuery {
join : schema .SafeQuery (join , args ),
})
return q
}
func (q *SelectQuery ) JoinOn (cond string , args ...interface {}) *SelectQuery {
return q .joinOn (cond , args , " AND " )
}
func (q *SelectQuery ) JoinOnOr (cond string , args ...interface {}) *SelectQuery {
return q .joinOn (cond , args , " OR " )
}
func (q *SelectQuery ) joinOn (cond string , args []interface {}, sep string ) *SelectQuery {
if len (q .joins ) == 0 {
q .err = errors .New ("bun: query has no joins" )
return q
}
j := &q .joins [len (q .joins )-1 ]
j .on = append (j .on , schema .SafeQueryWithSep (cond , args , sep ))
return q
}
func (q *SelectQuery ) Relation (name string , apply ...func (*SelectQuery ) *SelectQuery ) *SelectQuery {
if len (apply ) > 1 {
panic ("only one apply function is supported" )
}
if q .tableModel == nil {
q .setErr (errNilModel )
return q
}
join := q .tableModel .Join (name )
if join == nil {
q .setErr (fmt .Errorf ("%s does not have relation=%q" , q .table , name ))
return q
}
if len (apply ) == 1 {
join .apply = apply [0 ]
}
return q
}
func (q *SelectQuery ) forEachHasOneJoin (fn func (*relationJoin ) error ) error {
if q .tableModel == nil {
return nil
}
return q ._forEachHasOneJoin (fn , q .tableModel .GetJoins ())
}
func (q *SelectQuery ) _forEachHasOneJoin (fn func (*relationJoin ) error , joins []relationJoin ) error {
for i := range joins {
j := &joins [i ]
switch j .Relation .Type {
case schema .HasOneRelation , schema .BelongsToRelation :
if err := fn (j ); err != nil {
return err
}
if err := q ._forEachHasOneJoin (fn , j .JoinModel .GetJoins ()); err != nil {
return err
}
}
}
return nil
}
func (q *SelectQuery ) selectJoins (ctx context .Context , joins []relationJoin ) error {
for i := range joins {
j := &joins [i ]
var err error
switch j .Relation .Type {
case schema .HasOneRelation , schema .BelongsToRelation :
err = q .selectJoins (ctx , j .JoinModel .GetJoins ())
case schema .HasManyRelation :
err = j .selectMany (ctx , q .db .NewSelect ())
case schema .ManyToManyRelation :
err = j .selectM2M (ctx , q .db .NewSelect ())
default :
panic ("not reached" )
}
if err != nil {
return err
}
}
return nil
}
func (q *SelectQuery ) Operation () string {
return "SELECT"
}
func (q *SelectQuery ) AppendQuery (fmter schema .Formatter , b []byte ) (_ []byte , err error ) {
return q .appendQuery (fmter , b , false )
}
func (q *SelectQuery ) appendQuery (
fmter schema .Formatter , b []byte , count bool ,
) (_ []byte , err error ) {
if q .err != nil {
return nil , q .err
}
fmter = formatterWithModel (fmter , q )
cteCount := count && (len (q .group ) > 0 || q .distinctOn != nil )
if cteCount {
b = append (b , "WITH _count_wrapper AS (" ...)
}
if len (q .union ) > 0 {
b = append (b , '(' )
}
b , err = q .appendWith (fmter , b )
if err != nil {
return nil , err
}
b = append (b , "SELECT " ...)
if len (q .distinctOn ) > 0 {
b = append (b , "DISTINCT ON (" ...)
for i , app := range q .distinctOn {
if i > 0 {
b = append (b , ", " ...)
}
b , err = app .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
}
b = append (b , ") " ...)
} else if q .distinctOn != nil {
b = append (b , "DISTINCT " ...)
}
if count && !cteCount {
b = append (b , "count(*)" ...)
} else {
b , err = q .appendColumns (fmter , b )
if err != nil {
return nil , err
}
}
if q .hasTables () {
b , err = q .appendTables (fmter , b )
if err != nil {
return nil , err
}
}
if err := q .forEachHasOneJoin (func (j *relationJoin ) error {
b = append (b , ' ' )
b , err = j .appendHasOneJoin (fmter , b , q )
return err
}); err != nil {
return nil , err
}
for _ , j := range q .joins {
b , err = j .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
}
b , err = q .appendWhere (fmter , b , true )
if err != nil {
return nil , err
}
if len (q .group ) > 0 {
b = append (b , " GROUP BY " ...)
for i , f := range q .group {
if i > 0 {
b = append (b , ", " ...)
}
b , err = f .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
}
}
if len (q .having ) > 0 {
b = append (b , " HAVING " ...)
for i , f := range q .having {
if i > 0 {
b = append (b , " AND " ...)
}
b = append (b , '(' )
b , err = f .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
b = append (b , ')' )
}
}
if !count {
b , err = q .appendOrder (fmter , b )
if err != nil {
return nil , err
}
if q .limit != 0 {
b = append (b , " LIMIT " ...)
b = strconv .AppendInt (b , int64 (q .limit ), 10 )
}
if q .offset != 0 {
b = append (b , " OFFSET " ...)
b = strconv .AppendInt (b , int64 (q .offset ), 10 )
}
if !q .selFor .IsZero () {
b = append (b , " FOR " ...)
b , err = q .selFor .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
}
}
if len (q .union ) > 0 {
b = append (b , ')' )
for _ , u := range q .union {
b = append (b , u .expr ...)
b = append (b , '(' )
b , err = u .query .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
b = append (b , ')' )
}
}
if cteCount {
b = append (b , ") SELECT count(*) FROM _count_wrapper" ...)
}
return b , nil
}
func (q *SelectQuery ) appendColumns (fmter schema .Formatter , b []byte ) (_ []byte , err error ) {
start := len (b )
switch {
case q .columns != nil :
for i , col := range q .columns {
if i > 0 {
b = append (b , ", " ...)
}
if col .Args == nil && q .table != nil {
if field , ok := q .table .FieldMap [col .Query ]; ok {
b = append (b , q .table .SQLAlias ...)
b = append (b , '.' )
b = append (b , field .SQLName ...)
continue
}
}
b , err = col .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
}
case q .table != nil :
if len (q .table .Fields ) > 10 && fmter .IsNop () {
b = append (b , q .table .SQLAlias ...)
b = append (b , '.' )
b = dialect .AppendString (b , fmt .Sprintf ("%d columns" , len (q .table .Fields )))
} else {
b = appendColumns (b , q .table .SQLAlias , q .table .Fields )
}
default :
b = append (b , '*' )
}
if err := q .forEachHasOneJoin (func (join *relationJoin ) error {
if len (b ) != start {
b = append (b , ", " ...)
start = len (b )
}
b , err = q .appendHasOneColumns (fmter , b , join )
if err != nil {
return err
}
return nil
}); err != nil {
return nil , err
}
b = bytes .TrimSuffix (b , []byte (", " ))
return b , nil
}
func (q *SelectQuery ) appendHasOneColumns (
fmter schema .Formatter , b []byte , join *relationJoin ,
) (_ []byte , err error ) {
join .applyTo (q )
if join .columns != nil {
table := join .JoinModel .Table ()
for i , col := range join .columns {
if i > 0 {
b = append (b , ", " ...)
}
if col .Args == nil {
if field , ok := table .FieldMap [col .Query ]; ok {
b = join .appendAlias (fmter , b )
b = append (b , '.' )
b = append (b , field .SQLName ...)
b = append (b , " AS " ...)
b = join .appendAliasColumn (fmter , b , field .Name )
continue
}
}
b , err = col .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
}
return b , nil
}
for i , field := range join .JoinModel .Table ().Fields {
if i > 0 {
b = append (b , ", " ...)
}
b = join .appendAlias (fmter , b )
b = append (b , '.' )
b = append (b , field .SQLName ...)
b = append (b , " AS " ...)
b = join .appendAliasColumn (fmter , b , field .Name )
}
return b , nil
}
func (q *SelectQuery ) appendTables (fmter schema .Formatter , b []byte ) (_ []byte , err error ) {
b = append (b , " FROM " ...)
return q .appendTablesWithAlias (fmter , b )
}
func (q *SelectQuery ) appendOrder (fmter schema .Formatter , b []byte ) (_ []byte , err error ) {
if len (q .order ) > 0 {
b = append (b , " ORDER BY " ...)
for i , f := range q .order {
if i > 0 {
b = append (b , ", " ...)
}
b , err = f .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
}
return b , nil
}
return b , nil
}
func (q *SelectQuery ) Rows (ctx context .Context ) (*sql .Rows , error ) {
queryBytes , err := q .AppendQuery (q .db .fmter , q .db .makeQueryBytes ())
if err != nil {
return nil , err
}
query := internal .String (queryBytes )
return q .conn .QueryContext (ctx , query )
}
func (q *SelectQuery ) Exec (ctx context .Context ) (res sql .Result , err error ) {
queryBytes , err := q .AppendQuery (q .db .fmter , q .db .makeQueryBytes ())
if err != nil {
return nil , err
}
query := internal .String (queryBytes )
res , err = q .exec (ctx , q , query )
if err != nil {
return nil , err
}
return res , nil
}
func (q *SelectQuery ) Scan (ctx context .Context , dest ...interface {}) error {
model , err := q .getModel (dest )
if err != nil {
return err
}
if q .limit > 1 {
if model , ok := model .(interface { SetCap (int ) }); ok {
model .SetCap (int (q .limit ))
}
}
if q .table != nil {
if err := q .beforeSelectHook (ctx ); err != nil {
return err
}
}
queryBytes , err := q .AppendQuery (q .db .fmter , q .db .makeQueryBytes ())
if err != nil {
return err
}
query := internal .String (queryBytes )
res , err := q .scan (ctx , q , query , model , true )
if err != nil {
return err
}
if n , _ := res .RowsAffected (); n > 0 {
if tableModel , ok := model .(tableModel ); ok {
if err := q .selectJoins (ctx , tableModel .GetJoins ()); err != nil {
return err
}
}
}
if q .table != nil {
if err := q .afterSelectHook (ctx ); err != nil {
return err
}
}
return nil
}
func (q *SelectQuery ) beforeSelectHook (ctx context .Context ) error {
if hook , ok := q .table .ZeroIface .(BeforeSelectHook ); ok {
if err := hook .BeforeSelect (ctx , q ); err != nil {
return err
}
}
return nil
}
func (q *SelectQuery ) afterSelectHook (ctx context .Context ) error {
if hook , ok := q .table .ZeroIface .(AfterSelectHook ); ok {
if err := hook .AfterSelect (ctx , q ); err != nil {
return err
}
}
return nil
}
func (q *SelectQuery ) Count (ctx context .Context ) (int , error ) {
qq := countQuery {q }
queryBytes , err := qq .AppendQuery (q .db .fmter , nil )
if err != nil {
return 0 , err
}
query := internal .String (queryBytes )
ctx , event := q .db .beforeQuery (ctx , qq , query , nil )
var num int
err = q .conn .QueryRowContext (ctx , query ).Scan (&num )
q .db .afterQuery (ctx , event , nil , err )
return num , err
}
func (q *SelectQuery ) ScanAndCount (ctx context .Context , dest ...interface {}) (int , error ) {
var count int
var wg sync .WaitGroup
var mu sync .Mutex
var firstErr error
if q .limit >= 0 {
wg .Add (1 )
go func () {
defer wg .Done ()
if err := q .Scan (ctx , dest ...); err != nil {
mu .Lock ()
if firstErr == nil {
firstErr = err
}
mu .Unlock ()
}
}()
}
wg .Add (1 )
go func () {
defer wg .Done ()
var err error
count , err = q .Count (ctx )
if err != nil {
mu .Lock ()
if firstErr == nil {
firstErr = err
}
mu .Unlock ()
}
}()
wg .Wait ()
return count , firstErr
}
func (q *SelectQuery ) Exists (ctx context .Context ) (bool , error ) {
qq := existsQuery {q }
queryBytes , err := qq .AppendQuery (q .db .fmter , nil )
if err != nil {
return false , err
}
query := internal .String (queryBytes )
ctx , event := q .db .beforeQuery (ctx , qq , query , nil )
var exists bool
err = q .conn .QueryRowContext (ctx , query ).Scan (&exists )
q .db .afterQuery (ctx , event , nil , err )
return exists , err
}
type joinQuery struct {
join schema .QueryWithArgs
on []schema .QueryWithSep
}
func (j *joinQuery ) AppendQuery (fmter schema .Formatter , b []byte ) (_ []byte , err error ) {
b = append (b , ' ' )
b , err = j .join .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
if len (j .on ) > 0 {
b = append (b , " ON " ...)
for i , on := range j .on {
if i > 0 {
b = append (b , on .Sep ...)
}
b = append (b , '(' )
b , err = on .AppendQuery (fmter , b )
if err != nil {
return nil , err
}
b = append (b , ')' )
}
}
return b , nil
}
type countQuery struct {
*SelectQuery
}
func (q countQuery ) AppendQuery (fmter schema .Formatter , b []byte ) (_ []byte , err error ) {
return q .appendQuery (fmter , b , true )
}
type existsQuery struct {
*SelectQuery
}
func (q existsQuery ) AppendQuery (fmter schema .Formatter , b []byte ) (_ []byte , err error ) {
b = append (b , "SELECT EXISTS (" ...)
b , err = q .appendQuery (fmter , b , false )
if err != nil {
return nil , err
}
b = append (b , ")" ...)
return b , nil
}
The pages are generated with Golds v0.3.6 . (GOOS=darwin GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu .
PR and bug reports are welcome and can be submitted to the issue list .
Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds .