Kotlin协程的CoroutineContext.plus()一步步地分析

1,091 阅读4分钟

Kotlin的与Java差异巨大的语法时常给我带来困扰,这在学习协程的过程中非常明显。比如CoroutineContext.plus()方法,一直以来都没搞懂。

先说结论:

  1. 每个CoroutineContext类都有static final Element = 自己类.class的唯一对象
  2. 就像Map一样,每个CoroutineContext中,相同类型的Element只能有一份,如果 contextA + contextB 时B中有相同的Element, 就把A中的相同Element移除,放入B的Element。
  3. ContinuationInterceptor一定是放在CoroutineContext中Element列表最后的位置,主要就是Dipatchers.XX切换线程操作。在分析plus()方法时这一点并不重要。

所以,现在我要逐行逐行地分析一遍。

对以下的Scope打印结果和class

val scope0 = EmptyCoroutineContext
scope0 = EmptyCoroutineContext : kotlin.coroutines.EmptyCoroutineContext

val scope1 = CoroutineName("Hello")
CoroutineName(Hello) : kotlinx.coroutines.CoroutineName

val scope2 = scope1 + Dispatchers.Main
scope2 = [CoroutineName(Hello), Main] : kotlin.coroutines.CombinedContext

val scope3 = scope2 + Dispatchers.IO
scope3 = [CoroutineName(Hello), Dispatchers.IO] : kotlin.coroutines.CombinedContext

val scope4 = scope3 + EmptyCoroutineContext
scope4 = [CoroutineName(Hello), Dispatchers.IO] : kotlin.coroutines.CombinedContext

val scope5 = scope4 + CoroutineName("World")
scope5 = [CoroutineName(World), Dispatchers.IO] : kotlin.coroutines.CombinedContext

看不懂plus()方法的问题点是什么?

public operator fun plus(context: CoroutineContext): CoroutineContext =
    if (context === EmptyCoroutineContext) this else // fast path -- avoid lambda creation
        context.fold(this) { acc, element ->
            val removed = acc.minusKey(element.key)
            if (removed === EmptyCoroutineContext) element else {
                // make sure interceptor is always last in the context (and thus is fast to get when present)
                val interceptor = removed[ContinuationInterceptor]
                if (interceptor == null) CombinedContext(removed, element) else {
                    val left = removed.minusKey(ContinuationInterceptor)
                    if (left === EmptyCoroutineContext) CombinedContext(element, interceptor) else
                        CombinedContext(CombinedContext(left, element), interceptor)
                }
            }
        }

其实主要问题在于: fold(this) { acc, element } 是啥?

fold() 是啥?

实现fold()方法的一共有3个类,而所有的CoroutineContext都是这3个类的子类

object EmptyCoroutineContext
public override fun <R> fold(initial: R, operation: (R, Element) -> R): R = initial

interface Element
CoroutineName implements Element
Dispatchers.XX implements Element 几个分发器最终继承的也都是Element
public override fun <R> fold(initial: R, operation: (R, Element) -> R): R =
    operation(initial, this)
    
class CombinedContext
public override fun <R> fold(initial: R, operation: (R, Element) -> R): R =
    operation(left.fold(initial, operation), element)

有了这3个fold()的实现,就可以逐行去分析plus()都做了什么。

以下分析过程中,+ 左边用A表示,右边用B表示 -> A + B

scope0

它就是一个普通的EmptyCoroutineContext,作为初始化值。

scope1 = EmptyCoroutineContext + CoroutineName("Hello")

  1. 先判断B是否是Empty, 不是

  2. B.fold(A) {} 进入 CoroutineName->Element.fold(A) {}

  3. fold()的第二个参数是一个两参函数,此时执行的是: Element.fold(A) { acc=A, element=B }

  4. 参数已经确定,继续plus(): 以各自的Element为Key, 从A中减去B, A是Empty, Empty.minusKey(Key)不管减什么都返回this, 本身就是空的,啥都减不掉,最终还是Empty。

  5. 即然要执行加法,一边已经是Empty,再加任何东西,最终值都是另外一边,那就直接返回B。

  6. 最终结果=B=CoroutineName("Hello")

scope2 = CoroutineName("Hello") + Dispatchers.Main

  1. 先判断B是否是Empty, 不是

  2. B.fold(A) {} 进入 Dispatchers.Main->Element.fold(A) {}

  3. 从上一小节可以知道,fold第二个参数的赋值情况: Element.fold(A) { acc=A, element=B }

  4. val removed = 从A中减去B: Element.minusKey(key)判断参数key和自身key是否相等, 如相等就是自己减自己,返回Empty, 如不等表示减不掉,返回this,此处Key不等,返回this=CoroutineName("Hello")

  5. 判断removed中是否有拦截器,也就是Dispatchers.XX,此处没有,返回CombinedContext(removed=CoroutineName("Hello"), element=Dispatchers.Main)

scope3 = [CoroutineName("Hello"), Dispatchers.Main] + Dispatchers.IO

  1. 先判断B是否是Empty, 不是

  2. B.fold(A) {} 进入 Dispatchers.IO->Element.fold(A) {}

  3. Element.fold(A) { acc=A, element=B }

  4. val removed = A - B, A是CombinedContext类型
    (1) 判断 combined.element.key 是否等于 B.key, 此处等于, 直接返回 left=CoroutineName("Hello")

  5. 判断removed中是否有拦截器,拦截器刚刚被减掉了,那就添加新拦截器,然后返回CombinedContext(removed=CoroutineName("Hello"), element=Dispatchers.IO)

scope4 = [CoroutineName("Hello"), Dispatchers.IO] + EmptyCoroutineContext

  1. 先判断B是否是Empty, 是, 加个空没意义,直接返回自身 [CoroutineName("Hello"), Dispatchers.IO]

scope5 = [CoroutineName("Hello"), Dispatchers.IO] + CoroutineName("World")

  1. 先判断B是否是Empty, 不是

    1. B.fold(A) {} 进入 CoroutineName->Element.fold(A) {}
  2. Element.fold(A) { acc=A, element=B }

  3. val removed = A - B, A是CombinedContext类型

    (1) 判断 combined.element.key 是否等于 B.key, 此处不等于

    (2) 从 combined.left中减去B.key , 两者都是CoroutineName, 相等

    (3) CoroutineName->Element相减结果是Empty, 返回element=Dispatchers.IO

  4. removed不是Empty, 继续

  5. 判断removed中是否有拦截器,此次是有的,用interceptor=Dispatchers.IO将它记录下来

  6. 从removed中减去拦截器, 得到left=Empty, 返回结果CombinedContext(element=CoroutineName("World"), interceptor=Dispatchers.IO)实现拦截器一定要放在最后

  7. 写这么多太累了, 不想再写条新的, 直接假设 removed=[SupervisorJob, Dispatchers.IO], 来看plus()的最后一个分支, 从第6点开始

    (1) 判断removed中是否有拦截器,有,用interceptor=Dispatchers.IO将它记录下来

    (2) 从removed中减去拦截器,得到left=SupervisorJob

    (3) left != Empty, 返回结果 CombinedContext(CombinedContext(left=SupervisorJob, element=B=CoroutineName("World")), interceptor=Dispatchers.IO)

    (4) 结论同样是: 如有相同的Key则替换, Dispatchers放到最后