go 第三方库源码解读---golang-set

290 阅读13分钟

我们今天来学习 golang-set 包,这是一个用于处理集合的包,它提供了一些常用的集合操作,比如并集、交集、差集等,并且支持线程安全

学习一个库从它的测试用例开始,我们来看的第一个测试用例时用力测试 NewSet 这个 api

NewSet 会基于给定的元素一个新的集合 Set,这个集合有很多方法:

  • Add:向集合中添加一个元素,返回该元素是否被添加成功
  • Append:将多个元素添加到集合中,返回添加的元素数量
  • Cardinality:返回集合中元素的个数
  • Clear:清空集合
  • Clone:克隆一个集合
  • Contains:返回给定元素是否在集合中
  • ContainsOne:返回给定元素是否在集合中
  • ContainsAny:给定的元素中至少有一个在集合中
  • ContainsAnyElement:检查是否包含另一个集合中的任意一个元素
  • Equal:如果两个集合中的的元素个数相等,并且元素也是一样的,那么他们就是相等的,和元素的顺序没有关系。如果两个集合中的元素类型不一致,会发生 panic
  • IsEmpty:判断集合是否为空
  • IsProperSubset:子集,用于判断集合 a 是否是集合 b 的子集,且 a != b,注意传入的参数必须与方法的接收者类型相同,否则会引发 panic
  • IsProperSuperset:超集,用于判断集合 a 是否是集合 b 的超集,且 a != b,注意传入的参数必须与方法的接收者类型相同,否则会引发 panic
  • IsSubset:真子集,用于判断集合 a 是否是集合 b 的子集,注意传入的参数必须与方法的接收者类型相同,否则会引发 panic
  • IsSuperset:真超集,用于判断集合 a 是否是集合 b 的超集,注意传入的参数必须与方法的接收者类型相同,否则会引发 panic
  • Each:遍历每个元素并对每个元素执行传递的函数,如果传递的函数返回 true,则停止遍历
  • Iter:返回一个迭代器,迭代器是一个只读的通道,通过这个通道可以遍历集合中的每个元素,可以使用 for range 来遍历
  • Iterator:返回一个迭代器,迭代器是一个只读的通道,通过这个通道可以遍历集合中的每个元素,可以使用 for range 来遍历,和 Iter 方法的区别在于 Iterator 返回的是一个 Iterator 类型的结构体
  • Remove:从集合中移除单个元素
  • RemoveAll:从集合中移除多个元素
  • String:返回集合的字符串表示
  • Union:并集,返回一个包含两个集合中所有元素的新集合,注意传入的参数必须与方法的接收者类型相同,否则会引发 panic
  • Intersect:交集,返回一个包含两个集合中共有元素的新集合,注意传入的参数必须与方法的接收者类型相同,否则会引发 panic
  • Difference:差集,返回一个包含 a 集合中的元素不在 b 集合中的元素的新集合,注意传入的参数必须与方法的接收者类型相同,否则会引发 panic
  • SymmetricDifference:对称差,返回一个包含两个集合中不相同元素的新集合,注意传入的参数必须与方法的接收者类型相同,否则会引发 panic
  • Pop:从集合中移除一个元素,返回移除的元素和是否移除成功
  • ToSlice:将集合的成员转成切片
  • MarshalJSON:将集合转成 json 格式
  • UnmarshalJSON:将 json 格式的数据转成集合

我们先来看第一个测试用例:Test_NewSet

前置知识

在学习这个库之前,需要了解一些前置知识,比如 comparable 接口

comparable

comparable 是一个内置接口约束,它表示这个类型可以使用 ==!= 运算符进行比较

  • 基本类型:
    • 布尔值
    • 数字
    • 字符串
    • channel
    • 接口
    • 指针
  • 复合类型:
    • 数组
    • 结构体

要注意的是 mapkey 不能使用 comparable 接口,因为 map 需要具体的可比较类型作为 key

// 错误
var m map[comparable]string

// 正确
var m1 map[string]string
var m2 map[int]string
var m3 map[struct{X int}]string

但是在泛型函数或类型中,可以使用 comparable 作为类型约束

// 泛型函数
func NewMap[K comparable, V any]() map[K]V {
	return make(map[K]V)
}

// 泛型类型
type Map[K comparable, V any] struct {
	M map[K]V
}

Test_NewSet

Test_NewSet 测试用例测试了 NewSet 的基本功能,我传入的是一个可变参数,要保证它的类型是正确的

在这个测试用例中包含了 Set 的两个方法:CardinalityEqual

func Test_NewSet(t *testing.T) {
	a := NewSet[int]()
	if a.Cardinality() != 0 {
		t.Error("NewSet should start out as an empty set")
	}

	assertEqual(NewSet([]int{}...), NewSet[int](), t)
	assertEqual(NewSet([]int{1}...), NewSet(1), t)
	assertEqual(NewSet([]int{1, 2}...), NewSet(1, 2), t)
	assertEqual(NewSet([]string{"a"}...), NewSet("a"), t)
	assertEqual(NewSet([]string{"a", "b"}...), NewSet("a", "b"), t)
}

NewSet

在学习 Cardinality 方法之前,我们先来看下 NewSet

NewSet 是一个泛型函数,它接受一个可变参数,这个参数的类型是 comparable 的,然后通过 newThreadSafeSetWithSize 函数创建一个线程安全的 Set,并将参数中的元素添加到这个 Set

func NewSet[T comparable](vals ...T) Set[T] {
	s := newThreadSafeSetWithSize[T](len(vals))
	for _, item := range vals {
		s.Add(item)
	}
	return s
}

newThreadSafeSetWithSize 函数是基于传入的基数创建一个线程安全的 Set,这个线程安全的 Set 是一个 threadSafeSet 类型的结构体

func newThreadSafeSetWithSize[T comparable](cardinality int) *threadSafeSet[T] {
	return &threadSafeSet[T]{
		uss: newThreadUnsafeSetWithSize[T](cardinality),
	}
}

threadSafeSet 结构体如下:

type threadSafeSet[T comparable] struct {
	sync.RWMutex
	uss *threadUnsafeSet[T]
}

线程安全的底层也是一个线程不安全的 map

type threadUnsafeSet[T comparable] map[T]struct{}

为什么线程安全的 Set 使用了线程不安全的 map 呢?线程安全相比于线程不安全多了一个锁,所以底层还是一个 map 操作,只是在操作 map 时加了锁

func newThreadUnsafeSetWithSize[T comparable](cardinality int) *threadUnsafeSet[T] {
	t := make(threadUnsafeSet[T], cardinality)
	return &t
}

Set 实现了之后,已经创建了一个空的 map 集合,接下来就是将参数中的元素添加到这个集合中,添加元素的方法是 Add

Add

这个 Add 方法是 threadSafeSet 上的方法,因为这个 Set 对外提供的方法都是线程安全的,所以 Add 方法也是线程安全的

Add 在执行前会先加锁,然后调用 threadUnsafeSet 上的 Add 方法,最后释放锁

func (t *threadSafeSet[T]) Add(v T) bool {
	t.Lock()
	ret := t.uss.Add(v)
	t.Unlock()
	return ret
}

threadUnsafeSet 上的 Add 方法如下:

在添加之前先获取下原先 map 的长度,然后将元素添加到 map 中,添加后在和原先的长度进行比较,如果不相等,说明添加成功,返回 true,否则返回 false

func (s threadUnsafeSet[T]) Add(v T) bool {
	prevLen := len(s)
	s[v] = struct{}{}
	return prevLen != len(s)
}

Set 是基于 map 实现的,但是我们并不关心 map 中的 value,所以 value 使用了一个空的结构体 struct{},这样可以节省空间

Equal

在测试中要比较两个 Set 是否相等,需要有 Equal 方法,如下:

set1 := NewSet([]int{1}...)
set2 := NewSet(1)
set1.Equal(set2)

Equal 方法会先获取两个 Set 的锁,然后调用 threadUnsafeSet 上的 Equal 方法,最后释放锁

为什么要获取两个 Set 锁呢?这是因为要保证在比较过程中两个 Set 都不会被修改,如果只锁其中一个 Set 就会出现在比较的过程中另一个 Set 被修改的情况

func (t *threadSafeSet[T]) Equal(other Set[T]) bool {
	o := other.(*threadSafeSet[T])

	t.RLock()
	o.RLock()

	ret := t.uss.Equal(o.uss)
	t.RUnlock()
	o.RUnlock()
	return ret
}

threadUnsafeSet 上的 Equal 先是比较两个 Set 的基数是否相等,如果不相等,直接返回 false

然后遍历一个 Set 中的元素,判断另一个 Set 中是否包含这个元素,如果不包含,返回 false,否则返回 true

func (s *threadUnsafeSet[T]) Equal(other Set[T]) bool {
	o := other.(*threadUnsafeSet[T])

	if s.Cardinality() != other.Cardinality() {
		return false
	}
	for elem := range *s {
		if !o.contains(elem) {
			return false
		}
	}
	return true
}

Append

Append 的作用是将多个元素添加到集合中,返回添加的元素数量

这个实现也分为线程安全和线程不安全两个部分

我们先来看线程安全的 Append 方法,Append 接收可变参数,进入方法时先加锁,然后调用 threadUnsafeSet 上的 Append 方法,最后释放锁

func (t *threadSafeSet[T]) Append(v ...T) int {
	t.Lock()
	ret := t.uss.Append(v...)
	t.Unlock()
	return ret
}

threadUnsafeSet 上的 Append 方法和 Add 类似,只是它接收的是可变参数,遍历参数,将每个元素添加到 map

func (s *threadUnsafeSet[T]) Append(v ...T) int {
	prevLen := len(*s)
	for _, val := range v {
		(*s)[val] = struct{}{}
	}
	return len(*s) - prevLen
}

Contains

Contains 方法用于判断给定的元素是否在集合中,如果在集合中返回 true,否则返回 false

也可以传入多个元素,只有当所有元素都在集合中时才返回 true

func (t *threadSafeSet[T]) Contains(v ...T) bool {
	t.RLock()
	ret := t.uss.Contains(v...)
	t.RUnlock()

	return ret
}
func (s *threadUnsafeSet[T]) Contains(v ...T) bool {
	for _, val := range v {
		if _, ok := (*s)[val]; !ok {
			return false
		}
	}
	return true
}

ContainsOne

ContainsOne 方法用于判断给定的元素是否在集合中,如果在集合中返回 true,否则返回 false

Contains 方法不同的是,ContainsOne 只接收一个参数

func (t *threadSafeSet[T]) ContainsOne(v T) bool {
	t.RLock()
	ret := t.uss.ContainsOne(v)
	t.RUnlock()

	return ret
}
func (s *threadUnsafeSet[T]) ContainsOne(v T) bool {
	_, ok := (*s)[v]
	return ok
}

ContainsAny

ContainsAny 方法用于判断给定的元素中是否至少有一个存在于集合中,如果存在就返回 true,否则返回 false

Contains 判断的逻辑正好相反

func (t *threadSafeSet[T]) ContainsAny(v ...T) bool {
	t.RLock()
	ret := t.uss.ContainsAny(v...)
	t.RUnlock()

	return ret
}
func (s *threadUnsafeSet[T]) ContainsAny(v ...T) bool {
	for _, val := range v {
		if _, ok := (*s)[val]; ok {
			return true
		}
	}
	return false
}

ContainsAnyElement

ContainsAnyElement 方法用于检测两个集合是否有至少一个元素相等

在遍历两个集合时,需要先判断一下集合的大小,遍历相的小的集合

func (t *threadSafeSet[T]) ContainsAnyElement(other Set[T]) bool {
	o := other.(*threadSafeSet[T])

	t.RLock()
	o.RLock()

	ret := t.uss.ContainsAnyElement(o.uss)

	t.RUnlock()
	o.RUnlock()
	return ret
}
func (s *threadUnsafeSet[T]) ContainsAnyElement(other Set[T]) bool {
	o := other.(*threadUnsafeSet[T])

	// loop over smaller set
	if s.Cardinality() < other.Cardinality() {
		for elem := range *s {
			if o.contains(elem) {
				return true
			}
		}
	} else {
		for elem := range *o {
			if s.contains(elem) {
				return true
			}
		}
	}
	return false
}

Remove

Remove 方法用于从集合中移除一个元素

func (t *threadSafeSet[T]) Remove(v T) {
	t.Lock()
	delete(*t.uss, v)
	t.Unlock()
}
func (s threadUnsafeSet[T]) Remove(v T) {
	delete(s, v)
}

RemoveAll

RemoveAll 方法用于从集合中移除多个元素

func (t *threadSafeSet[T]) RemoveAll(i ...T) {
	t.Lock()
	t.uss.RemoveAll(i...)
	t.Unlock()
}
func (s threadUnsafeSet[T]) RemoveAll(i ...T) {
	for _, elem := range i {
		delete(s, elem)
	}
}

Clear

Clear 方法用于清空集合,在 go 中,如果循环遍历一个 map 并删除其中的元素,编译器会优化成 mapclear

func (t *threadSafeSet[T]) Clear() {
	t.Lock()
	t.uss.Clear()
	t.Unlock()
}
func (s *threadUnsafeSet[T]) Clear() {
	for key := range *s {
		delete(*s, key)
	}
}

IsSubset

IsSubset 方法用于判断集合 a 是否是集合 b 的子集

func (t *threadSafeSet[T]) IsSubset(other Set[T]) bool {
	o := other.(*threadSafeSet[T])

	t.RLock()
	o.RLock()

	ret := t.uss.IsSubset(o.uss)
	t.RUnlock()
	o.RUnlock()
	return ret
}
func (s *threadUnsafeSet[T]) IsSubset(other Set[T]) bool {
	o := other.(*threadUnsafeSet[T])
	if s.Cardinality() > other.Cardinality() {
		return false
	}
	for elem := range *s {
		if !o.contains(elem) {
			return false
		}
	}
	return true
}

IsProperSubset

IsProperSubset 用来判断 a 是否是 b 的真子集,即 ab 的子集,但是 a != b

func (t *threadSafeSet[T]) IsProperSubset(other Set[T]) bool {
	o := other.(*threadSafeSet[T])

	t.RLock()
	defer t.RUnlock()
	o.RLock()
	defer o.RUnlock()

	return t.uss.IsProperSubset(o.uss)
}
func (s *threadUnsafeSet[T]) IsProperSubset(other Set[T]) bool {
	return s.Cardinality() < other.Cardinality() && s.IsSubset(other)
}

IsSuperset

IsSuperset 用来判断 a 是否是 b 的超集

func (t *threadSafeSet[T]) IsSuperset(other Set[T]) bool {
	return other.IsSubset(t)
}
func (s *threadUnsafeSet[T]) IsSuperset(other Set[T]) bool {
	return other.IsSubset(s)
}

IsProperSuperset

IsProperSuperset 用来判断 a 是否是 b 的真超集

func (t *threadSafeSet[T]) IsProperSuperset(other Set[T]) bool {
	return other.IsProperSubset(t)
}
func (s *threadUnsafeSet[T]) IsProperSuperset(other Set[T]) bool {
	return s.Cardinality() > other.Cardinality() && s.IsSuperset(other)
}

Union

Union 方法的作用是合并两个集合,返回一个新的集合,这个新的集合包含两个集合中的所有元素

func (t *threadSafeSet[T]) Union(other Set[T]) Set[T] {
	o := other.(*threadSafeSet[T])

	t.RLock()
	o.RLock()

	unsafeUnion := t.uss.Union(o.uss).(*threadUnsafeSet[T])
	ret := &threadSafeSet[T]{uss: unsafeUnion}
	t.RUnlock()
	o.RUnlock()
	return ret
}
func (s threadUnsafeSet[T]) Union(other Set[T]) Set[T] {
	o := other.(*threadUnsafeSet[T])

	n := s.Cardinality()
	if o.Cardinality() > n {
		n = o.Cardinality()
	}
	unionedSet := make(threadUnsafeSet[T], n)

	for elem := range s {
		unionedSet.add(elem)
	}
	for elem := range *o {
		unionedSet.add(elem)
	}
	return &unionedSet
}

Intersect

Intersect 方法的作用是返回两个集合中共有的元素,Intersect 遍历时会先遍历小的集合

func (t *threadSafeSet[T]) Intersect(other Set[T]) Set[T] {
	o := other.(*threadSafeSet[T])

	t.RLock()
	o.RLock()

	unsafeIntersection := t.uss.Intersect(o.uss).(*threadUnsafeSet[T])
	ret := &threadSafeSet[T]{uss: unsafeIntersection}
	t.RUnlock()
	o.RUnlock()
	return ret
}
func (s *threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] {
	o := other.(*threadUnsafeSet[T])

	intersection := newThreadUnsafeSet[T]()
	// loop over smaller set
	if s.Cardinality() < other.Cardinality() {
		for elem := range *s {
			if o.contains(elem) {
				intersection.add(elem)
			}
		}
	} else {
		for elem := range *o {
			if s.contains(elem) {
				intersection.add(elem)
			}
		}
	}
	return intersection
}

Difference

Difference 方法的作用是返回 a 集合中的元素不在 b 集合中的元素

func (t *threadSafeSet[T]) Difference(other Set[T]) Set[T] {
	o := other.(*threadSafeSet[T])

	t.RLock()
	o.RLock()

	unsafeDifference := t.uss.Difference(o.uss).(*threadUnsafeSet[T])
	ret := &threadSafeSet[T]{uss: unsafeDifference}
	t.RUnlock()
	o.RUnlock()
	return ret
}
func (s *threadUnsafeSet[T]) Difference(other Set[T]) Set[T] {
	o := other.(*threadUnsafeSet[T])

	diff := newThreadUnsafeSet[T]()
	for elem := range *s {
		if !o.contains(elem) {
			diff.add(elem)
		}
	}
	return diff
}

SymmetricDifference

SymmetricDifference 方法的作用是返回两个集合中不相同的元素

SymmetricDifferenceDifference 的区别在于 Difference 只返回 a 集合中的元素不在 b 集合中的元素,而 SymmetricDifference 返回两个集合中不相同的元素

func (t *threadSafeSet[T]) SymmetricDifference(other Set[T]) Set[T] {
	o := other.(*threadSafeSet[T])

	t.RLock()
	o.RLock()

	unsafeDifference := t.uss.SymmetricDifference(o.uss).(*threadUnsafeSet[T])
	ret := &threadSafeSet[T]{uss: unsafeDifference}
	t.RUnlock()
	o.RUnlock()
	return ret
}
func (s *threadUnsafeSet[T]) SymmetricDifference(other Set[T]) Set[T] {
	o := other.(*threadUnsafeSet[T])

	sd := newThreadUnsafeSet[T]()
	for elem := range *s {
		if !o.contains(elem) {
			sd.add(elem)
		}
	}
	for elem := range *o {
		if !s.contains(elem) {
			sd.add(elem)
		}
	}
	return sd
}

Clone

Clone 方法用于克隆一个集合,返回一个新的集合

func (t *threadSafeSet[T]) Clone() Set[T] {
	t.RLock()

	unsafeClone := t.uss.Clone().(*threadUnsafeSet[T])
	ret := &threadSafeSet[T]{uss: unsafeClone}
	t.RUnlock()
	return ret
}
func (s *threadUnsafeSet[T]) Clone() Set[T] {
	clonedSet := newThreadUnsafeSetWithSize[T](s.Cardinality())
	for elem := range *s {
		clonedSet.add(elem)
	}
	return clonedSet
}

Each

Each 方法用于遍历集合中的每个元素,并对每个元素执行传递的函数,如果传递的函数返回 true,则停止遍历

func (t *threadSafeSet[T]) Each(cb func(T) bool) {
	t.RLock()
	for elem := range *t.uss {
		if cb(elem) {
			break
		}
	}
	t.RUnlock()
}
func (s *threadUnsafeSet[T]) Each(cb func(T) bool) {
	for elem := range *s {
		if cb(elem) {
			break
		}
	}
}

Iter

Iter 方法用于返回一个迭代器,迭代器是一个只读的通道,通过这个通道可以遍历集合中的每个元素,可以使用 for range 来遍历

func (t *threadSafeSet[T]) Iter() <-chan T {
	ch := make(chan T)
	go func() {
		t.RLock()

		for elem := range *t.uss {
			ch <- elem
		}
		close(ch)
		t.RUnlock()
	}()

	return ch
}
func (s *threadUnsafeSet[T]) Iter() <-chan T {
	ch := make(chan T)
	go func() {
		for elem := range *s {
			ch <- elem
		}
		close(ch)
	}()

	return ch
}

Iterator

Iterator 方法用于返回一个迭代器,迭代器是一个只读的通道,通过这个通道可以遍历集合中的每个元素,可以使用 for range 来遍历,和 Iter 方法的区别在于 Iterator 返回的是一个 Iterator 类型的结构体

func (t *threadSafeSet[T]) Iterator() *Iterator[T] {
	iterator, ch, stopCh := newIterator[T]()

	go func() {
		t.RLock()
	L:
		for elem := range *t.uss {
			select {
			case <-stopCh:
				break L
			case ch <- elem:
			}
		}
		close(ch)
		t.RUnlock()
	}()

	return iterator
}
func (s *threadUnsafeSet[T]) Iterator() *Iterator[T] {
	iterator, ch, stopCh := newIterator[T]()

	go func() {
	L:
		for elem := range *s {
			select {
			case <-stopCh:
				break L
			case ch <- elem:
			}
		}
		close(ch)
	}()

	return iterator
}

Pop

Pop 方法用于从集合中移除一个元素,返回移除的元素和是否移除成功

func (t *threadSafeSet[T]) Pop() (T, bool) {
	t.Lock()
	defer t.Unlock()
	return t.uss.Pop()
}
func (s *threadUnsafeSet[T]) Pop() (v T, ok bool) {
	for item := range *s {
		delete(*s, item)
		return item, true
	}
	return v, false
}

ToSlice

ToSlice 方法用于将集合的成员转成切片

func (t *threadSafeSet[T]) ToSlice() []T {
	keys := make([]T, 0, t.Cardinality())
	t.RLock()
	for elem := range *t.uss {
		keys = append(keys, elem)
	}
	t.RUnlock()
	return keys
}
func (s threadUnsafeSet[T]) ToSlice() []T {
	keys := make([]T, 0, s.Cardinality())
	for elem := range s {
		keys = append(keys, elem)
	}

	return keys
}

String

String 方法用于返回集合的字符串表示

func (t *threadSafeSet[T]) String() string {
	t.RLock()
	ret := t.uss.String()
	t.RUnlock()
	return ret
}
func (s threadUnsafeSet[T]) String() string {
	items := make([]string, 0, len(s))
	for elem := range s {
		items = append(items, fmt.Sprintf("%v", elem))
	}
	return fmt.Sprintf("Set{%s}", strings.Join(items, ", "))
}

isEmpty

isEmpty 方法用于判断集合是否为空

func (t *threadSafeSet[T]) IsEmpty() bool {
	return t.Cardinality() == 0
}
func (s *threadUnsafeSet[T]) IsEmpty() bool {
	return s.Cardinality() == 0
}

MarshalJSON

MarshalJSON 方法用于将集合转成 json 格式

func (t *threadSafeSet[T]) MarshalJSON() ([]byte, error) {
	t.RLock()
	b, err := t.uss.MarshalJSON()
	t.RUnlock()

	return b, err
}
func (s threadUnsafeSet[T]) MarshalJSON() ([]byte, error) {
	items := make([]string, 0, s.Cardinality())

	for elem := range s {
		b, err := json.Marshal(elem)
		if err != nil {
			return nil, err
		}

		items = append(items, string(b))
	}

	return []byte(fmt.Sprintf("[%s]", strings.Join(items, ","))), nil
}

UnmarshalJSON

UnmarshalJSON 方法用于将 json 格式的数据转成集合

func (t *threadSafeSet[T]) UnmarshalJSON(p []byte) error {
	t.RLock()
	err := t.uss.UnmarshalJSON(p)
	t.RUnlock()

	return err
}
func (s *threadUnsafeSet[T]) UnmarshalJSON(b []byte) error {
	var i []T
	err := json.Unmarshal(b, &i)
	if err != nil {
		return err
	}
	s.Append(i...)

	return nil
}