通用分类树工具库(Category)

1 阅读11分钟

通用分类树工具库:从设计到使用

一篇关于如何设计和实现一个泛型、高性能、易用的分类树处理工具库的完整技术分享

本文目录

  • 写在前面

  • 一、设计目标:我们要解决什么问题?

  • 二、核心数据结构设计

  • 三、核心功能实现

  • 四、使用示例

  • 五、性能优化与注意事项

  • 六、总结


在业务系统开发中,分类树(Category Tree)是一个非常常见的需求:
商品分类、组织架构、行政区划、菜单权限...几乎每个项目都会遇到。:
但每次都要重复写递归找子节点、找父节点、列表转树形结构这些代码,不仅繁琐,还容易出错。
本文将分享一个泛型分类树工具库的设计与实现过程,它具备以下特点:

  • 泛型支持:ID 可以是 int 或 string,数据可以是任意类型
  • 零依赖:纯标准库实现,开箱即用
  • 功能完整:列表转树、树转列表、查找子树、查找祖级、查找特定节点
  • 类型安全:充分利用 Go 泛型,编译时保证类型正确

一、设计目标:我们要解决什么问题?

在开始写代码之前,先明确我们要解决的核心问题:

1.1 常见的分类树操作需求

// 假设我们有这样的分类数据
type Category struct {
    ID   int
    Pid  int
    Name string
}

var categories = []Category{
    {ID: 1, Pid: 0, Name: "数码"},
    {ID: 2, Pid: 1, Name: "手机"},
    {ID: 3, Pid: 2, Name: "iPhone"},
    // ...
}

我们需要的能力:

  • 列表 → 树:将平铺的列表转换为树形结构(用于前端展示)
  • 树 → 列表:将树形结构重新平铺(用于数据导出)
  • 查找子树:给定一个节点,找到它下面所有的子节点(含自身)
  • 查找祖级:给定一个节点,找到它上面所有的父节点
  • 查找节点:在树中查找特定 ID 的节点

1.2 设计

  • ID 类型不确定:可能是 int(可以拓展为 comparable ),可能是 string
  • 数据类型不确定:可能是商品分类,可能是部门,可能是区域
  • 性能要求:频繁的递归查找不能太慢
  • 易用性:API 要简单直观

二、核心数据结构设计

2.1 Item:树的节点定义

type (
    // keyType 约束 ID 的类型只能是 int 或 string (可以改为comparable)
    keyType interface {
        int | string
    }

    // Item 树的节点
    Item[T keyType, D any] struct {
        ID       T             `json:"id"`                 // 节点ID
        Pid      T             `json:"pid"`                 // 父节点ID
        Name     string        `json:"name"`                // 节点名称
        Raw      D             `json:"raw,omitempty,optional"` // 原始数据
        Children []*Item[T, D] `json:"children,omitempty,optional"` // 子节点
    }
)

设计思路:

  • 使用泛型
    • T支持灵活的 ID类型
    • D承载任意业务数据
  • Children是切片,保持顺序性
  • Raw字段可选,用于携带原始业务对象

2.2 Category:工具类的主体

type Category[T keyType, D any] struct {
    List  []*Item[T, D] // 平铺列表(原始数据)
    Trees []*Item[T, D] // 树形结构(根节点列表)
}

设计思路:

  • 同时维护列表和树,避免重复计算
  • 列表用于快速查找和祖级追溯
  • 树用于树形操作和子节点查找

三、核心功能实现

3.1 列表转树(Conv)

这是最核心的功能:将平铺的列表转换为树形结构。

// Conv 转换列表为分类树
func (c *Category[T, D]) Conv(list []D, call func(D) *Item[T, D]) *Category[T, D] {
    // 1. 转换列表
    var length = len(list)
    c.List = make([]*Item[T, D], length)
    for key, item := range list {
        var v = call(item)
        
        // 2. 过滤无效ID
        if any(v.ID) == nil {
            continue
        }
        
        // 处理 string 类型的空值
        if tmp, ok := any(v.ID).(string); ok {
            if tmp == "" {
              continue
            }
        }
        
        // 处理 int 类型的零值
        if tmp, ok := any(v.ID).(int); ok {
            if tmp tmp == 0 {
                continue
            }
        }

        c.List[key] = v
    }

    if len(c.List) <= 0 {
        return c
    }

    // TODO 如果使用comparable这里要注意区分 comparable/keyType
    // 3. 构建树形结构(根节点ID为0)
    c.Trees = c.makeTrees(T(0))
    return c
}

关键细节:

  • 通过回调函数让调用方决定如何构建 Item
  • 自动过滤无效 ID(nil、空字符串、0值)
  • 根节点的 PID 约定为 0(或空字符串)

3.2 递归构建树(makeTrees)

// makeTrees 递归构建树形结构
func (c *Category[T, D]) makeTrees(pid T) []*Item[T, D] {
    var children []*Item[T, D]
    for _, item := range c.List {
        var value = new(Item[T, D])
        *value = *item  // 值拷贝,避免相互影响
        
        if value.Pid == pid {
            children = append(children, value)
            // 递归找子节点
            value.Children = c.makeTrees(value.ID)
        }
    }

    if len(children) <= 0 {
        children = []*Item[T, D]{}
    }

    return children
}

设计要点:

  • 值拷贝*value = *item 确保修改子节点不影响原列表
  • 递归终止:当没有子节点时返回空切片
  • 深度优先:先找子节点,再找孙节点

3.3 树转列表(SubFlatList)

// SubFlatList 树状结构转平铺列表
func (c *Category[T, D]) SubFlatList(trees []*Item[T, D]) []*Item[T, D] {
    var list []*Item[T, D]
    for _, item := range trees {
        var val = new(Item[T, D])
        *val = *item
        val.Children = nil  // 平铺时清除子节点
        
        list = append(list, val)
        if len(item.Children) > 0 {
            list = append(list, c.SubFlatList(item.Children)...)
        }
    }
    return list
}

应用场景:

  • 数据库存储
  • 数据导出
  • 序列化传输

3.4 查找子树(FindTrees)

// FindTrees 查找指定ID下所有子集(包含自身)
func (c *Category[T, D]) FindTrees(parentIds []T, list []D, call func(D) *Item[T, D]) []T {
    var trees = c.Conv(list, call).Trees
    if len(trees) <= 0 {
        return nil
    }

    var (
        ids     = append([]T{}, parentIds...)
        records []*Item[T, D]
    )
    
    // 1. 找到所有父节点
    for _, id := range parentIds {
        var item = c.FindId(id, trees)
        if item == nil {
            continue
        }
        records = append(records, item)
    }

    // 2. 收集所有子节点ID
    for _, item := range records {
        for _, subItem := range c.SubFlatList(item.Children) {
            ids = append(ids, subItem.ID)
        }
    }

    return ids
}

典型场景:

  • 删除分类时找出所有子分类
  • 权限继承时找出所有子菜单
  • 统计某个节点下所有数据

3.5 查找祖级(FindParents)

// FindParents 查找祖级节点
func (c *Category[T, D]) FindParents(ID T) []*Item[T, D] {
    // 1. 找到当前节点
    var current *Item[T, D]
    for _, item := range c.List {
        if item.ID == ID {
            current = new(Item[T, D])
            *current = *item
        }
    }

    if current == nil {
        return nil
    }

    var data = []*Item[T, D]{current}
    
    // 2. 检查是否为根节点
    if pid, ok := any(current.Pid).(string); ok && pid == "" {
        return data
    }
    if pid, ok := any(current.Pid).(int); ok && pid == 0 {
        return data
    }

    // 3. 递归找父节点并反转顺序(从上到下)
    var (
        list  = append([]*Item[T, D]{current}, c.findParents(current.Pid)...)
        res   = make([]*Item[T, D], len(list))
        index = 0
    )
    for i := len(list) - 1; i >= 0; i-- {
        res[index] = list[i]
        index += 1
    }

    return res
}

func (c *Category[T, D]) findParents(pid T) []*Item[T, D] {
    var parents []*Item[T, D]
    for _, item := range c.List {
        if item.ID == pid {
            var val = new(Item[T, D])
            *val = *item
            parents = append(parents, val)

            // 递归终止条件
            if pid, ok := any(val.Pid).(string); ok && pid == "" {
                break
            }
            if pid, ok := any(val.Pid).(int); ok && pid == 0 {
                break
            }

            parents = append(parents, c.findParents(val.Pid)...)
        }
    }
    return parents
}

返回值设计:返回从根到当前节点的路径,便于面包屑导航。

3.6 查找节点(Find/FindId)

// Find 在树中查找指定ID的节点
func (c *Category[T, D]) Find(ID T) *Item[T, D] {
    return c.find(ID, c.Trees)
}

func (c *Category[T, D]) find(ID T, subs []*Item[T, D]) *Item[T, D] {
    for _, item := range subs {
        if item.ID == ID {
            return item
        }

        if len(item.Children) > 0 {
            if val := c.find(ID, item.Children); val != nil {
                return val
            }
        }
    }
    return nil
}

// FindId 在指定数据中查找节点(支持外部传入数据)
func (c *Category[T, D]) FindId(id T, data []*Item[T, D]) *Item[T, D] {
    for _, item := range data {
        if item.ID == id {
            return item
        }

        if len(item.Children) > 0 {
            var data = c.FindId(id, item.Children)
            if data != nil {
                return data
            }
        }
    }
    return nil
}

四、使用示例

4.1 基础用法

package main

import (
    "encoding/json"
    "fmt"
    "strings"
)

// 定义业务数据结构
type ProductCategory struct {
    ID    int
    Pid   int
    Name  string
    Level int // 业务特有字段
}

func main() {
    // 准备数据
    var data = []ProductCategory{
        {ID: 1, Pid: 0, Name: "数码", Level: 1},
        {ID: 2, Pid: 0, Name: "家电", Level: 1},
        {ID: 3, Pid: 1, Name: "手机", Level: 2},
        {ID: 4, Pid: 3, Name: "iPhone", Level: 3},
        {ID: 5, Pid: 3, Name: "华为", Level: 3},
        {ID: 6, Pid: 2, Name: "电视", Level: 2},
    }

    // 创建分类树
    var category = New[int, ProductCategory]().Conv(
        data,
        func(item ProductCategory) *Item[int, ProductCategory] {
            return &Item[int, ProductCategory]{
                ID:   item.ID,
                Pid:  item.Pid,
                Name: item.Name,
                Raw:  item, // 携带原始数据
            }
        },
    )

    // 查看树形结构
    treesJSON, _ := json.MarshalIndent(category.Trees, "", "  ")
    fmt.Printf("树形结构:\n%s\n", treesJSON)

    // 查找 iPhone 的祖级路径
    var (
        parents := category.FindParents(4)
        names []string
    )
    for _, p := range parents {
        names = append(names, p.Name)
    }
    fmt.Printf("iPhone 的路径:%s\n", strings.Join(names, " > "))
    
    // 在树中查找节点
    var node = category.Find(3)
    if node != nil {
        fmt.Printf("找到节点:ID=%d, Name=%s\n", node.ID, node.Name)
    }
}

4.2 实际业务场景

// 场景1:删除分类时找出所有子分类ID
func DeleteCategory(cateID int) {
    // 获取所有分类(从数据库)
    var allCategories []ProductCategory
    db.Find(&allCategories)
    
    var cate = New[int, ProductCategory]().Conv(
        allCategories,
        func(item ProductCategory) *Item[int, ProductCategory] {
            return &Item[int, ProductCategory]{ID: item.ID, Pid: item.Pid}
        },
    )
    
    // 找出所有要删除的ID(包含子分类)
    var idsToDelete = cate.FindTrees([]int{cateID}, allCategories, 
        func(item ProductCategory) *Item[int, ProductCategory] {
            return &Item[int, ProductCategory]{ID: item.ID, Pid: item.Pid}
        },
    )
    
    // 批量删除
    db.Where("id IN ?", idsToDelete).Delete(&ProductCategory{})
}

// 场景2:构建前端级联选择器数据
type CascaderOption struct {
    Value    int              `json:"value"`
    Label    string           `json:"label"`
    Children []CascaderOption `json:"children,omitempty"`
}

func BuildCascaderOptions() []CascaderOption {
    var all []ProductCategory
    db.Find(&all)
    
    var cate = New[int, ProductCategory]().Conv(all, 
        func(item ProductCategory) *Item[int, ProductCategory] {
            return &Item[int, ProductCategory]{
                ID: item.ID, Pid: item.Pid, Name: item.Name,
            }
        },
    )
    
    return convertToCascader(cate.Trees)
}

func convertToCascader(trees []*Item[int, ProductCategory]) []CascaderOption {
    var res []CascaderOption
    for _, node := range trees {
        res = append(res, CascaderOption{
            Value:    node.ID,
            Label:    node.Name,
            Children: convertToCascader(node.Children),
        })
    }
    return res
}

// 场景3:面包屑导航
func GetBreadcrumb(cateID int) string {
    var all []ProductCategory
    db.Find(&all)
    
    var cate = New[int, ProductCategory]().Conv(
        all, 
        func(item ProductCategory) *Item[int, ProductCategory] {
            return &Item[int, ProductCategory]{
                ID: item.ID, Pid: item.Pid, Name: item.Name,
            }
        },
    )
    
    var (
        parents = cate.FindParents(cateID)
        names []string
    )
    for _, p := range parents {
        names = append(names, p.Name)
    }
    return strings.Join(names, " > ")
}

4.3 字符串类型ID的使用

// 适用于 MongoDB、UUID 等场景
type Org struct {
    ID      string
    Pid     string
    Name    string
    Manager string
}

func HandleOrg() {
    var (
        orgs = []Org{
            {ID: "root", Pid: "", Name: "总公司", Manager: "张三"},
            {ID: "dept1", Pid: "root", Name: "技术部", Manager: "李四"},
            {ID: "dept2", Pid: "dept1", Name: "前端组", Manager: "王五"},
        }
        tree = New[string, Org]().Conv(
            orgs,
            func(item Org) *Item[string, Org] {
                return &Item[string, Org]{
                    ID:   item.ID,
                    Pid:  item.Pid,
                    Name: item.Name,
                    Raw:  item,
                }
            },
        )
    )
    
    // 查找前端组的所有上级
    var parents = tree.FindParents("dept2")
    for _, p := range parents {
        fmt.Printf("上级:%s(负责人:%s)\n", p.Name, p.Raw.Manager)
    }
    
    // 树转列表
    var flatList = tree.SubFlatList(tree.Trees)
    fmt.Printf("平铺后共 %d 个节点\n", len(flatList))
}

五、性能优化与注意事项

5.1 值拷贝的设计考量

// 为什么需要值拷贝?
var value = new(Item[T, D])
*value = *item  // 如果不拷贝,修改 Children 会影响原列表

如果不进行值拷贝,对 Children 的修改会污染原始数据,导致多次调用结果不一致。

5.2 递归的深度控制

对于分类树这种层级有限的场景(通常不超过10层),递归完全够用。如果担心栈溢出,可以改为迭代实现:

// 迭代版本的查找(防止栈溢出)
func (c *Category[T, D]) FindIterative(ID T) *Item[T, D] {
    var stack = make([]*Item[T, D], 0)
    stack = append(stack, c.Trees...)
    
    for len(stack) > 0 {
        node := stack[len(stack)-1]
        stack = stack[:len(stack)-1]
        
        if node.ID == ID {
            return node
        }
        
        if len(node.Children) > 0 {
            stack = append(stack, node.Children...)
        }
    }
    return nil
}

5.3 内存使用优化

  • 复用实例:复用 Category 实例,避免重复构建
  • 按需携带 Raw:如果不需要原始数据,可以不传 Raw
  • 大数据量处理:考虑分批处理或使用数据库递归查询
// 优化示例:不携带 Raw 数据
var res = New[int, ProductCategory]().Conv(
    data,
    func(item ProductCategory) *Item[int, ProductCategory] {
        return &Item[int, ProductCategory]{
            ID:   item.ID,
            Pid:  item.Pid,
            Name: item.Name,
            // Raw: item, // 如果不需原始数据,省略该字段
        }
    },
)

5.4 并发安全

当前实现非并发安全,如果需要在并发环境使用:

type SafeCategory[T keyType, D any] struct {
    mu sync.RWMutex
    *Category[T, D]
}

func (s *SafeCategory[T, D]) Conv(list []D, call func(D) *Item[T, D]) *SafeCategory[T, D] {
    s.mu.Lock()
    defer s.mu.Unlock()
    s.Category.Conv(list, call)
    return s
}

// 其他方法类似...

5.5 常见陷阱

  • 循环引用:确保数据中没有循环引用(A的父是B,B的父是A)
  • ID唯一性:同一棵树中ID必须唯一
  • 根节点约定:根节点的Pid必须是0或空字符串
  • 空值处理:Conv时会自动过滤无效ID

六、总结

已实现的功能

  • 列表 <-----> 树的双向转换
  • 查找子树(包含自身)
  • 查找祖级路径
  • 查找任意节点
  • 泛型支持 int/string ID
  • 任意业务数据类型
  • 零依赖,纯标准库实现

设计心得

  • 泛型:以前需要用 interface{} + 类型断言的地方,使用泛型既安全又优雅
  • 值拷贝要谨慎:该拷贝时一定要拷贝,不该拷贝时不要浪费内存
  • API 要直观:方法命名要符合直觉(Find、Conv、SubFlatList)

源码

package categories

type (
	keyType interface {
		int | string
	}

	Item[T keyType, D any] struct {
		ID       T             `json:"id"`
		Pid      T             `json:"pid"`
		Name     string        `json:"name"`
		Raw      D             `json:"raw,omitempty,optional"`
		Children []*Item[T, D] `json:"children,omitempty,optional"`
	}

	Category[T keyType, D any] struct {
		List, Trees []*Item[T, D]
	}
)

func New[T keyType, D any]() *Category[T, D] {
	return &Category[T, D]{}
}

// Conv 转换列表为分类列表
func (c *Category[T, D]) Conv(list []D, call func(D) *Item[T, D]) *Category[T, D] {
	var length = len(list)
	c.List = make([]*Item[T, D], length)
	for key, item := range list {
		var v = call(item)
		if any(v.ID) == nil {
			continue
		}

		if tmp, ok := any(v.ID).(string); ok {
			if tmp == "" {
				continue
			}
		}

		if tmp, ok := any(v.ID).(int); ok {
			if tmp == 0 {
				continue
			}
		}

		c.List[key] = v
	}

	if len(c.List) <= 0 {
		return c
	}

	// 将分类结构化
	c.Trees = c.makeTrees(T(0))
	return c
}

// SubFlatList 结构化分类
func (c *Category[T, D]) makeTrees(pid T) []*Item[T, D] {
	var children []*Item[T, D]
	for _, item := range c.List {
		var value = new(Item[T, D])
		*value = *item
		if value.Pid == pid {
			children = append(children, value)
			value.Children = c.makeTrees(value.ID)
		}
	}

	if len(children) <= 0 {
		children = []*Item[T, D]{}
	}

	return children
}

// SubFlatList 子集树状结构转平铺列表
func (c *Category[T, D]) SubFlatList(trees []*Item[T, D]) []*Item[T, D] {
	var list []*Item[T, D]
	for _, item := range trees {
		var val = new(Item[T, D])
		*val = *item
		val.Children = nil

		list = append(list, val)
		if len(item.Children) > 0 {
			list = append(list, c.SubFlatList(item.Children)...)
		}
	}

	return list
}

// FindTrees 查找指定id下所有子集包含自身
func (c *Category[T, D]) FindTrees(parentIds []T, list []D, call func(D) *Item[T, D]) []T {
	var trees = c.Conv(list, call).Trees
	if len(trees) <= 0 {
		return nil
	}

	var (
		ids     = append([]T{}, parentIds...)
		records []*Item[T, D]
	)
	// 查询
	for _, id := range parentIds {
		var item = c.FindId(id, trees)
		if item == nil {
			continue
		}

		records = append(records, item)
	}

	for _, item := range records {
		for _, item := range c.SubFlatList(item.Children) {
			ids = append(ids, item.ID)
		}
	}

	return ids
}

func (c *Category[T, D]) FindId(id T, data []*Item[T, D]) *Item[T, D] {
	for _, item := range data {
		if item.ID == id {
			return item
		}

		if len(item.Children) > 0 {
			var data = c.FindId(id, item.Children)
			if data != nil {
				return data
			}
		}
	}

	return nil
}

// Find 查找id
func (c *Category[T, D]) Find(ID T) *Item[T, D] {
	return c.find(ID, c.Trees)
}

func (c *Category[T, D]) find(ID T, subs []*Item[T, D]) *Item[T, D] {
	for _, item := range subs {
		if item.ID == ID {
			return item
		}

		if len(item.Children) > 0 {
			if val := c.find(ID, item.Children); val != nil {
				return val
			}
		}
	}

	return nil
}

// FindParents 查找祖级
func (c *Category[T, D]) FindParents(ID T) []*Item[T, D] {
	var current *Item[T, D]
	for _, item := range c.List {
		if item.ID == ID {
			current = new(Item[T, D])
			*current = *item
		}
	}

	if current == nil {
		return nil
	}

	var data = []*Item[T, D]{current}
	if pid, ok := any(current.Pid).(string); ok {
		if pid == "" {
			return data
		}
	}

	if pid, ok := any(current.Pid).(int); ok {
		if pid == 0 {
			return data
		}
	}

	var (
		list  = append([]*Item[T, D]{current}, c.findParents(current.Pid)...)
		res   = make([]*Item[T, D], len(list))
		index = 0
	)
	for i := len(list) - 1; i >= 0; i-- {
		res[index] = list[i]
		index += 1
	}

	return res
}

func (c *Category[T, D]) findParents(pid T) []*Item[T, D] {
	var parents []*Item[T, D]
	for _, item := range c.List {
		if item.ID == pid {
			var val = new(Item[T, D])
			*val = *item
			parents = append(parents, val)

			if pid, ok := any(val.Pid).(string); ok {
				if pid == "" {
					break
				}
			}

			if pid, ok := any(val.Pid).(int); ok {
				if pid == 0 {
					break
				}
			}

			parents = append(parents, c.findParents(val.Pid)...)
		}
	}

	return parents
}