Gorm源码分析和UML设计

阅读量 2,464 · 阅读时长约 6 分钟

【背景介绍】

GORM 是 Go 语言的一款 ORM 组件,核心思路是把日常执行的 SQL 语句构建成对象,再把对象转换成 SQL,交给数据库驱动执行。

组件的特性:对数据库、表及各种特性做了全面封装,因此使用起来表达式会显得非常轻巧。

同时也支持数据库操作中常用及特殊场景的需求:常用的如增删改查、事务等机制;特殊的如迁移、事件回调机制等。

话不多说,以下会分为几块来介绍下

目录结构-> 了解下项目的结构层次,这样清楚核心的代码和非核心代码的分布,及功能模块主要集中在哪些文件

代码设计-> 采用了UML类图,把核心模块进行区域划分,同时也体现了各种依赖的关系

源码分析-> 基于代码设计的角度来对于源代码的解析,突出重要表达的内容。 尤其是设计的思路。

相关资料引用-> 更高的视角去了解gorm的应用价值和支持的业务场景,以及类似组件的选型对比。

一、目录结构

├── association.go
├── callbacks				// 回调
│   ├── associations.go
│   ├── callbacks.go
│   ├── callmethod.go
│   ├── create.go			// 创建
│   ├── delete.go			// 删除
│   ├── helper.go			// 辅助类
│   ├── interfaces.go
│   ├── preload.go
│   ├── query.go			// 查询
│   ├── raw.go				// 
│   ├── row.go
│   ├── transaction.go // 事务
│   ├── update.go			 // 更新
├── callbacks.go			 // 回调
├── chainable_api.go
├── clause						// 条件表达式
│   ├── clause.go
│   ├── delete.go
│   ├── expression.go   // 表达式
│   ├── from.go				  // from 条件
│   ├── group_by.go		  // group by 条件
│   ├── insert.go				// insert 条件
│   ├── joins.go				// joins 连接条件
│   ├── limit.go			  // limit的限制条件
│   ├── locking.go			// 锁条件
│   ├── on_conflict.go
│   ├── order_by.go			// order by 条件
│   ├── returning.go
│   ├── select.go				// select 的条件
│   ├── set.go			
│   ├── update.go				// 更新条件
│   ├── values.go				//
│   ├── where.go				// where 条件
│   └── with.go			
├── errors.go
├── finisher_api.go
├── gorm.go						// gorm 入口(启动器)
├── interfaces.go			// 暴露的接口
├── logger						// 日志
│   ├── logger.go
│   ├── sql.go
├── migrator					// 迁移
│   ├── column_type.go // 列类型
│   └── migrator.go
├── migrator.go				// 迁移
├── model.go					// 模型层
├── prepare_stmt.go   // 预编译语句(prepared statement)
├── scan.go					
├── schema						// 模型结构(schema)解析
│   ├── check.go      // 检查/约束
│   ├── field.go      // 字段
│   ├── index.go	    // 索引
│   ├── interfaces.go	// 接口层
│   ├── naming.go			// 命名策略(表名/列名)
│   ├── pool.go				// 字段值对象池(并非数据库连接池)
│   ├── relationship.go //关系
│   ├── schema.go       // 模型结构(Schema)定义与解析
│   ├── serializer.go   // 序列化
│   ├── utils.go        // 数据库工具
├── soft_delete.go		// 逻辑删除
├── statement.go			// 语句(Statement)构建
└── utils
    ├── utils.go			// 工具服务

二、 代码设计

2.1 数据库配置和启动

image-20220407020339945

2.2 Statement的设计

image-20220407020528001

2.3 数据库协议的设计

image-20220407020639256

2.4 表和列的设计:支持增删改查的操作

image-20220407020839697

2.5 字段及其特性

image-20220407020949770

2.6 各种查询表达式的构建

image-20220407021048912

2.7 事务和预表达式的设计

image-20220407021205051

2.8 日志的处理

image-20220407021243074

2.9 事件机制(事件回调/触发)与数据库迁移

image-20220407021406470

三、 源码分析

DB-初始化

  1. 初始化 DB 连接,并且建立连接池

  2. 使用DB

  3. DB事务管理

  4. ORM 逻辑处理

  • gorm.go
// DB returns the underlying `*sql.DB`.
//
// When the configured connection pool implements GetDBConnector (and the
// value is non-nil), the pool is asked for the connection via GetDBConn.
// Otherwise the pool itself must be a *sql.DB; any other pool type yields
// ErrInvalidDB.
func (db *DB) DB() (*sql.DB, error) {
	pool := db.ConnPool

	if connector, isConnector := pool.(GetDBConnector); isConnector && connector != nil {
		return connector.GetDBConn()
	}

	sqlDB, isSQLDB := pool.(*sql.DB)
	if !isSQLDB {
		return nil, ErrInvalidDB
	}
	return sqlDB, nil
}
// DB GORM DB definition
//
// DB is the handle returned by Open. Chainable methods such as Clauses obtain
// the instance they work on through getInstance, which returns either this
// value or a fresh *DB that shares the same *Config.
type DB struct {
	// Config is shared by pointer with every instance derived via getInstance.
	*Config
	// Error is the error recorded on this instance; getInstance copies it into
	// derived instances.
	Error        error
	// RowsAffected presumably holds the row count of the last executed
	// statement (it is set by code not shown here) — TODO confirm.
	RowsAffected int64
	// Statement is the statement currently being built for this instance.
	Statement    *Statement
	// clone controls getInstance: a value <= 0 reuses this instance, 1 creates
	// a new *DB with an empty Statement, any other value clones the Statement.
	clone        int
}

// Config GORM config: the core database configuration.
//
// DB embeds *Config, and getInstance hands the same pointer to every derived
// DB, so all of them observe the same settings.
type Config struct {
	// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
	// You can disable it by setting `SkipDefaultTransaction` to true
	SkipDefaultTransaction bool
	// NamingStrategy tables, columns naming strategy
	NamingStrategy schema.Namer
	// FullSaveAssociations full save associations
	FullSaveAssociations bool
	// Logger
	Logger logger.Interface
	// NowFunc the function to be used when creating a new timestamp
	NowFunc func() time.Time
	// DryRun generate sql without execute
	DryRun bool
	// PrepareStmt executes the given query in cached statement
	PrepareStmt bool
	// DisableAutomaticPing
	DisableAutomaticPing bool
	// DisableForeignKeyConstraintWhenMigrating
	DisableForeignKeyConstraintWhenMigrating bool
	// DisableNestedTransaction disable nested transaction
	DisableNestedTransaction bool
	// AllowGlobalUpdate allow global update
	AllowGlobalUpdate bool
	// QueryFields executes the SQL query with all fields of the table
	QueryFields bool
	// CreateBatchSize default create batch size
	CreateBatchSize int

	// ClauseBuilders clause builder
	// (keyed by clause name; Statement.Build uses an entry here instead of
	// the clause's own Build method when one is registered)
	ClauseBuilders map[string]clause.ClauseBuilder
	// ConnPool db conn pool
	ConnPool ConnPool
	// Dialector database dialector
	Dialector
	// Plugins registered plugins
	Plugins map[string]Plugin

	// callbacks holds the registered callbacks (type defined elsewhere).
	callbacks  *callbacks
	// cacheStore caches parsed model schemas; it is passed to schema.Parse.
	cacheStore *sync.Map
}

// getInstance returns the *DB that a chainable method should operate on.
//
// When clone is not positive the receiver itself is returned, so chained calls
// keep accumulating on the same Statement. Otherwise a new *DB is created that
// shares the receiver's *Config and inherits its Error:
//   - clone == 1: the new DB gets an empty Statement that only carries over
//     the receiver's connection pool and context;
//   - any other value: the new DB gets a copy of the receiver's Statement
//     (via Statement.clone, defined elsewhere), re-pointed at the new DB.
func (db *DB) getInstance() *DB {
	if db.clone <= 0 {
		return db
	}

	fresh := &DB{Config: db.Config, Error: db.Error}

	switch db.clone {
	case 1:
		// Start from a blank statement; only pool and context are inherited.
		fresh.Statement = &Statement{
			DB:       fresh,
			ConnPool: db.Statement.ConnPool,
			Context:  db.Statement.Context,
			Clauses:  map[string]clause.Clause{},
			Vars:     make([]interface{}, 0, 8),
		}
	default:
		// Copy the existing statement and make it point back at the new DB.
		fresh.Statement = db.Statement.clone()
		fresh.Statement.DB = fresh
	}

	return fresh
}

// Open initialize db session based on dialector
//
// NOTE: the article stops this excerpt partway through the function; only the
// option/dialector application phase is shown here.
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
	config := &Config{}

	// Move *Config options to the front so they are applied before the other
	// options.
	// NOTE(review): sort.Slice is not stable, so the relative order of the
	// remaining (non-*Config) options is not guaranteed to be preserved.
	sort.Slice(opts, func(i, j int) bool {
		_, isConfig := opts[i].(*Config)
		_, isConfig2 := opts[j].(*Config)
		return isConfig && !isConfig2
	})

	for _, opt := range opts {
		if opt != nil {
			if applyErr := opt.Apply(config); applyErr != nil {
				return nil, applyErr
			}
			// Deferred: AfterInitialize runs when Open returns (in reverse
			// option order) with the final value of the named result db, and
			// any error it returns overwrites the named result err.
			// NOTE(review): these defers also run on the early-return path
			// above, where db is nil and err may be replaced.
			defer func(opt Option) {
				if errr := opt.AfterInitialize(db); errr != nil {
					err = errr
				}
			}(opt)
		}
	}

	// A dialector may also adjust the config if it exposes Apply.
	if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
		if err = d.Apply(config); err != nil {
			return
		}
	}

core-0.1-Statement-构建

  • note
  • statement.go
// Statement statement
//
// Statement accumulates the state of one SQL statement being built: the
// clauses added so far, the generated SQL text and, presumably, its bound
// values, plus references to the owning DB, connection pool, schema and
// context.
type Statement struct {
	// DB is the owning *DB; getInstance points it at the instance it creates.
	*DB
	TableExpr            *clause.Expr
	Table                string
	Model                interface{}
	Unscoped             bool
	// Dest is the destination SetColumn writes into (a map, a slice of maps,
	// or a struct described by Schema).
	Dest                 interface{}
	ReflectValue         reflect.Value
	// Clauses holds the added clauses keyed by name (e.g. "WHERE"); Build
	// looks them up by name.
	Clauses              map[string]clause.Clause
	BuildClauses         []string
	Distinct             bool
	Selects              []string // selected columns
	Omits                []string // omit columns
	Joins                []join
	Preloads             map[string][]interface{}
	Settings             sync.Map
	// ConnPool and Context are inherited from the parent by getInstance.
	ConnPool             ConnPool
	// Schema is the parsed model schema; it may be nil.
	Schema               *schema.Schema
	Context              context.Context
	RaiseErrorOnNotFound bool
	SkipHooks            bool
	// SQL is the generated SQL text.
	SQL                  strings.Builder
	// Vars presumably holds the values bound to SQL placeholders
	// (getInstance preallocates capacity 8) — TODO confirm.
	Vars                 []interface{}
	CurDestIndex         int
	attrs                []interface{}
	assigns              []interface{}
	scopes               []func(*DB) *DB
}

core-0.2-statement-条件语句的构建

    1. 根据sql 条件的语句来构建statement
  • statement.go(Statement.Build)、clause/clause.go(Clause.Build)、chainable_api.go(DB.Clauses)
// Build writes the named clauses into the statement, in the order given.
//
// Names without an entry in stmt.Clauses are skipped. A single space is
// written between consecutive clauses. If the DB registers a ClauseBuilder
// under the same name, that builder is used instead of the clause's own
// Build method.
func (stmt *Statement) Build(clauses ...string) {
	needSpace := false

	for _, name := range clauses {
		c, found := stmt.Clauses[name]
		if !found {
			continue
		}

		if needSpace {
			stmt.WriteByte(' ')
		}
		needSpace = true

		if custom, hasCustom := stmt.DB.ClauseBuilders[name]; hasCustom {
			custom(c, stmt)
			continue
		}
		c.Build(stmt)
	}
}

// Build writes the clause into builder.
//
// A custom Builder, when set, takes full control. Otherwise, if the clause
// has an Expression, the output is assembled as
//
//	[BeforeExpression ' '] [Name ' '] [AfterNameExpression ' '] Expression [' ' AfterExpression]
//
// where each bracketed part appears only when the corresponding field is set.
// A clause with neither a Builder nor an Expression writes nothing.
func (c Clause) Build(builder Builder) {
	if c.Builder != nil {
		c.Builder(c, builder)
		return
	}
	if c.Expression == nil {
		return
	}

	if before := c.BeforeExpression; before != nil {
		before.Build(builder)
		builder.WriteByte(' ')
	}

	if c.Name != "" {
		builder.WriteString(c.Name)
		builder.WriteByte(' ')
	}

	if afterName := c.AfterNameExpression; afterName != nil {
		afterName.Build(builder)
		builder.WriteByte(' ')
	}

	c.Expression.Build(builder)

	if after := c.AfterExpression; after != nil {
		builder.WriteByte(' ')
		after.Build(builder)
	}
}

// Clauses adds clauses to the statement of the instance returned by
// getInstance, and returns that instance.
//
// Each condition is dispatched by type, first match wins:
//   - clause.Interface values are added directly with Statement.AddClause;
//   - StatementModifier values modify the statement in place;
//   - anything else is collected and, after the loop, converted into a single
//     WHERE clause through Statement.BuildCondition.
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
	tx = db.getInstance()

	var pending []interface{}
	for _, cond := range conds {
		switch typed := cond.(type) {
		case clause.Interface:
			tx.Statement.AddClause(typed)
		case StatementModifier:
			typed.ModifyStatement(tx.Statement)
		default:
			pending = append(pending, cond)
		}
	}

	if len(pending) == 0 {
		return tx
	}
	tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(pending[0], pending[1:]...)})
	return tx
}

core-0.3-BuildCondition-构建表达式的条件,如where条件

  • 核心思想是根据参数类型(字符串、*DB、结构体/切片等)分别构建 where 等条件表达式,其中结构体/切片参数通过反射机制解析
  • statement.go
// BuildCondition build condition
//
// BuildCondition converts a query plus optional args into clause expressions.
// NOTE: the article shows only the leading string-handling branch; the rest
// of the function body is not included in this excerpt.
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression {
	if s, ok := query.(string); ok {
		// if it is a number, then treats it as primary key
		// (The branch below handles the NON-numeric case, i.e. Atoi failed;
		// numeric strings continue past this excerpt.)
		if _, err := strconv.Atoi(s); err != nil {
			// Empty query and no args: no condition at all.
			if s == "" && len(args) == 0 {
				return nil
			}

			// Raw SQL: no args, or args bound to "?" placeholders.
			// NOTE(review): the `len(args) > 0 &&` test is redundant after `||`.
			if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
				// looks like a where condition
				return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
			}

			// Args present, no "?", but an "@": treat as named parameters.
			if len(args) > 0 && strings.Contains(s, "@") {
				// looks like a named query
				return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
			}

			// Contains inner whitespace: treat as a raw SQL fragment.
			if strings.Contains(strings.TrimSpace(s), " ") {
				// looks like a where condition
				return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
			}

			// A bare identifier with exactly one arg: `column = arg`.
			if len(args) == 1 {
				return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
			}
		}
	}

// Build WHERE conditions from a *DB argument by merging its existing WHERE
// expressions into conds.
// (Excerpt of one case of a type switch inside BuildCondition; the switch
// header and `v`/`conds` declarations are not shown.)
	case *DB:
			// Only the WHERE clause of the given *DB is used.
			if cs, ok := v.Statement.Clauses["WHERE"]; ok {
				if where, ok := cs.Expression.(clause.Where); ok {
					// A lone OR group is converted to an AND group before merging.
					if len(where.Exprs) == 1 {
						if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
							where.Exprs[0] = clause.AndConditions(orConds)
						}
					}
					// All of its expressions are wrapped into one AND group.
					conds = append(conds, clause.And(where.Exprs...))
				} else if cs.Expression != nil {
					// Not a clause.Where: keep the stored expression as-is.
					conds = append(conds, cs.Expression)
				}
			}


// Default: build conditions from an arbitrary argument via reflection.
// (Excerpt from the argument loop inside BuildCondition; `arg`, `idx`, `args`
// and `conds` are declared in code not shown.)
	reflectValue := reflect.Indirect(reflect.ValueOf(arg))
			// Strip any remaining pointer levels (e.g. **T).
			for reflectValue.Kind() == reflect.Ptr {
				reflectValue = reflectValue.Elem()
			}

			// If arg parses as a model schema, emit `column = value` conditions
			// from its field values.
			if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
				// For the first argument, the following string args name the
				// columns (DB name or field name) to restrict conditions to.
				selectedColumns := map[string]bool{}
				if idx == 0 {
					for _, v := range args[1:] {
						if vs, ok := v.(string); ok {
							selectedColumns[vs] = true
						}
					}
				}
				restricted := len(selectedColumns) != 0

				switch reflectValue.Kind() {
				case reflect.Struct:
					// Selected fields are always used (even when zero). Without a
					// selection, every readable field with a non-zero value is used.
					for _, field := range s.Fields {
						selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
						if selected || (!restricted && field.Readable) {
							if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
								// Column is the DB name if known, else the Go field name
								// (only for fields with a data type).
								if field.DBName != "" {
									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
								} else if field.DataType != "" {
									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
								}
							}
						}
					}
				case reflect.Slice, reflect.Array:
					// Same rules as the struct case, applied to every element; all
					// resulting conditions are appended to conds.
					for i := 0; i < reflectValue.Len(); i++ {
						for _, field := range s.Fields {
							selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
							if selected || (!restricted && field.Readable) {
								if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
									if field.DBName != "" {
										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
									} else if field.DataType != "" {
										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
									}
								}
							}
						}
					}
				}

				// NOTE(review): `break` exits the innermost enclosing for/switch/
				// select, which is outside this excerpt. If this code sits in a
				// type-switch case (the `case *DB:` fragment suggests so), it leaves
				// only the switch, not the loop over args.
				if restricted {
					break
				}
			} else if !reflectValue.IsValid() {
				// Not parseable as a schema and the value is invalid
				// (e.g. arg is an untyped nil).
				stmt.AddError(ErrInvalidData)
			} else if len(conds) == 0 {
				// No conditions yet: treat the args as primary-key values.
				if len(args) == 1 {
					switch reflectValue.Kind() {
					case reflect.Slice, reflect.Array:
						// A single slice/array arg expands into `primary IN (...)`;
						// an empty slice adds no condition.
						// optimize reflect value length
						valueLen := reflectValue.Len()
						values := make([]interface{}, valueLen)
						for i := 0; i < valueLen; i++ {
							values[i] = reflectValue.Index(i).Interface()
						}

						if len(values) > 0 {
							conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
						}
						return conds
					}
				}

				conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
			}

core-0.4-SelectAndOmitColumns-计算需要选择(Select)和忽略(Omit)的列

  • note
  • statement.go
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
//
// The first result maps column (or relationship) names to true when selected
// and false when omitted. The second result is presumably notRestricted,
// which is set when "*" is selected.
// NOTE: the article elides the rest of the body (omit handling and the
// return statement), so this excerpt does not compile on its own.
// requireCreate/requireUpdate are not used in the visible part.
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
	results := map[string]bool{}
	notRestricted := false

	// select columns
	// Each entry is resolved, in order, as: a raw name when there is no schema;
	// "*" = every DB column; clause.Associations = every relationship name;
	// a known field = its DB name; a nameMatcher match = the captured name;
	// otherwise the raw string.
	for _, column := range stmt.Selects {
		if stmt.Schema == nil {
			results[column] = true
		} else if column == "*" {
			notRestricted = true
			for _, dbName := range stmt.Schema.DBNames {
				results[dbName] = true
			}
		} else if column == clause.Associations {
			for _, rel := range stmt.Schema.Relationships.Relations {
				results[rel.Name] = true
			}
		} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
			results[field.DBName] = true
		} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 {
			// nameMatcher is a package-level regexp defined elsewhere; presumably
			// it extracts a column from a qualified name — TODO confirm.
			results[matches[1]] = true
		} else {
			results[column] = true
		}
	}
	// ....
}

core-0.5-SetColumn-为 Dest 设置指定列的值

  • 根据 Dest 的类型(map、map 切片或结构体)写入指定列的值
  • statement.go
// SetColumn set column's value
//   stmt.SetColumn("Name", "jinzhu") // Hooks Method
//   stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
//
// SetColumn writes value under name into stmt.Dest:
//   - map[string]interface{}: the key is set directly;
//   - []map[string]interface{}: the key is set in every map;
//   - otherwise, when a Schema is present, the field is looked up by name and
//     Dest is dereferenced to reach the underlying value.
// NOTE: the article truncates this function after the pointer dereference;
// the remaining body and closing braces are missing, so this excerpt is not
// compilable as shown. fromCallbacks is not used in the visible part.
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
	if v, ok := stmt.Dest.(map[string]interface{}); ok {
		v[name] = value
	} else if v, ok := stmt.Dest.([]map[string]interface{}); ok {
		for _, m := range v {
			m[name] = value
		}
	} else if stmt.Schema != nil {
		if field := stmt.Schema.LookUpField(name); field != nil {
			destValue := reflect.ValueOf(stmt.Dest)
			// Follow pointers down to the actual destination value.
			for destValue.Kind() == reflect.Ptr {
				destValue = destValue.Elem()
			}
	}
}

Schema-0.1-协议设计

  • note
  • schema.go
// Schema is the parsed description of a model type: its table, its fields
// (indexed by Go name and by DB column name), primary keys, relationships,
// per-operation clause lists and hook flags.
type Schema struct {
	Name                      string
	ModelType                 reflect.Type
	Table                     string
	PrioritizedPrimaryField   *Field
	// DBNames lists each distinct column name once, in first-seen order.
	DBNames                   []string
	PrimaryFields             []*Field
	PrimaryFieldDBNames       []string
	// Fields is the ordered list of parsed fields (embedded structs flattened).
	Fields                    []*Field
	FieldsByName              map[string]*Field
	FieldsByDBName            map[string]*Field
	FieldsWithDefaultDBValue  []*Field // fields with default value assigned by database
	Relationships             Relationships
	CreateClauses             []clause.Interface
	QueryClauses              []clause.Interface
	UpdateClauses             []clause.Interface
	DeleteClauses             []clause.Interface
	// Hook flags: presumably true when the model implements the corresponding
	// hook method — TODO confirm.
	BeforeCreate, AfterCreate bool
	BeforeUpdate, AfterUpdate bool
	BeforeDelete, AfterDelete bool
	BeforeSave, AfterSave     bool
	AfterFind                 bool
	// err is the parse error returned to callers that load this Schema from
	// the cache.
	err                       error
	// initialized is closed when parsing finishes; readers of a cached Schema
	// wait on it before using the Schema.
	initialized               chan struct{}
	namer                     Namer
	cacheStore                *sync.Map
}

Schema-0.2-解析字段和列的信息

  • note
  • schema.go
	// (Excerpt from the schema parsing function; its header and the
	// declarations of modelType, tableName, cacheStore, namer and
	// schemaCacheKey are not shown.)
	schema := &Schema{
		Name:           modelType.Name(),
		ModelType:      modelType,
		Table:          tableName,
		FieldsByName:   map[string]*Field{},
		FieldsByDBName: map[string]*Field{},
		Relationships:  Relationships{Relations: map[string]*Relationship{}},
		cacheStore:     cacheStore,
		namer:          namer,
		initialized:    make(chan struct{}),
	}
	// When the schema initialization is completed, the channel will be closed
	defer close(schema.initialized)

	// Load existing schema cache, return if exists
	if v, ok := cacheStore.Load(schemaCacheKey); ok {
		s := v.(*Schema)
		// Wait for the initialization of other goroutines to complete
		<-s.initialized
		return s, s.err
	}

	// Parse exported fields only; fields of an embedded schema are flattened
	// into this schema's field list.
	for i := 0; i < modelType.NumField(); i++ {
		if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
			if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
				schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
			} else {
				schema.Fields = append(schema.Fields, field)
			}
		}
	}

	for _, field := range schema.Fields {
		// Derive the column name from the naming strategy when no explicit
		// column was given and the field has a data type.
		if field.DBName == "" && field.DataType != "" {
			field.DBName = namer.ColumnName(schema.Table, field.Name)
		}

		if field.DBName != "" {
			// nonexistence or shortest path or first appear prioritized if has permission
			if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
				if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
					schema.DBNames = append(schema.DBNames, field.DBName)
				}
				schema.FieldsByDBName[field.DBName] = field
				schema.FieldsByName[field.Name] = field

				// The replaced field v is no longer a primary key of this schema.
				// NOTE(review): the range keeps iterating after the in-place
				// removal; this assumes v appears at most once in PrimaryFields.
				if v != nil && v.PrimaryKey {
					for idx, f := range schema.PrimaryFields {
						if f == v {
							schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
						}
					}
				}

				if field.PrimaryKey {
					schema.PrimaryFields = append(schema.PrimaryFields, field)
				}
			}
		}

		// Index by Go name if absent, or replace an entry tagged with "-".
		if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
			schema.FieldsByName[field.Name] = field
		}

		field.setupValuerAndSetter()
	}

Schema-0.3-ParseField

  • 解析字段的信息填充
  • field.go
// ParseField parses reflect.StructField to Field
//
// ParseField builds a *Field from a struct field, reading options from its
// `gorm` tag (entries separated by ";").
// NOTE: the article stops this excerpt partway through the function.
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
	var (
		// err is not used in the visible excerpt; presumably used later.
		err        error
		tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";")
	)

	// Defaults: the field is creatable, updatable and readable. Key/constraint
	// flags come from tag keys, and HasDefaultValue starts out true only for
	// AUTOINCREMENT fields.
	field := &Field{
		Name:                   fieldStruct.Name,
		DBName:                 tagSetting["COLUMN"],
		BindNames:              []string{fieldStruct.Name},
		FieldType:              fieldStruct.Type,
		IndirectFieldType:      fieldStruct.Type,
		StructField:            fieldStruct,
		Tag:                    fieldStruct.Tag,
		TagSettings:            tagSetting,
		Schema:                 schema,
		Creatable:              true,
		Updatable:              true,
		Readable:               true,
		PrimaryKey:             utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]),
		AutoIncrement:          utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
		HasDefaultValue:        utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
		NotNull:                utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
		Unique:                 utils.CheckTruth(tagSetting["UNIQUE"]),
		Comment:                tagSetting["COMMENT"],
		AutoIncrementIncrement: 1,
	}

	// Strip pointer levels to get the underlying field type.
	for field.IndirectFieldType.Kind() == reflect.Ptr {
		field.IndirectFieldType = field.IndirectFieldType.Elem()
	}

	fieldValue := reflect.New(field.IndirectFieldType)
	// if field is valuer, used its value or first field as data type
	valuer, isValuer := fieldValue.Interface().(driver.Valuer)
	if isValuer {
		// Types that implement GormDataTypeInterface skip this inference.
		if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
			// Prefer the value produced by Value() on a zero instance, if valid.
			if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil {
				fieldValue = reflect.ValueOf(v)
			}

			// Use the field struct's first field type as data type, e.g: use `string` for sql.NullString
			// For non-time structs: nested fields' gorm tag settings are merged into
			// field.TagSettings (existing keys win), then recursion descends into the
			// first nested field's type. reflect.New never yields an invalid Value,
			// so the second loop returns after its first iteration.
			var getRealFieldValue func(reflect.Value)
			getRealFieldValue = func(v reflect.Value) {
				var (
					rv     = reflect.Indirect(v)
					rvType = rv.Type()
				)

				if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) {
					for i := 0; i < rvType.NumField(); i++ {
						for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") {
							if _, ok := field.TagSettings[key]; !ok {
								field.TagSettings[key] = value
							}
						}
					}

					for i := 0; i < rvType.NumField(); i++ {
						newFieldType := rvType.Field(i).Type
						for newFieldType.Kind() == reflect.Ptr {
							newFieldType = newFieldType.Elem()
						}

						fieldValue = reflect.New(newFieldType)
						if rvType != reflect.Indirect(fieldValue).Type() {
							getRealFieldValue(fieldValue)
						}

						if fieldValue.IsValid() {
							return
						}
					}
				}
			}

			getRealFieldValue(fieldValue)
		}
	}

四、 相关资料和引用

gotouml.png