5-1-3 协程实战-挂起流程

35 阅读6分钟

Kotlin协程挂起流程深度解析

本文深入剖析Kotlin协程的挂起机制,从编译器转换到运行时状态管理,全面解析挂起函数的执行流程。

1. 挂起函数的编译器转换

1.1 基础挂起函数的转换

原始代码:

/**
 * Loads the profile and then the avatar for [userId] and combines them into a [User].
 *
 * The two calls run one after the other; each one is a suspension point.
 */
suspend fun fetchUser(userId: Int): User {
    val userProfile = fetchProfile(userId)  // suspension point 1
    val userAvatar = fetchAvatar(userId)    // suspension point 2
    return User(userProfile, userAvatar)
}

/** Calls `delay(1000)` and then returns a [Profile] whose name is "User <userId>". */
suspend fun fetchProfile(userId: Int): Profile {
    val displayName = "User $userId"
    delay(1_000)
    return Profile(userId, displayName)
}

/** Calls `delay(500)` and then returns an [Avatar] pointing at "avatar_<userId>.png". */
suspend fun fetchAvatar(userId: Int): Avatar {
    val fileName = "avatar_$userId.png"
    delay(500)
    return Avatar(fileName)
}

编译器转换后(伪代码):

// Pseudo-code: the state machine the compiler derives from fetchUser
class FetchUserContinuation(
    completion: Continuation<User>
) : ContinuationImpl(completion) {
    // Index of the state to run on the next invokeSuspend call
    var label: Int = 0
    var result: Any? = null
    // Locals of fetchUser that must survive a suspension
    var userId: Int = 0
    var profile: Profile? = null

    /**
     * Runs fetchUser from the state stored in [label].
     *
     * The finished value is returned rather than delivered to `completion` here: the caller
     * of invokeSuspend is responsible for passing it on, so resuming `completion` in this
     * method as well would deliver the result twice.
     *
     * @param result value produced by the suspension point that just finished (unused in state 0)
     * @return COROUTINE_SUSPENDED when a callee suspended, otherwise the finished [User]
     * @throws IllegalStateException when [label] is not a known state
     */
    fun invokeSuspend(result: Any?): Any? {
        var lastResult = result
        // Loop so that a callee which returns without suspending falls through to the next state
        while (true) {
            when (label) {
                0 -> {
                    // Advance the label before the call: a later resume must enter state 1
                    label = 1
                    lastResult = fetchProfile(userId, this)
                    if (lastResult === COROUTINE_SUSPENDED) return COROUTINE_SUSPENDED
                }
                1 -> {
                    // First suspension point finished
                    profile = lastResult as Profile
                    label = 2
                    lastResult = fetchAvatar(userId, this)
                    if (lastResult === COROUTINE_SUSPENDED) return COROUTINE_SUSPENDED
                }
                2 -> {
                    // Second suspension point finished: build the final value
                    val avatar = lastResult as Avatar
                    val loadedProfile = checkNotNull(profile) { "profile must be set before state 2" }
                    return User(loadedProfile, avatar)
                }
                else -> throw IllegalStateException("Invalid state")
            }
        }
    }
}

// CPS form of fetchUser: the continuation becomes an extra parameter and the return type widens to Any?
fun fetchUser(userId: Int, completion: Continuation<User>): Any? {
    // Reuse the state machine when it is the one passed in; otherwise wrap the caller's
    // continuation in a fresh state machine seeded with the argument
    val stateMachine = (completion as? FetchUserContinuation)
        ?: FetchUserContinuation(completion).also { created -> created.userId = userId }

    // Enter the state machine; there is no previous result on this path
    return stateMachine.invokeSuspend(null)
}

2. Continuation接口详解

2.1 Continuation接口定义

/**
 * A continuation after a suspension point that produces a value of type [T].
 */
interface Continuation<in T> {
    // Coroutine context that belongs to this continuation
    val context: CoroutineContext
    
    // Resumes execution, passing a successful or failed result as the outcome of the suspension point
    fun resumeWith(result: Result<T>)
}

// Simplified sketch of the base class that compiler-generated state machines extend
internal abstract class BaseContinuationImpl : Continuation<Any?> {
    /**
     * Resumes this state machine with [result], then walks up the completion chain.
     *
     * Each finished state machine hands its outcome to its `completion`. The walk is a loop
     * instead of recursion so that a long chain does not grow the stack. It stops when a
     * state machine suspends again or when a completion that is not a state machine is
     * reached; that last one is resumed with the final outcome.
     */
    final override fun resumeWith(result: Result<Any?>) {
        var current = this
        var param = result

        while (true) {
            with(current) {
                val outcome: Result<Any?> = try {
                    // Run the state machine from its current label
                    val outcome = invokeSuspend(param)
                    if (outcome === COROUTINE_SUSPENDED) {
                        return  // suspended again; a later resumeWith continues from here
                    }
                    Result.success(outcome)
                } catch (exception: Throwable) {
                    // A throwing state machine completes with a failure
                    Result.failure(exception)
                }

                // Nothing to notify when there is no completion
                val next = completion ?: return

                if (next is BaseContinuationImpl) {
                    // Parent is a state machine too: feed it the outcome on the next iteration
                    current = next
                    param = outcome
                } else {
                    // Top-level completion reached: deliver the outcome and stop
                    next.resumeWith(outcome)
                    return
                }
            }
        }
    }

    /**
     * Executes the next step of the state machine.
     *
     * @return COROUTINE_SUSPENDED if the step suspended, otherwise the final value
     */
    abstract fun invokeSuspend(result: Result<Any?>): Any?
}

3. 挂起点状态管理

3.1 挂起点的标识

// Marker instance that a suspending call returns to signal "suspended, no value yet";
// callers recognise it by identity (===)
internal val COROUTINE_SUSPENDED = Any()

/**
 * Example suspending function with three suspension points.
 *
 * @return the three intermediate values joined as "data1-data2-data3"
 */
suspend fun complexOperation(): String {
    // Suspension point 1: resumed from a background thread after a 100 ms sleep.
    // The type argument is spelled out because nothing else gives the compiler a type for T.
    val data1 = suspendCoroutine<String> { cont ->
        thread {
            Thread.sleep(100)
            cont.resume("data1")
        }
    }

    // Suspension point 2: same pattern with a 50 ms sleep
    val data2 = suspendCoroutine<String> { cont ->
        thread {
            Thread.sleep(50)
            cont.resume("data2")
        }
    }

    // Suspension point 3: runs the block in the Dispatchers.IO context
    val data3 = withContext(Dispatchers.IO) {
        delay(200)
        "data3"
    }

    return "$data1-$data2-$data3"
}

// Pseudo-code: the state machine derived from complexOperation.
// The suspending calls are written in source form; real generated code calls their
// continuation-passing variants.
class ComplexOperationContinuation(
    completion: Continuation<String>
) : ContinuationImpl(completion) {
    // Index of the state to run on the next invokeSuspend call
    var label = 0
    // Locals of complexOperation that must survive a suspension
    var data1: String? = null
    var data2: String? = null
    var data3: String? = null

    /**
     * Runs complexOperation from the state stored in [label].
     *
     * @param result outcome of the suspension point that just finished; a failure is rethrown
     * @return the value of the suspending call of the current state, or the final string in state 3
     * @throws IllegalStateException when [label] is not a known state
     */
    override fun invokeSuspend(result: Result<Any?>): Any? {
        when (label) {
            0 -> {
                label = 1
                // Suspension point 1: must resume with the String that state 1 reads back
                return suspendCoroutine<String> { cont ->
                    thread {
                        Thread.sleep(100)
                        cont.resume("data1")
                    }
                }
            }
            1 -> {
                data1 = result.getOrThrow() as String
                label = 2
                // Suspension point 2: must resume with the String that state 2 reads back
                return suspendCoroutine<String> { cont ->
                    thread {
                        Thread.sleep(50)
                        cont.resume("data2")
                    }
                }
            }
            2 -> {
                data2 = result.getOrThrow() as String
                label = 3
                // Suspension point 3: withContext with this state machine as the continuation
                return withContext(Dispatchers.IO, this) {
                    delay(200)
                    "data3"
                }
            }
            3 -> {
                data3 = result.getOrThrow() as String
                return "$data1-$data2-$data3"
            }
            else -> throw IllegalStateException()
        }
    }
}

4. 协程调度器与线程切换

4.1 调度器实现原理

// Sketch of how withContext picks between running in place and dispatching.
// NOTE(review): appears to be a simplified rendering of the kotlinx.coroutines implementation — confirm against the library source.
suspend fun <T> withContext(
    context: CoroutineContext,
    block: suspend CoroutineScope.() -> T
): T = suspendCoroutineUninterceptedOrReturn { uCont ->
    // Context of the calling coroutine, and that context combined with the requested one
    val oldContext = uCont.context
    val newContext = oldContext + context
    
    // Same context instance: run the block in place
    if (newContext === oldContext) {
        val coroutine = ScopeCoroutine(newContext, uCont)
        return@suspendCoroutineUninterceptedOrReturn coroutine.startUndispatchedOrReturn(coroutine, block)
    }
    
    // Context differs: check whether the interceptor (dispatcher) is still the same
    if (newContext[ContinuationInterceptor] == oldContext[ContinuationInterceptor]) {
        // Same interceptor: start undispatched
        val coroutine = UndispatchedCoroutine(newContext, uCont)
        return@suspendCoroutineUninterceptedOrReturn coroutine.startUndispatchedOrReturn(coroutine, block)
    }
    
    // Different interceptor: start the block as a separate coroutine in the new context
    val coroutine = DispatchedCoroutine(newContext, uCont)
    block.startCoroutineCancellable(coroutine, coroutine)
    // The lambda's value is whatever getResult() returns
    coroutine.getResult()
}

// Sketch of the coroutine that withContext creates on its dispatching path.
// NOTE(review): illustrative pseudo-code — `state` is handed to resumeWith as-is; confirm against the library source.
private class DispatchedCoroutine<in T>(
    context: CoroutineContext,
    uCont: Continuation<T>
) : ScopeCoroutine<T>(context, uCont) {
    
    // Hands the final state back to the outer continuation
    override fun afterResume(state: Any?) {
        // Outer continuation; nothing is done when it is not a ContinuationImpl
        val delegate = uCont as? ContinuationImpl ?: return
        val context = delegate.context
        
        // Interceptor (dispatcher) of the outer continuation's context
        val interceptor = context[ContinuationInterceptor]
        
        if (interceptor != null) {
            // Resume through the continuation wrapped by the interceptor
            interceptor.interceptContinuation(delegate).resumeWith(state)
        } else {
            // No interceptor: resume directly
            delegate.resumeWith(state)
        }
    }
}

// Illustrative sketch of a dispatcher (pseudo-code: a dotted name is not a valid object declaration)
object Dispatchers.IO : CoroutineDispatcher() {
    // Runs the block by handing it to DefaultExecutor
    override fun dispatch(context: CoroutineContext, block: Runnable) {
        // Hand the block over for execution
        DefaultExecutor.execute(block)
    }
    
    // Reports whether dispatch() has to be used for the given context
    override fun isDispatchNeeded(context: CoroutineContext): Boolean {
        // This sketch always asks for a dispatch
        return true
    }
}

5. 挂起函数的执行流程跟踪

5.1 自定义挂起跟踪工具

import kotlin.coroutines.*
import kotlin.coroutines.intrinsics.*

/**
 * Prints a log entry whenever a traced suspending call is entered, left or resumed,
 * and keeps the currently open calls on an internal stack to measure their duration.
 */
class SuspensionTracer {
    private val openPoints = mutableListOf<SuspensionPoint>()

    /** One recorded entry into a traced call; timestamp and thread default to "now" and the current thread. */
    data class SuspensionPoint(
        val functionName: String,
        val lineNumber: Int,
        val timestamp: Long = System.nanoTime(),
        val threadName: String = Thread.currentThread().name
    )

    /** Pushes a new entry for [functionName] at [lineNumber] and logs it with thread and stack depth. */
    fun traceEnter(functionName: String, lineNumber: Int) {
        val entry = SuspensionPoint(functionName, lineNumber)
        openPoints += entry
        println("↘ 进入挂起函数: $functionName (L$lineNumber)")
        println("   线程: ${Thread.currentThread().name}")
        println("   栈深度: ${openPoints.size}")
    }

    /** Pops the most recent entry and logs [result] with the elapsed time; does nothing when no entry is open. */
    fun traceExit(functionName: String, result: Any?) {
        val entered = openPoints.removeLastOrNull() ?: return
        // Nanoseconds to fractional milliseconds
        val elapsedMs = (System.nanoTime() - entered.timestamp) / 1_000_000.0
        println("↗ 离开挂起函数: $functionName")
        println("   结果: $result")
        println("   耗时: ${elapsedMs}ms")
        println("   线程: ${Thread.currentThread().name}")
    }

    /** Logs the class name and context of the [continuation] that is about to be resumed. */
    fun traceResume(continuation: Continuation<*>) {
        val continuationName = continuation::class.simpleName
        val continuationContext = continuation.context
        println("↻ 恢复协程")
        println("   续体: $continuationName")
        println("   上下文: $continuationContext")
    }
}

/**
 * Runs [block] as a coroutine and logs entry, exit and resumption through a [SuspensionTracer].
 *
 * Built on `suspendCoroutine`, which stays correct whether [block] finishes before this call
 * returns or only later. The caller is resumed through its context's interceptor, if any.
 *
 * [block] is `noinline` because it is used as a value (receiver of `startCoroutine`).
 *
 * @param block the suspending work to trace
 * @return the value produced by [block]; a failure of [block] is rethrown to the caller
 */
suspend inline fun <T> traceSuspend(
    noinline block: suspend () -> T
): T = suspendCoroutine { cont ->
    val tracer = SuspensionTracer()

    // First stack frame that is neither the stack-trace call itself, the tracer, nor coroutine machinery
    val caller = Thread.currentThread().stackTrace.firstOrNull { frame ->
        frame.className != "java.lang.Thread" &&
            !frame.className.contains("SuspensionTracer") &&
            !frame.className.contains("Coroutine")
    }

    caller?.let { tracer.traceEnter(it.methodName, it.lineNumber) }

    // Completion of the traced block: log first, then pass the outcome on unchanged
    val tracedCont = object : Continuation<T> {
        override val context: CoroutineContext = cont.context

        override fun resumeWith(result: Result<T>) {
            caller?.let { tracer.traceExit(it.methodName, result.getOrNull()) }
            tracer.traceResume(cont)
            cont.resumeWith(result)
        }
    }

    block.startCoroutine(tracedCont)
}

/** Usage example: a traced block that prints a line, calls `delay(100)` and yields a fixed message. */
suspend fun tracedOperation(): String {
    return traceSuspend {
        println("  执行操作...")
        delay(100)
        "操作完成"
    }
}

/** Usage example: a traced block that runs [tracedOperation] twice in sequence and joins both results. */
suspend fun nestedTracedOperation(): String {
    return traceSuspend {
        val first = tracedOperation()
        val second = tracedOperation()
        "$first + $second"
    }
}

6. 挂起函数的异常处理流程

6.1 异常传播机制

/**
 * Shows exception propagation: runs [mayFail] twice in sequence and joins the results.
 *
 * @throws Exception any exception from either attempt, after its message has been printed
 */
suspend fun riskyOperation(): String =
    try {
        val first = mayFail(1)   // may throw
        val second = mayFail(2)  // may throw
        "$first + $second"
    } catch (e: Exception) {
        println("捕获异常: ${e.message}")
        // Rethrow so the caller decides how to handle it
        throw e
    }

/**
 * Resumes immediately with either a success message or an [IllegalStateException],
 * chosen by comparing `Math.random()` with 0.5.
 */
suspend fun mayFail(attempt: Int): String = suspendCoroutine { cont ->
    val succeeded = Math.random() > 0.5
    when {
        succeeded -> cont.resume("成功-$attempt")
        else -> cont.resumeWithException(IllegalStateException("操作 $attempt 失败"))
    }
}

// Pseudo-code: the state machine derived from riskyOperation, including its try/catch
class RiskyOperationContinuation(
    completion: Continuation<String>
) : ContinuationImpl(completion) {
    // Index of the state to run on the next invokeSuspend call
    var label = 0
    // Local of riskyOperation that must survive the second suspension
    var result1: String? = null

    /**
     * Runs riskyOperation from the state stored in [label].
     *
     * A failed suspension point arrives as a failed [result]; `getOrThrow()` rethrows it
     * inside the try block, so the source's catch block sees it like a normal exception.
     *
     * @return the value of the suspending call of the current state, or the final string in state 2
     * @throws Exception whatever the source's catch block rethrows
     */
    override fun invokeSuspend(result: Result<Any?>): Any? {
        return try {
            when (label) {
                0 -> {
                    label = 1
                    mayFail(1, this)
                }
                1 -> {
                    result1 = result.getOrThrow() as String
                    label = 2
                    mayFail(2, this)
                }
                2 -> {
                    val result2 = result.getOrThrow() as String
                    "$result1 + $result2"
                }
                else -> throw IllegalStateException()
            }
        } catch (e: Exception) {
            // Mirror the source's catch block. Rethrow instead of returning a Result:
            // a returned Result object would be treated as the function's successful value.
            println("捕获异常: ${e.message}")
            throw e
        }
    }
}

7. 协程取消与挂起

7.1 取消检查点

/**
 * Shows where cancellation is checked in a sequence of suspending calls.
 *
 * @return both delay messages joined by " - "
 */
suspend fun cancellableOperation(): String {
    // No check is written at this call site; any check for this step lives inside the callee
    val first = delayWithCancellation(100)

    // Explicit check between the two steps
    kotlin.coroutines.coroutineContext.ensureActive()

    val second = delayWithCancellation(200)

    return "$first - $second"
}

/**
 * Suspends for [timeMillis] milliseconds on a worker thread and supports cancellation.
 *
 * On cancellation the handler prints a message and interrupts the worker, so the thread
 * does not keep sleeping after nobody is waiting for it.
 *
 * @param timeMillis how long the worker sleeps before resuming
 * @return a message that contains the requested delay
 */
suspend fun delayWithCancellation(timeMillis: Long): String {
    return suspendCancellableCoroutine { cont ->
        // Created unstarted so the cancellation handler is registered before the work begins
        val worker = thread(start = false) {
            try {
                Thread.sleep(timeMillis)
            } catch (interrupted: InterruptedException) {
                // Woken up by the cancellation handler: there is nothing left to deliver
                return@thread
            }
            if (cont.isActive) {
                cont.resume("延迟 $timeMillis ms")
            }
        }

        // Cancellation handler: report and stop the worker
        cont.invokeOnCancellation {
            println("操作被取消")
            worker.interrupt()
        }

        worker.start()
    }
}

// Pseudo-code: the state machine derived from cancellableOperation.
// The compiler adds no cancellation checks of its own; the only check here is the
// explicit ensureActive() call written in the source function.
class CancellableOperationContinuation(
    completion: Continuation<String>
) : ContinuationImpl(completion) {
    // Index of the state to run on the next invokeSuspend call
    var label = 0
    // Local of cancellableOperation that must survive the second suspension
    var result1: String? = null

    /**
     * Runs cancellableOperation from the state stored in [label].
     *
     * @param result outcome of the suspension point that just finished; a failure is rethrown
     * @return the value of the suspending call of the current state, or the final string in state 2
     * @throws IllegalStateException when [label] is not a known state
     */
    override fun invokeSuspend(result: Result<Any?>): Any? {
        when (label) {
            0 -> {
                label = 1
                return delayWithCancellation(100, this)
            }
            1 -> {
                result1 = result.getOrThrow() as String

                // The explicit ensureActive() between the two calls in the source
                val context = completion.context
                context.ensureActive()

                label = 2
                return delayWithCancellation(200, this)
            }
            2 -> {
                val result2 = result.getOrThrow() as String
                return "$result1 - $result2"
            }
            else -> throw IllegalStateException("Invalid state")
        }
    }
}

8. 挂起函数的状态可视化

8.1 协程状态跟踪器

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import java.util.concurrent.atomic.AtomicInteger

/**
 * Keeps one status record per named coroutine and publishes every change on an unlimited channel.
 *
 * NOTE(review): the record map is a plain mutable map — confirm whether callers that update
 * it from several threads need synchronization.
 */
class CoroutineStateTracker {
    private val coroutineStates = mutableMapOf<String, CoroutineState>()
    private val stateChannel = Channel<StateUpdate>(Channel.UNLIMITED)

    /** Latest known state of one tracked coroutine; timestamps come from `System.currentTimeMillis()`. */
    data class CoroutineState(
        val name: String,
        var status: Status,
        var thread: String?,
        var suspensionPoint: String?,
        var startTime: Long,
        var lastUpdate: Long
    )

    enum class Status {
        CREATED, RUNNING, SUSPENDED, RESUMING, COMPLETED, CANCELLED
    }

    /** One state change as delivered to the monitor. */
    data class StateUpdate(
        val coroutineName: String,
        val newStatus: Status,
        val thread: String?,
        val suspensionPoint: String?
    )

    /**
     * Runs [block] and records CREATED, RUNNING and a final status for [name].
     *
     * @throws CancellationException rethrown after CANCELLED is recorded, so that cancellation
     *   still reaches the caller
     * @throws Exception any other failure of [block], after COMPLETED is recorded
     */
    suspend fun trackCoroutine(name: String, block: suspend () -> Unit) {
        updateState(name, Status.CREATED, null, null)

        try {
            updateState(name, Status.RUNNING, Thread.currentThread().name, null)
            block()
            updateState(name, Status.COMPLETED, null, null)
        } catch (e: CancellationException) {
            updateState(name, Status.CANCELLED, null, null)
            // Swallowing this would hide the cancellation from the caller
            throw e
        } catch (e: Exception) {
            updateState(name, Status.COMPLETED, null, null)
            throw e
        }
    }

    // Creates the record for `name` on first use, overwrites its fields and publishes the change
    private fun updateState(
        name: String,
        status: Status,
        thread: String?,
        point: String?
    ) {
        val state = coroutineStates.getOrPut(name) {
            CoroutineState(
                name = name,
                status = Status.CREATED,
                thread = null,
                suspensionPoint = null,
                startTime = System.currentTimeMillis(),
                lastUpdate = System.currentTimeMillis()
            )
        }

        state.status = status
        state.thread = thread
        state.suspensionPoint = point
        state.lastUpdate = System.currentTimeMillis()

        // trySend does not suspend; the channel is unlimited
        stateChannel.trySend(StateUpdate(name, status, thread, point))
    }

    /**
     * Wraps [block]: records SUSPENDED with [operationName] before it runs and RESUMING
     * after it returns or throws.
     *
     * @return the value produced by [block]
     * @throws Exception any failure of [block], after RESUMING is recorded
     */
    suspend fun <T> traceSuspend(
        coroutineName: String,
        operationName: String,
        block: suspend () -> T
    ): T {
        updateState(
            coroutineName,
            Status.SUSPENDED,
            Thread.currentThread().name,
            operationName
        )

        return try {
            val result = block()

            updateState(
                coroutineName,
                Status.RESUMING,
                Thread.currentThread().name,
                null
            )

            result
        } catch (e: Exception) {
            updateState(
                coroutineName,
                Status.RESUMING,
                Thread.currentThread().name,
                null
            )
            throw e
        }
    }

    /** Prints every record with the time elapsed since it was created. */
    fun printStates() {
        println("\n=== 协程状态快照 ===")
        coroutineStates.values.forEach { state ->
            val duration = System.currentTimeMillis() - state.startTime
            println("${state.name}: ${state.status} (${duration}ms)")
            println("  线程: ${state.thread}")
            println("  挂起点: ${state.suspensionPoint ?: "无"}")
        }
    }

    /**
     * Launches a coroutine that prints every published state change.
     *
     * @param scope scope that owns the monitor; defaults to [GlobalScope] as before. Pass a
     *   lifecycle-bound scope so the monitor is cancelled together with its owner.
     * @return the monitor's job; cancel it to stop monitoring
     */
    fun startMonitor(scope: CoroutineScope = GlobalScope) = scope.launch {
        for (update in stateChannel) {
            println("[监控] ${update.coroutineName}: ${update.newStatus}")
            if (update.thread != null) {
                println("     线程: ${update.thread}")
            }
            if (update.suspensionPoint != null) {
                println("     挂起点: ${update.suspensionPoint}")
            }
        }
    }
}

// Usage example: track one worker coroutine through two traced steps, then print a snapshot
fun main() = runBlocking {
    val tracker = CoroutineStateTracker()
    tracker.startMonitor()

    val workerName = "worker-1"
    val worker = launch {
        tracker.trackCoroutine(workerName) {
            val fetched = tracker.traceSuspend(workerName, "fetchData") {
                delay(100)
                "data1"
            }

            val processed = tracker.traceSuspend(workerName, "processData") {
                delay(150)
                "processed: $fetched"
            }

            println("结果: $processed")
        }
    }

    // Wait for the worker before printing the final snapshot
    worker.join()
    tracker.printStates()
}

9. 挂起函数的性能优化

9.1 避免挂起开销

/**
 * Optimization 1: skip the suspending call when it is not needed.
 *
 * Returns the cached value for [key] when there is one; only a cache miss reaches fetchFromNetwork.
 */
suspend fun getCachedOrFetch(key: String): String =
    cache[key] ?: fetchFromNetwork(key)

/**
 * Optimization 2: batch the work behind a single context switch.
 *
 * Enters `Dispatchers.IO` once and fetches every item inside that one block, one after another.
 *
 * @return each item mapped to its fetched value
 */
suspend fun fetchMultiple(items: List<String>): Map<String, String> =
    withContext(Dispatchers.IO) {
        items.associate { item -> item to fetchItem(item) }
    }

/**
 * Optimization 3: stream values through a channel.
 *
 * Sends "item-0" through "item-99" and calls `ensureActive()` after every tenth send.
 */
fun CoroutineScope.produceItems(): ReceiveChannel<String> = produce {
    for (index in 0 until 100) {
        send("item-$index")
        // Explicit check on every tenth item
        if (index % 10 == 0) {
            ensureActive()
        }
    }
}

/**
 * Optimization 4: an inline suspending wrapper that never suspends.
 *
 * The value of [block] is returned straight from the `suspendCoroutineUninterceptedOrReturn`
 * lambda. That lambda must either return its result or return COROUTINE_SUSPENDED and resume
 * the continuation later. Doing both — resuming and returning a value — resumes the caller
 * twice, so the continuation is deliberately left untouched.
 *
 * @param block non-suspending computation to run in place
 * @return the value produced by [block]
 * @throws Throwable anything [block] throws, propagated unchanged
 */
inline suspend fun <T> fastOperation(
    crossinline block: () -> T
): T = suspendCoroutineUninterceptedOrReturn { _ ->
    block()
}

/** Usage example: the block runs in place, so this call does not suspend. */
suspend fun optimizedWork() {
    // Quick computation passed as the block
    val total = fastOperation { (1..100).sum() }
}

10. 挂起流程总结

10.1 完整挂起流程时序图

协程调用者          挂起函数          调度器            线程池
    |                |                |                |
    | 1.调用挂起函数  |                |                |
    |--------------->|                |                |
    |                |                |                |
    |                | 2.创建状态机    |                |
    |                |--------------->|                |
    |                |                |                |
    |                | 3.检查是否需要调度 |                |
    |                |--------------->|                |
    |                |                |                |
    |                | 4.挂起并返回SUSPENDED|            |
    |<---------------|                |                |
    |                |                |                |
    |                | 5.调度到线程池   |                |
    |                |                |--------------->|
    |                |                |                |
    |                |                | 6.在线程中执行  |
    |                |                |<-------------->|
    |                |                |                |
    |                | 7.完成,恢复协程 |                |
    |                |<---------------|                |
    |                |                |                |
    | 8.继续执行      |                |                |
    |<---------------|                |                |

10.2 关键要点

  1. 状态机转换:编译器将挂起函数转换为状态机,挂起点把函数体切分成若干状态(N个挂起点对应N+1个状态)
  2. 续体传递:通过Continuation接口在挂起点间传递执行状态
  3. 挂起标识:COROUTINE_SUSPENDED是特殊的返回值,表示函数已挂起
  4. 调度器作用:决定协程在哪个线程恢复执行
  5. 取消传播:取消状态通过协程上下文中的Job传播;取消是协作式的,由可取消的挂起函数(如delay、suspendCancellableCoroutine)或显式调用ensureActive()来检查,编译器不会自动插入检查
  6. 异常处理:异常通过Result包装在续体中传播

10.3 性能建议

  1. 避免频繁的挂起/恢复操作
  2. 批量处理异步任务
  3. 使用适当的调度器
  4. 合理设置协程上下文
  5. 及时释放资源,避免内存泄漏

理解Kotlin协程的挂起流程对于编写高效、正确的异步代码至关重要。通过深入了解内部机制,可以更好地调试性能问题,编写更优雅的并发代码。