【背景介绍】
GORM 是 Go 语言的一款 ORM 组件,核心思路是把日常执行的 SQL 语句构建成对象,再把对象转换成 SQL,交给数据库驱动去执行。
组件的特性:对数据库、表及各种特性做了全面的封装,因此使用时写出来的表达式看起来会非常轻巧。
同时也支持常用及特殊场景的需求:常用的如增删改查、事务等机制;特殊的如迁移、事件回调机制等。
话不多说,以下会分为几块来介绍下:
目录结构-> 了解下项目的结构层次,这样清楚核心的代码和非核心代码的分布,及功能模块主要集中在哪些文件
代码设计-> 采用了UML类图,把核心模块进行区域划分,同时也体现了各种依赖的关系
源码分析-> 基于代码设计的角度对源代码进行解析,突出重点内容,尤其是设计思路。
相关资料引用-> 更高的视角去了解gorm的应用价值和支持的业务场景,以及类似组件的选型对比。
一、目录结构
├── association.go
├── callbacks // 回调
│ ├── associations.go
│ ├── callbacks.go
│ ├── callmethod.go
│ ├── create.go // 创建
│ ├── delete.go // 删除
│ ├── helper.go // 辅助类
│ ├── interfaces.go
│ ├── preload.go
│ ├── query.go // 查询
│ ├── raw.go //
│ ├── row.go
│ ├── transaction.go // 事务
│ ├── update.go // 更新
├── callbacks.go // 回调
├── chainable_api.go
├── clause // 条件表达式
│ ├── clause.go
│ ├── delete.go
│ ├── expression.go // 表达式
│ ├── from.go // from 条件
│ ├── group_by.go // group by 条件
│ ├── insert.go // insert 条件
│ ├── joins.go // joins 连接条件
│ ├── limit.go // limit的限制条件
│ ├── locking.go // 锁条件
│ ├── on_conflict.go
│ ├── order_by.go // order by 条件
│ ├── returning.go
│ ├── select.go // select 的条件
│ ├── set.go
│ ├── update.go // 更新条件
│ ├── values.go //
│ ├── where.go // where 条件
│ └── with.go
├── errors.go
├── finisher_api.go
├── gorm.go // gorm 启动器
├── interfaces.go // 暴露的接口
├── logger // 日志
│ ├── logger.go
│ ├── sql.go
├── migrator // 迁移
│ ├── column_type.go // 列类型
│ └── migrator.go
├── migrator.go // 迁移
├── model.go // 模型层
├── prepare_stmt.go // 预表达式
├── scan.go
├── schema // 库
│ ├── check.go // 检查/约束
│ ├── field.go // 字段
│ ├── index.go // 索引
│ ├── interfaces.go // 接口层
│ ├── naming.go // 命名策略
│ ├── pool.go // 连接池
│ ├── relationship.go //关系
│ ├── schema.go // 库
│ ├── serializer.go // 序列化
│ ├── utils.go // 数据库工具
├── soft_delete.go // 逻辑删除
├── statement.go // 表达式
└── utils
├── utils.go // 工具服务
二、 代码设计
2.1 数据库配置和启动
2.2 Statement的设计
2.3 数据库协议的设计
2.4 表和列的设计:支持增删改查的操作
2.5 字段及其特性
2.6 各种查询表达式的构建
2.7 事务和预表达式的设计
2.8 日志的处理
2.9 事件机制,如服务于事件回调/触发 + 库的迁移
三、 源码分析
DB-初始化
-
初始化DB连接,并且建立连接池
-
使用DB
-
DB事务管理
-
ORM 逻辑处理
- statement.go
// DB unwraps the connection pool held by this GORM handle and returns the
// underlying `*sql.DB`.
//
// Lookup order:
//  1. a pool implementing GetDBConnector is asked for its connection via
//     GetDBConn, whose result and error are returned unchanged;
//  2. a pool that is itself a *sql.DB is returned directly.
//
// Any other pool type, including a nil ConnPool, yields ErrInvalidDB.
func (db *DB) DB() (*sql.DB, error) {
	switch pool := db.ConnPool.(type) {
	case GetDBConnector:
		return pool.GetDBConn()
	case *sql.DB:
		return pool, nil
	}
	return nil, ErrInvalidDB
}
// DB GORM DB definition
// DB is the GORM database handle; chainable calls operate on it (or on a
// derived instance produced by getInstance).
type DB struct {
// Shared configuration, embedded so Config fields (ConnPool, Dialector,
// ClauseBuilders, ...) are reachable directly on *DB. getInstance shares
// this same pointer with every derived instance.
*Config
// Error state of this handle; getInstance copies it into derived instances.
Error error
// Presumably the number of rows affected by the last executed operation;
// it is not written anywhere in this excerpt — confirm.
RowsAffected int64
// Statement being built for this handle. For derived instances,
// getInstance sets Statement.DB to point back at the new handle.
Statement *Statement
// Controls getInstance: 0 = reuse this *DB as-is; 1 = derive a new *DB
// with a fresh Statement; any other positive value = derive a new *DB
// with a copy of the current Statement.
clone int
}
// Config is GORM's core database configuration. It is embedded in every
// *DB, and a *Config may also be passed to Open as an Option.
// Config GORM config
type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool
// NamingStrategy tables, columns naming strategy (e.g. used to derive a
// column name from a field name when no explicit column is given)
NamingStrategy schema.Namer
// FullSaveAssociations full save associations
FullSaveAssociations bool
// Logger
Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp
NowFunc func() time.Time
// DryRun generates SQL without executing it
DryRun bool
// PrepareStmt executes queries through cached prepared statements
PrepareStmt bool
// DisableAutomaticPing
DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool
// DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool
// AllowGlobalUpdate allow global update
AllowGlobalUpdate bool
// QueryFields executes the SQL query with all fields of the table
QueryFields bool
// CreateBatchSize default create batch size
CreateBatchSize int
// ClauseBuilders overrides how individual clauses are rendered, keyed by
// clause name: Statement.Build calls a registered builder instead of the
// clause's own Build method.
ClauseBuilders map[string]clause.ClauseBuilder
// ConnPool db conn pool; DB() unwraps it to the underlying *sql.DB.
ConnPool ConnPool
// Dialector database dialector (embedded)
Dialector
// Plugins registered plugins
Plugins map[string]Plugin
// callbacks presumably holds the registered callback chains; its type is
// not shown in this excerpt — confirm.
callbacks *callbacks
// cacheStore caches parsed schemas; it is passed to schema.Parse.
cacheStore *sync.Map
}
// getInstance returns the *DB that a chainable call should operate on.
//
// When db.clone is not positive, the receiver itself is returned.
// Otherwise a new *DB is created. It shares db's Config and carries db's
// current Error:
//   - clone == 1: it gets a brand-new Statement that only inherits the
//     connection pool and context (no clauses, empty Vars with capacity 8);
//   - any other positive value: it gets a copy of db's Statement (made by
//     Statement.clone), re-pointed at the new *DB.
func (db *DB) getInstance() *DB {
	if db.clone <= 0 {
		return db
	}

	tx := &DB{Config: db.Config, Error: db.Error}
	if db.clone != 1 {
		// Keep everything built so far, but bind the copy to the new handle.
		tx.Statement = db.Statement.clone()
		tx.Statement.DB = tx
		return tx
	}

	// Start from an empty statement that only inherits pool and context.
	src := db.Statement
	tx.Statement = &Statement{
		DB:       tx,
		ConnPool: src.ConnPool,
		Context:  src.Context,
		Clauses:  map[string]clause.Clause{},
		Vars:     make([]interface{}, 0, 8),
	}
	return tx
}
// Open initializes a GORM DB session on top of the given dialector
// (establishes the connection and loads the database driver).
// Open initialize db session based on dialector
//
// NOTE: abridged excerpt — only the option handling and the dialector's
// optional Apply step are shown; the rest of Open is omitted.
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
config := &Config{}
// Move *Config options ahead of all other options, presumably so a whole
// Config is applied first and later options can refine it.
// NOTE(review): sort.Slice is not stable, so the relative order of the
// other options is not guaranteed to match the caller's order — confirm
// whether sort.SliceStable was intended.
sort.Slice(opts, func(i, j int) bool {
_, isConfig := opts[i].(*Config)
_, isConfig2 := opts[j].(*Config)
return isConfig && !isConfig2
})
for _, opt := range opts {
// nil options are skipped.
if opt != nil {
if applyErr := opt.Apply(config); applyErr != nil {
return nil, applyErr
}
// Each option's AfterInitialize hook is deferred. The hooks run when
// Open returns, in reverse option order, and receive the named result
// db. A non-nil hook error overwrites the named result err.
// NOTE(review): these deferred hooks also run on the early
// `return nil, ...` paths, where db is nil — confirm AfterInitialize
// tolerates a nil *DB.
defer func(opt Option) {
if errr := opt.AfterInitialize(db); errr != nil {
err = errr
}
}(opt)
}
}
// If the dialector can adjust the Config itself (optional Apply method),
// let it do so; an error aborts Open.
if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
if err = d.Apply(config); err != nil {
return
}
}
// ... (remainder of Open omitted in this excerpt)
core-0.1-Statement-构建
- note
- statement.go
// Statement statement
// Statement holds the state of one SQL statement being built for a *DB.
// This includes the target table/model, the collected clauses, the
// select/omit lists, and the generated SQL text plus its bound variables.
type Statement struct {
// Owning handle (embedded); getInstance points this at the *DB the
// statement belongs to.
*DB
// Presumably a raw table expression used in place of Table when set — confirm.
TableExpr *clause.Expr
// Table name the statement targets.
Table string
// Model value supplied by the caller — presumably used to resolve Schema; confirm.
Model interface{}
// Presumably disables default scoping such as soft delete (soft_delete.go) — confirm.
Unscoped bool
// Destination value; SetColumn writes into it when it is a map, a slice
// of maps, or (via Schema) a struct.
Dest interface{}
// Presumably the reflected form of Dest/Model — confirm.
ReflectValue reflect.Value
// Clauses keyed by clause name (e.g. "WHERE"); Build renders them in the
// order it is asked for.
Clauses map[string]clause.Clause
// Presumably the clause names later passed to Build — confirm.
BuildClauses []string
Distinct bool
Selects []string // selected columns ("*" selects all; see SelectAndOmitColumns)
Omits []string // omit columns
Joins []join
Preloads map[string][]interface{}
// Per-statement settings. This field is a sync.Map, which must not be
// copied after first use, so a Statement must not be copied by value;
// getInstance uses Statement.clone() instead.
Settings sync.Map
// Connection pool the statement runs on (inherited by getInstance).
ConnPool ConnPool
// Parsed schema of the model; may be nil (SelectAndOmitColumns checks this).
Schema *schema.Schema
// Context of the statement; also passed to field.ValueOf when reading values.
Context context.Context
RaiseErrorOnNotFound bool
SkipHooks bool
// Generated SQL text — presumably the target of the WriteByte/WriteString
// calls made while building clauses; confirm.
SQL strings.Builder
// Bound variables for the SQL placeholders (getInstance preallocates capacity 8).
Vars []interface{}
CurDestIndex int
attrs []interface{}
assigns []interface{}
scopes []func(*DB) *DB
}
core-0.2-statement-条件语句的构建
-
- 根据sql 条件的语句来构建statement
- statement.go
// Build writes the SQL for the named clauses into the statement, in the
// order the names are given.
//
// Names with no entry in stmt.Clauses are skipped. Rendered clauses are
// separated by a single space. If a builder is registered in
// stmt.DB.ClauseBuilders under a clause's name, that builder is used
// instead of the clause's own Build method.
func (stmt *Statement) Build(clauses ...string) {
	wroteAny := false
	for _, name := range clauses {
		c, present := stmt.Clauses[name]
		if !present {
			continue
		}
		if wroteAny {
			stmt.WriteByte(' ')
		}
		wroteAny = true

		if custom, overridden := stmt.DB.ClauseBuilders[name]; overridden {
			custom(c, stmt)
		} else {
			c.Build(stmt)
		}
	}
}
// Build renders the clause into builder.
//
// A clause with a custom Builder function delegates entirely to it.
// Otherwise, if it has an Expression, the following parts are written in
// order, separated by single spaces and skipping any that are unset:
//
//	BeforeExpression  Name  AfterNameExpression  Expression  AfterExpression
//
// A clause with neither a Builder nor an Expression writes nothing.
func (c Clause) Build(builder Builder) {
	if c.Builder != nil {
		c.Builder(c, builder)
		return
	}
	if c.Expression == nil {
		return
	}

	if before := c.BeforeExpression; before != nil {
		before.Build(builder)
		builder.WriteByte(' ')
	}
	if name := c.Name; name != "" {
		builder.WriteString(name)
		builder.WriteByte(' ')
	}
	if afterName := c.AfterNameExpression; afterName != nil {
		afterName.Build(builder)
		builder.WriteByte(' ')
	}

	c.Expression.Build(builder)

	if after := c.AfterExpression; after != nil {
		builder.WriteByte(' ')
		after.Build(builder)
	}
}
// Clauses adds clause expressions to the statement and returns the *DB
// they were applied to. That *DB comes from getInstance, so it may be a
// new instance rather than db itself.
//
// Each condition is dispatched on its type, checked in this order:
//   - clause.Interface: added directly with Statement.AddClause;
//   - StatementModifier: given the statement to modify via ModifyStatement;
//   - anything else: collected. After the loop, all collected conditions
//     are combined through Statement.BuildCondition and added as a
//     clause.Where.
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
	tx = db.getInstance()

	var whereConds []interface{}
	for _, cond := range conds {
		switch c := cond.(type) {
		case clause.Interface:
			tx.Statement.AddClause(c)
		case StatementModifier:
			c.ModifyStatement(tx.Statement)
		default:
			whereConds = append(whereConds, cond)
		}
	}

	if len(whereConds) == 0 {
		return tx
	}
	exprs := tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)
	tx.Statement.AddClause(clause.Where{Exprs: exprs})
	return tx
}
core-0.3-BuildCondition-构建表达式的条件,如where条件
- 对于结构体等参数,核心思想是通过反射机制来构建where等各种条件表达式
- statement.go
// BuildCondition build condition
// BuildCondition turns a query plus its args into a list of clause.Expression
// conditions; DB.Clauses, for example, uses it to build a WHERE clause.
//
// NOTE: abridged excerpt. The function body is cut off after the
// string-query branch. The later snippets (the `case *DB:` branch and the
// reflection-based branch) come from an elided loop/type switch over the
// arguments. The names `v`, `conds`, `arg` and `idx` used there are
// declared in the omitted code.
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression {
// A string query is interpreted heuristically, checked in this order:
//   1. "" with no args                     -> no condition (nil)
//   2. no args, or the text contains "?"   -> raw SQL expression with positional vars
//   3. args and the text contains "@"      -> named-parameter expression
//   4. the trimmed text contains a space   -> raw SQL expression
//   5. exactly one arg                     -> `column = arg` equality
// A string that parses as an integer skips all of these.
if s, ok := query.(string); ok {
// if it is a number, then treats it as primary key
if _, err := strconv.Atoi(s); err != nil {
if s == "" && len(args) == 0 {
return nil
}
// NOTE(review): `len(args) > 0 &&` is redundant here, since the left
// operand already covers len(args) == 0.
if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
// looks like a where condition
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
}
if len(args) > 0 && strings.Contains(s, "@") {
// looks like a named query
return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
}
if strings.Contains(strings.TrimSpace(s), " ") {
// looks like a where condition
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
}
if len(args) == 1 {
return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
}
}
}
// Build WHERE conditions from an argument that is itself a *DB: merge in
// its WHERE clause as one AND group. A single OR-group is re-labeled as
// an AND group first.
// NOTE(review): `where` is a copy of the clause.Where value, but its Exprs
// slice shares the backing array with v's statement, so the assignment to
// where.Exprs[0] also mutates v's WHERE clause — confirm this is intended.
case *DB:
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 {
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
where.Exprs[0] = clause.AndConditions(orConds)
}
}
conds = append(conds, clause.And(where.Exprs...))
} else if cs.Expression != nil {
conds = append(conds, cs.Expression)
}
}
// Default: build conditions via reflection. Dereference arg through all
// pointer levels, then try to parse it as a model schema.
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
for reflectValue.Kind() == reflect.Ptr {
reflectValue = reflectValue.Elem()
}
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
// For the first argument, any following string args name the columns to
// use. When such names are given, only those columns become conditions,
// even if their values are zero.
selectedColumns := map[string]bool{}
if idx == 0 {
for _, v := range args[1:] {
if vs, ok := v.(string); ok {
selectedColumns[vs] = true
}
}
}
restricted := len(selectedColumns) != 0
switch reflectValue.Kind() {
case reflect.Struct:
// One `column = value` condition per selected field. Without a
// selection, this covers every readable field with a non-zero value.
for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
}
}
}
}
case reflect.Slice, reflect.Array:
// Same rule as the struct case, applied to every element.
for i := 0; i < reflectValue.Len(); i++ {
for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
}
}
}
}
}
}
// NOTE(review): `break` exits the innermost enclosing for/switch/select in
// the elided code. If that is a type switch rather than the loop over args,
// this break is effectively a no-op — confirm the intended target.
if restricted {
break
}
} else if !reflectValue.IsValid() {
stmt.AddError(ErrInvalidData)
} else if len(conds) == 0 {
// Not a model: treat the value(s) as primary keys (`pk IN (...)`).
// A single slice/array arg is expanded into its elements.
if len(args) == 1 {
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
// optimize reflect value length
valueLen := reflectValue.Len()
values := make([]interface{}, valueLen)
for i := 0; i < valueLen; i++ {
values[i] = reflectValue.Index(i).Interface()
}
if len(values) > 0 {
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
}
return conds
}
}
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
}
SelectAndOmitColumns-计算需要选择(Select)和忽略(Omit)的列信息
- note
- statement.go
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
// SelectAndOmitColumns computes which columns the statement selects (true)
// or omits (false). The returned map is keyed by the column name as
// resolved below; for clause.Associations the keys are relationship names.
// The bool result (notRestricted) becomes true when "*" is selected.
//
// requireCreate and requireUpdate are not read in the portion shown here.
// NOTE: abridged excerpt — the omit handling and the return statement are
// elided at "// ....".
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
results := map[string]bool{}
notRestricted := false
// select columns
for _, column := range stmt.Selects {
// Without a parsed schema, names are used verbatim.
if stmt.Schema == nil {
results[column] = true
} else if column == "*" {
// "*" selects every known DB column and lifts the restriction.
notRestricted = true
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = true
}
} else if column == clause.Associations {
// Selecting clause.Associations marks every relationship, by name.
for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = true
}
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
// A known field is normalized to its DB column name.
results[field.DBName] = true
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 {
// Otherwise use the single capture group of nameMatcher (a regexp
// defined elsewhere — presumably extracting the column from a
// qualified name such as `table.column`; confirm).
results[matches[1]] = true
} else {
results[column] = true
}
}
// ....
}
core-0.5-SetColumn-设置(填充)列的值
- 为statement的目标对象(Dest)设置指定列的值
- statement.go
// SetColumn set column's value
// SetColumn assigns value to the column `name` on the statement's destination:
//   - Dest is a map[string]interface{}: the key is set directly;
//   - Dest is a []map[string]interface{}: the key is set on every map;
//   - otherwise, if a Schema is available and `name` resolves to a field,
//     Dest is dereferenced through all pointer levels so the field can be set.
//
// If none of these apply, nothing happens in the portion shown here.
// The optional fromCallbacks flag distinguishes callback calls from hook
// calls (see the examples below); it is not read in the portion shown here.
// NOTE: abridged excerpt — the code that writes the struct field after
// dereferencing destValue is elided.
//
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
if v, ok := stmt.Dest.(map[string]interface{}); ok {
v[name] = value
} else if v, ok := stmt.Dest.([]map[string]interface{}); ok {
for _, m := range v {
m[name] = value
}
} else if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(name); field != nil {
destValue := reflect.ValueOf(stmt.Dest)
for destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
}
}
Schema-0.1-协议设计
- note
- schema.go
// Schema is the parsed description of a model type. It records the table,
// the fields (indexed by Go name and by DB column name), the primary keys,
// the relationships, and per-operation clause/hook information.
type Schema struct {
// Model type name (modelType.Name()).
Name string
ModelType reflect.Type
// Table name resolved for the model.
Table string
// Presumably the primary-key field preferred when several exist — confirm.
PrioritizedPrimaryField *Field
// DB column names, each recorded once, in first-seen field order.
DBNames []string
// Fields flagged as primary keys.
PrimaryFields []*Field
PrimaryFieldDBNames []string
// All parsed fields; fields of embedded schemas are flattened in.
Fields []*Field
FieldsByName map[string]*Field
FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
Relationships Relationships
// Presumably extra clauses contributed for each operation — confirm.
CreateClauses []clause.Interface
QueryClauses []clause.Interface
UpdateClauses []clause.Interface
DeleteClauses []clause.Interface
// Presumably set when the model defines the corresponding hook method — confirm.
BeforeCreate, AfterCreate bool
BeforeUpdate, AfterUpdate bool
BeforeDelete, AfterDelete bool
BeforeSave, AfterSave bool
AfterFind bool
// Parse error; returned to callers that load this schema from the cache.
err error
// Closed when initialization completes; goroutines that load the schema
// from the cache wait on it.
initialized chan struct{}
// Naming strategy used to derive column names.
namer Namer
// Shared cache of parsed schemas.
cacheStore *sync.Map
}
Schema-0.2-解析字段和列的信息
- note
- schema.go
// Construct a new, not-yet-initialized Schema for modelType.
schema := &Schema{
Name: modelType.Name(),
ModelType: modelType,
Table: tableName,
FieldsByName: map[string]*Field{},
FieldsByDBName: map[string]*Field{},
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
}
// When the schema initialization is completed, the channel will be closed.
// The close is deferred, so it also fires on the cache-hit return below,
// where this fresh schema is simply discarded.
defer close(schema.initialized)
// Load the existing schema from the cache; return it if present
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
// Parse every exported struct field. The fields of an embedded schema are
// flattened into schema.Fields.
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
} else {
schema.Fields = append(schema.Fields, field)
}
}
}
// Index the parsed fields by DB column name and by Go field name.
for _, field := range schema.Fields {
// A field with a data type but no explicit column gets a column name
// from the naming strategy.
if field.DBName == "" && field.DataType != "" {
field.DBName = namer.ColumnName(schema.Table, field.Name)
}
if field.DBName != "" {
// nonexistence or shortest path or first appear prioritized if has permission:
// a field claims this DB column if no field has it yet, or if it is
// creatable/updatable/readable and has fewer BindNames than the
// current owner (presumably a shallower embedding path).
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
// Record each DB column name once, in first-seen order.
if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
schema.DBNames = append(schema.DBNames, field.DBName)
}
schema.FieldsByDBName[field.DBName] = field
schema.FieldsByName[field.Name] = field
// If the displaced owner was a primary key, drop it from PrimaryFields.
// NOTE(review): elements are removed while ranging over the same
// slice; this is fine as long as v appears at most once in
// PrimaryFields — confirm.
if v != nil && v.PrimaryKey {
for idx, f := range schema.PrimaryFields {
if f == v {
schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
}
}
}
if field.PrimaryKey {
schema.PrimaryFields = append(schema.PrimaryFields, field)
}
}
}
// Register the field by Go name, unless another field already owns it.
// An existing entry whose TagSettings carry "-" (presumably `gorm:"-"`)
// may be replaced.
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByName[field.Name] = field
}
// Presumably prepares the field's reflection-based getter/setter — confirm.
field.setupValuerAndSetter()
}
Schema-0.3-ParseField
- 解析字段的信息填充
- field.go
// ParseField parses reflect.StructField to Field
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var (
err error
tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";")
)
field := &Field{
Name: fieldStruct.Name,
DBName: tagSetting["COLUMN"],
BindNames: []string{fieldStruct.Name},
FieldType: fieldStruct.Type,
IndirectFieldType: fieldStruct.Type,
StructField: fieldStruct,
Tag: fieldStruct.Tag,
TagSettings: tagSetting,
Schema: schema,
Creatable: true,
Updatable: true,
Readable: true,
PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]),
AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
Comment: tagSetting["COMMENT"],
AutoIncrementIncrement: 1,
}
for field.IndirectFieldType.Kind() == reflect.Ptr {
field.IndirectFieldType = field.IndirectFieldType.Elem()
}
fieldValue := reflect.New(field.IndirectFieldType)
// if field is valuer, used its value or first field as data type
valuer, isValuer := fieldValue.Interface().(driver.Valuer)
if isValuer {
if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil {
fieldValue = reflect.ValueOf(v)
}
// Use the field struct's first field type as data type, e.g: use `string` for sql.NullString
var getRealFieldValue func(reflect.Value)
getRealFieldValue = func(v reflect.Value) {
var (
rv = reflect.Indirect(v)
rvType = rv.Type()
)
if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) {
for i := 0; i < rvType.NumField(); i++ {
for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") {
if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value
}
}
}
for i := 0; i < rvType.NumField(); i++ {
newFieldType := rvType.Field(i).Type
for newFieldType.Kind() == reflect.Ptr {
newFieldType = newFieldType.Elem()
}
fieldValue = reflect.New(newFieldType)
if rvType != reflect.Indirect(fieldValue).Type() {
getRealFieldValue(fieldValue)
}
if fieldValue.IsValid() {
return
}
}
}
}
getRealFieldValue(fieldValue)
}
}