Go语言基础十——map的实现原理

286 阅读6分钟

我正在参加「掘金·启航计划」

1、map实现原理

1.1、key,value存储

map是通过key获取value的一种数据结构,其底层存储方式为数组,存储时key不能重复,当key重复时,value进行覆盖,我们通过key进行hash运算(把key转化为一个整形数字)然后对数组的长度取余,得到key存储在数组的哪个下标位置,最后将key和value组装为一个结构体,放入数组下标处

length = len(array) = 4
hashkey1 = hash(xiaoming) = 4
index1 = hashkey1 % 4 = 0
hashkey2 = hash(xiaoli) = 6
index2 = hashkey2 % 4 = 2

1.2、hash冲突

如上图所示,数组一个下标处只能存储一个元素,也就是说一个数组下标只能存储一对key,value, 如果另一个key经过hash后得到index与之前重复,就存在hash冲突

1.2.1、开放定址法

也就是说当我们存储一个key,value时,发现hashkey(key)的下标已经被别的key占用,那我们在这个数组中空间中重新找一个没被占用的存储这个冲突的key,那么没被占用的有很多,找哪个好呢?常见的有线性探测法,线性补偿探测法,随机探测法,

  • 线性探测:按照顺序来,从冲突的下标处开始往后探测,到达数组末尾时,从数组开始处探测,直到找到一个空位置存储这个key,当数组都找不到的情况下会扩容(事实是当数组容量快满的时候就会扩容了);查找某一个key的时候,找到key对应的下标,比较key是否相等,如果相等直接取出来,否则按照书序探测直到碰到一个空位置,说明key不存在。如下图,首先存储key=xiaoming在下标0处,当存储key=xiaowang时,hash冲突了,按照线性探测,存储在下标1处,(红色线是冲突或者下标已经被占用)再者key=xiaozhao存储在下标4处,当存储key=xiaoliu时,hash冲突了,按照线性探测,从头开始,存储在下标2处(黄色是冲突或者下标已经被占用)

1.2.2、拉链法

简单理解为链表,当key的hash冲突时,我们在冲突位置的元素上形成一个链表,通过指针互连接,当查找时,发现key冲突,顺着链表一直往下找,直到链表的尾节点,找不到返回空。

1.2.3、开放定址(线性探测)和拉链法优缺点

  • 拉链法比线性探测处理简单
  • 线性探测查找比拉链法会更消耗时间
  • 线性探测会更加容易导致扩容
  • 拉链存储了指针,所以空间上会比线性探测占用多一点
  • 拉链是动态申请存储空间,所以更适合链长不确定的

1.3、map实现原理

map源码位于src/runtime/map.go中,map也是数组存储的,每个数组下标处存储的是一个bucket,bucket的类型代码如下,每个bucket可以存储8个kv键值对,当每个bucket存储的kv对到达8个之后,会通过overflow指针指向一个新的bucket,从而行程一个链表,bmap的结构中,没有显示定义kv的结构和overflow指针,是通过指针运算进行访问的。

//bucket结构体定义,b就是bucket

type bmap struct{
	// tophash generally contains the top byte of the hash value
	// for each key in this bucket. If tophash[0] < minTopHash,
	// tophash[0] is a bucket evacuation state instead.
    //翻译:tophash 通常包含bucket中每个键的hash值得高八位。
    如果tophash[0]小于minTopHash,topHash[0]为桶疏散状态,
    //bucketCnt的初始值是8
	tophash [bucketCnt]uint8
	// Followed by bucketCnt keys and then bucketCnt elems.
	// NOTE: packing all the keys together and then all the elems together makes the
	// code a bit more complicated than alternating key/elem/key/elem/... but it allows
	// us to eliminate padding which would be needed for, e.g., map[int64]int8.
	// Followed by an overflow pointer.

}

看上面代码和注释,我们可以得到bucket中存储的kv是这样的,tophash用来快速查找key值是否在该bucket中,而不用每次都通过真值进行比较;还有kv的存放,为什么不是k1v1,k2v2...而是k1k2...v1v2..存放,上面代码注释说的map[int64]int8;key是int64(8个字节),value是int8(1个字节),kv的长度不同,如果按照kv格式存放,则考虑内存对齐,v也占用int64,而按照后者存储,8个v刚好占用一个int64,节省内存。

分析一下go的整体内存结构,如下图,当往map中存储一个kv对时,通过k获取hash值,hash值得低8位和bucket数组长度取余,定位到在数组中的哪个下标,hash值得高8位存储在bucket中的tophash中,用来快速判断key是否存在,key和value的具体值则通过指针运算存储,当一个bucket满时,通过overflow指针连接到下一个bucket。

go的map存储源码如下:

func mapassign(t *maptype, h *hmap, key unsafe.Pointer) unsafe.Pointer {
	if h == nil {
		panic(plainError("assignment to entry in nil map"))
	}
	if raceenabled {
		callerpc := getcallerpc()
		pc := funcPC(mapassign)
		racewritepc(unsafe.Pointer(h), callerpc, pc)
		raceReadObjectPC(t.key, key, callerpc, pc)
	}
	if msanenabled {
		msanread(key, t.key.size)
	}
	if h.flags&hashWriting != 0 {
		throw("concurrent map writes")
	}
    //计算hash值
	hash := t.hasher(key, uintptr(h.hash0))

	// Set hashWriting after calling t.hasher, since t.hasher may panic,
	// in which case we have not actually done a write.
	h.flags ^= hashWriting
    
    //如果bucket数组一开始为空,则初始化
	if h.buckets == nil {
		h.buckets = newobject(t.bucket) // newarray(t.bucket, 1)
	}

again:
    //定位存储在哪一个bucket中
	bucket := hash & bucketMask(h.B)
	if h.growing() {
		growWork(t, h, bucket)
	}
    //得到bucket的结构体
	b := (*bmap)(unsafe.Pointer(uintptr(h.buckets) + bucket*uintptr(t.bucketsize)))
    //获取高八位的hash值
	top := tophash(hash)

	var inserti *uint8
	var insertk unsafe.Pointer
	var elem unsafe.Pointer
bucketloop:
    //死循环
	for {
        //循环bucket中的tophash数组
		for i := uintptr(0); i < bucketCnt; i++ {
            //如果hash不相等
			if b.tophash[i] != top {
                //判断是否为空,为空则插入
				if isEmpty(b.tophash[i]) && inserti == nil {
					inserti = &b.tophash[i]
					insertk = add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
					elem = add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.elemsize))
				}
                //插入成功,终止最外层循环
				if b.tophash[i] == emptyRest {
					break bucketloop
				}
				continue
			}
            //到这里说明高8位的hash一样,获取已存在的key
			k := add(unsafe.Pointer(b), dataOffset+i*uintptr(t.keysize))
			if t.indirectkey() {
				k = *((*unsafe.Pointer)(k))
			}
            //判断两个key是否相等,不相等就循环下一个
			if !t.key.equal(key, k) {
				continue
			}
			// already have a mapping for key. Update it.
            //如果相等,则更新
			if t.needkeyupdate() {
				typedmemmove(t.key, k, key)
			}
            //获取已存在的value
			elem = add(unsafe.Pointer(b), dataOffset+bucketCnt*uintptr(t.keysize)+i*uintptr(t.elemsize))
			goto done
		}
        //如果上一个bucket没能插入,则通过overflow获取链表上的下一个bucket
		ovf := b.overflow(t)
		if ovf == nil {
			break
		}
		b = ovf
	}

	// Did not find mapping for key. Allocate new cell & add entry.

	// If we hit the max load factor or we have too many overflow buckets,
	// and we're not already in the middle of growing, start growing.
	if !h.growing() && (overLoadFactor(h.count+1, h.B) || tooManyOverflowBuckets(h.noverflow, h.B)) {
		hashGrow(t, h)
		goto again // Growing the table invalidates everything, so try again
	}

	if inserti == nil {
		// all current buckets are full, allocate a new one.
		newb := h.newoverflow(t, b)
		inserti = &newb.tophash[0]
		insertk = add(unsafe.Pointer(newb), dataOffset)
		elem = add(insertk, bucketCnt*uintptr(t.keysize))
	}

	// store new key/elem at insert position
	if t.indirectkey() {
		kmem := newobject(t.key)
		*(*unsafe.Pointer)(insertk) = kmem
		insertk = kmem
	}
	if t.indirectelem() {
		vmem := newobject(t.elem)
		*(*unsafe.Pointer)(elem) = vmem
	}
	typedmemmove(t.key, insertk, key)
    //将高八位hash值进行存储
	*inserti = top
	h.count++

done:
	if h.flags&hashWriting == 0 {
		throw("concurrent map writes")
	}
	h.flags &^= hashWriting
	if t.indirectelem() {
		elem = *((*unsafe.Pointer)(elem))
	}
	return elem
}