package bun
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"time"
"github.com/uptrace/bun/schema"
)
type structTableModel struct {
db *DB
table *schema .Table
rel *schema .Relation
joins []relationJoin
dest interface {}
root reflect .Value
index []int
strct reflect .Value
structInited bool
structInitErr error
columns []string
scanIndex int
}
var _ tableModel = (*structTableModel )(nil )
func newStructTableModel (db *DB , dest interface {}, table *schema .Table ) *structTableModel {
return &structTableModel {
db : db ,
table : table ,
dest : dest ,
}
}
func newStructTableModelValue (db *DB , dest interface {}, v reflect .Value ) *structTableModel {
return &structTableModel {
db : db ,
table : db .Table (v .Type ()),
dest : dest ,
root : v ,
strct : v ,
}
}
func (m *structTableModel ) Value () interface {} {
return m .dest
}
func (m *structTableModel ) Table () *schema .Table {
return m .table
}
func (m *structTableModel ) Relation () *schema .Relation {
return m .rel
}
func (m *structTableModel ) Root () reflect .Value {
return m .root
}
func (m *structTableModel ) Index () []int {
return m .index
}
func (m *structTableModel ) ParentIndex () []int {
return m .index [:len (m .index )-len (m .rel .Field .Index )]
}
func (m *structTableModel ) Mount (host reflect .Value ) {
m .strct = host .FieldByIndex (m .rel .Field .Index )
m .structInited = false
}
func (m *structTableModel ) initStruct () error {
if m .structInited {
return m .structInitErr
}
m .structInited = true
switch m .strct .Kind () {
case reflect .Invalid :
m .structInitErr = errNilModel
return m .structInitErr
case reflect .Interface :
m .strct = m .strct .Elem ()
}
if m .strct .Kind () == reflect .Ptr {
if m .strct .IsNil () {
m .strct .Set (reflect .New (m .strct .Type ().Elem ()))
m .strct = m .strct .Elem ()
} else {
m .strct = m .strct .Elem ()
}
}
m .mountJoins ()
return nil
}
func (m *structTableModel ) mountJoins () {
for i := range m .joins {
j := &m .joins [i ]
switch j .Relation .Type {
case schema .HasOneRelation , schema .BelongsToRelation :
j .JoinModel .Mount (m .strct )
}
}
}
var _ schema .BeforeScanHook = (*structTableModel )(nil )
func (m *structTableModel ) BeforeScan (ctx context .Context ) error {
if !m .table .HasBeforeScanHook () {
return nil
}
return callBeforeScanHook (ctx , m .strct .Addr ())
}
var _ schema .AfterScanHook = (*structTableModel )(nil )
func (m *structTableModel ) AfterScan (ctx context .Context ) error {
if !m .table .HasAfterScanHook () || !m .structInited {
return nil
}
var firstErr error
if err := callAfterScanHook (ctx , m .strct .Addr ()); err != nil && firstErr == nil {
firstErr = err
}
for _ , j := range m .joins {
switch j .Relation .Type {
case schema .HasOneRelation , schema .BelongsToRelation :
if err := j .JoinModel .AfterScan (ctx ); err != nil && firstErr == nil {
firstErr = err
}
}
}
return firstErr
}
func (m *structTableModel ) GetJoin (name string ) *relationJoin {
for i := range m .joins {
j := &m .joins [i ]
if j .Relation .Field .Name == name || j .Relation .Field .GoName == name {
return j
}
}
return nil
}
func (m *structTableModel ) GetJoins () []relationJoin {
return m .joins
}
func (m *structTableModel ) AddJoin (j relationJoin ) *relationJoin {
m .joins = append (m .joins , j )
return &m .joins [len (m .joins )-1 ]
}
func (m *structTableModel ) Join (name string ) *relationJoin {
return m .join (m .strct , name )
}
func (m *structTableModel ) join (bind reflect .Value , name string ) *relationJoin {
path := strings .Split (name , "." )
index := make ([]int , 0 , len (path ))
currJoin := relationJoin {
BaseModel : m ,
JoinModel : m ,
}
var lastJoin *relationJoin
for _ , name := range path {
relation , ok := currJoin .JoinModel .Table ().Relations [name ]
if !ok {
return nil
}
currJoin .Relation = relation
index = append (index , relation .Field .Index ...)
if j := currJoin .JoinModel .GetJoin (name ); j != nil {
currJoin .BaseModel = j .BaseModel
currJoin .JoinModel = j .JoinModel
lastJoin = j
} else {
model , err := newTableModelIndex (m .db , m .table , bind , index , relation )
if err != nil {
return nil
}
currJoin .Parent = lastJoin
currJoin .BaseModel = currJoin .JoinModel
currJoin .JoinModel = model
lastJoin = currJoin .BaseModel .AddJoin (currJoin )
}
}
return lastJoin
}
func (m *structTableModel ) updateSoftDeleteField (tm time .Time ) error {
if !m .strct .IsValid () {
return nil
}
fv := m .table .SoftDeleteField .Value (m .strct )
return m .table .UpdateSoftDeleteField (fv , tm )
}
func (m *structTableModel ) ScanRows (ctx context .Context , rows *sql .Rows ) (int , error ) {
if !rows .Next () {
return 0 , rows .Err ()
}
var n int
if err := m .ScanRow (ctx , rows ); err != nil {
return 0 , err
}
n ++
for rows .Next () {
n ++
}
if err := rows .Err (); err != nil {
return 0 , err
}
return n , nil
}
func (m *structTableModel ) ScanRow (ctx context .Context , rows *sql .Rows ) error {
columns , err := rows .Columns ()
if err != nil {
return err
}
m .columns = columns
dest := makeDest (m , len (columns ))
return m .scanRow (ctx , rows , dest )
}
func (m *structTableModel ) scanRow (ctx context .Context , rows *sql .Rows , dest []interface {}) error {
if err := m .BeforeScan (ctx ); err != nil {
return err
}
m .scanIndex = 0
if err := rows .Scan (dest ...); err != nil {
return err
}
if err := m .AfterScan (ctx ); err != nil {
return err
}
return nil
}
func (m *structTableModel ) Scan (src interface {}) error {
column := m .columns [m .scanIndex ]
m .scanIndex ++
return m .ScanColumn (unquote (column ), src )
}
func (m *structTableModel ) ScanColumn (column string , src interface {}) error {
if ok , err := m .scanColumn (column , src ); ok {
return err
}
if column == "" || column [0 ] == '_' || m .db .flags .Has (discardUnknownColumns ) {
return nil
}
return fmt .Errorf ("bun: %s does not have column %q" , m .table .TypeName , column )
}
func (m *structTableModel ) scanColumn (column string , src interface {}) (bool , error ) {
if src != nil {
if err := m .initStruct (); err != nil {
return true , err
}
}
if field , ok := m .table .FieldMap [column ]; ok {
if src == nil && m .isNil () {
return true , nil
}
return true , field .ScanValue (m .strct , src )
}
if joinName , column := splitColumn (column ); joinName != "" {
if join := m .GetJoin (joinName ); join != nil {
return true , join .JoinModel .ScanColumn (column , src )
}
if m .table .ModelName == joinName {
return true , m .ScanColumn (column , src )
}
}
return false , nil
}
func (m *structTableModel ) isNil () bool {
return m .strct .Kind () == reflect .Ptr && m .strct .IsNil ()
}
func (m *structTableModel ) AppendNamedArg (
fmter schema .Formatter , b []byte , name string ,
) ([]byte , bool ) {
return m .table .AppendNamedArg (fmter , b , name , m .strct )
}
func unquote (s string ) string {
if s == "" {
return s
}
if s [0 ] == '"' && s [len (s )-1 ] == '"' {
return s [1 : len (s )-1 ]
}
return s
}
func splitColumn (s string ) (string , string ) {
if i := strings .Index (s , "__" ); i >= 0 {
return s [:i ], s [i +2 :]
}
return "" , s
}
The pages are generated with Golds v0.3.6 . (GOOS=darwin GOARCH=amd64)
Golds is a Go 101 project developed by Tapir Liu .
PR and bug reports are welcome and can be submitted to the issue list .
Please follow @Go100and1 (reachable from the left QR code) to get the latest news of Golds .