简易实现 Go 的 ORM 框架(下) | 青训营

52 阅读13分钟

书接上回,juejin.cn/post/726781…

四、链式操作与更新删除

这部分要实现以下两个内容:

  • 通过链式(chain)操作,支持查询条件(where, order by, limit 等)的叠加。
  • 实现记录的更新(update)、删除(delete)和统计(count)功能。

1.支持 Update、Delete 和 Count

对于Sql语句的构造,一直是Clause在负责,所以要增加更新、删除、统计这三个功能的话,需要在Clause.go中新增它们对应的子句生成器:

const (
	INSERT Type = iota
	VALUES
	SELECT
	LIMIT
	WHERE
	ORDERBY
	UPDATE
	DELETE
	COUNT
)

接着实现对应字句的generator,并注册到全局变量generators中:

func _update(values ...interface{}) (string, []interface{}) {
	tableName := values[0]
	m := values[1].(map[string]interface{})
	var keys []string
	var vars []interface{}
	for k, v := range m {
		keys = append(keys, k+" = ?")
		vars = append(vars, v)
	}
	return fmt.Sprintf("UPDATE %s SET %s", tableName, strings.Join(keys, ", ")), vars
}

func _delete(values ...interface{}) (string, []interface{}) {
	return fmt.Sprintf("DELETE FROM %s", values[0]), []interface{}{}
}

func _count(values ...interface{}) (string, []interface{}) {
	return _select(values[0], []string{"count(*)"})
}

有了第三部分对Clause的讲解,看这部分的代码也没那么困难了,好理解了许多。

下面就要在session/record.go中拼接SQL语句并调用了:

func (s *Session) Update(kv ...interface{}) (int64, error) {
	m, ok := kv[0].(map[string]interface{})
	if !ok {
		m = make(map[string]interface{})
		for i := 0; i < len(kv); i += 2 {
			m[kv[i].(string)] = kv[i+1]
		}
	}
	s.clause.Set(clause.UPDATE, s.RefTable().Name, m)
	sql, vars := s.clause.Build(clause.UPDATE, clause.WHERE)
	result, err := s.Raw(sql, vars...).Exec()
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}

func (s *Session) Delete() (int64, error) {
	s.clause.Set(clause.DELETE, s.RefTable().Name)
	sql, vars := s.clause.Build(clause.DELETE, clause.WHERE)
	result, err := s.Raw(sql, vars...).Exec()
	if err != nil {
		return 0, err
	}
	return result.RowsAffected()
}

func (s *Session) Count() (int64, error) {
	s.clause.Set(clause.COUNT, s.RefTable().Name)
	sql, vars := s.clause.Build(clause.COUNT, clause.WHERE)
	row := s.Raw(sql, vars...).QueryRow()
	var tmp int64
	if err := row.Scan(&tmp); err != nil {
		return 0, err
	}
	return tmp, nil
}

Update 方法接收一个或多个参数,其中第一个参数为 map[string]interface{} 类型的键值对,表示要更新成的新值。如果第一个参数不是该类型,则解析后续的可变参数,并形成一个键值对映射表。接着,使用 clause.Set 函数设置查询语句类型为 UPDATE,并将解析出来的表名和更改的键值对映射表传给 SQL 语句生成器,使用该生成器根据参数构建完整的 SQL 更新语句,最后通过 s.Raw 函数执行 SQL 查询并返回结果;

Delete 方法与 Update 方法类似,使用 clause.Set 函数设置查询语句类型为 DELETE,并将表名传入 SQL 语句生成器,使用该生成器构建 SQL 删除语句,最后通过 s.Raw 函数执行 SQL 查询并返回结果;

Count 方法使用 clause.Set 函数设置查询语句类型为 COUNT,并传入表名给SQL语句生成器,使用该生成器构建 SQL 统计语句,并通过 s.Raw 函数执行 SQL 语句查询,返回统计结果。

2.链式调用

链式调用是一种编程方式,在这种编程方式中,可以将多个方法调用连起来形成一个链,而不需要每个方法调用的结果都去声明一个新的变量。在链式调用中,每个方法调用的结果都是另一个对象的引用。也就是说,可以在一个方法调用之后,直接在返回的对象上调用另一个方法,这样就能够简介地完成一系列操作。

对于SQL语句的构造非常适合采用链式调用的方法,以下为一个例子:

s := geeorm.NewEngine("sqlite3", "gee.db").NewSession()
var users []User
s.Where("Age > 18").Limit(3).Find(&users)

从例子中可以看出,WHERELIMITORDER BY等查询条件语句非常适合链式调用,接下来在 session/record.go 中添加对应的方法:

func (s *Session) Limit(num int) *Session {
	s.clause.Set(clause.LIMIT, num)
	return s
}

func (s *Session) Where(desc string, args ...interface{}) *Session {
	var vars []interface{}
	s.clause.Set(clause.WHERE, append(append(vars, desc), args...)...)
	return s
}

func (s *Session) OrderBy(desc string) *Session {
	s.clause.Set(clause.ORDERBY, desc)
	return s
}

Limit 函数用于设置 SQL 查询语句中的返回记录数限制,通过 clause.Set 函数将该限制条件设置到 SQL 语句生成器(clause)中,并返回该 Session 结构体指针,以便链式调用。对于这一行代码s.clause.Set(clause.WHERE, append(append(vars, desc), args...)...)来说:第一次 appenddesc加入到vars的末尾,得到一个新的切片。第二次appendargs中所有的元素(使用...语法)也都加入到该新的切片的末尾。最后,第二个...表示将该新的切片展开,即将新的切片中的所有元素作为参数传递给Set方法,从而完成将条件设置到clause中的操作;

Where函数用于设置 SQL 查询语句中的条件,函数接收desc stringargs ...interface{}两个参数,其中desc表示 SQL 查询语句中的条件描述,args则表示该描述中需要传入的参数(可以没有参数)。函数使用clause.Set函数将该条件设置到 SQL 语句生成器中,并返回该Session结构体指针,以便链式调用;

OrderBy 函数用于设置 SQL 查询语句中的排序规则,函数接收一个参数desc string,表示 SQL 查询语句中的排序描述,可以包含多个字段,每个字段之间使用逗号分隔。函数使用 clause.Set 函数将该排序规则设置到 SQL 语句生成器中,并返回该Session结构体指针,以便链式调用。

3.First只返回一条记录

gorm框架中,经常会使用First()方法来返回一条记录,具体使用为:

u := &User{}
_ = s.OrderBy("Age DESC").First(u)

下面就来实现:

func (s *Session) First(value interface{}) error {
	dest := reflect.Indirect(reflect.ValueOf(value))
	destSlice := reflect.New(reflect.SliceOf(dest.Type())).Elem()
	if err := s.Limit(1).Find(destSlice.Addr().Interface()); err != nil {
		return err
	}
	if destSlice.Len() == 0 {
		return errors.New("NOT FOUND")
	}
	dest.Set(destSlice.Index(0))
	return nil
}

首先,在函数内部,使用 reflect 包将目标结构体或指针变量value转换为一个 reflect.Value 类型的变量dest,并从该变量中获取真正的结构体类型的信息。然后,创建一个匿名的切片类型,并将该切片的元素类型设置为目标结构体的类型。接着,通过Limit 函数将 SQL 查询语句中的限制条件设置为1,以获取仅查询一条记录的效果。

接下来,使用Find函数从数据库中查询符合条件的记录,并将查询结果映射到创建的 destSlice 切片中。如果查询失败,则直接返回一个非 nil的错误;如果查询成功却没有记录,则返回一个NOT FOUND的错误。

最后,通过Set方法将切片中的第一条记录赋值给目标结构体并返回nil,表示查询成功。如果切片为空,则返回一个NOT FOUND的错误。

五、实现钩子

本部分将实现以下两个功能:

  • 通过反射(reflect)获取结构体绑定的钩子(hooks),并调用。
  • 支持增删查改(CRUD)前后调用钩子。

1.Hook机制

Hook通常被叫做钩子,是一种允许程序代码拦截、监控或修改其他代码行为的机制。通俗来说,Hook机制就是在程序执行过程中,通过注入一些代码逻辑,在某些特定的时间点或事件发生时获取或干预程序的运行或执行流程。例如Goland在执行Ctrl+s的时候,会自动格式化代码再保存。对于ORM框架来说,在增删改查的前后记录都是非常适合的。

2.实现钩子

GeeORM 的钩子与结构体绑定,即每个结构体需要实现各自的钩子。hook相关的代码实现在session/hooks.go中:

const (
	BeforeQuery  = "BeforeQuery"
	AfterQuery   = "AfterQuery"
	BeforeUpdate = "BeforeUpdate"
	AfterUpdate  = "AfterUpdate"
	BeforeDelete = "BeforeDelete"
	AfterDelete  = "AfterDelete"
	BeforeInsert = "BeforeInsert"
	AfterInsert  = "AfterInsert"
)

func (s *Session) CallMethod(method string, value interface{}) {
	fm := reflect.ValueOf(s.RefTable().Model).MethodByName(method)
	if value != nil {
		fm = reflect.ValueOf(value).MethodByName(method)
	}
	param := []reflect.Value{reflect.ValueOf(s)}
	if fm.IsValid() {
		if v := fm.Call(param); len(v) > 0 {
			if err, ok := v[0].Interface().(error); ok {
				log.Error(err)
			}
		}
	}
	return
}

CallMethod()方法使用reflect包获取指定方法的函数对象,其中s.RefTable().Model表示指定模型对应的结构体对象。因此,获取函数对象的方式是使用reflect.ValueOf获取该结构体对象的类型,并通过MethodByName方法获取该类型中名称为method的方法的反射类型。如果传递了非空的value参数,则获取该参数的类型,并获取其内部类型的名称为method的方法的反射类型。

然后,创建一个[]reflect.Value类型的变量param,并将当前Session结构体的指针作为参数,构建一个参数列表。接下来,调用fm.Call(param)方法,将包含当前Session结构体的指针的param作为参数传递给该函数对象,并执行该函数,返回一个[]reflect.Value类型的值。最后判断函数是否有效。

接下来,依次在Find、Insert、Update、Delete 方法内部调用CallMethod()方法。例如,Find方法修改为:

func (s *Session) Find(values interface{}) error {
	s.CallMethod(BeforeQuery, nil)
    // ...
    for rows.Next() {
        dest := reflect.New(destType).Elem()
        // ...
        s.CallMethod(AfterQuery, dest.Addr().Interface())
        // ...
	}
	return rows.Close()
}

六、支持事务(Transaction)

本部分将实现以下两个功能:

  • 介绍数据库中的事务(transaction)。
  • 封装事务,用户自定义回调函数实现原子操作。

1.事务的ACID属性

数据库事务(transaction)是访问并可能操作各种数据项的一个数据库操作序列,这些操作要么全部执行,要么全部不执行,是一个不可分割的工作单位。事务由事务开始与事务结束之间执行的全部数据库操作组成。

举一个简单的例子,转账。A 转账给 B 一万元,那么数据库至少需要执行 2 个操作:

  • 1)A 的账户减掉一万元。
  • 2)B 的账户增加一万元。

这两个操作要么全部执行,代表转账成功。任意一个操作失败了,之前的操作都必须回退,代表转账失败。一个操作完成,另一个操作失败,这种结果是不能够接受的。这种场景就非常适合利用数据库事务的特性来解决。

如果一个数据库支持事务,那么必须具备 ACID 四个属性。

  • 原子性(Atomicity):一个事务是一个原子操作,它要么全部执行,要么全部不执行。如果一个事务的所有操作都成功,则认为该事务成功,事务中任何一个操作失败,整个事务都必须回滚到最初状态。这保证了事务的完整性。
  • 一致性(Consistency):一个事务开始之前和结束之后,数据库中的数据必须保持一致性状态。这意味着数据必须符合所有的预设规则,包括完整性约束、默认值、触发器和任何其他的数据库规则。如果一个事务获得了数据库的一致性,它将是一个有效的事务。
  • 隔离性(Isolation):一个事务的执行必须与其他事务是隔离的,即一个事务的执行结果对其他事务是透明的。一个事务在提交之前,对其他事务不可见。
  • 持久性(Durability):一个事务完成之后,它对于数据库的改变必须永久保存,即使发生了系统崩溃或断电等灾难性事件。

2.SQLite和Go标准库中的事务

如果想要在SQLite中创建事务,可以采用以下方法:

sqlite> BEGIN;
sqlite> DELETE FROM User WHERE Age > 25;
sqlite> INSERT INTO User VALUES ("Tom", 25), ("Jack", 18);
sqlite> COMMIT;

BEGIN开启事务,COMMIT提交事务,ROLLBACK回滚事务。任何一个事务,均以BEGIN开始,COMMITROLLBACK结束。

Go语言中也提供了支持事务的结构,具体代码如下:

func main() {
	db, _ := sql.Open("sqlite3", "gee.db")
	defer func() { _ = db.Close() }()
	_, _ = db.Exec("CREATE TABLE IF NOT EXISTS User(`Name` text);")

	tx, _ := db.Begin()
	_, err1 := tx.Exec("INSERT INTO User(`Name`) VALUES (?)", "Tom")
	_, err2 := tx.Exec("INSERT INTO User(`Name`) VALUES (?)", "Jack")
	if err1 != nil || err2 != nil {
		_ = tx.Rollback()
		log.Println("Rollback", err1, err2)
	} else {
		_ = tx.Commit()
		log.Println("Commit")
	}
}

具体表现为:调用db.Begin()得到*sql.Tx对象,使用tx.Exec()执行一系列操作,如果发生错误,通过tx.Rollback()回滚,如果没有发生错误,则通过tx.Commit()提交。

3.GeeORM 支持事务

对于之前的增删改查操作,都是执行后自动提交,且每个操作相互独立。如果想要支持事务,执行SQL语句的对象应该由sql.DB改为*sql.Tx,为了防止对代码进行大量的修改,我们设计:当Session中的tx不为空时,使用tx执行SQL语句,否则使用db。接下来修改raw.go

type Session struct {
	db       *sql.DB
	dialect  dialect.Dialect
	tx       *sql.Tx
	refTable *schema.Schema
	clause   clause.Clause
	sql      strings.Builder
	sqlVars  []interface{}
}

type CommonDB interface {
	Query(query string, args ...interface{}) (*sql.Rows, error)
	QueryRow(query string, args ...interface{}) *sql.Row
	Exec(query string, args ...interface{}) (sql.Result, error)
}

var _ CommonDB = (*sql.DB)(nil)
var _ CommonDB = (*sql.Tx)(nil)

func (s *Session) DB() CommonDB {
	if s.tx != nil {
		return s.tx
	}
	return s.db
}

接着新建文件session/transaction.go封装事务的Begin、Commit 和 Rollback三个接口:

func (s *Session) Begin() (err error) {
	log.Info("transaction begin")
	if s.tx, err = s.db.Begin(); err != nil {
		log.Error(err)
		return
	}
	return
}

func (s *Session) Commit() (err error) {
	log.Info("transaction commit")
	if err = s.tx.Commit(); err != nil {
		log.Error(err)
	}
	return
}

func (s *Session) Rollback() (err error) {
	log.Info("transaction rollback")
	if err = s.tx.Rollback(); err != nil {
		log.Error(err)
	}
	return
}

主要就是调用s.db.Begin()得到一个sql.Tx并赋值给s.tx,其余的操作就是在原有的基础上加上log

最后一步,在geeorm.go中为用户提供一键式使用的接口:

type TxFunc func(session2 *session.Session) (interface{}, error)

func (e *Engine) Transcation(f TxFunc) (result interface{}, err error) {
	s := e.NewSession()
	if err := s.Begin(); err != nil {
		return nil, err
	}
	defer func() {
		if p := recover(); p != nil {
			_ = s.Rollback()
			panic(p)
		} else if err != nil {
			_ = s.Rollback()
		} else {
			err = s.Commit()
		}
	}()
	return f(s)
}

这部分代码是实现了一个数据库事务,其中TxFunc是一个函数类型,表示要在事务中执行的函数。

Transcation方法接收一个TxFunc类型的参数f,并返回一个接口型的变量result 和一个错误变量err

首先创建一个新的Session实例s,这个实例是通过调用Engine 实例的方法NewSession创建的;

然后通过调用Begin方法开始一个新的事务。如果在开始事务时发生了错误,则返回错误;

在事务中调用函数f,并将s作为参数传递给该函数;

接下来,用defer定义一个匿名函数。如果panic函数被调用,则回滚事务,并重新抛出panic。如果在运行函数时发生了错误,则回滚事务。如果函数执行成功,则提交事务。最后,Transcation方法返回f(s)函数的执行结果和错误(如果有)。

这样的话用户只需要将所有的操作放到一个回调函数中,作为入参传递给e.Transaction(),发生任何错误,自动回滚,如果没有错误发生,则提交。

七、数据库迁移(Migrate)

本部分将实现以下两个功能:

  • 结构体(struct)变更时,数据库表的字段(field)自动迁移(migrate)。
  • 仅支持字段新增与删除,不支持字段类型变更。

1.使用SQL语句Migrate

本框架的迁移操作仅针对最为简单的场景,即支持字段的新增与删除,不支持字段类型变更。

让我们来看看Sqlite是如何支持Migrate的:

# 新增字段
ALTER TABLE table_name ADD COLUMN col_name, col_type;
# 删除字段
CREATE TABLE new_table AS SELECT col1, col2, ... from old_table
DROP TABLE old_table
ALTER TABLE new_table RENAME TO old_table;

对于 SQLite 来说,删除字段并不像新增字段那么容易,一个比较可行的方法需要执行下列几个步骤:

old_table中挑选需要保留的字段到new_table中;删除old_table;重命名new_tableold_table

2.GeeORM实现Migrate

按照原生的SQL命令,利用之前实现的事务,在geeorm.go中实现Migrate 方法:

func difference(a []string, b []string) (diff []string) {
	mapB := make(map[string]bool)
	for _, v := range b {
		mapB[v] = true
	}
	for _, v := range a {
		if _, ok := mapB[v]; !ok {
			diff = append(diff, v)
		}
	}
	return
}

func (e *Engine) Migrate(value interface{}) error {
	_, err := e.Transcation(func(s *session.Session) (result interface{}, err error) {
		if !s.Model(value).HasTable() {
			log.Infof("table %s doesn't exist", s.RefTable().Name)
			return nil, s.CreateTable()
		}
		table := s.RefTable()
		rows, _ := s.Raw(fmt.Sprintf("SELECT * FROM %s LIMIT 1", table.Name)).QueryRows()
		columns, _ := rows.Columns()
		addCols := difference(table.FieldNames, columns)
		delCols := difference(columns, table.FieldNames)
		log.Infof("added cols %v, deleted cols %v", addCols, delCols)
		for _, col := range addCols {
			f := table.GetField(col)
			sqlStr := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s;", table.Name, f.Name, f.Type)
			if _, err = s.Raw(sqlStr).Exec(); err != nil {
				return
			}
		}
		if len(delCols) == 0 {
			return
		}
		tmp := "tmp_" + table.Name
		fieldStr := strings.Join(table.FieldNames, ", ")
		s.Raw(fmt.Sprintf("CREATE TABLE %s AS SELECT %s from %s;", tmp, fieldStr, table.Name))
		s.Raw(fmt.Sprintf("DROP TABLE %s;", table.Name))
		s.Raw(fmt.Sprintf("ALTER TABLE %s RENAME TO %s;", tmp, table.Name))
		_, err = s.Exec()
		return
	})
	return err
}

difference方法被用于找到两个字符串切片table.FieldNamescolumns中不同的元素,以便进行相应的表结构变更,新表 - 旧表 = 新增字段,旧表 - 新表 = 删除字段;

Migrate方法使用e.Transcation创建一个新的Session并启动一个事务。在事务内部,首先判断指定结构体对应的表是否存在。如果不存在,则创建新表并返回错误。如果表已经存在,则获取该表的字段列表和数据类型,并查询表中的一行数据。

然后使用difference函数找到要添加和删除的字段,并用ALTER TABLE语句添加新字段。如果在执行添加字段的过程中发生错误,则返回错误。如果要删除的字段列表不为空,则通过创建临时表、将数据从旧表复制到临时表、删除旧表和重命名临时表的方式实现删除字段。

最后,调用s.Exec执行所有 SQL 语句并提交事务。如果在执行此过程中发生错误,则回滚事务并返回错误。

八、总结

至此,geeorm的框架就实行完毕了,总体功能较为粗造,比如数据库的迁移仅仅考虑了最简单的场景。实现的特性也比较少,比如结构体嵌套的场景,外键的场景,复合主键的场景都没有覆盖。可它不是一个在生产中使用的框架,而是一个帮助初学者理解orm框架的入门小项目,从这个角度来说,我认为它是完美的!

在其中可以很好的体会反射的使用,以及如何完成表结构和结构体之间的映射,如何支持多种数据库的操作,以及对事务的支持,这篇博客尽管已经完成,但还是需要反复的回过头来阅读。