Okio - 源码分析(三)-超时机制

130 阅读2分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第8天,点击查看活动详情

超时机制

Okio 的亮点之一是增加了超时机制,用于防止 I/O 因意外情况而一直阻塞;默认的超时机制是同步的。

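AsyncTimeout 是 Okio 中异步超时机制的实现。所有已调用 enter() 的 AsyncTimeout 对象作为节点组成一个单链表,节点按剩余等待时间从小到大排序;head 是一个哨兵头结点,只起占位作用。后台有一个名为 Watchdog 的守护线程,它并不会反复遍历所有节点,而是只等待队首(最早到期)的节点:一旦该节点超时,就将它从链表中移除,并调用它的 timedOut() 方法(子类通常在这里关闭 Socket)。如果队列持续空闲超过 60 秒,Watchdog 线程会自行退出,下次有超时需要调度时再重新创建。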

/**
 * Asynchronous timeout implementation.
 *
 * Each instance is a node in a singly linked queue shared by all instances (companion-object
 * state guarded by the `AsyncTimeout::class.java` monitor). The queue starts at a sentinel
 * [head] node and is ordered by time remaining until timeout, smallest first. A background
 * [Watchdog] daemon thread waits for the first node to expire, unlinks it, and calls
 * [timedOut] on it.
 *
 * Call [enter] before a blocking operation and [exit] after it, or wrap the operation with
 * [withTimeout], [sink], or [source].
 */
open class AsyncTimeout : Timeout() {
  /**
   * True if this node is currently in the queue. The watchdog unlinks a timed-out node
   * without clearing this flag; [cancelScheduledTimeout] relies on that to detect a timeout.
   */
  private var inQueue = false

  /** The next node in the linked list. */
  private var next: AsyncTimeout? = null

  /** If scheduled, the instant (compared against `System.nanoTime()`) at which the watchdog should time out. */
  private var timeoutAt = 0L

 
  /**
   * Adds this AsyncTimeout to the watchdog queue.
   *
   * Does nothing when neither a timeout nor a deadline is set. Throws
   * [IllegalStateException] ("Unbalanced enter/exit") if this node was entered and has
   * not been exited yet.
   */
  fun enter() {
    val timeoutNanos = timeoutNanos()
    val hasDeadline = hasDeadline()
    if (timeoutNanos == 0L && !hasDeadline) {
      return // No timeout and no deadline? Don't bother with the queue.
    }
    scheduleTimeout(this, timeoutNanos, hasDeadline)
  }

  /**
   * Removes this node from the linked list.
   *
   * @return true if the timeout fired first (the watchdog had already removed this node);
   *   false if the node was removed in time or was never scheduled.
   */
  fun exit(): Boolean {
    return cancelScheduledTimeout(this)
  }

  /** Nanoseconds from [now] until this node times out; zero or negative once it has expired. */
  private fun remainingNanos(now: Long) = timeoutAt - now

  /**
   * Invoked on the watchdog thread, outside the queue lock, when this node times out.
   * The default does nothing; subclasses override it to react, mainly to close the socket.
   */
  protected open fun timedOut() {}

  /**
   * Returns a [Sink] that delegates to [sink] and applies this timeout to every
   * `write`, `flush`, and `close`.
   */
  fun sink(sink: Sink): Sink {
    return object : Sink {
      override fun write(source: Buffer, byteCount: Long) {
        checkOffsetAndCount(source.size, 0, byteCount)

        var remaining = byteCount
        while (remaining > 0L) {
          // Choose this chunk's size: accumulate whole segments until at least
          // TIMEOUT_WRITE_SIZE bytes (or all remaining bytes) are covered, so each chunk
          // ends on a segment boundary or exactly at byteCount.
          // NOTE(review): `source.head!!` / `s.next!!` assume checkOffsetAndCount guarantees
          // byteCount <= source.size, and that sink.write consumes exactly toWrite bytes so
          // source.head advances each iteration — confirm.
          var toWrite = 0L
          var s = source.head!!
          while (toWrite < TIMEOUT_WRITE_SIZE) {
            val segmentSize = s.limit - s.pos
            toWrite += segmentSize.toLong()
            if (toWrite >= remaining) {
              toWrite = remaining
              break
            }
            s = s.next!!
          }

          // Each chunk gets its own enter()/exit() window, so one large write is not
          // bounded by a single timeout for its whole duration.
          withTimeout { sink.write(source, toWrite) }
          remaining -= toWrite
        }
      }

      override fun flush() {
        withTimeout { sink.flush() }
      }

      override fun close() {
        withTimeout { sink.close() }
      }

      override fun timeout() = this@AsyncTimeout

      override fun toString() = "AsyncTimeout.sink($sink)"
    }
  }


  /**
   * Returns a [Source] that delegates to [source] and applies this timeout to every
   * `read` and `close`.
   */
  fun source(source: Source): Source {
    return object : Source {
      override fun read(sink: Buffer, byteCount: Long): Long {
        return withTimeout { source.read(sink, byteCount) }
      }

      override fun close() {
        withTimeout { source.close() }
      }

      override fun timeout() = this@AsyncTimeout

      override fun toString() = "AsyncTimeout.source($source)"
    }
  }

  /**
   * Runs [block] between [enter] and [exit] and returns its result.
   *
   * How a fired timeout is reported:
   * - [block] throws an [IOException]: it is rethrown unchanged if the timeout did not fire;
   *   otherwise it is replaced by [newTimeoutException] with the original as the cause.
   * - [block] returns normally: if the timeout fired, [newTimeoutException] (no cause) is
   *   thrown instead of returning the result.
   * - [block] throws anything else: that exception propagates unchanged.
   */
  inline fun <T> withTimeout(block: () -> T): T {
    var throwOnTimeout = false
    enter()
    try {
      val result = block()
      throwOnTimeout = true
      return result
    } catch (e: IOException) {
      // This exit() clears inQueue, so the exit() in `finally` then returns false and
      // cannot throw a second exception.
      throw if (!exit()) e else `access$newTimeoutException`(e)
    } finally {
      val timedOut = exit()
      if (timedOut && throwOnTimeout) throw `access$newTimeoutException`(null)
    }
  }

  /** Lets the public inline [withTimeout] call the protected [newTimeoutException]. */
  @PublishedApi // Binary compatible trampoline function
  internal fun `access$newTimeoutException`(cause: IOException?) = newTimeoutException(cause)


  /**
   * Creates the exception thrown when this timeout fires: an [InterruptedIOException] with
   * message "timeout", chained to [cause] when it is non-null. Subclasses may override this
   * to return a different exception type.
   */
  protected open fun newTimeoutException(cause: IOException?): IOException {
    val e = InterruptedIOException("timeout")
    if (cause != null) {
      e.initCause(cause)
    }
    return e
  }

  /**
   * Daemon thread that services the timeout queue. It exits after the queue has been empty
   * for the idle timeout; [scheduleTimeout] starts a new one when [head] is null.
   */
  private class Watchdog internal constructor() : Thread("Okio Watchdog") {
    init {
      isDaemon = true
    }

    override fun run() {
      while (true) {
        try {
          var timedOut: AsyncTimeout? = null
          synchronized(AsyncTimeout::class.java) {
            timedOut = awaitTimeout()
            // awaitTimeout() returns the sentinel head only when the queue stayed empty for
            // the whole idle timeout: clear head and let this thread exit. The next
            // scheduleTimeout() creates a new head and starts a new watchdog.
            if (timedOut === head) {
              head = null
              return
            }
          }

          // Notify the timed-out node (if any) outside the lock.
          timedOut?.timedOut()
        } catch (ignored: InterruptedException) {
          // Interrupts are ignored; keep servicing the queue.
        }
      }
    }
  }

  companion object {

    /**
     * Chunk threshold for [sink] writes (64 KiB): each chunk accumulates whole segments until
     * it covers at least this many bytes, or all remaining bytes.
     */
    private const val TIMEOUT_WRITE_SIZE = 64 * 1024

    /** How long the watchdog waits on an empty queue before its thread exits (60 seconds). */
    private val IDLE_TIMEOUT_MILLIS = TimeUnit.SECONDS.toMillis(60)
    private val IDLE_TIMEOUT_NANOS = TimeUnit.MILLISECONDS.toNanos(IDLE_TIMEOUT_MILLIS)

 
    /** Sentinel head of the timeout queue; null when no watchdog thread is running. */
    private var head: AsyncTimeout? = null

    /**
     * Inserts [node] into the queue in order of time remaining, first creating [head] and
     * starting a watchdog if none is running. When both a timeout and a deadline are set,
     * the node times out at whichever comes first.
     */
    private fun scheduleTimeout(node: AsyncTimeout, timeoutNanos: Long, hasDeadline: Boolean) {
      synchronized(AsyncTimeout::class.java) {
        check(!node.inQueue) { "Unbalanced enter/exit" }
        node.inQueue = true

        if (head == null) {
          head = AsyncTimeout()
          Watchdog().start()
        }

        val now = System.nanoTime()
        if (timeoutNanos != 0L && hasDeadline) {
          node.timeoutAt = now + minOf(timeoutNanos, node.deadlineNanoTime() - now)
        } else if (timeoutNanos != 0L) {
          node.timeoutAt = now + timeoutNanos
        } else if (hasDeadline) {
          node.timeoutAt = node.deadlineNanoTime()
        } else {
          // Unreachable via enter(), which returns early when neither is set.
          throw AssertionError()
        }

        // Insert in sorted order. The strict `<` places the node after existing nodes
        // with the same remaining time.
        val remainingNanos = node.remainingNanos(now)
        var prev = head!!
        while (true) {
          if (prev.next == null || remainingNanos < prev.next!!.remainingNanos(now)) {
            node.next = prev.next
            prev.next = node
            if (prev === head) {
              // Wake up the watchdog when inserting at the front.
              (AsyncTimeout::class.java as Object).notify()
            }
            break
          }
          prev = prev.next!!
        }
      }
    }

    /**
     * Does the actual removal of [node] from the queue.
     *
     * @return false if the node was not in the queue, or was found and unlinked; true if it
     *   was marked in-queue but is no longer linked, meaning the watchdog already removed it
     *   because it timed out.
     */
    private fun cancelScheduledTimeout(node: AsyncTimeout): Boolean {
      synchronized(AsyncTimeout::class.java) {
        if (!node.inQueue) return false
        node.inQueue = false

        var prev = head
        while (prev != null) {
          if (prev.next === node) {
            prev.next = node.next
            node.next = null
            return false
          }
          prev = prev.next
        }
        // Not found in the list: the watchdog unlinked it, so it timed out.
        return true
      }
    }

    /**
     * Waits for the first node in the queue to time out. The caller must hold the
     * `AsyncTimeout::class.java` monitor, because `Object.wait` requires it (the wait
     * releases it temporarily). [head] is assumed non-null here.
     *
     * @return the sentinel [head] if the queue stayed empty for the full idle timeout; the
     *   expired first node, already unlinked (its `inQueue` is left true), if one has timed
     *   out; or null if the caller should simply call again because a wait ended early or
     *   the queue changed.
     */
    @Throws(InterruptedException::class)
    internal fun awaitTimeout(): AsyncTimeout? {
      val node = head!!.next
      // The queue is empty (only the sentinel head): wait up to the idle timeout.
      if (node == null) {
        val startNanos = System.nanoTime()
        (AsyncTimeout::class.java as Object).wait(IDLE_TIMEOUT_MILLIS)
        return if (head!!.next == null && System.nanoTime() - startNanos >= IDLE_TIMEOUT_NANOS) {
          head // The idle timeout elapsed.
        } else {
          null // The situation has changed.
        }
      }

      var waitNanos = node.remainingNanos(System.nanoTime())

      // The first node has not expired yet: wait until it would, splitting the delay into
      // the millis + nanos arguments Object.wait expects, then let the caller re-check.
      if (waitNanos > 0) {
        val waitMillis = waitNanos / 1000000L
        waitNanos -= waitMillis * 1000000L
        (AsyncTimeout::class.java as Object).wait(waitMillis, waitNanos.toInt())
        return null
      }
      // The first node has timed out: unlink it and hand it to the caller.
      head!!.next = node.next
      node.next = null
      return node
    }
  }
}

默认情况下没有设置任何超时(timeoutNanos 为 0 且没有 deadline),此时 enter() 会直接返回、不会加入队列,因此需要我们自己设置超时时间。