database/sql and GORM: Design and Practice | Youth Training Camp Notes


Understanding database/sql

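The database/sql package provides a generic interface around SQL (or SQL-like) databases. It must be used together with at least one database driver; see sqldrivers for a list of available drivers.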

Basic usage

The following is a basic demo; see the official documentation for more usage.

package main

import (
   "database/sql"
   "fmt"
   _ "github.com/go-sql-driver/mysql"
   "log"
   "time"
)

type Blog struct {
   Id          int64     `json:"id"`
   Title       string    `json:"title"`
   Content     string    `json:"content"`
   Description string    `json:"description"`
   CreateTime  time.Time `json:"create_time"`
}

func main() {
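   // Open a DB handle: pass the driver name and the data source name (DSN)
   db, err := sql.Open("mysql", "root:2001@tcp(127.0.0.1:3306)/blog?parseTime=true")
   if err != nil {
      panic(err)
   }
   defer db.Close()
   // Run the query
   rows, err := db.Query("select id, title, content, description, create_time from t_blog")
   if err != nil {
      // rows is nil when Query fails, so stop here instead of calling rows.Close()
      log.Fatal(err)
   }
   defer rows.Close()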

   var blogList []Blog
   // Iterate over the rows and scan each one into a Blog
   for rows.Next() {
      var blog Blog
      err := rows.Scan(&blog.Id, &blog.Title, &blog.Content, &blog.Description, &blog.CreateTime)
      if err != nil {
         log.Fatal(err)
      }
      blogList = append(blogList, blog)
   }
   // Check for any error that ended the iteration early
   if rows.Err() != nil {
      log.Fatal(rows.Err())
   }
   fmt.Println(blogList)
}

Design principles

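The database/sql library provides applications with a set of features for connecting to and operating on databases, as shown in the figure below:

[Figure: overview of the database/sql architecture]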

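The DB struct holds all of the connection state, including the connection pool configuration, and provides methods for changing that configuration: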

type DB struct {
   // Atomic access only. At top of struct to prevent mis-alignment
   // on 32-bit platforms. Of type time.Duration.
   waitDuration int64 // Total time waited for new connections.

   connector driver.Connector
   // numClosed is an atomic counter which represents a total number of
   // closed connections. Stmt.openStmt checks it before cleaning closed
   // connections in Stmt.css.
   numClosed uint64

   mu           sync.Mutex    // protects following fields
   freeConn     []*driverConn // free connections ordered by returnedAt oldest to newest
   connRequests map[uint64]chan connRequest
   nextRequest  uint64 // Next key to use in connRequests.
   numOpen      int    // number of opened and pending open connections
   // Used to signal the need for new connections
   // a goroutine running connectionOpener() reads on this chan and
   // maybeOpenNewConnections sends on the chan (one send per needed connection)
   // It is closed during db.Close(). The close tells the connectionOpener
   // goroutine to exit.
   openerCh          chan struct{}
   closed            bool
   dep               map[finalCloser]depSet
   lastPut           map[*driverConn]string // stacktrace of last conn's put; debug only
   maxIdleCount      int                    // zero means defaultMaxIdleConns; negative means 0
   maxOpen           int                    // <= 0 means unlimited
   maxLifetime       time.Duration          // maximum amount of time a connection may be reused
   maxIdleTime       time.Duration          // maximum amount of time a connection may be idle before being closed
   cleanerCh         chan struct{}
   waitCount         int64 // Total number of connections waited for.
   maxIdleClosed     int64 // Total number of connections closed due to idle count.
   maxIdleTimeClosed int64 // Total number of connections closed due to idle time.
   maxLifetimeClosed int64 // Total number of connections closed due to max connection lifetime limit.

   stop func() // stop cancels the connection opener.
}

The following code is a real example of an operation going through the connection pool:

func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
   var rows *Rows
   var err error
   var isBadConn bool
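   // retry up to maxBadConnRetries times if the connection turns out to be bad
   for i := 0; i < maxBadConnRetries; i++ {
      // get a connection from the pool (a cached one, or a new one if none is free)
      rows, err = db.query(ctx, query, args, cachedOrNewConn)
      isBadConn = errors.Is(err, driver.ErrBadConn)
      if !isBadConn {
         break
      }
   }
   // if every retry hit a bad connection, force a brand-new connection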
   if isBadConn {
      return db.query(ctx, query, args, alwaysNewConn)
   }
   return rows, err
}

The sql package also provides a function for registering a database driver manually:

// Register makes a database driver available by the provided name.
// If Register is called twice with the same name or if driver is nil,
// it panics.
func Register(name string, driver driver.Driver) {
   driversMu.Lock()
   defer driversMu.Unlock()
   if driver == nil {
      panic("sql: Register driver is nil")
   }
   if _, dup := drivers[name]; dup {
      panic("sql: Register called twice for driver " + name)
   }
   drivers[name] = driver
}

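Because the example imports the driver with a blank identifier (_ "github.com/go-sql-driver/mysql"), the init function of that package runs before main starts and registers the mysql driver: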

func init() {
   sql.Register("mysql", &MySQLDriver{})
}

Every database driver must implement the Driver interface:

type Driver interface {
   // Open returns a new connection to the database; name is the driver-specific DSN
   Open(name string) (Conn, error)
}

Here is the implementation in MySQLDriver:

func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
   // parse the DSN string into a Config
   cfg, err := ParseDSN(dsn)
   if err != nil {
      return nil, err
   }
   c := &connector{
      cfg: cfg,
   }
   // open a connection with that configuration
   return c.Connect(context.Background())
}

Another important interface is Connector, which defines two methods:

type Connector interface {
   // Connect returns a connection to the database
   Connect(context.Context) (Conn, error)

   // Driver returns the underlying Driver of the Connector
   Driver() Driver
}

The corresponding implementation is in connector.go of the github.com/go-sql-driver/mysql package, if you want to read it.

We can also connect to the database with a configuration struct instead of a DSN string:

connector, err := mysql.NewConnector(&mysql.Config{
   User:      "root",
   Passwd:    "2001",
   Net:       "tcp",
   Addr:      "127.0.0.1:3306",
   DBName:    "blog",
   ParseTime: true,
})
if err != nil {
   panic(err)
}
db := sql.OpenDB(connector)

A DB is typically used through one of the following kinds of connection:

  1. Direct connection: Conn
  2. Prepared statement: Stmt
  3. Transaction: Tx

Returned data is typically handled in one of the following ways:

  1. Exec/ExecContext: returns a Result
  2. Query/QueryContext: returns Rows (with Columns)
  3. QueryRow/QueryRowContext: returns a Row (a simplified Rows)

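Let's look at the Rows interface in database/sql/driver, which defines the operations a driver must support for processing returned data: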

type Rows interface {
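   // Columns returns the names of the columns
   Columns() []string

   // Close closes the rows iterator
   Close() error

   // Next fills dest with the next row; it returns io.EOF when there are no more rows
   Next(dest []Value) error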
}

The github.com/go-sql-driver/mysql package has two implementations, binaryRows and textRows, which you can read on your own.

Getting started with GORM

Official documentation: GORM Guides

GORM's design principles: a minimal API, test-first, least surprise, flexible extension, no dependencies, and reliability.

GORM is feature-complete:

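  • Full-featured ORM
  • Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism, Single-table inheritance)
  • Hooks (before/after Create, Save, Update, Delete, Find)
  • Eager loading with Preload and Joins
  • Transactions, nested transactions, Save Point, Rollback To Saved Point
  • Context, prepared statement mode, DryRun mode
  • Batch insert, FindInBatches, Find/Create with Map, CRUD with SQL expressions and Context Valuer
  • SQL builder, Upsert, locking, Optimizer/Index/Comment Hints, named arguments, subqueries
  • Composite primary keys, indexes, constraints
  • Auto Migration
  • Custom Logger
  • Flexible, extensible plugin API: Database Resolver (multiple databases, read/write splitting), Prometheus…
  • Every feature is thoroughly tested
  • Developer friendly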

Basic usage

Create a connection and query data:

package main

import (
   "fmt"
   "gorm.io/driver/mysql"
   "gorm.io/gorm"
   "gorm.io/gorm/schema"
   "log"
   "time"
)

type Blog struct {
   Id          int64     `json:"id"`
   Title       string    `json:"title"`
   Content     string    `json:"content"`
   Description string    `json:"description"`
   CreateTime  time.Time `json:"create_time"`
}

func (Blog) TableName() string {
   return "t_blog"
}

func main() {
   db, err := gorm.Open(
      mysql.Open("root:2001@tcp(127.0.0.1:3306)/blog?parseTime=true"),
      &gorm.Config{
         NamingStrategy: schema.NamingStrategy{
            SingularTable: true,
         },
      },
   )
   if err != nil {
      panic(err)
   }

   var blogList []Blog
   if err := db.Select("id", "title", "content", "description", "create_time").Find(&blogList, 1).Error; err != nil {
      log.Print(err)
   }

   fmt.Println(blogList)

}

Create, read, update and delete operations:

package main

import (
   "errors"
   "fmt"
   "gorm.io/driver/mysql"
   "gorm.io/gorm"
   "gorm.io/gorm/schema"
   "log"
   "strconv"
)

type Product struct {
   Id    int64  `json:"id"`
   Name  string `json:"name"`
   Color string `json:"color"`
}

func (Product) TableName() string {
   return "t_product"
}

func main() {
   db, err := gorm.Open(
      mysql.Open("root:2001@tcp(127.0.0.1:3306)/blog?parseTime=true"),
      &gorm.Config{
         NamingStrategy: schema.NamingStrategy{
            SingularTable: true,
         },
      },
   )
   if err != nil {
      panic(err)
   }
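   // AutoMigrate creates the table plus any missing foreign keys, constraints, columns and indexes
   //db.AutoMigrate(&Product{})
   // CreateTable only creates the table
   // db.Migrator().CreateTable(&Product{})
   // https://gorm.io/zh_CN/docs/migration.html

   // insert a record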
   product := Product{
      Name:  "1",
      Color: "red",
   }
   if err := db.Create(&product).Error; err != nil {
      log.Print(err)
   }
   fmt.Println(product.Id)

   // insert several records at once
   productList := []Product{
      {
         Name:  "2",
         Color: "red",
      },
      {
         Name:  "3",
         Color: "red",
      },
      {
         Name:  "4",
         Color: "red",
      },
   }
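   // insert the whole slice in a single statement
   db.Create(&productList)
   // or split the insert into batches of a given size
   //db.CreateInBatches(&productList, 100)
   // the primary keys are filled in after the insert
   for _, product := range productList {
      fmt.Println(product.Id)
   }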

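   // query the product whose id is 1
   var p Product
   db.First(&p, 1)
   fmt.Println(p)
   // query the product whose name is "1"; reset p first, because GORM also
   // uses a non-zero primary key in the destination as a query condition
   p = Product{}
   db.First(&p, "name = ?", "1")
   fmt.Println(p)

   result := db.Find(&productList, []int{1, 2, 3})
   // number of records found
   fmt.Println(strconv.FormatInt(result.RowsAffected, 10))
   // no data found: First, Take and Last return gorm.ErrRecordNotFound (Find does not)
   var missing Product
   if err := db.First(&missing, 999).Error; errors.Is(err, gorm.ErrRecordNotFound) {
      fmt.Println("product 999 not found")
   }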

   // update a single column
   db.Model(&p).Update("name", "update")
   db.Model(&p).Update("color", "blue")

   // update multiple columns (with a struct, only non-zero fields are updated)
   db.Model(&p).Updates(Product{
      Name:  "1",
      Color: "1",
   })
   db.Model(&p).Updates(map[string]interface{}{
      "name":  "update",
      "color": "blue",
   })

   // batch update: every row matching the condition
   db.Model(&Product{}).Where("id < ?", 10).Updates(map[string]interface{}{
      "name":  "update",
      "color": "blue",
   })

   // delete the product (a hard delete, since Product has no DeletedAt field)
   db.Delete(&p)
}

Model definition

For the struct below, GORM applies the following conventions when creating the table:

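  • The table name is the snake_case plural form of the struct name
  • Column names are the snake_case form of the field names
  • A field named ID is the primary key; if it is numeric, it is an auto-increment primary key
  • CreatedAt: set to the current time when a record is created
  • UpdatedAt: set to the current time when a record is created or updated
  • A gorm.DeletedAt field enables soft delete by default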

For the detailed rules, see the official documentation: Conventions

type User struct {
   Id           int64          `json:"id"`
   Name         string         `json:"name"`
   Email        *string        `json:"email"`
   Age          int8           `json:"age"`
   Birthday     *time.Time     `json:"birthday"`
   MemberNumber sql.NullString `json:"member_number"`
   ActivatedAt  sql.NullTime   `json:"activated_at"`
   CreatedAt    time.Time      `json:"created_at"`
   UpdatedAt    time.Time      `json:"updated_at"`
   DeletedAt    gorm.DeletedAt `json:"deleted_at" gorm:"index"`
}
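The naming rules listed above can be approximated in a few lines of Go. This is a simplified, hypothetical sketch: GORM's schema.NamingStrategy relies on the github.com/jinzhu/inflection library for plurals and handles more initialisms, so edge cases will differ. The `singular` flag corresponds to the SingularTable: true option used in this article's examples.

```go
package main

import (
	"fmt"
	"strings"
	"unicode"
)

// toSnakeCase converts a Go identifier to snake_case, keeping runs of
// capitals (such as "ID") together: "MemberNumber" -> "member_number".
func toSnakeCase(name string) string {
	runes := []rune(name)
	var b strings.Builder
	for i, r := range runes {
		if unicode.IsUpper(r) {
			if i > 0 {
				prev := runes[i-1]
				nextLower := i+1 < len(runes) && unicode.IsLower(runes[i+1])
				if unicode.IsLower(prev) || unicode.IsDigit(prev) || (unicode.IsUpper(prev) && nextLower) {
					b.WriteByte('_')
				}
			}
			b.WriteRune(unicode.ToLower(r))
			continue
		}
		b.WriteRune(r)
	}
	return b.String()
}

// pluralize applies a few naive English rules (an assumption, not GORM's rules).
func pluralize(word string) string {
	switch {
	case strings.HasSuffix(word, "y") && len(word) > 1 && !strings.ContainsRune("aeiou", rune(word[len(word)-2])):
		return word[:len(word)-1] + "ies"
	case strings.HasSuffix(word, "s"), strings.HasSuffix(word, "x"),
		strings.HasSuffix(word, "ch"), strings.HasSuffix(word, "sh"):
		return word + "es"
	default:
		return word + "s"
	}
}

// tableName derives the default table name from a struct name.
func tableName(structName string, singular bool) string {
	name := toSnakeCase(structName)
	if singular {
		return name
	}
	return pluralize(name)
}

func main() {
	fmt.Println(tableName("User", false))       // users
	fmt.Println(tableName("CreditCard", false)) // credit_cards
	fmt.Println(tableName("Company", false))    // companies
	fmt.Println(tableName("Company", true))     // company
	fmt.Println(toSnakeCase("MemberNumber"))    // member_number
	fmt.Println(toSnakeCase("UserID"))          // user_id
}
```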

Associations

Associations include one-to-one, one-to-many, many-to-many and more; in GORM they are declared as follows:

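type Account struct {
   gorm.Model
   UserID uint `json:"user_id"` // has one: foreign key referencing User
}

type Pet struct {
   gorm.Model
   UserID *uint `json:"user_id"` // has many: foreign key referencing User
   Toy    Toy   `json:"toy" gorm:"polymorphic:Owner"`
}

type Toy struct {
   gorm.Model
   Name      string `json:"name"`
   OwnerID   uint   `json:"owner_id"`   // polymorphic: primary key of the owner
   OwnerType string `json:"owner_type"` // polymorphic: table name of the owner
}

type Company struct {
   ID   int    `json:"id"`
   Name string `json:"name"`
}

type Language struct {
   gorm.Model
   Name string `json:"name"`
}

// A User has one Account (has one), has many Pets (has many) and many Toys (polymorphic has many);
// belongs to a Company (belongs to) and to a Manager (self-referential belongs to), and manages a Team (self-referential has many);
// speaks several Languages (many to many) and has many Friends (self-referential many to many);
// and each of the user's Pets has one Toy (polymorphic has one)
type User struct {
   gorm.Model
   Name      string     `json:"name"`
   Account   Account    `json:"account"`
   Pets      []*Pet     `json:"pets"`
   Toys      []Toy      `json:"toys" gorm:"polymorphic:Owner"`
   CompanyID *int       `json:"company_id"`
   Company   Company    `json:"company"`
   ManagerID *uint      `json:"manager_id"`
   Manager   *User      `json:"manager"`
   Team      []User     `json:"team" gorm:"foreignkey:ManagerID"`
   Languages []Language `json:"languages" gorm:"many2many:UserSpeak"`
   Friends   []*User    `json:"friends" gorm:"many2many:user_friends"`
}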

CRUD on associations:

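// example languages (hypothetical values for this snippet)
languageZH := Language{Name: "ZH"}
languageEN := Language{Name: "EN"}
languageDE := Language{Name: "DE"}

// save the user together with its associations (upsert)
db.Save(&User{
   Name: "jinzhu",
   Languages: []Language{
      languageZH,
      languageEN,
   },
})
// user is assumed to be an existing record (with a primary key) loaded from the database
var user User
var languages []Language
// association mode
langAssociation := db.Model(&user).Association("Languages")
// find the associated languages
langAssociation.Find(&languages)
// add Chinese and English to the languages the user speaks
langAssociation.Append([]Language{languageZH, languageEN})
// replace the user's languages with Chinese and German
langAssociation.Replace([]Language{languageZH, languageDE})
// remove these two languages from the user's languages
langAssociation.Delete(languageZH, languageEN)
// remove all of the user's languages
langAssociation.Clear()
// return the number of languages the user speaks
langAssociation.Count()

// batch mode for Append and Replace
var users = []User{{}, {}, {}}
langAssociation = db.Model(&users).Association("Languages")

// in batch mode, Append and Replace take one argument per source record.
// For example, with 3 users: add userA to user1's Team, add userB to user2's Team,
// and add userA, userB and userC to user3's Team
userA, userB, userC := User{}, User{}, User{}
db.Model(&users).Association("Team").Append(&userA, &userB, &[]User{
   userA, userB, userC,
})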

Eager loading while querying:

package main

import (
   "gorm.io/driver/mysql"
   "gorm.io/gorm"
   "gorm.io/gorm/clause"
   "gorm.io/gorm/schema"
)

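type Account struct {
   gorm.Model
   UserID uint `json:"user_id"` // has one: foreign key referencing User
}

type Pet struct {
   gorm.Model
   UserID *uint `json:"user_id"` // has many: foreign key referencing User
   Toy    Toy   `json:"toy" gorm:"polymorphic:Owner"`
}

type Toy struct {
   gorm.Model
   Name      string `json:"name"`
   OwnerID   uint   `json:"owner_id"`   // polymorphic: primary key of the owner
   OwnerType string `json:"owner_type"` // polymorphic: table name of the owner
}

type Company struct {
   ID    int    `json:"id"`
   Name  string `json:"name"`
   Alive bool   `json:"alive"`
}

type Language struct {
   gorm.Model
   Name string `json:"name"`
}

// A User has one Account (has one), has many Pets (has many) and many Toys (polymorphic has many);
// belongs to a Company (belongs to) and to a Manager (self-referential belongs to), and manages a Team (self-referential has many);
// speaks several Languages (many to many) and has many Friends (self-referential many to many);
// and each of the user's Pets has one Toy (polymorphic has one)
type User struct {
   gorm.Model
   Name      string     `json:"name"`
   Account   Account    `json:"account"`
   Pets      []*Pet     `json:"pets"`
   Toys      []Toy      `json:"toys" gorm:"polymorphic:Owner"`
   CompanyID *int       `json:"company_id"`
   Company   Company    `json:"company"`
   ManagerID *uint      `json:"manager_id"`
   Manager   *User      `json:"manager"`
   Team      []User     `json:"team" gorm:"foreignkey:ManagerID"`
   Languages []Language `json:"languages" gorm:"many2many:UserSpeak"`
   Friends   []*User    `json:"friends" gorm:"many2many:user_friends"`
}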

func main() {
   db, err := gorm.Open(
      mysql.Open("root:2001@tcp(127.0.0.1:3306)/blog?parseTime=true"),
      &gorm.Config{
         NamingStrategy: schema.NamingStrategy{
            SingularTable: true,
         },
      },
   )
   if err != nil {
      panic(err)
   }

   var users []User
   // load the users together with their pets and account
   db.Preload("Pets").Preload("Account").Find(&users)
   // select * from users
   // select * from pets where user_id in (1,2,3,4) // has many
   // select * from accounts where user_id in (1,2,3,4) // has one

   var user User
   // load belongs-to associations with JOINs (a single SQL statement)
   db.Joins("Company").Joins("Manager").First(&user, 1)
   db.Joins("Company", db.Where(&Company{Alive: true})).Find(&users)

   // preload all associations (first level only)
   db.Preload(clause.Associations).Find(&users)

   // The examples below assume User also has an Orders association
   // (Orders -> OrderItems -> Product), which is not defined above.
   // nested preloading
   db.Preload("Orders.OrderItems.Product").Find(&users)
   // nested preloading plus all first-level associations
   db.Preload("Orders.OrderItems.Product").Preload(clause.Associations).Find(&users)

   // load the users with their orders that are not cancelled
   db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
   db.Preload("Orders", "state = ?", "paid").Preload("Orders.OrderItems").Find(&users)

   // customize the preload SQL with a function
   db.Preload("Orders", func(db *gorm.DB) *gorm.DB {
      return db.Order("orders.amount DESC")
   }).Find(&users)

}
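The SQL comments in the Preload example show why preloading avoids the N+1 query problem: the parents are loaded first, their primary keys are collected, and all children are fetched with a single IN query and matched back in memory. The sketch below reproduces just that idea with hypothetical helpers (`buildPreloadQuery`, `groupByParent`) and a plain struct; it does not call GORM.

```go
package main

import (
	"fmt"
	"strings"
)

// buildPreloadQuery builds the single IN query that loads children for many parents.
func buildPreloadQuery(table, foreignKey string, parentIDs []uint) (string, []any) {
	placeholders := make([]string, len(parentIDs))
	args := make([]any, len(parentIDs))
	for i, id := range parentIDs {
		placeholders[i] = "?"
		args[i] = id
	}
	sql := fmt.Sprintf("SELECT * FROM %s WHERE %s IN (%s)", table, foreignKey, strings.Join(placeholders, ","))
	return sql, args
}

type pet struct {
	ID     uint
	UserID uint
	Name   string
}

// groupByParent attaches the loaded children to their parents in memory.
func groupByParent(pets []pet) map[uint][]pet {
	byUser := make(map[uint][]pet)
	for _, p := range pets {
		byUser[p.UserID] = append(byUser[p.UserID], p)
	}
	return byUser
}

func main() {
	sql, args := buildPreloadQuery("pets", "user_id", []uint{1, 2, 3, 4})
	fmt.Println(sql, args) // SELECT * FROM pets WHERE user_id IN (?,?,?,?) [1 2 3 4]

	loaded := []pet{{1, 1, "cat"}, {2, 1, "dog"}, {3, 3, "fish"}}
	fmt.Println(len(groupByParent(loaded)[1])) // 2
}
```

Two queries are issued no matter how many users there are, instead of one query per user. Joins takes the other route: a single SQL statement, suitable for one-to-one and belongs-to associations.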

Cascade delete with GORM:

package main

import (
   "gorm.io/driver/mysql"
   "gorm.io/gorm"
   "gorm.io/gorm/clause"
   "gorm.io/gorm/schema"
)

type Order struct {
   UserId uint `json:"user_id"`
}

type Account struct {
   UserId uint `json:"user_id"`
}

type CreditCard struct {
   UserId uint `json:"user_id"`
}

type User struct {
   ID          uint         `json:"id"`
   Name        string       `json:"name"`
   Orders      []Order      `json:"orders"  gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
   Account     Account      `json:"account"  gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
   CreditCards []CreditCard `json:"credit_cards"  gorm:"constraint:OnUpdate:CASCADE,OnDelete:CASCADE"`
}

func main() {
   db, err := gorm.Open(
      mysql.Open("root:2001@tcp(127.0.0.1:3306)/blog?parseTime=true"),
      &gorm.Config{
         NamingStrategy: schema.NamingStrategy{
            SingularTable: true,
         },
      },
   )
   if err != nil {
      panic(err)
   }

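   // the foreign key constraints only exist if the tables are migrated with GORM
   db.AutoMigrate(&User{})
   // Method 1: without soft delete, deleting a User also deletes its dependents through
   // the ON DELETE CASCADE constraint (a primary key is needed: GORM refuses to delete without conditions)
   db.Delete(&User{ID: 1})

   // Method 2: cascade delete with Select, which relies on neither database constraints nor soft delete
   // delete the user's account when deleting the user
   db.Select("Account").Delete(&User{ID: 1})

   // delete the user's Orders and CreditCards records, and the orders' BillingAddress, when deleting the user
   // (BillingAddress is assumed to be an association of Order; it is not defined above)
   db.Select("Orders", "Orders.BillingAddress", "CreditCards").Delete(&User{ID: 1})

   // delete the user together with all of its has one / has many / many2many associations
   db.Select(clause.Associations).Delete(&User{ID: 1})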
}

GORM design principles

GORM is built on top of Go's native database/sql package:

[Figure: GORM architecture on top of database/sql]

SQL generation

Consider the following GORM query:

db.Where("role <> ?", "manager").Where("age > ?", 35).Limit(100).Order("age desc").Find(&users)

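The Where, Limit and Order calls in the middle are chainable methods that add clauses to the SQL statement; only the final Find is a finisher method, which determines the table to query and the type of the result and then executes the query. The SQL that is finally executed is: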

SELECT * FROM users WHERE role <> "manager" AND age > 35 ORDER BY age desc LIMIT 100

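Internally, GORM generates SQL with a SQL builder. For each operation, GORM creates a *gorm.Statement object; every GORM API call adds or modifies clauses on that statement, and at the end GORM generates the SQL from those clauses.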
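A toy version of this design makes it concrete: chain methods only record clauses on a statement, and the finisher renders them in a fixed clause order, so calling Limit before Order makes no difference. The `statement` type and its methods are hypothetical and far simpler than *gorm.Statement; each method works on a copy so that a shared base statement is never modified, loosely in the spirit of getInstance.

```go
package main

import (
	"fmt"
	"strings"
)

// statement is a hypothetical, heavily simplified stand-in for *gorm.Statement:
// it only records clauses; nothing is rendered until Build is called.
type statement struct {
	table  string
	wheres []string
	vars   []any
	order  string
	limit  int
}

// Where appends a condition; slices are copied so the receiver stays untouched.
func (s statement) Where(cond string, args ...any) statement {
	s.wheres = append(append([]string(nil), s.wheres...), cond)
	s.vars = append(append([]any(nil), s.vars...), args...)
	return s
}

func (s statement) Order(order string) statement { s.order = order; return s }
func (s statement) Limit(n int) statement        { s.limit = n; return s }

// Build is the "finisher": it renders the recorded clauses in a fixed order.
func (s statement) Build() (string, []any) {
	var b strings.Builder
	b.WriteString("SELECT * FROM " + s.table)
	if len(s.wheres) > 0 {
		b.WriteString(" WHERE " + strings.Join(s.wheres, " AND "))
	}
	if s.order != "" {
		b.WriteString(" ORDER BY " + s.order)
	}
	if s.limit > 0 {
		fmt.Fprintf(&b, " LIMIT %d", s.limit)
	}
	return b.String(), s.vars
}

func main() {
	users := statement{table: "users"}
	sql, vars := users.Where("role <> ?", "manager").Where("age > ?", 35).Limit(100).Order("age desc").Build()
	fmt.Println(sql)  // SELECT * FROM users WHERE role <> ? AND age > ? ORDER BY age desc LIMIT 100
	fmt.Println(vars) // [manager 35]
}
```

Keeping the values in vars rather than splicing them into the SQL string is what lets the driver send them as bound parameters.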

Here is how GORM adds SQL conditions:

func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
   tx = db.getInstance()
   if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
      tx.Statement.AddClause(clause.Where{Exprs: conds})
   }
   return
}

func (db *DB) Limit(limit int) (tx *DB) {
   tx = db.getInstance()
   tx.Statement.AddClause(clause.Limit{Limit: &limit})
   return
}

A GORM finisher method executes the GORM Statement:

func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
   tx = db.getInstance()
   if len(conds) > 0 {
      if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
         tx.Statement.AddClause(clause.Where{Exprs: exprs})
      }
   }
   tx.Statement.Dest = dest
   return tx.callbacks.Query().Execute(tx)
}

func (p *processor) Execute(db *DB) *DB {
   // call scopes
   for len(db.Statement.scopes) > 0 {
      db = db.executeScopes()
   }

   var (
      curTime           = time.Now()
      stmt              = db.Statement
      resetBuildClauses bool
   )

   if len(stmt.BuildClauses) == 0 {
      stmt.BuildClauses = p.Clauses
      resetBuildClauses = true
   }

   if optimizer, ok := db.Statement.Dest.(StatementModifier); ok {
      optimizer.ModifyStatement(stmt)
   }

   // assign model values
   if stmt.Model == nil {
      stmt.Model = stmt.Dest
   } else if stmt.Dest == nil {
      stmt.Dest = stmt.Model
   }

   // parse model values
   if stmt.Model != nil {
      if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
         if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil {
            db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
         } else {
            db.AddError(err)
         }
      }
   }

   // assign stmt.ReflectValue
   if stmt.Dest != nil {
      stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
      for stmt.ReflectValue.Kind() == reflect.Ptr {
         if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() {
            stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
         }

         stmt.ReflectValue = stmt.ReflectValue.Elem()
      }
      if !stmt.ReflectValue.IsValid() {
         db.AddError(ErrInvalidValue)
      }
   }

   for _, f := range p.fns {
      f(db)
   }

   if stmt.SQL.Len() > 0 {
      db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
         sql, vars := stmt.SQL.String(), stmt.Vars
         if filter, ok := db.Logger.(ParamsFilter); ok {
            sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
         }
         return db.Dialector.Explain(sql, vars...), db.RowsAffected
      }, db.Error)
   }

   if !stmt.DB.DryRun {
      stmt.SQL.Reset()
      stmt.Vars = nil
   }

   if resetBuildClauses {
      stmt.BuildClauses = nil
   }

   return db
}

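Different databases, and even different versions of the same database, support different SQL. For example, a locking read (a "current read" with a shared lock) needs different SQL depending on the database version. How does GORM handle this?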

// Different databases, and even different versions of the same database, support different SQL
// SELECT * FROM `users` LOCK IN SHARE MODE // MySQL < 8,MariaDB
// SELECT * FROM `users` FOR SHARE OF `users`// MySQL 8
db.Clauses(clause.Locking{
   Strength: "SHARE",
   Table:    clause.Table{Name: clause.CurrentTable},
}).Find(&users)

Looking at the Initialize method of GORM's MySQL dialector, we can see that it already queries the database version during initialization:

func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
   if dialector.DriverName == "" {
      dialector.DriverName = "mysql"
   }

   if dialector.DefaultDatetimePrecision == nil {
      dialector.DefaultDatetimePrecision = &defaultDatetimePrecision
   }

   if dialector.Conn != nil {
      db.ConnPool = dialector.Conn
   } else {
      db.ConnPool, err = sql.Open(dialector.DriverName, dialector.DSN)
      if err != nil {
         return err
      }
   }

   withReturning := false
   if !dialector.Config.SkipInitializeWithVersion {
      err = db.ConnPool.QueryRowContext(context.Background(), "SELECT VERSION()").Scan(&dialector.ServerVersion)
      if err != nil {
         return err
      }

      if strings.Contains(dialector.ServerVersion, "MariaDB") {
         dialector.Config.DontSupportRenameIndex = true
         dialector.Config.DontSupportRenameColumn = true
         dialector.Config.DontSupportForShareClause = true
         dialector.Config.DontSupportNullAsDefaultValue = true
         withReturning = checkVersion(dialector.ServerVersion, "10.5")
      } else if strings.HasPrefix(dialector.ServerVersion, "5.6.") {
         dialector.Config.DontSupportRenameIndex = true
         dialector.Config.DontSupportRenameColumn = true
         dialector.Config.DontSupportForShareClause = true
      } else if strings.HasPrefix(dialector.ServerVersion, "5.7.") {
         dialector.Config.DontSupportRenameColumn = true
         dialector.Config.DontSupportForShareClause = true
      } else if strings.HasPrefix(dialector.ServerVersion, "5.") {
         dialector.Config.DisableDatetimePrecision = true
         dialector.Config.DontSupportRenameIndex = true
         dialector.Config.DontSupportRenameColumn = true
         dialector.Config.DontSupportForShareClause = true
      }

      if strings.Contains(dialector.ServerVersion, "TiDB") {
         dialector.Config.DontSupportRenameColumnUnique = true
      }
   }

   // register callbacks
   callbackConfig := &callbacks.Config{
      CreateClauses: CreateClauses,
      QueryClauses:  QueryClauses,
      UpdateClauses: UpdateClauses,
      DeleteClauses: DeleteClauses,
   }

   if !dialector.Config.DisableWithReturning && withReturning {
      callbackConfig.LastInsertIDReversed = true

      if !utils.Contains(callbackConfig.CreateClauses, "RETURNING") {
         callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING")
      }

      if !utils.Contains(callbackConfig.UpdateClauses, "RETURNING") {
         callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING")
      }

      if !utils.Contains(callbackConfig.DeleteClauses, "RETURNING") {
         callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING")
      }
   }

   callbacks.RegisterDefaultCallbacks(db, callbackConfig)

   for k, v := range dialector.ClauseBuilders() {
      db.ClauseBuilders[k] = v
   }
   return
}
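Applied to the locking example, those version checks boil down to a choice between two SQL suffixes. The hypothetical helpers below mirror only the DontSupportForShareClause branches of Initialize (MariaDB and every MySQL 5.x version get LOCK IN SHARE MODE, newer servers get FOR SHARE OF); they are not GORM's clause builder.

```go
package main

import (
	"fmt"
	"strings"
)

// supportsForShare mirrors the DontSupportForShareClause logic in Initialize.
func supportsForShare(serverVersion string) bool {
	if strings.Contains(serverVersion, "MariaDB") || strings.HasPrefix(serverVersion, "5.") {
		return false
	}
	return true
}

// shareLockClause returns the shared-lock suffix appropriate for the server.
func shareLockClause(serverVersion, table string) string {
	if !supportsForShare(serverVersion) {
		return "LOCK IN SHARE MODE"
	}
	return fmt.Sprintf("FOR SHARE OF `%s`", table)
}

func main() {
	for _, v := range []string{"5.7.40", "8.0.32", "10.6.12-MariaDB"} {
		fmt.Printf("%-16s SELECT * FROM `users` %s\n", v, shareLockClause(v, "users"))
	}
}
```

Because the version is read once in Initialize, the decision is made per *gorm.DB rather than per query; setting SkipInitializeWithVersion skips the check.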

GORM can extend Clauses with hints; the official documentation describes this in detail.

// extend the SELECT clause
db.Clauses(hints.New("MRR(idx1)")).Find(&User{})
// SELECT /*+MRR(idx1)*/ * FROM users

// extend the FROM clause
db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{})
// SELECT * FROM `users` USE INDEX (`idx_user_name`)

db.Clauses(hints.ForceIndex("idx_user_name", "idx_user_id").ForJoin()).Find(&User{})
// SELECT * FROM `users` FORCE INDEX FOR JOIN (`idx_user_name`, `idx_user_id`)

// add comments before, inside or after a clause
db.Clauses(hints.Comment("select", "master")).Find(&User{})
// SELECT /*master*/ * FROM `users`;

db.Clauses(hints.CommentBefore("insert", "node2")).Create(&User{})
// /*node2*/ INSERT INTO `users` ...;

db.Clauses(hints.CommentAfter("where", "hint")).Find(&User{}, "id = ?", 1)
// SELECT * FROM `users` WHERE id = ? /* hint*/

Plugin extensions

GORM itself is built on Callbacks, including Create, Query, Update, Delete, Row and Raw. You can also define your own GORM Callbacks however you like.

Callbacks are registered on the global *gorm.DB, not at the session level. If you need a *gorm.DB with different callbacks, you must initialize another *gorm.DB.

Official plugin documentation

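For example, the CREATE process is made up of the following predefined callbacks, listed in their registration order:

// predefined CREATE callbacks (GORM registers these itself; they are listed here for illustration)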
db.Callback().Create().Register("gorm:begin_transaction", func(db *gorm.DB) {})
db.Callback().Create().Register("gorm:before_create", func(db *gorm.DB) {})
db.Callback().Create().Register("gorm:save_before_associations", func(db *gorm.DB) {})
db.Callback().Create().Register("gorm:create", func(db *gorm.DB) {})
db.Callback().Create().Register("gorm:save_after_associations", func(db *gorm.DB) {})
db.Callback().Create().Register("gorm:after_create", func(db *gorm.DB) {})
db.Callback().Create().Register("gorm:commit_or_rollback_transaction", func(db *gorm.DB) {})

When a callback is registered, the processor's compile() method is called:

func (c *callback) Register(name string, fn func(*DB)) error {
   c.name = name
   c.handler = fn
   c.processor.callbacks = append(c.processor.callbacks, c)
   return c.processor.compile()
}

Let's look at the compile() method:

func (p *processor) compile() (err error) {
   var callbacks []*callback
   for _, callback := range p.callbacks {
      if callback.match == nil || callback.match(p.db) {
         callbacks = append(callbacks, callback)
      }
   }
   p.callbacks = callbacks

   if p.fns, err = sortCallbacks(p.callbacks); err != nil {
      p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
   }
   return
}

compile() sorts the callbacks and assigns the result to the fns field; when the SQL is finally executed, Execute iterates over fns and calls every function:

func (p *processor) Execute(db *DB) *DB {
   // invoke every registered callback (plugin) in order
   for _, f := range p.fns {
      f(db)
   }
   // ... (rest of the method omitted)
}

Common ways to work with callbacks:

// register a new callback
db.Callback().Create().Register("myplugin", func(db *gorm.DB) {})
// remove a callback
db.Callback().Create().Remove("gorm:begin_transaction")
// replace a callback
db.Callback().Create().Replace("gorm:begin_transaction", func(db *gorm.DB) {})
// get a registered callback
db.Callback().Create().Get("gorm:begin_transaction")
// specify the callback order
db.Callback().Create().Before("gorm:create").After("myplugin").Register("myplugin2", func(db *gorm.DB) {})
// register before all other callbacks
db.Callback().Create().Before("*").Register("myplugin:new_callback", func(db *gorm.DB) {})
// register only when a condition holds
db.Callback().Create().Match(func(db *gorm.DB) bool {
   return !db.SkipDefaultTransaction
}).Register("gorm:begin_transaction", func(db *gorm.DB) {})

Callbacks provide an interception mechanism that has many uses, for example:

  1. Multi-tenancy
  2. Multiple databases and read/write splitting
  3. Encryption/decryption, chaos engineering, etc.

Below is a brief introduction to some of these use cases.

Multi-tenancy:

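In a multi-tenant system, most operations apply to a single tenant. Adding the tenant condition by hand every time we access the database is inconvenient, so we can use a callback to solve this once and for all: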

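package main

import (
   "context"
   "errors"
   "gorm.io/driver/mysql"
   "gorm.io/gorm"
   "gorm.io/gorm/schema"
)

// ctxKey is an unexported type for context keys, so they cannot collide with keys from other packages
type ctxKey string

const tenantIDKey ctxKey = "tenant_id"

// Article is an example model; every tenant-scoped table needs a tenant_id column
type Article struct {
   ID       uint
   TenantID string
   Title    string
}

func main() {
   db, err := gorm.Open(
      mysql.Open("root:2001@tcp(127.0.0.1:3306)/blog?parseTime=true"),
      &gorm.Config{
         NamingStrategy: schema.NamingStrategy{
            SingularTable: true,
         },
      },
   )
   if err != nil {
      panic(err)
   }

   // filter by TenantID: fill it in on create, add it as a condition everywhere else
   db.Callback().Create().Before("gorm:create").Register("set_tenant_id", setTenantID)
   db.Callback().Query().Before("gorm:query").Register("set_tenant_scope", setTenantScope)
   db.Callback().Delete().Before("gorm:delete").Register("set_tenant_scope", setTenantScope)
   db.Callback().Update().Before("gorm:update").Register("set_tenant_scope", setTenantScope)

   // the tenant ID travels with each request in its context
   ctx := context.WithValue(context.Background(), tenantIDKey, "1")
   var articles []Article
   db.WithContext(ctx).Find(&articles) // the query gets WHERE tenant_id = "1"
}

// setTenantScope adds the tenant condition to queries, updates and deletes
func setTenantScope(db *gorm.DB) {
   if tenantID, err := getTenantID(db.Statement.Context); err == nil {
      db.Where("tenant_id = ?", tenantID)
   } else {
      db.AddError(err)
   }
}

// setTenantID writes the tenant ID into the records being created
func setTenantID(db *gorm.DB) {
   if tenantID, err := getTenantID(db.Statement.Context); err == nil {
      db.Statement.SetColumn("TenantID", tenantID, true)
   } else {
      db.AddError(err)
   }
}

func getTenantID(ctx context.Context) (string, error) {
   if tenantID, ok := ctx.Value(tenantIDKey).(string); ok {
      return tenantID, nil
   }
   return "", errors.New("find tenant id error")
}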

Sharding and read/write splitting:

DBResolver adds multiple-database support to GORM, with the following features:

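  • Multiple sources and replicas
  • Read/write splitting
  • Automatic connection switching based on the working table or struct
  • Manual connection switching
  • Load balancing across sources/replicas
  • Works with raw SQL
  • Transactions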

Official documentation: DBResolver

db.Use(dbresolver.Register(dbresolver.Config{
  // use `db2` as sources, `db3`, `db4` as replicas
  Sources:  []gorm.Dialector{mysql.Open("db2_dsn")},
  Replicas: []gorm.Dialector{mysql.Open("db3_dsn"), mysql.Open("db4_dsn")},
  // sources/replicas load balancing policy
  Policy: dbresolver.RandomPolicy{},
  // print sources/replicas mode in logger
  ResolverModeReplica: true,
}).Register(dbresolver.Config{
  // use `db1` as sources (DB's default connection), `db5` as replicas for `User`, `Address`
  Replicas: []gorm.Dialector{mysql.Open("db5_dsn")},
}, &User{}, &Address{}).Register(dbresolver.Config{
  // use `db6`, `db7` as sources, `db8` as replicas for `orders`, `Product`
  Sources:  []gorm.Dialector{mysql.Open("db6_dsn"), mysql.Open("db7_dsn")},
  Replicas: []gorm.Dialector{mysql.Open("db8_dsn")},
}, "orders", &Product{}, "secondary"))

// use Write mode: read user from the sources db `db1`
db.Clauses(dbresolver.Write).First(&user)
// specify a resolver: read user from the replicas db `db8` of `secondary`
db.Clauses(dbresolver.Use("secondary")).First(&user)
// specify a resolver and Write mode: read user from the sources db `db6` or `db7` of `secondary`
db.Clauses(dbresolver.Use("secondary"), dbresolver.Write).First(&user)
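The core decision DBResolver makes can be illustrated with a tiny router: reads go to a replica, while writes (and anything forced into write mode, as dbresolver.Write does) go to a source. This is a hypothetical sketch: it classifies statements by their SQL prefix and uses round-robin instead of RandomPolicy so that the output is deterministic, whereas DBResolver decides per GORM operation through its callbacks.

```go
package main

import (
	"fmt"
	"strings"
)

// router is a hypothetical, simplified read/write splitter.
type router struct {
	sources, replicas       []string
	nextSource, nextReplica int
}

// pick returns the name of the database that should run query.
func (r *router) pick(query string, forceWrite bool) string {
	isRead := strings.HasPrefix(strings.ToUpper(strings.TrimSpace(query)), "SELECT")
	if isRead && !forceWrite && len(r.replicas) > 0 {
		db := r.replicas[r.nextReplica%len(r.replicas)]
		r.nextReplica++
		return db
	}
	db := r.sources[r.nextSource%len(r.sources)]
	r.nextSource++
	return db
}

func main() {
	r := &router{sources: []string{"db2"}, replicas: []string{"db3", "db4"}}
	fmt.Println(r.pick("SELECT * FROM users", false))       // db3
	fmt.Println(r.pick("select * from users", false))       // db4
	fmt.Println(r.pick("UPDATE users SET name = ?", false)) // db2
	fmt.Println(r.pick("SELECT * FROM users", true))        // db2, like dbresolver.Write
}
```

Forcing reads to a source matters right after a write, when replication lag could otherwise return stale data from a replica.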

ConnPool

The DB type of the database/sql library maintains a connection pool, and *sql.DB implements GORM's ConnPool interface:

// ConnPool db conns pool interface
type ConnPool interface {
   PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
   ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
   QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
   QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}

gorm.PreparedStmtDB also implements the ConnPool interface, which is what makes PrepareStmt mode possible:

package main

import (
   "gorm.io/driver/mysql"
   "gorm.io/gorm"
   "gorm.io/gorm/schema"
)

type User struct {
   gorm.Model
   Name      string  `json:"name"`
   Age       int     `json:"age"`
   CompanyId *int    `json:"company_id"`
   ManagerID *uint   `json:"manager_id"`
   Manager   *User   `json:"manager"`
   Team      []User  `json:"team" gorm:"foreignkey:ManagerID"`
   Friends   []*User `json:"friends" gorm:"many2many:user_friends"`
}

func main() {
   db, err := gorm.Open(
      mysql.Open("root:2001@tcp(127.0.0.1:3306)/blog?parseTime=true"),
      &gorm.Config{
         NamingStrategy: schema.NamingStrategy{
            SingularTable: true,
         },
         // global mode: every DB operation creates a prepared statement and caches it (the cache key is the SQL without the arguments)
         PrepareStmt: true,
      },
   )
   if err != nil {
      panic(err)
   }

   var user User
   db.First(&user, 1)

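   // session mode: operations in this session create and cache prepared statements
   tx := db.Session(&gorm.Session{
      PrepareStmt: true,
   })
   tx.First(&user, 1)
   tx.Model(&user).Update("Age", 18)

   // statements cached in global mode can also be used by the session
   tx.Find(&user, 2)

   // close the prepared statements of the current session
   if stmtManager, ok := tx.ConnPool.(*gorm.PreparedStmtDB); ok {
      stmtManager.Close()
   }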

}
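At its core, PrepareStmt mode is a cache keyed by SQL text: the first time a statement is seen it is prepared, and later calls with the same SQL (but possibly different arguments) reuse it. The generic `stmtCache` below is a hypothetical sketch of that idea with a pluggable prepare function, so it runs without a database; gorm.PreparedStmtDB also deals with transactions, failed statements and closing, which are not modeled here.

```go
package main

import (
	"fmt"
	"sync"
)

// stmtCache prepares each distinct SQL string once and reuses the result.
type stmtCache[S any] struct {
	mu      sync.Mutex
	prepare func(query string) (S, error)
	stmts   map[string]S
}

func newStmtCache[S any](prepare func(string) (S, error)) *stmtCache[S] {
	return &stmtCache[S]{prepare: prepare, stmts: make(map[string]S)}
}

// Get returns the cached statement for query, preparing it on first use.
func (c *stmtCache[S]) Get(query string) (S, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if s, ok := c.stmts[query]; ok {
		return s, nil
	}
	s, err := c.prepare(query)
	if err != nil {
		var zero S
		return zero, err // failed statements are not cached
	}
	c.stmts[query] = s
	return s, nil
}

func main() {
	prepared := 0
	cache := newStmtCache(func(q string) (string, error) {
		prepared++
		return "stmt#" + fmt.Sprint(prepared), nil
	})
	a, _ := cache.Get("SELECT * FROM `user` WHERE `id` = ?")
	b, _ := cache.Get("SELECT * FROM `user` WHERE `id` = ?") // same SQL, reused
	c, _ := cache.Get("UPDATE `user` SET `age` = ? WHERE `id` = ?")
	fmt.Println(a, b, c, prepared) // stmt#1 stmt#1 stmt#2 2
}
```

The payoff is that repeated queries skip the server-side parse step; the cost is one open statement per distinct SQL string, which is why the cache should be closed when a session ends.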

Dialector

Databases come in many dialects, such as MySQL, PostgreSQL and ClickHouse, and GORM supports many of them. A Dialector makes it possible to:

  • Customize SQL generation
  • Customize GORM plugins
  • Customize the ConnPool
  • Customize enterprise-specific logic

package main

import (
   "xxx.io/caches"
   "gorm.io/driver/clickhouse"
   "gorm.io/driver/mysql"
   "gorm.io/driver/postgres"
   "gorm.io/gorm"
)

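func main() {
   dsn := "root:2001@tcp(127.0.0.1:3306)/blog?parseTime=true&charset=utf8mb4&loc=Local"
   // Choose one Dialector per *gorm.DB; the alternatives are shown one after another.
   // Each driver expects its own DSN format (PostgreSQL, for example, uses
   // "host=... user=... dbname=..."), so dsn would differ in practice.
   db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
   db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
   db, err = gorm.Open(clickhouse.Open(dsn), &gorm.Config{})
   // a caching Dialector that falls back to MySQL; the import of the lru package is not shown
   db, err = gorm.Open(caches.New(caches.Config{
      Fallback: mysql.Open(dsn),
      Store:    lru.New(lru.Config{}),
   }), &gorm.Config{})
   _, _ = db, err
}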
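One concrete thing a Dialector decides is how identifiers are quoted and how bind variables are written. The `dialect` interface below is a hypothetical two-method reduction (gorm.Dialector's own QuoteTo and BindVarTo write into a clause.Writer instead of returning strings), but it shows why the same query renders differently for MySQL and PostgreSQL.

```go
package main

import (
	"fmt"
	"strings"
)

// dialect is a hypothetical, reduced version of what a Dialector decides.
type dialect interface {
	Quote(identifier string) string
	BindVar(position int) string // position is 1-based
}

type mysqlDialect struct{}

func (mysqlDialect) Quote(id string) string { return "`" + id + "`" }
func (mysqlDialect) BindVar(int) string     { return "?" }

type postgresDialect struct{}

func (postgresDialect) Quote(id string) string { return `"` + id + `"` }
func (postgresDialect) BindVar(n int) string   { return fmt.Sprintf("$%d", n) }

// selectWhere renders an equality query on the given columns for a dialect.
func selectWhere(d dialect, table string, columns ...string) string {
	conds := make([]string, len(columns))
	for i, col := range columns {
		conds[i] = d.Quote(col) + " = " + d.BindVar(i+1)
	}
	return "SELECT * FROM " + d.Quote(table) + " WHERE " + strings.Join(conds, " AND ")
}

func main() {
	fmt.Println(selectWhere(mysqlDialect{}, "users", "name", "age"))
	// SELECT * FROM `users` WHERE `name` = ? AND `age` = ?
	fmt.Println(selectWhere(postgresDialect{}, "users", "name", "age"))
	// SELECT * FROM "users" WHERE "name" = $1 AND "age" = $2
}
```

Because the query-building code only talks to the interface, adding a database means adding a Dialector rather than touching the builder.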

GORM best practices

GORM also offers the following advanced features:

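  1. Expressions
  2. Serialization
  3. Batch data operations
  4. Code reuse
  5. Splitting databases and tables
  6. Sharding
  7. Chaos engineering
  8. Logger/Trace
  9. Database migration
  10. Gen code generation
  11. Raw SQL
  12. Security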

Next, we demonstrate how to use these advanced features.

Expressions

Expression-related operations:

package main

import (
   "context"
   "fmt"
   "gorm.io/datatypes"
   "gorm.io/driver/mysql"
   "gorm.io/gorm"
   "gorm.io/gorm/clause"
   "gorm.io/gorm/schema"
)

type User struct {
   ID        int      `json:"id"`
   Name      string   `json:"name"`
   CompanyId *int     `json:"company_id"`
   ManagerID *uint    `json:"manager_id"`
   Manager   *User    `json:"manager"`
   Team      []User   `json:"team" gorm:"foreignkey:ManagerID"`
   Friends   []*User  `json:"friends" gorm:"many2many:user_friends"`
   Location  Location `json:"location"`
}

type Company struct {
   ID   int
   Name string
}

type Product struct {
   Id    int64   `json:"id"`
   Name  string  `json:"name"`
   Color string  `json:"color"`
   Price float64 `json:"price"`
}

func main() {
   db, err := gorm.Open(
      mysql.Open("root:2001@tcp(127.0.0.1:3306)/blog?parseTime=true"),
      &gorm.Config{
         NamingStrategy: schema.NamingStrategy{
            SingularTable: true,
         },
         // Global mode: every DB operation is prepared and cached (the cache stores the SQL text, not the parameter values)
         PrepareStmt: true,
      },
   )
   if err != nil {
      panic(err)
   }

   // Creating and updating with SQL expressions
   // Method 1: embed a SQL expression with gorm.Expr
   db.Model(User{}).Create(map[string]interface{}{
      "Name":     "jinzhu",
      "Location": gorm.Expr("ST_PointFromText(?)", "POINT(100 100)"),
   })
   // INSERT INTO `user` (`name`,`location`) VALUES ("jinzhu",ST_PointFromText("POINT(100 100)"));
   db.Model(&Product{Id: 1}).Update("price", gorm.Expr("price * ? + ?", 2, 100))
   // UPDATE `product` SET `price`=price * 2 + 100 WHERE `id` = 1;

   // Method 2: a type implementing GormValuer produces the SQL expression / subquery
   db.Model(User{}).Create(map[string]interface{}{
      "Name":     "jinzhu",
      "Location": Location{X: 100, Y: 100},
   })
   db.Model(&User{ID: 1}).Updates(User{Name: "jinzhu", Location: Location{X: 100, Y: 100}})

   // Method 3: pass a *gorm.DB as a subquery (table names are singular because of SingularTable)
   subQuery := db.Model(&Company{}).Select("name").Where("company.id = user.company_id")
   db.Model(&User{ID: 1}).Updates(map[string]interface{}{
      "company_name": subQuery,
   })
   // UPDATE `user` SET `company_name`=(SELECT `name` FROM `company` WHERE company.id = user.company_id) WHERE `id` = 1;

   // Querying with SQL expressions
   // Method 1: use gorm.Expr
   db.Where("location = ?", gorm.Expr("ST_PointFromText(?)", "POINT(100 100)")).First(&User{})
   // SELECT * FROM `user` WHERE location = ST_PointFromText("POINT(100 100)") ORDER BY `user`.`id` LIMIT 1;

   // Method 2: a struct that implements GormValuer
   db.Where("location = ?", Location{X: 100, Y: 100}).First(&User{})
   // SELECT * FROM `user` WHERE location = ST_PointFromText("POINT(100 100)") ORDER BY `user`.`id` LIMIT 1;

   // Method 3: a custom condition implementing clause.Expression
   db.Find(&User{}, datatypes.JSONQuery("attributes").HasKey("role"))
   db.Clauses(datatypes.JSONQuery("attributes").HasKey("org", "name")).Find(&User{})

   // Method 4: subquery
   db.Where("name in (?)", db.Model(&User{}).Select("name").Where("id > 10")).Find(&User{})
}

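type Location struct {
   X, Y int
}

// GormDataType sets the column type used by migrations
func (loc Location) GormDataType() string {
   return "geometry"
}

// GormValue implements gorm.Valuer, so writing a Location emits ST_PointFromText("POINT(x y)").
// Reading the column back into a Location also requires a Scan method (sql.Scanner), omitted here.
func (loc Location) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
   return gorm.Expr("ST_PointFromText(?)", fmt.Sprintf("POINT(%d %d)", loc.X, loc.Y))
}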

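Serialization

Serialization: see the official documentation.

A Serializer defines how a value is serialized and deserialized, so it must implement the following interface: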

import (
    "context"
    "reflect"

    "gorm.io/gorm/schema"
)

type SerializerInterface interface {
    Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) error
    SerializerValuerInterface
}

type SerializerValuerInterface interface {
    Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error)
}

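For example, the built-in JSONSerializer is implemented as follows (this code lives inside package schema, hence the unqualified Field type):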

// JSONSerializer is the JSON serializer
type JSONSerializer struct {
}

// Scan implements deserialization (database value -> Go field)
func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
    fieldValue := reflect.New(field.FieldType)

    if dbValue != nil {
        var bytes []byte
        switch v := dbValue.(type) {
        case []byte:
            bytes = v
        case string:
            bytes = []byte(v)
        default:
            return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
        }

        err = json.Unmarshal(bytes, fieldValue.Interface())
    }

    field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
    return
}

// Value implements serialization (Go field -> database value)
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
    return json.Marshal(fieldValue)
}

Batch Operations

package main

import (
   "database/sql"
   "fmt"
   "gorm.io/driver/mysql"
   "gorm.io/driver/sqlite"
   "gorm.io/gorm"
   "gorm.io/gorm/clause"
   "gorm.io/gorm/schema"
)

type User struct {
   ID        int     `json:"id"`
   Name      string  `json:"name"`
   CompanyId *int    `json:"company_id"`
   ManagerID *uint   `json:"manager_id"`
   Manager   *User   `json:"manager"`
   Team      []User  `json:"team" gorm:"foreignkey:ManagerID"`
   Friends   []*User `json:"friends" gorm:"many2many:user_friends"`
}

func main() {
   db, err := gorm.Open(
      mysql.Open("root:2001@tcp(127.0.0.1:3306)/blog?parseTime=true"),
      &gorm.Config{
         NamingStrategy: schema.NamingStrategy{
            SingularTable: true,
         },
         // Global mode: every DB operation is prepared and cached (the cache stores the SQL text, not the parameter values)
         PrepareStmt: true,
      },
   )
   if err != nil {
      panic(err)
   }
   // Batch create: one INSERT for the whole slice; primary keys are backfilled
   var users = []User{{Name: "jinzhu1"}, {Name: "jinzhu2"}, {Name: "jinzhu3"}}
   db.Create(&users)
   // Alternatively, insert in batches of 100 rows: db.CreateInBatches(users, 100)
   for _, user := range users {
      fmt.Println(user.ID) // 1, 2, 3
   }

   // Batch query: iterate over the result set row by row
   rows, err := db.Model(&User{}).Select("name", "age", "email").Where("role = ?", "admin").Rows()
   if err != nil {
      panic(err)
   }
   defer rows.Close()
   for rows.Next() {
      // Method 1: sql.Rows.Scan; nullable columns need sql.Null* types (or pointers)
      var name string
      var age int
      var email sql.NullString
      rows.Scan(&name, &age, &email)
      // Method 2: gorm ScanRows maps the current row onto a struct
      var user User
      db.ScanRows(rows, &user)
   }
   // Process the matching records 100 at a time
   db.Where("role = ?", "admin").FindInBatches(&users, 100, func(tx *gorm.DB, batch int) error {
      return nil
   })

   // Batch upsert
   // Ignore conflicts
   db.Clauses(clause.OnConflict{DoNothing: true}).Create(&users)
   // INSERT INTO "user" *** ON CONFLICT DO NOTHING; // PostgreSQL
   // INSERT INTO `user` *** ON DUPLICATE KEY UPDATE `id`=`id`; // MySQL

   db.Clauses(clause.Insert{Modifier: "IGNORE"}).Create(&users)
   // INSERT IGNORE INTO `user` ***; // MySQL

   // On conflict, update specific columns
   db.Clauses(clause.OnConflict{
      Columns:   []clause.Column{{Name: "id"}},
      DoUpdates: clause.Assignments(map[string]interface{}{"deleted_at": nil}),
   }).Create(&users)
   // INSERT INTO `user` *** ON DUPLICATE KEY UPDATE `deleted_at`=NULL;

   // On conflict, update with a SQL expression
   db.Clauses(clause.OnConflict{
      Columns:   []clause.Column{{Name: "id"}},
      DoUpdates: clause.Assignments(map[string]interface{}{"count": gorm.Expr("GREATEST(count, VALUES(count))")}),
   }).Create(&users)
   // INSERT INTO `user` *** ON DUPLICATE KEY UPDATE `count`=GREATEST(count, VALUES(count));

   // On conflict, set specific columns to the new values
   db.Clauses(clause.OnConflict{
      Columns:   []clause.Column{{Name: "id"}},
      DoUpdates: clause.AssignmentColumns([]string{"name", "age"}),
   }).Create(&users)
   // INSERT INTO `user` *** ON DUPLICATE KEY UPDATE `name`=VALUES(name), `age`=VALUES(age);

   // On conflict, set every column except the primary key to the new values
   db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&users)
   // INSERT INTO `user` *** ON DUPLICATE KEY UPDATE `name`=VALUES(name), `age`=VALUES(age), ...;

   // Speeding up bulk writes
   // Method 1: disable the default transaction
   db, err = gorm.Open(sqlite.Open("gorm.db"), &gorm.Config{
      SkipDefaultTransaction: true,
   })
   db.Create(&User{})

   tx := db.Session(&gorm.Session{SkipDefaultTransaction: true})
   tx.Create(&User{})

   // Method 2: bulk inserts run hooks by default; skip them with SkipHooks
   db.Session(&gorm.Session{SkipHooks: true}).Create(&users)
   db.Session(&gorm.Session{SkipHooks: true}).CreateInBatches(users, 1000)

   // Method 3: use prepared statements
   db, err = gorm.Open(sqlite.Open("gorm.db"), &gorm.Config{PrepareStmt: true})
   db.Create(&users)

   // Combine all of the above
   db = db.Session(&gorm.Session{
      PrepareStmt:            true,
      SkipDefaultTransaction: true,
      SkipHooks:              true,
      CreateBatchSize:        1000,
   })
   db.Create(&users)
}

Code Reuse

package main

import (
   "gorm.io/driver/mysql"
   "gorm.io/gorm"
   "gorm.io/gorm/schema"
   "net/http"
   "strconv"
)

type User struct {
   ID int `json:"id"`
}

type Article struct {
   ID int `json:"id"`
}

func main() {
   db, err := gorm.Open(
      mysql.Open("root:2001@tcp(127.0.0.1:3306)/blog?parseTime=true"),
      &gorm.Config{
         NamingStrategy: schema.NamingStrategy{
            SingularTable: true,
         },
         // Global mode: every DB operation is prepared and cached (the cache stores the SQL text, not the parameter values)
         PrepareStmt: true,
      },
   )
   if err != nil {
      panic(err)
   }
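   // Hypothetical request; in a real handler r is provided by the HTTP server
   r, _ := http.NewRequest(http.MethodGet, "/users?page=2&page_size=20", nil)
   // Code sharing: the same pagination scope works for any model
   db.Scopes(Paginate(r)).Find(&[]User{})
   db.Scopes(Paginate(r)).Find(&[]Article{})

}

// Paginate returns a scope that reads page/page_size from the query string
func Paginate(r *http.Request) func(db *gorm.DB) *gorm.DB {
   return func(db *gorm.DB) *gorm.DB {
      q := r.URL.Query()
      page, _ := strconv.Atoi(q.Get("page"))
      if page <= 0 {
         page = 1
      }
      pageSize, _ := strconv.Atoi(q.Get("page_size"))
      switch {
      case pageSize > 100:
         pageSize = 100
      case pageSize <= 0:
         pageSize = 10
      }
      offset := (page - 1) * pageSize
      return db.Offset(offset).Limit(pageSize)
   }
}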
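The scope's arithmetic is easy to test on its own. This stdlib-only sketch pulls it into a hypothetical pure function `pageWindow` that follows the same rules as `Paginate`: a missing or invalid page falls back to 1, page sizes above 100 are capped, and non-positive sizes default to 10.

```go
package main

import (
	"fmt"
	"strconv"
)

// pageWindow converts raw query-string values into OFFSET and LIMIT,
// applying the same defaults and cap as the Paginate scope.
func pageWindow(pageStr, sizeStr string) (offset, limit int) {
	page, _ := strconv.Atoi(pageStr)
	if page <= 0 {
		page = 1
	}
	size, _ := strconv.Atoi(sizeStr)
	switch {
	case size > 100:
		size = 100
	case size <= 0:
		size = 10
	}
	return (page - 1) * size, size
}

func main() {
	fmt.Println(pageWindow("3", "20"))  // 40 20
	fmt.Println(pageWindow("", ""))     // 0 10
	fmt.Println(pageWindow("2", "500")) // 100 100
}
```

Keeping the clamping in one place means every model that uses the scope gets the same protection against huge or negative page sizes.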

Splitting Databases and Tables

package main

import (
   "gorm.io/driver/mysql"
   "gorm.io/gorm"
   "gorm.io/gorm/schema"
   "strconv"
)

type User struct {
   ID int `json:"id"`
}

func (u User) TableName() string {
   return "users"
}

type Article struct {
   ID int `json:"id"`
}

func main() {
   db, err := gorm.Open(
      mysql.Open("root:2001@tcp(127.0.0.1:3306)/blog?parseTime=true"),
      &gorm.Config{
         NamingStrategy: schema.NamingStrategy{
            SingularTable: true,
         },
         // Global mode: every DB operation is prepared and cached (the cache stores the SQL text, not the parameter values)
         PrepareStmt: true,
      },
   )
   if err != nil {
      panic(err)
   }

   // SELECT * FROM users_2019;
   db.Scopes(TableOfYear(&User{}, 2019)).Find(&[]User{})
   // SELECT * FROM org1.users;
   db.Scopes(TableOfOrg(&User{}, "org1")).Find(&[]User{})
   // SELECT * FROM users12;
   db.Scopes(TableOfUser(&User{ID: 12})).Find(&[]User{})

}

// Derive the table name from the object itself (e.g. via its TableName method)
func TableOfUser(user *User) func(db *gorm.DB) *gorm.DB {
   return func(db *gorm.DB) *gorm.DB {
      suffix := user.ID
      return db.Table(user.TableName() + strconv.Itoa(suffix))
   }
}

// Split databases using the passed-in value (same connection, different database)
func TableOfOrg(user *User, dbName string) func(db *gorm.DB) *gorm.DB {
   return func(db *gorm.DB) *gorm.DB {
      tableName := dbName + "." + user.TableName()
      return db.Table(tableName)
   }
}

// Split tables using the passed-in value
func TableOfYear(user *User, year int) func(db *gorm.DB) *gorm.DB {
   return func(db *gorm.DB) *gorm.DB {
      tableName := user.TableName() + "_" + strconv.Itoa(year)
      return db.Table(tableName)
   }
}
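Table and database names cannot be sent as bound parameters, so `TableOfOrg` concatenates `dbName` straight into the SQL. If that value can come from a request, validate it before building the table name. A minimal stdlib-only sketch, assuming names are restricted to lowercase letters, digits and underscores (an illustrative rule; adapt it to your own naming scheme); `qualifiedTable` is a hypothetical helper whose result you would pass to `db.Table(...)` only when it returns no error.

```go
package main

import (
	"fmt"
	"regexp"
)

// identPattern is an assumed naming rule: lowercase letters, digits and underscores, 1 to 64 chars.
var identPattern = regexp.MustCompile(`^[a-z0-9_]{1,64}$`)

// qualifiedTable returns "<dbName>.<table>" only when both parts match identPattern.
func qualifiedTable(dbName, table string) (string, error) {
	for _, part := range []string{dbName, table} {
		if !identPattern.MatchString(part) {
			return "", fmt.Errorf("invalid identifier %q", part)
		}
	}
	return dbName + "." + table, nil
}

func main() {
	fmt.Println(qualifiedTable("org1", "users"))
	fmt.Println(qualifiedTable("org1; drop table users;", "users"))
}
```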

Sharding

Official documentation

package main

import (
   "gorm.io/driver/mysql"
   "gorm.io/gorm"
   "gorm.io/gorm/schema"
   "gorm.io/sharding"
)

type User struct {
   ID int `json:"id"`
}

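type Order struct {
   ID     int64 `json:"id"`
   UserId int64 `json:"user_id"`
}

// TableName keeps the table name "orders" (matching the sharding registration) despite SingularTable
func (Order) TableName() string {
   return "orders"
}

// Notification and AuditLog are illustrative models that share one sharding rule
type Notification struct {
   ID     int64 `json:"id"`
   UserId int64 `json:"user_id"`
}

type AuditLog struct {
   ID     int64 `json:"id"`
   UserId int64 `json:"user_id"`
}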

func main() {
   db, err := gorm.Open(
      mysql.Open("root:2001@tcp(127.0.0.1:3306)/blog?parseTime=true"),
      &gorm.Config{
         NamingStrategy: schema.NamingStrategy{
            SingularTable: true,
         },
         // Global mode: every DB operation is prepared and cached (the cache stores the SQL text, not the parameter values)
         PrepareStmt: true,
      },
   )
   if err != nil {
      panic(err)
   }

   db.Use(sharding.Register(sharding.Config{
      ShardingKey:         "user_id",
      NumberOfShards:      64,
      PrimaryKeyGenerator: sharding.PKSnowflake,
   }, "orders").Register(sharding.Config{
      ShardingKey:         "user_id",
      NumberOfShards:      256,
      PrimaryKeyGenerator: sharding.PKSnowflake,
      // Notification and AuditLog share the same sharding rule
   }, Notification{}, AuditLog{}))

   // Create with GORM: user_id 2 is written to the orders_02 table
   db.Create(&Order{UserId: 2})
   // sql: INSERT INTO orders_02 ...

   // Raw SQL insert: user_id 3 is written to the orders_03 table
   db.Exec("INSERT INTO orders(user_id) VALUES(?)", int64(3))
}
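The table suffix in the comments above comes from taking the sharding key modulo the number of shards and zero-padding it to the digit count of the shard count (64 shards give two digits, 256 give three). The stdlib-only sketch below reproduces that rule; it reflects our reading of the plugin's default algorithm and is an assumption, so check the `gorm.io/sharding` source or set your own `ShardingAlgorithm` if you depend on exact names. `shardTable` is a hypothetical helper and assumes non-negative keys.

```go
package main

import "fmt"

// shardTable returns the physical table for a sharding key value:
// key % shards, zero-padded to as many digits as the shard count has.
func shardTable(table string, key, shards int64) string {
	width := len(fmt.Sprint(shards))
	return fmt.Sprintf("%s_%0*d", table, width, key%shards)
}

func main() {
	fmt.Println(shardTable("orders", 2, 64))         // orders_02
	fmt.Println(shardTable("orders", 3, 64))         // orders_03
	fmt.Println(shardTable("notifications", 5, 256)) // notifications_005
}
```

Because the suffix depends only on the key, every query must carry the sharding key (`user_id` here) for the plugin to know which physical table to hit.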

Chaos Engineering

I could not find official documentation for this; it is probably closed source.

package main

import (
   "gorm.io/driver/mysql"
   "gorm.io/gorm"
   "gorm.io/gorm/sqlchaos"
)

type User struct {
   ID     int    `json:"id"`
   Name   string `json:"name"`
   Result int    `json:"result"`
}

func main() {
   db, err := gorm.Open(
      mysql.Open("root:2001@tcp(127.0.0.1:3306)/blog?parseTime=true"),
      &gorm.Config{},
      sqlchaos.WithChaos(
         sqlchaos.Config{
            PSM:     "service name",
            DBName:  "dbname",
            EnvList: []string{"ppe", "boe"}, // environments where chaos drills run
         }),
   )
   if err != nil {
      panic(err)
   }
   db.Create(&User{ID: 1024, Name: "rick", Result: 10})
   // INSERT INTO `users` (`id`,`name`,`result`) VALUES (1024,"rick",10)
   // sqlchaos may tamper with it, for example into
   // INSERT INTO `users` (`id`,`name`,`result`) VALUES (1024,"morty",100)
}

Logger/Trace

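GORM has a default logger implementation; by default it prints slow SQL and errors.

The logger accepts only a few options, which you can customize at initialization, for example: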

import (
  "log"
  "os"
  "time"

  "gorm.io/driver/sqlite"
  "gorm.io/gorm"
  "gorm.io/gorm/logger"
)

newLogger := logger.New(
  log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
  logger.Config{
    SlowThreshold:             time.Second,   // Slow SQL threshold
    LogLevel:                  logger.Silent, // Log level (Silent logs nothing; logger.Info logs every statement)
    IgnoreRecordNotFoundError: true,          // Ignore ErrRecordNotFound error for logger
    ParameterizedQueries:      true,          // Don't include params in the SQL log
    Colorful:                  false,         // Disable color
  },
)

// Global mode
db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{
  Logger: newLogger,
})

// Continuous session mode
var user User
tx := db.Session(&gorm.Session{Logger: newLogger})
tx.First(&user)
tx.Model(&user).Update("Age", 18)

Refer to GORM's default logger when defining your own logger.

A logger must implement the following interface. Its methods receive a context, so you can use it to trace logs:

type Interface interface {
    LogMode(LogLevel) Interface
    Info(context.Context, string, ...interface{})
    Warn(context.Context, string, ...interface{})
    Error(context.Context, string, ...interface{})
    Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
}

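Database Migration

AutoMigrate migrates your schema automatically and keeps it up to date.

GORM provides the Migrator interface, which offers a unified API across databases so you can build database-independent migrations. For example:

SQLite does not support ALTER COLUMN, and older versions do not support DROP COLUMN either; when you change a table's structure, GORM creates a new table, copies all the data over, drops the old table, and renames the new one.

Some MySQL versions do not support renaming columns or indexes, so GORM runs different SQL depending on the MySQL version you use.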
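The SQLite table-rebuild strategy just described can be expressed as a fixed sequence of statements. This stdlib-only sketch generates that sequence for a `users` table; the `rebuildStatements` helper and the `__temp` suffix are illustrative assumptions rather than the exact SQL GORM's SQLite migrator emits, and in practice the steps should run inside a single transaction.

```go
package main

import (
	"fmt"
	"strings"
)

// rebuildStatements returns the create/copy/drop/rename steps used to change a
// table's structure when ALTER COLUMN is unavailable. newColumnsDDL is the column
// list of the new table; keep lists the columns whose data is copied over.
func rebuildStatements(table, newColumnsDDL string, keep []string) []string {
	tmp := table + "__temp"
	cols := strings.Join(keep, ", ")
	return []string{
		fmt.Sprintf("CREATE TABLE %s (%s)", tmp, newColumnsDDL),
		fmt.Sprintf("INSERT INTO %s (%s) SELECT %s FROM %s", tmp, cols, cols, table),
		fmt.Sprintf("DROP TABLE %s", table),
		fmt.Sprintf("ALTER TABLE %s RENAME TO %s", tmp, table),
	}
}

func main() {
	stmts := rebuildStatements("users", "id INTEGER PRIMARY KEY, name TEXT NOT NULL", []string{"id", "name"})
	for _, s := range stmts {
		fmt.Println(s + ";")
	}
}
```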

type Migrator interface {
  // AutoMigrate
  AutoMigrate(dst ...interface{}) error

  // Database
  CurrentDatabase() string
  FullDataTypeOf(*schema.Field) clause.Expr

  // Tables
  CreateTable(dst ...interface{}) error
  DropTable(dst ...interface{}) error
  HasTable(dst interface{}) bool
  RenameTable(oldName, newName interface{}) error
  GetTables() (tableList []string, err error)

  // Columns
  AddColumn(dst interface{}, field string) error
  DropColumn(dst interface{}, field string) error
  AlterColumn(dst interface{}, field string) error
  MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
  HasColumn(dst interface{}, field string) bool
  RenameColumn(dst interface{}, oldName, field string) error
  ColumnTypes(dst interface{}) ([]ColumnType, error)

  // Views
  CreateView(name string, option ViewOption) error
  DropView(name string) error

  // Constraints
  CreateConstraint(dst interface{}, name string) error
  DropConstraint(dst interface{}, name string) error
  HasConstraint(dst interface{}, name string) bool

  // Indexes
  CreateIndex(dst interface{}, name string) error
  DropIndex(dst interface{}, name string) error
  HasIndex(dst interface{}, name string) bool
  RenameIndex(dst interface{}, oldName, newName string) error
}

Migrating tables:

// Create the table for `User`
db.Migrator().CreateTable(&User{})

// Append "ENGINE=InnoDB" to the SQL that creates the `User` table
db.Set("gorm:table_options", "ENGINE=InnoDB").Migrator().CreateTable(&User{})

// Check whether the table for `User` exists
db.Migrator().HasTable(&User{})
db.Migrator().HasTable("users")

// Drop the table if it exists (foreign key constraints are ignored or dropped along with it)
db.Migrator().DropTable(&User{})
db.Migrator().DropTable("users")

// Rename the table
db.Migrator().RenameTable(&User{}, &UserInfo{})
db.Migrator().RenameTable("users", "user_infos")

Migrating columns:

type User struct {
  Name string
}

// Add the name column
db.Migrator().AddColumn(&User{}, "Name")
// Drop the name column
db.Migrator().DropColumn(&User{}, "Name")
// Alter the name column
db.Migrator().AlterColumn(&User{}, "Name")
// Check whether the name column exists
db.Migrator().HasColumn(&User{}, "Name")

type User struct {
  Name    string
  NewName string
}

// Rename a column
db.Migrator().RenameColumn(&User{}, "Name", "NewName")
db.Migrator().RenameColumn(&User{}, "name", "new_name")

// Column types
columnTypes, err := db.Migrator().ColumnTypes(&User{}) // returns ([]gorm.ColumnType, error)

type ColumnType interface {
    Name() string
    DatabaseTypeName() string                 // varchar
    ColumnType() (columnType string, ok bool) // varchar(64)
    PrimaryKey() (isPrimaryKey bool, ok bool)
    AutoIncrement() (isAutoIncrement bool, ok bool)
    Length() (length int64, ok bool)
    DecimalSize() (precision int64, scale int64, ok bool)
    Nullable() (nullable bool, ok bool)
    Unique() (unique bool, ok bool)
    ScanType() reflect.Type
    Comment() (value string, ok bool)
    DefaultValue() (value string, ok bool)
}

Gen Code Generation

Official documentation

It is quite straightforward to use gen for your application. Here is how it works:

  1. Write the configuration in Go
package main

import (
  "gorm.io/driver/mysql"
  "gorm.io/gen"
  "gorm.io/gorm"

  "your_project/model" // your own model package (hypothetical path)
)

// Dynamic SQL
type Querier interface {
  // SELECT * FROM @@table WHERE name = @name{{if role !=""}} AND role = @role{{end}}
  FilterWithNameAndRole(name, role string) ([]gen.T, error)
}

func main() {
  g := gen.NewGenerator(gen.Config{
    OutPath: "../query",
    Mode:    gen.WithoutContext | gen.WithDefaultQuery | gen.WithQueryInterface, // generate mode
  })

  gormdb, _ := gorm.Open(mysql.Open("root:@(127.0.0.1:3306)/demo?charset=utf8mb4&parseTime=True&loc=Local"))
  g.UseDB(gormdb) // reuse your gorm db

  // Generate basic type-safe DAO API for struct `model.User` following conventions
  g.ApplyBasic(model.User{})

  // Generate Type Safe API with Dynamic SQL defined on Querier interface for `model.User` and `model.Company`
  g.ApplyInterface(func(Querier){}, model.User{}, model.Company{})

  // Generate the code
  g.Execute()
}
  2. Generate the code

go run main.go

  3. Use the generated code in your project
import "your_project/query"

func main() {
  u := query.User

  // Basic DAO API
  user, err := u.Where(u.Name.Eq("modi")).First()

  // Dynamic SQL API
  users, err := u.FilterWithNameAndRole("modi", "admin")
}
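The `{{if role !=""}} ... {{end}}` comment on `FilterWithNameAndRole` tells Gen to include the role condition only when `role` is non-empty. Written by hand, the generated method has to do roughly what this stdlib-only sketch does; `filterWithNameAndRole` is a hypothetical illustration with the table name passed in explicitly (Gen fills `@@table` from the model), not Gen's actual output.

```go
package main

import "fmt"

// filterWithNameAndRole builds the SQL and arguments described by the template
// SELECT * FROM @@table WHERE name = @name{{if role !=""}} AND role = @role{{end}}
func filterWithNameAndRole(table, name, role string) (string, []interface{}) {
	sql := "SELECT * FROM " + table + " WHERE name = ?"
	args := []interface{}{name}
	if role != "" {
		sql += " AND role = ?"
		args = append(args, role)
	}
	return sql, args
}

func main() {
	fmt.Println(filterWithNameAndRole("users", "modi", "admin"))
	fmt.Println(filterWithNameAndRole("users", "modi", ""))
}
```

Gen's advantage is that this branching is generated from the comment and checked at compile time, instead of being maintained by hand for every query.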

Running Raw SQL

Official documentation

Raw queries with Scan:

type Result struct {
  ID   int
  Name string
  Age  int
}

var result Result
db.Raw("SELECT id, name, age FROM users WHERE id = ?", 3).Scan(&result)

db.Raw("SELECT id, name, age FROM users WHERE name = ?", "jinzhu").Scan(&result)

var age int
db.Raw("SELECT SUM(age) FROM users WHERE role = ?", "admin").Scan(&age)

var users []User
// RETURNING works on databases that support it, such as PostgreSQL and SQLite; MySQL does not
db.Raw("UPDATE users SET name = ? WHERE age = ? RETURNING id, name", "jinzhu", 20).Scan(&users)

Executing raw SQL with Exec:

db.Exec("DROP TABLE users")
db.Exec("UPDATE orders SET shipped_at = ? WHERE id IN ?", time.Now(), []int64{1, 2, 3})

// Exec with SQL Expression
db.Exec("UPDATE users SET money = ? WHERE name = ?", gorm.Expr("money * ? + ?", 10000, 1), "jinzhu")

Note: GORM can cache prepared SQL statements to improve performance; see Performance for details.

GORM supports named arguments in the form of sql.NamedArg, map[string]interface{}{} or a struct, for example:

db.Where("name1 = @name OR name2 = @name", sql.Named("name", "jinzhu")).Find(&user)
// SELECT * FROM `users` WHERE name1 = "jinzhu" OR name2 = "jinzhu"

db.Where("name1 = @name OR name2 = @name", map[string]interface{}{"name": "jinzhu2"}).First(&result3)
// SELECT * FROM `users` WHERE name1 = "jinzhu2" OR name2 = "jinzhu2" ORDER BY `users`.`id` LIMIT 1

// 原生 SQL 及命名参数
db.Raw("SELECT * FROM users WHERE name1 = @name OR name2 = @name2 OR name3 = @name",
   sql.Named("name", "jinzhu1"), sql.Named("name2", "jinzhu2")).Find(&user)
// SELECT * FROM users WHERE name1 = "jinzhu1" OR name2 = "jinzhu2" OR name3 = "jinzhu1"

db.Exec("UPDATE users SET name1 = @name, name2 = @name2, name3 = @name",
   sql.Named("name", "jinzhunew"), sql.Named("name2", "jinzhunew2"))
// UPDATE users SET name1 = "jinzhunew", name2 = "jinzhunew2", name3 = "jinzhunew"

db.Raw("SELECT * FROM users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2",
   map[string]interface{}{"name": "jinzhu", "name2": "jinzhu2"}).Find(&user)
// SELECT * FROM users WHERE (name1 = "jinzhu" AND name3 = "jinzhu") AND name2 = "jinzhu2"

type NamedArgument struct {
    Name string
    Name2 string
}

db.Raw("SELECT * FROM users WHERE (name1 = @Name AND name3 = @Name) AND name2 = @Name2",
     NamedArgument{Name: "jinzhu", Name2: "jinzhu2"}).Find(&user)
// SELECT * FROM users WHERE (name1 = "jinzhu" AND name3 = "jinzhu") AND name2 = "jinzhu2"
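Named arguments are ultimately rewritten into ordinary positional placeholders before the query reaches `database/sql`. The stdlib-only sketch below shows that rewriting step for a `map[string]interface{}`; `expandNamed` is a hypothetical helper, and it deliberately ignores details a real implementation must handle, such as `@` inside string literals.

```go
package main

import (
	"fmt"
	"strings"
)

// isIdentByte reports whether b can appear in a parameter name.
func isIdentByte(b byte) bool {
	return b == '_' || ('a' <= b && b <= 'z') || ('A' <= b && b <= 'Z') || ('0' <= b && b <= '9')
}

// expandNamed replaces each @name with ? and appends the matching value to args,
// so a name used twice contributes its value twice.
func expandNamed(query string, named map[string]interface{}) (string, []interface{}, error) {
	var out strings.Builder
	var args []interface{}
	for i := 0; i < len(query); i++ {
		if query[i] != '@' {
			out.WriteByte(query[i])
			continue
		}
		j := i + 1
		for j < len(query) && isIdentByte(query[j]) {
			j++
		}
		name := query[i+1 : j]
		v, ok := named[name]
		if !ok {
			return "", nil, fmt.Errorf("missing named argument %q", name)
		}
		out.WriteByte('?')
		args = append(args, v)
		i = j - 1 // resume right after the name
	}
	return out.String(), args, nil
}

func main() {
	q, args, err := expandNamed("name1 = @name OR name2 = @name2 OR name3 = @name",
		map[string]interface{}{"name": "jinzhu1", "name2": "jinzhu2"})
	fmt.Println(q, args, err) // name1 = ? OR name2 = ? OR name3 = ? [jinzhu1 jinzhu2 jinzhu1] <nil>
}
```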

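Security

Official documentation

GORM builds SQL with database/sql argument placeholders, which escape the arguments automatically and prevent SQL injection.

Note: the SQL printed by the Logger is not escaped the way the SQL that is finally executed is, so be careful when copying and running it.

User input must only ever be passed as arguments, for example: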

userInput := "jinzhu;drop table users;"

// Safe: the value is escaped
db.Where("name = ?", userInput).First(&user)

// SQL injection
db.Where(fmt.Sprintf("name = %v", userInput)).First(&user)

Inline conditions:

// Escaped
db.First(&user, "name = ?", userInput)

// SQL injection
db.First(&user, fmt.Sprintf("name = %v", userInput))

When retrieving records by an integer primary key taken from user input, type-check the variable first.

userInputID := "1=1;drop table users;"
// Safe: returns an error
id, err := strconv.Atoi(userInputID)
if err != nil {
    return err
}
db.First(&user, id)

// SQL injection
db.First(&user, userInputID)
// SELECT * FROM users WHERE 1=1;drop table users;

To support certain features, some inputs are not escaped; be careful with user input passed to the following methods.

db.Select("name; drop table users;").First(&user)
db.Distinct("name; drop table users;").First(&user)

db.Model(&user).Pluck("name; drop table users;", &names)

db.Group("name; drop table users;").First(&user)

db.Group("name").Having("1 = 1;drop table users;").First(&user)

db.Raw("select name from users; drop table users;").First(&user)

db.Exec("select name from users; drop table users;")

db.Order("name; drop table users;").First(&user)

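The general rule for avoiding SQL injection is: never trust user-submitted data. You can validate user input against a whitelist of known-safe, approved, predefined values, and whenever you use user input, pass it only as arguments.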
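`Order` is one of the methods listed above that does not escape its input, and sort columns are a common place where user input leaks into SQL. Following the whitelist advice, this stdlib-only sketch maps user-facing sort keys to known column names and directions; `orderClause` and the `sortable` map are hypothetical, and the resulting string is what you would pass to `db.Order(...)`.

```go
package main

import (
	"fmt"
	"strings"
)

// sortable maps user-facing sort keys to trusted column names (the whitelist).
var sortable = map[string]string{
	"name":    "name",
	"created": "created_at",
}

// orderClause builds an ORDER BY expression from untrusted input, falling back to
// "id ASC" for any unknown key and accepting only asc/desc as the direction.
func orderClause(key, dir string) string {
	col, ok := sortable[key]
	if !ok {
		return "id ASC"
	}
	if strings.EqualFold(dir, "desc") {
		return col + " DESC"
	}
	return col + " ASC"
}

func main() {
	fmt.Println(orderClause("created", "desc"))                // created_at DESC
	fmt.Println(orderClause("name; drop table users;", "asc")) // id ASC
}
```

Because the output is assembled only from strings in the whitelist, nothing the user sends can reach the SQL text directly.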