GORM 基础CRUD | 豆包MarsCode AI刷题

119 阅读16分钟
11-19基础CRUD
11-20更新事务部分

一、基础

注意:由于GORM的DB是链式操作,所有的动词如CreateFind等都是Terminal操作,到这里就会将前面的构造转换成SQL去执行,就算这些动词后面还有条件,也不会生效。

1.1、创建连接

以Mysql为例:

// 参考 https://github.com/go-sql-driver/mysql#dsn-data-source-name 获取详情
dsn := "root:123456@tcp(127.0.0.1:3306)/gorm?charset=utf8mb4&parseTime=True&loc=Local"
db, err := gorm.Open(mysql.Open(dsn),
    &gorm.Config{})

1.2、创建数据

假设我们数据库中有一张表格:Product,表格格式如下:

新建结构体对应这个表格:

type Product struct {
	ID    uint   `gorm:"primary_key;auto_increment"`
	Code  string `gorm:"column:code"`
	Price uint   `gorm:"column:price;default:50"`
}

在Gorm中,需要指定某个结构体对应哪张表格:

// 定义这个结构体对应的表名
func (p Product) TableName() string {
	return "product"
}

如果不指定,则默认使用结构体名称的蛇形负数作为表明去查找。蛇形负数就是把形如CreditCard映射成credit_cards

p := &Product{Code: "D42", Price: 100}
ret := db.Create(p)
fmt.Println(ret.Error, ret.RowsAffected, p.ID)

由于是链式调用,因此Create会返回一个tx *DB结构的数据,这里面我们可以输出是否报错、影响行数等信息。

products := []*Product{{Code: "D41"}, {Code: "D42"}, {Code: "D43"}}
ret = db.Create(&products)
fmt.Println(ret.RowsAffected)

批量插入还有一个办法:db.CreateInBatches, 可以指定插入多少数据。

Note:无法像Create函数中传入结构体,所以,我们应该传入指针。

ret = db.Select("Code").Create(&Product{Code: "D45", Price: 150})

上面执行完之后,Price字段并不会出现在对应的记录中,因为我们创建的时候指定了只使用Code字段。

上面所有代码执行之后,查看表格中的数据:

会不会有有人好奇D45这里为什么不是50?

答案很简单,我们只是在定义结构体的时候,不指定Price会给一个默认值。但是这里给了150,而我们新建记录的时候是不看这个字段的,数据表中price字段也没有默认值,因此这里是空。

创建钩子 Hook

Hook 是在创建、查询、更新、删除等操作之前、之后调用的函数。详见Hook文档,写的很清晰。

GORM允许用户通过实现BeforeSaveBeforeCreateAfterSaveAfterCreate来自定义钩子。这些钩子方法会在创建一条记录的时候被调用。

假如说我们想在创建实体之后输出一下创建的实体信息:

/*
创建之后输出实体信息。
*/
func (p *Product) AfterCreate(tx *gorm.DB) (err error) {
	fmt.Printf("%#v\n", p)
	return
}

这时候再重新执行一下上面的代码,控制台输出如下:

&create.Product{ID:0x1, Code:"D46", Price:0x64}

根据Map创建数据:

db.Model(&User{}).Create(map[string]interface{}{
  "Name": "jinzhu", "Age": 18,
})

// batch insert from `[]map[string]interface{}{}`
db.Model(&User{}).Create([]map[string]interface{}{
  {"Name": "jinzhu_1", "Age": 18},
  {"Name": "jinzhu_2", "Age": 20},
})

db.Model(entity)会根据entityTableName方法找到对应的表名。

Note: 当使用map来创建时,钩子方法不会执行,关联不会被保存且不会回写主键。

关联创建

建表语句:

CREATE TABLE `users` (
  `id` int(11) NOT NULL AUTO_INCREMENT,
  `name` varchar(255) NOT NULL,
  `created_at` datetime DEFAULT NULL,
  `updated_at` datetime DEFAULT NULL,
  `deleted_at` datetime DEFAULT NULL,
  PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

CREATE TABLE `credit_cards` (
  `id` int(11) NOT NULL AUTO_INCREMENT,
  `number` varchar(255) NOT NULL,
  `user_id` int(11) NOT NULL,
  `created_at` datetime DEFAULT NULL,
  `updated_at` datetime DEFAULT NULL,
  `deleted_at` datetime DEFAULT NULL,
  PRIMARY KEY (`id`),
  FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

执行下面代码:

type CreditCard struct {
	gorm.Model
	Number string
	UserID uint
}

type User struct {
	gorm.Model
	Name       string
	CreditCard CreditCard
}

func RelateCreate() {
	db, err := gorm.Open(mysql.Open("root:123456@tcp(127.0.0.1:3306)/gorm?charset=utf8mb4&parseTime=True&loc=Local"),
		&gorm.Config{})

	if err != nil {
		panic(err)
	}

	db.Create(&User{
		Name:       "jinzhu",
		CreditCard: CreditCard{Number: "411111111111"},
	})
}

执行之后,会发现userscredit_cards都有记录更新。

那么,如何跳过关联更新,也就是不更新credit_cards的数据呢?

通过Select, Omit方法跳过关联更新。

db.Omit("CreditCard").Create(&User{
    Name: "scy",
})
db.Omit(clause.Associations).Create(&User{
    Name: "byte",
})

默认值

可以通过结构体TagDefault来定义字段的默认值:

type User struct {
  ID   int64
  Name string `gorm:"default:galeone"`
  Age  int64  `gorm:"default:18"`
}

这些默认值会被当作结构体字段的零值插入到数据库中。

注意:当结构体中字段的默认值是0值的时候,比如说0, '', false,这些字段值不会被保存到数据库中,可以使用指针类型 或 Scanner/Valuer来避免这种情况

type User struct {
  gorm.Model
  Name string
  Age  *int           `gorm:"default:18"`
  Active sql.NullBool `gorm:"default:true"`
}

如果想要在数据库迁移的时候跳过默认值,可以使用 default:(-)

type User struct {
  ID        string `gorm:"default:uuid_generate_v3()"` // db func
  FirstName string
  LastName  string
  Age       uint8
  FullName  string `gorm:"->;type:GENERATED ALWAYS AS (concat(firstname,' ',lastname));default:(-);"`
}

Upsert及冲突

Upsert

当冲突的时候,假如说ID=1已经存在了,我们在创建的时候依旧指定ID=1那么就会出现冲突,可以使用

clause.OnConflict处理数据冲突。

db, err := gorm.Open(mysql.Open("root:123456@tcp(127.0.0.1:3306)/gorm?charset=utf8mb4&parseTime=True&loc=Local"),
    &gorm.Config{})

if err != nil {
    panic(err)
}
p := &Product{Code: "D46", Price: 100, ID: 1}
ret := db.Create(p)
fmt.Println(ret.Error, ret.RowsAffected, p.ID)

正常情况执行之后:Error 1062 (23000): Duplicate entry '1' for key 'product.PRIMARY' 0 1

跳过冲突:

ret := db.Clauses(clause.OnConflict{DoNothing: true}).Create(p)

1.3、查询数据

查询单个数据

Take, First, Last用于查询单个数据

  • First:获取第一条记录(按照主键升序来)
  • Last:获取最后一条记录(主键降序),其实就是主键升序的最后一个记录
  • Take:获取一条记录(没有指定排序字段)

但是用这些查询的时候可能会返回错误:ErrRecordNotFound。 用上面查询单个数据的时候,需要额外检查是否有这个错误。

type Product struct {
	ID    uint   `gorm:"primary_key;auto_increment"`
	Code  string `gorm:"column:code"`
	Price uint   `gorm:"column:price;default:50"`
}

// 定义这个结构体对应的表名
func (p Product) TableName() string {
	return "product"
}

// 创建钩子,查询后输出结果
func (p *Product) AfterFind(tx *gorm.DB) (err error) {
	fmt.Printf("product: %#v\n", p)
	return
}
func SimpleQuery() {
	dsn := "root:123456@tcp(127.0.0.1:3306)/gorm?charset=utf8mb4&parseTime=True&loc=Local"
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	p := &Product{}
	// 查询单个数据
	ret := db.First(p)
	if errors.Is(ret.Error, gorm.ErrRecordNotFound) {
		// service
		panic(ret.Error)
	}
	// 查询最后一个
	p2 := &Product{}
	ret = db.Last(p2)
	if errors.Is(ret.Error, gorm.ErrRecordNotFound) {
		// service
		panic(ret.Error)
	}
	// Take
	p3 := &Product{}
	ret = db.Take(p3)
	if errors.Is(ret.Error, gorm.ErrRecordNotFound) {
		// service
		panic(ret.Error)
	}
}

如果你想避免ErrRecordNotFound错误,你可以使用Find,比如db.Limit(1).Find(&user)Find方法可以接受struct和slice的数据。

注意:对单个对象使用Find而不带limitdb.Find(&user)将会查询整个表并且只返回第一个对象,只是性能不高并且不确定的。

First and Last 方法会按主键排序找到第一条记录和最后一条记录 (分别)。 只有在目标 struct 是指针或者通过 db.Model() 指定 model 时,该方法才有效,此外,如果相关 model 没有定义主键,那么将按 model 的第一个字段进行排序。

主键检索

如果主键是数字类型,那么可以使用内联条件来检索对象。

p3 := &Product{}
db.First(&p3, 2)

检索全部对象

products := []*Product{}
db.Find(&products)

条件

String条件
p := &Product{}
// 查询id大于5的第一个记录
// SELECT * FROM users WHERE id > 5 ORDER BY id LIMIT 1;
db.Where("id > ?", 5).First(&p)

// 查询id 等于 5, 7的记录
// SELECT * FROM users WHERE id IN (5,7);
ps := []*Product{}
db.Where("id IN ?", []int{5, 7}).Find(&ps)
// 更多在文档中:https://gorm.io/zh_CN/docs/query.html#%E6%9D%A1%E4%BB%B6

为了方便查看sql语句,可以开启Logger

newLogger := logger.New(
    log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
    logger.Config{
        SlowThreshold:             time.Second, // Slow SQL threshold
        LogLevel:                  logger.Info, // Log level
        IgnoreRecordNotFoundError: true,        // Ignore ErrRecordNotFound error for logger
        ParameterizedQueries:      true,        // Don't include params in the SQL log
        Colorful:                  true,		// Enable Color
    },
)
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{Logger: newLogger})

注意,如果对象设置了主键,条件查询将不会覆盖主键的值,而是用And连接条件:

p1 := &Product{ID: 1}
ret := db.Debug().Where("id = ?", 2).Find(p1) // Debug会打印SQL语句
// SELECT * FROM `product` WHERE id = 2 AND `product`.`id` = 1
print(ret.RowsAffected) // 0
// 奇怪的是这里没有触发
if errors.Is(ret.Error, gorm.ErrRecordNotFound) {
    fmt.Println(ret.Error)
    panic(ret.Error) 
}
Struct & Map
p2 := &Product{}
// Struct Map 条件
db.Where(&Product{Price: 50}).First(p2)
// SELECT * FROM `product` WHERE `product`.`price` = ? ORDER BY `product`.`id` LIMIT ?
db.Where(map[string]interface{}{"price": 50}).Find(p2)
// SELECT * FROM `product` WHERE `price` = ? AND `product`.`id` = ?

ps := []*Product{}
// 根据主键
db.Where([]int{4, 5, 6}).Find(&ps)
// SELECT * FROM `product` WHERE `product`.`id` IN (?,?,?)

Note:如果使用结构体查询,当结构体中有零值如0, '', false等,这些字段不会被用于构建查询条件,可以使用Map解决这个问题。

例如:

db.Where(&Product{Price: 0}).Limit(1).Find(p2)
// SELECT * FROM `product` LIMIT ?

// 使用map
db.Where(map[string]interface{}{"price": 0}).Limit(1).Find(p2)
// SELECT * FROM `product` WHERE `price` = ? AND `product`.`id` = ? LIMIT

有时候,不需要结构体中的所有字段查询,那么可以指定用结构体的哪个字段:

// 指定结构体查询字段
db.Where(&Product{Price: 100, ID: 1}, "ID").Find(p2)
// SELECT * FROM `product` WHERE `product`.`id` = ?
内联条件

查询条件可以内联到First和Find等方法中,其方式类似于where。

db.Limit(1).Find(p2, "price = ?", 50)
// SELECT * FROM `product` WHERE price = ? LIMIT ?
Not条件
db.Not("price", 50).Limit(1).Find(p2)
// SELECT * FROM `product` WHERE `price` <> ? LIMIT ?

等价于Where条件

db.Where("price <> ?", 50).Limit(1).Find(p2)
// SELECT * FROM `product` WHERE price <> ? LIMIT ?
Or条件
// Or条件
db.Where("price > ?", 50).Or("id < ?", 5).Limit(1).Find(p2)
// SELECT * FROM `product` WHERE price > ? OR id < ? LIMIT ?

db.Where("price > ?", 50).Or(&Product{ID: 5, Price: 100}).Limit(1).Find(p2)
// SELECT * FROM `product` WHERE price > ? OR (`product`.`id` = ? AND `product`.`price` = ?) LIMIT ?

选择特定字段

Select用与选择想从数据库中获取的特定字段

db.Select("code").Where("price < 200").Limit(1).Find(p2)
// SELECT `code` FROM `product` WHERE price < 200 LIMIT ?

// or
// db.Select([]string{})... 切片传递多个字段

排序

// 排序
db.Order("price desc").Limit(1).First(p2)
// SELECT * FROM `product` ORDER BY price desc,`product`.`id` LIMIT ?
db.Clauses(clause.OrderBy{
    Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true},
}).Limit(1).Find(p2)
// SELECT * FROM `product` WHERE `product`.`id` = ? ORDER BY FIELD(id, ?,?,?) LIMIT ?

Limit & Offset

Limit前面用到多次,这里不做解释了。

Offset明确指出从那一个位置开始搜索。

db.Offset(3).Limit(1).Find(p2)
// SELECT * FROM `product` LIMIT ? OFFSET ?

输出的会是第四条记录。

Group By & Having

为了方便演示,将上面Product记录再增加一遍

type Result struct {
	Code  string 
	Total uint
}

// 创建钩子,查询后输出结果
func (r *Result) AfterFind(tx *gorm.DB) (err error) {
	fmt.Printf("result: %v\n", r)
	return
}
results := []*Result{}
db.Model(&Product{}).Select("code, sum(price) as total").Group("code").Find(&results)
// SELECT code, sum(price) as total FROM `product` GROUP BY `code`

db.Model(&Product{}).Select("code, sum(price) as total").Group("code").Having("sum(price) > ?", 150).Find(&results)
// SELECT code, sum(price) as total FROM `product` GROUP BY `code` HAVING sum(price) > ?

Distinct

查询对应列数据不同的元素:

db.Model(&Product{}).Distinct("code").Find(&ps)
// SELECT DISTINCT `code` FROM `product`

可以发现,Distinct只会返回code的内容:

product: &query.Product{ID:0x0, Code:"D46", Price:0x0}
product: &query.Product{ID:0x0, Code:"D41", Price:0x0}
product: &query.Product{ID:0x0, Code:"D42", Price:0x0}
product: &query.Product{ID:0x0, Code:"D43", Price:0x0}
product: &query.Product{ID:0x0, Code:"D45", Price:0x0}

Joins 连表查询

这里没有做样例,官方的样例很明了:

type result struct {
  Name  string
  Email string
}

db.Model(&User{}).Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Scan(&result{})
// SELECT users.name, emails.email FROM `users` left join emails on emails.user_id = users.id

rows, err := db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Rows()
for rows.Next() {
  ...
}

db.Table("users").Select("users.name, emails.email").Joins("left join emails on emails.user_id = users.id").Scan(&results)

// multiple joins with parameter
db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Joins("JOIN credit_cards ON credit_cards.user_id = users.id").Where("credit_cards.number = ?", "411111111111").Find(&user)
Joins预加载

预加载相关

db.Joins("Company").Find(&users)
// SELECT `users`.`id`,`users`.`name`,`users`.`age`,`Company`.`id` AS `Company__id`,`Company`.`name` AS `Company__name` FROM `users` LEFT JOIN `companies` AS `Company` ON `users`.`company_id` = `Company`.`id`;

// inner join
db.InnerJoins("Company").Find(&users)
// SELECT `users`.`id`,`users`.`name`,`users`.`age`,`Company`.`id` AS `Company__id`,`Company`.`name` AS `Company__name` FROM `users` INNER JOIN `companies` AS `Company` ON `users`.`company_id` = `Company`.`id`;
条件Joins
db.Joins("Company", db.Where(&Company{Alive: true})).Find(&users)
// SELECT `users`.`id`,`users`.`name`,`users`.`age`,`Company`.`id` AS `Company__id`,`Company`.`name` AS `Company__name` FROM `users` LEFT JOIN `companies` AS `Company` ON `users`.`company_id` = `Company`.`id` AND `Company`.`alive` = true;
Joins一个衍生表
type User struct {
    Id  int
    Age int
}

type Order struct {
    UserId     int
    FinishedAt *time.Time
}

// query 相当于一个虚拟表格
query := db.Table("order").Select("MAX(order.finished_at) as latest").Joins("left join user user on order.user_id = user.id").Where("user.age > ?", 18).Group("order.user_id")
db.Model(&Order{}).Joins("join (?) q on order.finished_at = q.latest", query).Scan(&results)
// SELECT `order`.`user_id`,`order`.`finished_at` FROM `order` join (SELECT MAX(order.finished_at) as latest FROM `order` left join user user on order.user_id = user.id WHERE user.age > 18 GROUP BY `order`.`user_id`) q on order.finished_at = q.latest

Scan

将结果扫描到结构体中,和Find用法相似。

1.4、高级查询

下期青训营再学。

高级查询

1.5、更新

设置一个钩子,自动输出更新后的结果,这样就不需要去看数据库了:

func (p *Product) AfterUpdate(tx *gorm.DB) (err error) {
	fmt.Printf("after Update: %#v\n", p)
	return
}

保存所有字段

Save方法会保存所有的字段,即使字段是零值。

// 先查询出来
db.First(p)
p.Price = 1000
db.Save(p)
// UPDATE `product` SET `code`=?,`price`=? WHERE `id` = ?

注意:Save是一个组合函数,如果保存的值中没有包含主键,那么它将执行Create,否则执行Update

不要将SaveModel一同使用,这是不被允许的。

更新单个列

使用Update更新单个列,要有一些条件,不然会导致ErrMissingWhereClause 错误。

p := &Product{}
// 先查询出来
db.First(p)

// 传入p,会调用p的主键,要确保有,不然没有更新条件,会报错
db.Model(p).Update("price", 800)
// UPDATE `product` SET `price`=? WHERE `id` = ?

更新多个列

Updates 方法支持 structmap[string]interface{} 参数, 对于struct结构,默认只会更新非零字段的值。

p := &Product{}
// 先查询出来
db.First(p)
// 传入p,会调用p的主键,要确保有,不然没有更新条件,会报错
db.Model(p).Updates(Product{Code: "S01", Price: 900})

注意 使用 struct 更新时, GORM 将只更新非零值字段,可以用map 来更新属性,或者使用 Select 声明字段来更新

更新选定字段

如果想要在更新的时候选择性的更新一个字段,可以使用SelectOmit

db.Model(p).Omit("price").Updates(map[string]interface{}{"price": 500, "code": "s02"})
// UPDATE `product` SET `code`=? WHERE `id` = ?

可以看到,跳过了price的更新。

等同于下面:

db.Model(p).Select("code").Updates(map[string]interface{}{"price": 500, "code": "s03"})
// UPDATE `product` SET `code`=? WHERE `id` = ?

批量更新

如果没有通过 Model 指定一个含有主键的记录, GORM会执行批量更新。

db.Model(&Product{}).Where("price = ?", 50).Update("price", 200)
//UPDATE `product` SET `price`=? WHERE price = ?

阻止全局更新

如果你执行一个没有任何条件的批量更新,GORM 默认不会运行,并且会返回 ErrMissingWhereClause 错误

ret := db.Model(&Product{}).Update("price", 500)
fmt.Println(ret.Error)

程序会终止。

你可以用一些条件,使用原生 SQL 或者启用 AllowGlobalUpdate 模式,例如:

db.Exec("UPDATE product SET price = ?", 500) // 原生SQL
// 启用AllowGlobalUpdate 
db.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&Product{}).Update("price", 500)

高级选项

高级选项 后面有时间再更新

使用SQL表达式更新

1.6、删除

创建一个钩子,提醒我们是否删除成功

func (p Product) AfterDelete(tx *gorm.DB) (err error) {
	if tx.Error != nil {
		fmt.Println(tx.Error)
		err = tx.Error
		return
	}
	fmt.Printf("Delete Successfully!\n")
	return nil
}

删除一条记录

删除一条记录时,删除对象需要指定主键,否则会触发批量删除

p := &Product{}
db.Last(p)
db.Delete(p)
// DELETE FROM `product` WHERE `product`.`id` = ?

// 带额外条件的删除
db.Where(xx).Delete(p)

根据主键删除

类似于根据主键查询

db.Delete(&Product{}, 9)
db.Delete(&Product{}, '9') // 会将'9' 自动转换成 9
db.Delete(&Product{}, []int{1,2,3})

批量删除

如果指定的值不包括主属性,那么会批量删除。

db.Where("price < ?", 10).Delete(&Product{})

可以将一个主键切片传递给Delete方法,以便更高效的删除数据量大的记录

ps := []Product{{ID: 1}, {ID: 2}}
db.Delete(&ps)

阻止全局删除

使用方式与 阻止全局更新一样。

软删除 !

如果模型包含了 gorm.DeletedAt字段(该字段也被包含在gorm.Model中),那么该模型将会自动获得软删除的能力!

当调用Delete时,GORM并不会从数据库中删除该记录,而是将该记录的DeleteAt设置为当前时间,而后的一般查询方法将无法查找到此条记录。

如果不想使用gorm.Model添加所有的字段,可以使用下面例子开启软删除特性:

type User struct {
  ID      int
  Deleted gorm.DeletedAt
  Name    string
}
查找被软删的记录
db.Unscoped().Where("age = 20").Find(&users)
// SELECT * FROM users WHERE age = 20;
永久删除
db.Unscoped().Delete(&order)
// DELETE FROM orders WHERE id=10;
删除标志

默认情况下,gorm.Model使用*time.Time作为DeletedAt 的字段类型,不过软删除插件gorm.io/plugin/soft_delete同时也提供其他的数据格式支持。

Note: 当使用DeletedAt创建唯一复合索引时,你必须使用其他的数据类型,例如通过gorm.io/plugin/soft_delete插件将字段类型定义为unix时间戳等等

具体可见

二、事务

禁用默认事务

默认情况下,GORM是开启事务的,测试一下:

func SimpleDemo() (err error) {
	newLogger := logger.New(
		log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
		logger.Config{
			SlowThreshold:             time.Second, // Slow SQL threshold
			LogLevel:                  logger.Info, // Log level
			IgnoreRecordNotFoundError: true,        // Ignore ErrRecordNotFound error for logger
			ParameterizedQueries:      true,        // Don't include params in the SQL log
			Colorful:                  true,
		},
	)
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
		Logger: newLogger,
	})
	if err != nil {
		panic(err)
	}
	// 演示默认事务
	err = db.Transaction(func(tx *gorm.DB) error {
		// 修改product的一个产品
		ret := tx.Model(&Product{}).Where("id = ?", 6).Update("price", 300)
		if ret.Error != nil {
			return errors.New(ret.Error.Error())
		}
		// 查看users表中不存在的记录
		user := &User{}
		// 应该报错 ErrRecord..
		ret = tx.Model(&User{}).Where("id = ?", 3).First(user)
		if ret.Error != nil {
			return errors.New(ret.Error.Error())
		}
		return nil
	})
	if err != nil {
		panic(err)
	}
	return err
}

执行完之后,会发现控制台报错,product表格也没有更新,这说明事务是默认开启的。

如果没有这方面的要求,可以在初始化的时候禁用,大约提升的性能:

db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
    Logger: newLogger,
    SkipDefaultTransaction: true, // 全部禁用事务
})
tx := db.Session(&gorm.Session{SkipDefaultTransaction: true})
// tx.xx 后续tx开头的都没有事务了

Note: 开启事务之后,一定要用tx,而不是db。在Transaction中,返回任何错误都是回滚事务,返回nil则会提交事务。

嵌套事务

GORM支持嵌套事务,可以回滚大事务内执行的一部分操作:

func EmbeddingTx() {
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
		Logger: getLogger(),
	})
	if err != nil {
		panic(err)
	}

	err = db.Transaction(func(tx *gorm.DB) error {
		// 修改product的一个产品
		ret := tx.Model(&Product{}).Where("id = ?", 6).Update("price", 300)
		if ret.Error != nil {
			return errors.New(ret.Error.Error())
		}

		err = tx.Transaction(func(tx1 *gorm.DB) error {
			// 查询user的一个记录
			user := &User{}
			// 应该报错 ErrRecord..
			ret = tx1.Model(&User{}).Where("id = ?", 3).First(user)
			if ret.Error != nil {
				return errors.New(ret.Error.Error())
			}
			return nil
		})

		return nil
	})
}

这里可以发现,product表中的数据已经修改了

手动事务

Gorm支持直接调用事务控制方法:

// 开始事务
tx := db.Begin()

// 在事务中执行一些 db 操作(从这里开始,您应该使用 'tx' 而不是 'db')
tx.Create(...)

// ...

// 遇到错误时回滚事务
tx.Rollback()

// 否则,提交事务
tx.Commit()

在手动提交事务中,每次遇到错误必须调用tx.Rollback(),不如自动事务方便,返回err就是回滚了。

func HandTx() error {
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
		Logger: getLogger(),
	})
	if err != nil {
		panic(err)
	}
	// 开启事务
	tx := db.Begin()
	// 最终一定会执行这个,保证事务回滚
	defer func() {
		if err := recover(); err != nil {
			tx.Rollback()
		}
	}()

	ret := tx.Model(&Product{}).Where("id = ?", 6).Update("price", 400)
	if ret.Error != nil {
		tx.Rollback() // 出错就要回滚
		return ret.Error
	}
	// 查询user的一个记录
	user := &User{}
	// 应该报错 ErrRecord..
	ret = tx.Model(&User{}).Where("id = ?", 3).First(user)
	if ret.Error != nil {
		tx.Rollback()
		return ret.Error
	}
	return tx.Commit().Error
}

执行完之后,会发现product表中的数据没有变化。

SavePoint、RollbackTo

GORM提供了SavePoint、RollbackTo方法,来提供保存点以及回滚至保存点功能,例如:

tx := db.Begin()
tx.Create(&user1)

tx.SavePoint("sp1") // 保存断点位置
tx.Create(&user2)
tx.RollbackTo("sp1") // 出错了就回滚到sp1这个位置

tx.Commit() // Commit user1