我们今天来学习 golang-set
包,这是一个用于处理集合的包,它提供了一些常用的集合操作,比如并集、交集、差集等,并且支持线程安全
学习一个库从它的测试用例开始,我们来看的第一个测试用例时用力测试 NewSet
这个 api
NewSet
会基于给定的元素一个新的集合 Set
,这个集合有很多方法:
Add
:向集合中添加一个元素,返回该元素是否被添加成功Append
:将多个元素添加到集合中,返回添加的元素数量Cardinality
:返回集合中元素的个数Clear
:清空集合Clone
:克隆一个集合Contains
:返回给定元素是否在集合中ContainsOne
:返回给定元素是否在集合中ContainsAny
:给定的元素中至少有一个在集合中ContainsAnyElement
:检查是否包含另一个集合中的任意一个元素Equal
:如果两个集合中的的元素个数相等,并且元素也是一样的,那么他们就是相等的,和元素的顺序没有关系。如果两个集合中的元素类型不一致,会发生panic
IsEmpty
:判断集合是否为空IsProperSubset
:子集,用于判断集合a
是否是集合b
的子集,且a != b
,注意传入的参数必须与方法的接收者类型相同,否则会引发panic
IsProperSuperset
:超集,用于判断集合a
是否是集合b
的超集,且a != b
,注意传入的参数必须与方法的接收者类型相同,否则会引发panic
IsSubset
:真子集,用于判断集合a
是否是集合b
的子集,注意传入的参数必须与方法的接收者类型相同,否则会引发panic
IsSuperset
:真超集,用于判断集合a
是否是集合b
的超集,注意传入的参数必须与方法的接收者类型相同,否则会引发panic
Each
:遍历每个元素并对每个元素执行传递的函数,如果传递的函数返回true
,则停止遍历Iter
:返回一个迭代器,迭代器是一个只读的通道,通过这个通道可以遍历集合中的每个元素,可以使用for range
来遍历Iterator
:返回一个迭代器,迭代器是一个只读的通道,通过这个通道可以遍历集合中的每个元素,可以使用for range
来遍历,和Iter
方法的区别在于Iterator
返回的是一个Iterator
类型的结构体Remove
:从集合中移除单个元素RemoveAll
:从集合中移除多个元素String
:返回集合的字符串表示Union
:并集,返回一个包含两个集合中所有元素的新集合,注意传入的参数必须与方法的接收者类型相同,否则会引发panic
Intersect
:交集,返回一个包含两个集合中共有元素的新集合,注意传入的参数必须与方法的接收者类型相同,否则会引发panic
Difference
:差集,返回一个包含a
集合中的元素不在b
集合中的元素的新集合,注意传入的参数必须与方法的接收者类型相同,否则会引发panic
SymmetricDifference
:对称差,返回一个包含两个集合中不相同元素的新集合,注意传入的参数必须与方法的接收者类型相同,否则会引发panic
Pop
:从集合中移除一个元素,返回移除的元素和是否移除成功ToSlice
:将集合的成员转成切片MarshalJSON
:将集合转成json
格式UnmarshalJSON
:将json
格式的数据转成集合
我们先来看第一个测试用例:Test_NewSet
前置知识
在学习这个库之前,需要了解一些前置知识,比如 comparable
接口
comparable
comparable
是一个内置接口约束,它表示这个类型可以使用 ==
和 !=
运算符进行比较
- 基本类型:
- 布尔值
- 数字
- 字符串
channel
- 接口
- 指针
- 复合类型:
- 数组
- 结构体
要注意的是 map
的 key
不能使用 comparable
接口,因为 map
需要具体的可比较类型作为 key
// 错误
var m map[comparable]string
// 正确
var m1 map[string]string
var m2 map[int]string
var m3 map[struct{X int}]string
但是在泛型函数或类型中,可以使用 comparable
作为类型约束
// 泛型函数
func NewMap[K comparable, V any]() map[K]V {
return make(map[K]V)
}
// 泛型类型
type Map[K comparable, V any] struct {
M map[K]V
}
Test_NewSet
Test_NewSet
测试用例测试了 NewSet
的基本功能,我传入的是一个可变参数,要保证它的类型是正确的
在这个测试用例中包含了 Set
的两个方法:Cardinality
和 Equal
func Test_NewSet(t *testing.T) {
a := NewSet[int]()
if a.Cardinality() != 0 {
t.Error("NewSet should start out as an empty set")
}
assertEqual(NewSet([]int{}...), NewSet[int](), t)
assertEqual(NewSet([]int{1}...), NewSet(1), t)
assertEqual(NewSet([]int{1, 2}...), NewSet(1, 2), t)
assertEqual(NewSet([]string{"a"}...), NewSet("a"), t)
assertEqual(NewSet([]string{"a", "b"}...), NewSet("a", "b"), t)
}
NewSet
在学习 Cardinality
方法之前,我们先来看下 NewSet
NewSet
是一个泛型函数,它接受一个可变参数,这个参数的类型是 comparable
的,然后通过 newThreadSafeSetWithSize
函数创建一个线程安全的 Set
,并将参数中的元素添加到这个 Set
中
func NewSet[T comparable](vals ...T) Set[T] {
s := newThreadSafeSetWithSize[T](len(vals))
for _, item := range vals {
s.Add(item)
}
return s
}
newThreadSafeSetWithSize
函数是基于传入的基数创建一个线程安全的 Set
,这个线程安全的 Set
是一个 threadSafeSet
类型的结构体
func newThreadSafeSetWithSize[T comparable](cardinality int) *threadSafeSet[T] {
return &threadSafeSet[T]{
uss: newThreadUnsafeSetWithSize[T](cardinality),
}
}
threadSafeSet
结构体如下:
type threadSafeSet[T comparable] struct {
sync.RWMutex
uss *threadUnsafeSet[T]
}
线程安全的底层也是一个线程不安全的 map
type threadUnsafeSet[T comparable] map[T]struct{}
为什么线程安全的 Set
使用了线程不安全的 map
呢?线程安全相比于线程不安全多了一个锁,所以底层还是一个 map
操作,只是在操作 map
时加了锁
func newThreadUnsafeSetWithSize[T comparable](cardinality int) *threadUnsafeSet[T] {
t := make(threadUnsafeSet[T], cardinality)
return &t
}
Set
实现了之后,已经创建了一个空的 map
集合,接下来就是将参数中的元素添加到这个集合中,添加元素的方法是 Add
Add
这个 Add
方法是 threadSafeSet
上的方法,因为这个 Set
对外提供的方法都是线程安全的,所以 Add
方法也是线程安全的
Add
在执行前会先加锁,然后调用 threadUnsafeSet
上的 Add
方法,最后释放锁
func (t *threadSafeSet[T]) Add(v T) bool {
t.Lock()
ret := t.uss.Add(v)
t.Unlock()
return ret
}
threadUnsafeSet
上的 Add
方法如下:
在添加之前先获取下原先 map
的长度,然后将元素添加到 map
中,添加后在和原先的长度进行比较,如果不相等,说明添加成功,返回 true
,否则返回 false
func (s threadUnsafeSet[T]) Add(v T) bool {
prevLen := len(s)
s[v] = struct{}{}
return prevLen != len(s)
}
Set
是基于 map
实现的,但是我们并不关心 map
中的 value
,所以 value
使用了一个空的结构体 struct{}
,这样可以节省空间
Equal
在测试中要比较两个 Set
是否相等,需要有 Equal
方法,如下:
set1 := NewSet([]int{1}...)
set2 := NewSet(1)
set1.Equal(set2)
Equal
方法会先获取两个 Set
的锁,然后调用 threadUnsafeSet
上的 Equal
方法,最后释放锁
为什么要获取两个 Set
锁呢?这是因为要保证在比较过程中两个 Set
都不会被修改,如果只锁其中一个 Set
就会出现在比较的过程中另一个 Set
被修改的情况
func (t *threadSafeSet[T]) Equal(other Set[T]) bool {
o := other.(*threadSafeSet[T])
t.RLock()
o.RLock()
ret := t.uss.Equal(o.uss)
t.RUnlock()
o.RUnlock()
return ret
}
threadUnsafeSet
上的 Equal
先是比较两个 Set
的基数是否相等,如果不相等,直接返回 false
然后遍历一个 Set
中的元素,判断另一个 Set
中是否包含这个元素,如果不包含,返回 false
,否则返回 true
func (s *threadUnsafeSet[T]) Equal(other Set[T]) bool {
o := other.(*threadUnsafeSet[T])
if s.Cardinality() != other.Cardinality() {
return false
}
for elem := range *s {
if !o.contains(elem) {
return false
}
}
return true
}
Append
Append
的作用是将多个元素添加到集合中,返回添加的元素数量
这个实现也分为线程安全和线程不安全两个部分
我们先来看线程安全的 Append
方法,Append
接收可变参数,进入方法时先加锁,然后调用 threadUnsafeSet
上的 Append
方法,最后释放锁
func (t *threadSafeSet[T]) Append(v ...T) int {
t.Lock()
ret := t.uss.Append(v...)
t.Unlock()
return ret
}
threadUnsafeSet
上的 Append
方法和 Add
类似,只是它接收的是可变参数,遍历参数,将每个元素添加到 map
中
func (s *threadUnsafeSet[T]) Append(v ...T) int {
prevLen := len(*s)
for _, val := range v {
(*s)[val] = struct{}{}
}
return len(*s) - prevLen
}
Contains
Contains
方法用于判断给定的元素是否在集合中,如果在集合中返回 true
,否则返回 false
也可以传入多个元素,只有当所有元素都在集合中时才返回 true
func (t *threadSafeSet[T]) Contains(v ...T) bool {
t.RLock()
ret := t.uss.Contains(v...)
t.RUnlock()
return ret
}
func (s *threadUnsafeSet[T]) Contains(v ...T) bool {
for _, val := range v {
if _, ok := (*s)[val]; !ok {
return false
}
}
return true
}
ContainsOne
ContainsOne
方法用于判断给定的元素是否在集合中,如果在集合中返回 true
,否则返回 false
和 Contains
方法不同的是,ContainsOne
只接收一个参数
func (t *threadSafeSet[T]) ContainsOne(v T) bool {
t.RLock()
ret := t.uss.ContainsOne(v)
t.RUnlock()
return ret
}
func (s *threadUnsafeSet[T]) ContainsOne(v T) bool {
_, ok := (*s)[v]
return ok
}
ContainsAny
ContainsAny
方法用于判断给定的元素中是否至少有一个存在于集合中,如果存在就返回 true
,否则返回 false
和 Contains
判断的逻辑正好相反
func (t *threadSafeSet[T]) ContainsAny(v ...T) bool {
t.RLock()
ret := t.uss.ContainsAny(v...)
t.RUnlock()
return ret
}
func (s *threadUnsafeSet[T]) ContainsAny(v ...T) bool {
for _, val := range v {
if _, ok := (*s)[val]; ok {
return true
}
}
return false
}
ContainsAnyElement
ContainsAnyElement
方法用于检测两个集合是否有至少一个元素相等
在遍历两个集合时,需要先判断一下集合的大小,遍历相的小的集合
func (t *threadSafeSet[T]) ContainsAnyElement(other Set[T]) bool {
o := other.(*threadSafeSet[T])
t.RLock()
o.RLock()
ret := t.uss.ContainsAnyElement(o.uss)
t.RUnlock()
o.RUnlock()
return ret
}
func (s *threadUnsafeSet[T]) ContainsAnyElement(other Set[T]) bool {
o := other.(*threadUnsafeSet[T])
// loop over smaller set
if s.Cardinality() < other.Cardinality() {
for elem := range *s {
if o.contains(elem) {
return true
}
}
} else {
for elem := range *o {
if s.contains(elem) {
return true
}
}
}
return false
}
Remove
Remove
方法用于从集合中移除一个元素
func (t *threadSafeSet[T]) Remove(v T) {
t.Lock()
delete(*t.uss, v)
t.Unlock()
}
func (s threadUnsafeSet[T]) Remove(v T) {
delete(s, v)
}
RemoveAll
RemoveAll
方法用于从集合中移除多个元素
func (t *threadSafeSet[T]) RemoveAll(i ...T) {
t.Lock()
t.uss.RemoveAll(i...)
t.Unlock()
}
func (s threadUnsafeSet[T]) RemoveAll(i ...T) {
for _, elem := range i {
delete(s, elem)
}
}
Clear
Clear
方法用于清空集合,在 go
中,如果循环遍历一个 map
并删除其中的元素,编译器会优化成 mapclear
func (t *threadSafeSet[T]) Clear() {
t.Lock()
t.uss.Clear()
t.Unlock()
}
func (s *threadUnsafeSet[T]) Clear() {
for key := range *s {
delete(*s, key)
}
}
IsSubset
IsSubset
方法用于判断集合 a
是否是集合 b
的子集
func (t *threadSafeSet[T]) IsSubset(other Set[T]) bool {
o := other.(*threadSafeSet[T])
t.RLock()
o.RLock()
ret := t.uss.IsSubset(o.uss)
t.RUnlock()
o.RUnlock()
return ret
}
func (s *threadUnsafeSet[T]) IsSubset(other Set[T]) bool {
o := other.(*threadUnsafeSet[T])
if s.Cardinality() > other.Cardinality() {
return false
}
for elem := range *s {
if !o.contains(elem) {
return false
}
}
return true
}
IsProperSubset
IsProperSubset
用来判断 a
是否是 b
的真子集,即 a
是 b
的子集,但是 a != b
func (t *threadSafeSet[T]) IsProperSubset(other Set[T]) bool {
o := other.(*threadSafeSet[T])
t.RLock()
defer t.RUnlock()
o.RLock()
defer o.RUnlock()
return t.uss.IsProperSubset(o.uss)
}
func (s *threadUnsafeSet[T]) IsProperSubset(other Set[T]) bool {
return s.Cardinality() < other.Cardinality() && s.IsSubset(other)
}
IsSuperset
IsSuperset
用来判断 a
是否是 b
的超集
func (t *threadSafeSet[T]) IsSuperset(other Set[T]) bool {
return other.IsSubset(t)
}
func (s *threadUnsafeSet[T]) IsSuperset(other Set[T]) bool {
return other.IsSubset(s)
}
IsProperSuperset
IsProperSuperset
用来判断 a
是否是 b
的真超集
func (t *threadSafeSet[T]) IsProperSuperset(other Set[T]) bool {
return other.IsProperSubset(t)
}
func (s *threadUnsafeSet[T]) IsProperSuperset(other Set[T]) bool {
return s.Cardinality() > other.Cardinality() && s.IsSuperset(other)
}
Union
Union
方法的作用是合并两个集合,返回一个新的集合,这个新的集合包含两个集合中的所有元素
func (t *threadSafeSet[T]) Union(other Set[T]) Set[T] {
o := other.(*threadSafeSet[T])
t.RLock()
o.RLock()
unsafeUnion := t.uss.Union(o.uss).(*threadUnsafeSet[T])
ret := &threadSafeSet[T]{uss: unsafeUnion}
t.RUnlock()
o.RUnlock()
return ret
}
func (s threadUnsafeSet[T]) Union(other Set[T]) Set[T] {
o := other.(*threadUnsafeSet[T])
n := s.Cardinality()
if o.Cardinality() > n {
n = o.Cardinality()
}
unionedSet := make(threadUnsafeSet[T], n)
for elem := range s {
unionedSet.add(elem)
}
for elem := range *o {
unionedSet.add(elem)
}
return &unionedSet
}
Intersect
Intersect
方法的作用是返回两个集合中共有的元素,Intersect
遍历时会先遍历小的集合
func (t *threadSafeSet[T]) Intersect(other Set[T]) Set[T] {
o := other.(*threadSafeSet[T])
t.RLock()
o.RLock()
unsafeIntersection := t.uss.Intersect(o.uss).(*threadUnsafeSet[T])
ret := &threadSafeSet[T]{uss: unsafeIntersection}
t.RUnlock()
o.RUnlock()
return ret
}
func (s *threadUnsafeSet[T]) Intersect(other Set[T]) Set[T] {
o := other.(*threadUnsafeSet[T])
intersection := newThreadUnsafeSet[T]()
// loop over smaller set
if s.Cardinality() < other.Cardinality() {
for elem := range *s {
if o.contains(elem) {
intersection.add(elem)
}
}
} else {
for elem := range *o {
if s.contains(elem) {
intersection.add(elem)
}
}
}
return intersection
}
Difference
Difference
方法的作用是返回 a
集合中的元素不在 b
集合中的元素
func (t *threadSafeSet[T]) Difference(other Set[T]) Set[T] {
o := other.(*threadSafeSet[T])
t.RLock()
o.RLock()
unsafeDifference := t.uss.Difference(o.uss).(*threadUnsafeSet[T])
ret := &threadSafeSet[T]{uss: unsafeDifference}
t.RUnlock()
o.RUnlock()
return ret
}
func (s *threadUnsafeSet[T]) Difference(other Set[T]) Set[T] {
o := other.(*threadUnsafeSet[T])
diff := newThreadUnsafeSet[T]()
for elem := range *s {
if !o.contains(elem) {
diff.add(elem)
}
}
return diff
}
SymmetricDifference
SymmetricDifference
方法的作用是返回两个集合中不相同的元素
SymmetricDifference
和 Difference
的区别在于 Difference
只返回 a
集合中的元素不在 b
集合中的元素,而 SymmetricDifference
返回两个集合中不相同的元素
func (t *threadSafeSet[T]) SymmetricDifference(other Set[T]) Set[T] {
o := other.(*threadSafeSet[T])
t.RLock()
o.RLock()
unsafeDifference := t.uss.SymmetricDifference(o.uss).(*threadUnsafeSet[T])
ret := &threadSafeSet[T]{uss: unsafeDifference}
t.RUnlock()
o.RUnlock()
return ret
}
func (s *threadUnsafeSet[T]) SymmetricDifference(other Set[T]) Set[T] {
o := other.(*threadUnsafeSet[T])
sd := newThreadUnsafeSet[T]()
for elem := range *s {
if !o.contains(elem) {
sd.add(elem)
}
}
for elem := range *o {
if !s.contains(elem) {
sd.add(elem)
}
}
return sd
}
Clone
Clone
方法用于克隆一个集合,返回一个新的集合
func (t *threadSafeSet[T]) Clone() Set[T] {
t.RLock()
unsafeClone := t.uss.Clone().(*threadUnsafeSet[T])
ret := &threadSafeSet[T]{uss: unsafeClone}
t.RUnlock()
return ret
}
func (s *threadUnsafeSet[T]) Clone() Set[T] {
clonedSet := newThreadUnsafeSetWithSize[T](s.Cardinality())
for elem := range *s {
clonedSet.add(elem)
}
return clonedSet
}
Each
Each
方法用于遍历集合中的每个元素,并对每个元素执行传递的函数,如果传递的函数返回 true
,则停止遍历
func (t *threadSafeSet[T]) Each(cb func(T) bool) {
t.RLock()
for elem := range *t.uss {
if cb(elem) {
break
}
}
t.RUnlock()
}
func (s *threadUnsafeSet[T]) Each(cb func(T) bool) {
for elem := range *s {
if cb(elem) {
break
}
}
}
Iter
Iter
方法用于返回一个迭代器,迭代器是一个只读的通道,通过这个通道可以遍历集合中的每个元素,可以使用 for range
来遍历
func (t *threadSafeSet[T]) Iter() <-chan T {
ch := make(chan T)
go func() {
t.RLock()
for elem := range *t.uss {
ch <- elem
}
close(ch)
t.RUnlock()
}()
return ch
}
func (s *threadUnsafeSet[T]) Iter() <-chan T {
ch := make(chan T)
go func() {
for elem := range *s {
ch <- elem
}
close(ch)
}()
return ch
}
Iterator
Iterator
方法用于返回一个迭代器,迭代器是一个只读的通道,通过这个通道可以遍历集合中的每个元素,可以使用 for range
来遍历,和 Iter
方法的区别在于 Iterator
返回的是一个 Iterator
类型的结构体
func (t *threadSafeSet[T]) Iterator() *Iterator[T] {
iterator, ch, stopCh := newIterator[T]()
go func() {
t.RLock()
L:
for elem := range *t.uss {
select {
case <-stopCh:
break L
case ch <- elem:
}
}
close(ch)
t.RUnlock()
}()
return iterator
}
func (s *threadUnsafeSet[T]) Iterator() *Iterator[T] {
iterator, ch, stopCh := newIterator[T]()
go func() {
L:
for elem := range *s {
select {
case <-stopCh:
break L
case ch <- elem:
}
}
close(ch)
}()
return iterator
}
Pop
Pop
方法用于从集合中移除一个元素,返回移除的元素和是否移除成功
func (t *threadSafeSet[T]) Pop() (T, bool) {
t.Lock()
defer t.Unlock()
return t.uss.Pop()
}
func (s *threadUnsafeSet[T]) Pop() (v T, ok bool) {
for item := range *s {
delete(*s, item)
return item, true
}
return v, false
}
ToSlice
ToSlice
方法用于将集合的成员转成切片
func (t *threadSafeSet[T]) ToSlice() []T {
keys := make([]T, 0, t.Cardinality())
t.RLock()
for elem := range *t.uss {
keys = append(keys, elem)
}
t.RUnlock()
return keys
}
func (s threadUnsafeSet[T]) ToSlice() []T {
keys := make([]T, 0, s.Cardinality())
for elem := range s {
keys = append(keys, elem)
}
return keys
}
String
String
方法用于返回集合的字符串表示
func (t *threadSafeSet[T]) String() string {
t.RLock()
ret := t.uss.String()
t.RUnlock()
return ret
}
func (s threadUnsafeSet[T]) String() string {
items := make([]string, 0, len(s))
for elem := range s {
items = append(items, fmt.Sprintf("%v", elem))
}
return fmt.Sprintf("Set{%s}", strings.Join(items, ", "))
}
isEmpty
isEmpty
方法用于判断集合是否为空
func (t *threadSafeSet[T]) IsEmpty() bool {
return t.Cardinality() == 0
}
func (s *threadUnsafeSet[T]) IsEmpty() bool {
return s.Cardinality() == 0
}
MarshalJSON
MarshalJSON
方法用于将集合转成 json
格式
func (t *threadSafeSet[T]) MarshalJSON() ([]byte, error) {
t.RLock()
b, err := t.uss.MarshalJSON()
t.RUnlock()
return b, err
}
func (s threadUnsafeSet[T]) MarshalJSON() ([]byte, error) {
items := make([]string, 0, s.Cardinality())
for elem := range s {
b, err := json.Marshal(elem)
if err != nil {
return nil, err
}
items = append(items, string(b))
}
return []byte(fmt.Sprintf("[%s]", strings.Join(items, ","))), nil
}
UnmarshalJSON
UnmarshalJSON
方法用于将 json
格式的数据转成集合
func (t *threadSafeSet[T]) UnmarshalJSON(p []byte) error {
t.RLock()
err := t.uss.UnmarshalJSON(p)
t.RUnlock()
return err
}
func (s *threadUnsafeSet[T]) UnmarshalJSON(b []byte) error {
var i []T
err := json.Unmarshal(b, &i)
if err != nil {
return err
}
s.Append(i...)
return nil
}