CoroutineContext源码解析

222 阅读10分钟

先把涉及到的源码先放上来,可以先忽略源码直接看下面解析,等看完解析再回头看源码。

public interface CoroutineContext {
    
    public operator fun <E : Element> get(key: Key<E>): E?

    
    public fun <R> fold(initial: R, operation: (R, Element) -> R): R

    
    public operator fun plus(context: CoroutineContext): CoroutineContext =
        if (context === EmptyCoroutineContext) this else // fast path -- avoid lambda creation
            context.fold(this) { acc, element ->
                val removed = acc.minusKey(element.key)
                if (removed === EmptyCoroutineContext) element else {
                    // make sure interceptor is always last in the context (and thus is fast to get when present)
                    val interceptor = removed[ContinuationInterceptor]
                    if (interceptor == null) CombinedContext(removed, element) else {
                        val left = removed.minusKey(ContinuationInterceptor)
                        if (left === EmptyCoroutineContext) CombinedContext(element, interceptor) else
                            CombinedContext(CombinedContext(left, element), interceptor)
                    }
                }
            }

    
    public fun minusKey(key: Key<*>): CoroutineContext


    public interface Key<E : Element>


    public interface Element : CoroutineContext {

        public val key: Key<*>

        public override operator fun <E : Element> get(key: Key<E>): E? =
            @Suppress("UNCHECKED_CAST")
            if (this.key == key) this as E else null

        public override fun <R> fold(initial: R, operation: (R, Element) -> R): R =
            operation(initial, this)

        public override fun minusKey(key: Key<*>): CoroutineContext =
            if (this.key == key) EmptyCoroutineContext else this
    }
}
public object EmptyCoroutineContext : CoroutineContext, Serializable {
    private const val serialVersionUID: Long = 0
    private fun readResolve(): Any = EmptyCoroutineContext

    public override fun <E : Element> get(key: Key<E>): E? = null
    public override fun <R> fold(initial: R, operation: (R, Element) -> R): R = initial
    public override fun plus(context: CoroutineContext): CoroutineContext = context
    public override fun minusKey(key: Key<*>): CoroutineContext = this
    public override fun hashCode(): Int = 0
    public override fun toString(): String = "EmptyCoroutineContext"
}

//--------------------- internal impl ---------------------

internal class CombinedContext(
    private val left: CoroutineContext,
    private val element: Element
) : CoroutineContext, Serializable {

    override fun <E : Element> get(key: Key<E>): E? {
        var cur = this
        while (true) {
            cur.element[key]?.let { return it }
            val next = cur.left
            if (next is CombinedContext) {
                cur = next
            } else {
                return next[key]
            }
        }
    }

    public override fun <R> fold(initial: R, operation: (R, Element) -> R): R =
        operation(left.fold(initial, operation), element)

    public override fun minusKey(key: Key<*>): CoroutineContext {
        element[key]?.let { return left }
        val newLeft = left.minusKey(key)
        return when {
            newLeft === left -> this
            newLeft === EmptyCoroutineContext -> element
            else -> CombinedContext(newLeft, element)
        }
    }

    //该类是组合类,起步价就是2个,然后多一个结点就+1
    private fun size(): Int {
        var cur = this
        var size = 2
        while (true) {
            cur = cur.left as? CombinedContext ?: return size
            size++
        }
    }

    //只要能获取到则为true
    private fun contains(element: Element): Boolean =
        get(element.key) == element

    //去进行遍历,但凡获取不到就是false,否则就是都能获取到
    private fun containsAll(context: CombinedContext): Boolean {
        var cur = context
        while (true) {
            if (!contains(cur.element)) return false
            val next = cur.left
            if (next is CombinedContext) {
                cur = next
            } else {
                return contains(next as Element)
            }
        }
    }

    override fun equals(other: Any?): Boolean =
        this === other || other is CombinedContext && other.size() == size() && other.containsAll(this)

    override fun hashCode(): Int = left.hashCode() + element.hashCode()

    override fun toString(): String =
        "[" + fold("") { acc, element ->
            if (acc.isEmpty()) element.toString() else "$acc, $element"
        } + "]"

    //将所有的结点上的element都统计到一个数组里面
    private fun writeReplace(): Any {
        val n = size()
        val elements = arrayOfNulls<CoroutineContext>(n)
        var index = 0
        fold(Unit) { _, element -> elements[index++] = element }
        check(index == n)
        @Suppress("UNCHECKED_CAST")
        return Serialized(elements as Array<CoroutineContext>)
    }

    private class Serialized(val elements: Array<CoroutineContext>) : Serializable {
        companion object {
            private const val serialVersionUID: Long = 0L
        }

        //将所有的结点都通过遍历来plus给组合起来;结合上面的write就是一个序列化的读取和写入
        private fun readResolve(): Any = elements.fold(EmptyCoroutineContext, CoroutineContext::plus)
    }
}

这里需要将CoroutineContext、Element、CombinedContext一块看;

CombineContext是继承自CoroutineContext,而且有两个元素

private val left: CoroutineContext,
private val element: Element

这里主要是实现了[Job()+CoroutineName("xx")+Dispatchers.IO+CoroutineExceptionHandler(...)]

一个+号贯穿了所有的集合,其实内部还有一个-号,不过是隐形的,没有暴露出来。

CoroutineContext中主要有四个接口方法:

  • get 获取
  • fold 遍历操作
  • plus 加入新的CoroutineContext后返回对应的集合
  • minusKey 返回 移除指定了key的集合

这里再说一个概念:就是目前我们在使用的时候用到的例如[ CoroutineName、Dispatchers.IO、CoroutineExceptionHandler、Job、ContinuationInterceptor ...]都是继承自Element。

一共四个方法,这里先看fold

flod

fold只有Element和CombinedContext做了实现;

Element
public override fun <R> fold(initial: R, operation: (R, Element) -> R): R = operation(initial, this)

就是普通的一个带初始值的传入操作;

举个例子:

//伪代码
CoroutineName("xx").fold(Dispatchers.IO){ "xx",Dispatchers.IO ->
    //具体操作
}

就是将两个参数传入进去,具体的操作可以自己再定义;

CombinedContext
public override fun <R> fold(initial: R, operation: (R, Element) -> R): R =
        operation(left.fold(initial, operation), element)

这里其实也是一样的操作,不过这里多嵌套了一层,主要是因为CombinedContext的结构不一样;这里先提前理解下CombinedContext的结构;

CombinedContext(CombinedContext(CombniedContext(...,CoroutineName("xx")),Dispatchers.IO),Job())

这样层层嵌套的一个集合;

上面举例的left就是

CombinedContext(CombniedContext(...,CoroutineName("xx")),Dispatchers.IO)

而对应的element就是Job()。

再回去看看fold方法大概就能理解,是保证这个集合里面的所有的Element都能得到遍历操作。

CombinedContext里面层层嵌套,一直到最后一个,那对应的left就是一个Element,是加入集合的第一个元素。而Element中的fold就是上面的源码所示,所以是一个完整的遍历操作。

至于为什么确定第一个是一个Element而不是CombinedContext,等下看plus的时候会解释。

这里放一个图方便理解CombinedContext的结构:

类似链表结构

WX20221213-152855@2x.png

get

get方法也是只有Element和CombinedContext中有实现:

Element
public override operator fun <E : Element> get(key: Key<E>): E? =
    @Suppress("UNCHECKED_CAST")
    if (this.key == key) this as E else null

这里对该Element执行get操作时,如果Key一样,就直接将该Element返回;

CombinedContext
override fun <E : Element> get(key: Key<E>): E? {
    var cur = this
    while (true) {
        cur.element[key]?.let { return it }
        val next = cur.left
        if (next is CombinedContext) {
            cur = next
        } else {
            return next[key]
        }
    }
}

从集合中获取,如果该CombinedContext结点中的element是对应的key则直接返回。

否则获取left对应的下一个结点,如果是CombinedContext类型则说明又是一个集合,然后将该结点赋值给cur,while循环继续;

如果不是CombinedContext类型,则说明到底了(链表到头了),到了最后一个Element,此时next[key]调用是Element的get方法。是就返回,不是就返回null。

plus

这个则只有CoroutineContext中有实现,子类都不曾覆写。

CoroutineContext
    public operator fun plus(context: CoroutineContext): CoroutineContext =
        //注释点1
        if (context === EmptyCoroutineContext) this else // fast path -- avoid lambda creation
            //注意点2
            context.fold(this) { acc, element ->
                //注意点3
                val removed = acc.minusKey(element.key)
                //注意点4
                if (removed === EmptyCoroutineContext) element else {
                    // make sure interceptor is always last in the context (and thus is fast to get when present)
                    //注意点5
                    val interceptor = removed[ContinuationInterceptor]
                    //注意点6
                    if (interceptor == null) CombinedContext(removed, element) else {
                        //注意点7
                        val left = removed.minusKey(ContinuationInterceptor)
                        //注意点8
                        if (left === EmptyCoroutineContext) CombinedContext(element, interceptor) else
                            //注意点9
                            CombinedContext(CombinedContext(left, element), interceptor)
                    }
                }
            }

先粗略讲解一下:

plus对应着操作符+,如果有新的Element要添加进来,就走的这个方法;

注意点1:

如果传进来的是个EmptyCoroutineContext,则没必要添加进来,则直接将已有的CoroutineContext返回就好。

注意点2:

接下来就是一个fold操作,这个前面讲过;如果传进来的是个Element则直接就是一个元素的操作;如果是CombinedContext则是对整个集合进行遍历操作。

这里看看里面具体的操作,此时需要注意的是,根据之前的fold的解说,acc是原本已有的,element是新传进来的。

注意点3:

这里先提前理解下minusKey函数,就理解为在已有元素的基础上移除指定key的元素,然后将剩下的返回。

注意点4:

如果移除后剩下的就是个EmptyCoroutineContext,则直接将要加入的element返回;因为可能原本什么就没有,这个新加入的element就是第一个,另一个就是原本可能就是一个element,而key正好就是新加入进来的key,则直接将那个element移除了,正好把新的element返回(就是我们之前说的,开启协程时,如果子协程指定了新的线程调度器,则会将之前的覆盖)。

注意点5:

从集合中获取拦截器(如果有添加自定义拦截器的话)

这里的拦截器是个什么东西呢?

    private fun fun25() {
        //拦截器
        class LogInterceptor() : ContinuationInterceptor {
            override val key = ContinuationInterceptor

            override fun <T> interceptContinuation(continuation: Continuation<T>) =
                object : Continuation<T> {
                    override val context: CoroutineContext
                        get() = EmptyCoroutineContext

                    override fun resumeWith(result: Result<T>) {
                        log("ContinuationInterceptor-start")
                        continuation.resumeWith(result)
                        log("ContinuationInterceptor-end")
                    }
                }
        }

        GlobalScope.launch(Dispatchers.Default + LogInterceptor()) {
            log("fun25")
        }
    }

打印的日志

ContinuationInterceptor-start
fun25
ContinuationInterceptor-end

在协程体执行的前后增加了插入处理,可以加入自己想处理的业务逻辑。

注意点5上面有一行注释

// make sure interceptor is always last in the context (and thus is fast to get when present)

就是使用时确保拦截器最好在最后一个添加进去,为的是在get的时候可以快速获取。(继续看后面的代码,也确实是这么处理的)

注意点6:

如果没有添加拦截器,则将元素加入集合中

CombinedContext(removed, element)

这里将移除后剩余的和新加入进来的element组合成了一个新的CombinedContext结点;

这里举个例子:

如果第一个元素是CoroutineName,第一个元素进来时是不走+号的,所以这里不走该方法;

然后第二个元素是Job(),之前没有job则返回的还是只有CoroutineName的element,然后这俩组合成了CombinedContext(CoroutineName,Job);

然后第三个元素Dispatchers.IO,也是和之前的组合CombinedContext(CombinedContext(CoroutineName,Job),Dispatcher.IO);

然后第四个元素是Job,是重复的,也就是会把之前的Job给剔除掉,removed就是CombinedContext(CoroutineName,Dispatchers.IO),再加入新加入的Job;ConmbinedContext(CombinedContext(CoroutineName,Dispatchers.IO),Job);

注意点7-9:

如果存在自定义拦截器,则先将拦截器移除:

  • 剩下的元素如果是EmptyCoroutineContext,则将新加进来的元素和拦截器组合,将拦截器放到尾部。
  • 否则就是先将移除了拦截器剩下的集合与新加进来的element组合,然后再将他们的组合与拦截器组合。

上面就是实现了一个操作:始终将拦截器放在尾部

minusKey

在已有元素的基础上移除指定key的元素,然后将剩下的返回。

同样是只有在Element和CombinedContext中有实现。

Element
public override fun minusKey(key: Key<*>): CoroutineContext =
            if (this.key == key) EmptyCoroutineContext else this

如果对应的key正好一样,则返回EmptyCoroutineContext,直接移除掉了。否则就是没移除掉还是将该元素返回。

举个例子,原本第一个Element是CoroutineName,这时又plus一个CoroutineName,在下面代码中:

[ CoroutineName("xx") ]

val removed = acc.minusKey(element.key)

移除后返回就是 EmptyCoroutineContext ,把对应Key的干掉了,啥都没了;如果plus进来的不是CoroutineName而是Job,则返回this就是CoroutineName本身了;然后后面会和Job组合成为CombinedContext(CoroutineName("xx") , Job())。

CombinedContext
public override fun minusKey(key: Key<*>): CoroutineContext {
    //注意点1
    element[key]?.let { return left }
    //注意点2
    val newLeft = left.minusKey(key)
    return when {
        newLeft === left -> this //注意点3
        newLeft === EmptyCoroutineContext -> element //注意点4
        else -> CombinedContext(newLeft, element) //注意点5
    }
}

注意点1:

如果element正好是则直接将left返回,正好把element移除了。

注意点2:

如果当前结点的element不是,则再往下走,找自己下一个结点,最终根据when有三种结果返回。

这样直接看代码会看不明白,还是结合例子来看:

一共四个return。

假设当前走了好几个plus,目前集合是这样的:

第一个return

CombinedContext(CoroutineName,Dispatchers.IO)

plus一个Dispatchers.Main,这样正好是element的key,则直接将left返回,而left是CoroutineName;然后和Dispatchers.Main结合就是CombinedContext(CoroutineName,Dispatchers.Main)

第二个return

CombinedContext(CoroutineName,Dispatchers.IO)

plus一个Job,element不是对应的key,继续往下走,left也就是CoroutineName也不是对应的key;而CoroutineName是Element类型,他对应的同名方法的处理是:如果不是对应的key则返回的是CoroutineName本身;

所以newLeft === left成立,返回的是this,也就是CombinedContext(CoroutineName,Dispatchers.IO)原封不动返回。

第三个return

CombinedContext(Dispatchers.IO,CoroutineName)

plus一个Dispatchers.Main,element是CoroutineName则继续往下走,走left.minusKey(key);这里left是Dispatchers.IO,是一个Element类型;这里Element类型对应的操作是直接返回一个EmptyCoroutineContext;

newLeft===EmptyCoroutineContext成立,于是返回element,也就是CoroutineName;

第四个return

CombinedContext(CombinedContext(Dispatchers.IO,CoroutineName),Job)

现在是两层CombinedContext起步,此时传进来一个CoroutineName;于是第一个element也就是Job,key不同继续往下走,此时left也就是CombinedContext(Dispatchers.IO,CoroutineName)也走同名方法;CombinedContext(Dispatchers.IO,CoroutineName)的element是CoroutineName,key对上了,于是直接将CombinedContext(Dispatchers.IO,CoroutineName)的left返回也就是Dispatchers.IO;

此时newLeft===Dispatchers.IO,走最后一个return:

CombinedContext(newLeft, element)

newLeft = Dispatchers.IO

element = Job

组合后CombinedContext(Dispatchers.IO, Job)

如此以上几种return都走到了;

回溯起来Element就像是View,CombinedContext就像是ViewGroup。