一个按秒统计 RTP 丢包数与收包数的算法实现

124 阅读1分钟

Kotlin 实现


import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.time.Clock

/**
 * Immutable snapshot of one conference's RTP statistics, as produced by [ConferenceMonitor.scrape].
 *
 * @property conferenceId identifier of the conference the figures belong to.
 * @property rtpLoss loss ratio over the monitoring window: lost / (received + lost).
 * @property rtpTotal received-packet figure reported by [ConferenceMonitor.scrape].
 * @property bucketIndex index of the window bucket that was being written when the snapshot was taken.
 */
data class ConferenceStatistic(
    val conferenceId: String,
    val rtpLoss: Float,
    val rtpTotal: Int,
    val bucketIndex: Int,
)

/**
 * RTP accounting for one one-second bucket of a sliding loss-measurement window.
 *
 * Sequence numbers are 16-bit values that wrap from 65535 to 0. A packet ahead of [nextSeq]
 * marks every skipped sequence number as lost. If one of those numbers arrives later, it is
 * removed from the loss set and counted as received.
 *
 * @property lossRtpSet sequence numbers this bucket currently counts as lost.
 * @property nextSeq sequence number expected next, or -1 before the first packet has been seen.
 * @property packetCount number of packets this bucket has counted as received.
 */
class Bucket(var lossRtpSet: MutableSet<Int>, var nextSeq: Int = -1, var packetCount: Int = 0) {
    /**
     * Records the arrival of the packet with sequence number [seq] (reduced to 16 bits).
     *
     * @return `true` if the packet was accounted for in this bucket. Returns `false` for an old
     *   packet that this bucket never recorded as lost: one at most [MAX_REORDER] behind [nextSeq],
     *   with wrap-around. The caller can then look for it in other buckets.
     */
    fun addRtp(seq: Int): Boolean {
        // Reduce to the 16-bit sequence space so every distance below is well defined and bounded.
        val s = seq and SEQ_MASK

        if (nextSeq == -1 || s == nextSeq) {
            // First packet of the stream, or exactly the expected one. nextSeq must point past the
            // packet just received; otherwise the first packet would later be counted as lost too.
            nextSeq = next(s)
            packetCount++
            return true
        }

        if (lossRtpSet.remove(s)) {
            // A packet this bucket had counted as lost finally arrived.
            packetCount++
            return true
        }

        // Distance behind the expected sequence number, modulo 2^16. This also catches late packets
        // from before a wrap (e.g. 65534 arriving while 3 is expected). Those must not be mistaken
        // for a forward jump of ~65k packets.
        val behind = (nextSeq - s) and SEQ_MASK
        if (behind <= MAX_REORDER) {
            // Late or duplicate packet that this bucket does not know about.
            return false
        }

        // New packet ahead of the expected one: every number in [nextSeq, s), wrapping, is missing.
        val gap = (s - nextSeq) and SEQ_MASK
        for (offset in 0..<gap) {
            lossRtpSet.add((nextSeq + offset) and SEQ_MASK)
        }
        nextSeq = next(s)
        packetCount++
        return true
    }

    /** Number of packets counted as received by this bucket. */
    fun receivedRtpCount(): Int = packetCount

    /** Number of sequence numbers this bucket currently counts as lost. */
    fun lossRtpCount(): Int = lossRtpSet.size

    companion object {
        /** Mask reducing a value to the 16-bit RTP sequence-number space. */
        private const val SEQ_MASK = 0xFFFF

        /** Largest backward distance from [nextSeq] still treated as a late packet rather than a wrap. */
        private const val MAX_REORDER = 10000

        /** Successor of [seq] in the 16-bit sequence space (65535 wraps to 0). */
        private fun next(seq: Int): Int = (seq + 1) and SEQ_MASK
    }
}

/**
 * Tracks RTP reception and loss for one conference over a sliding window of one-second buckets.
 * A background coroutine publishes the figures to Prometheus gauges roughly once per second.
 *
 * The window is a ring of `bucketLength + 1` [Bucket]s. The bucket at [bucketIndex] receives
 * packets for the current second; the others hold the preceding seconds. A late packet that an
 * older bucket had recorded as lost is credited back to that bucket.
 *
 * All methods that touch the window are `@Synchronized` on this instance.
 */
class ConferenceMonitor {
    private val scope: CoroutineScope

    // Epoch second represented by the bucket at bucketIndex.
    private var lastTimeStamp: Long
    private val conferenceId: String
    private var bucketIndex: Int = 0
    private val clock: Clock = Clock.systemUTC()
    lateinit var logger: Logger

    @Volatile
    private var running = true

    private val bucketLength = 5

    // bucketLength + 1 buckets, each covering one second; the one at bucketIndex is the write buffer.
    // Packets from more than bucketLength seconds ago are assumed never to arrive.
    // Each rotation resets the oldest bucket and reuses it.
    private val buckets: MutableList<Bucket> = MutableList(bucketLength + 1) { Bucket(mutableSetOf()) }

    // Periodic rotate-and-publish loop started by the constructor and cancelled by stop().
    private val publisherJob: Job

    constructor(conferenceId: String, scope: CoroutineScope) {
        this.conferenceId = conferenceId
        // Use the same clock as switchBucket so the first rotation is computed consistently.
        lastTimeStamp = clock.millis() / 1000
        logger = LoggerFactory.getLogger("$conferenceId-conferenceMonitor")
        this.scope = scope
        publisherJob = scope.launch {
            while (running) {
                publish()
                delay(1000)
            }
        }
    }

    /**
     * Records the RTP packet with sequence number [seq] in the current bucket.
     *
     * If the current bucket rejects it as an old packet, the packet is credited to the first bucket
     * that recorded it as lost. Packets found in no loss set (e.g. duplicates) are ignored.
     */
    @Synchronized
    fun addRtp(seq: Int) {
        switchBucket()
        if (buckets[bucketIndex].addRtp(seq)) {
            return
        }
        buckets.firstOrNull { seq in it.lossRtpSet }?.apply {
            lossRtpSet.remove(seq)
            packetCount++
        }
    }

    /**
     * Rotates the window so that the bucket at [bucketIndex] represents second [cSecond].
     *
     * One bucket is reset per elapsed second, so idle seconds show up as empty buckets instead of
     * being merged into the next one. At most one full ring is reset, since further rotation would
     * only clear the same buckets again. A backwards clock step is treated as a single rotation, as
     * before. Each reset bucket inherits the previous bucket's `nextSeq`, so gap detection continues.
     *
     * @return the bucket index in use before this call.
     */
    @Synchronized
    private fun switchBucket(cSecond: Long = clock.millis() / 1000): Int {
        val elapsed = cSecond - lastTimeStamp
        if (elapsed == 0L) {
            return bucketIndex
        }
        val previousIndex = bucketIndex
        lastTimeStamp = cSecond
        val steps = elapsed.coerceIn(1L, buckets.size.toLong()).toInt()
        repeat(steps) {
            val carriedNextSeq = buckets[bucketIndex].nextSeq
            bucketIndex = (bucketIndex + 1) % buckets.size
            buckets[bucketIndex].apply {
                packetCount = 0
                nextSeq = carriedNextSeq
                lossRtpSet.clear()
            }
        }
        return previousIndex
    }

    /**
     * Aggregates every bucket of the window into a [ConferenceStatistic].
     *
     * `rtpLoss` is lost / (received + lost) over the window. It is reported as 0 when at most 10
     * packets were received, because the ratio would be too noisy. `rtpTotal` is the received count
     * divided by `bucketLength` in both cases (an approximate per-second rate), so the gauge does
     * not jump between a window total and a per-second value.
     */
    @Synchronized
    fun scrape(): ConferenceStatistic {
        val lost = buckets.sumOf { it.lossRtpCount() }
        val received = buckets.sumOf { it.receivedRtpCount() }
        val perSecond = received / bucketLength
        if (received <= 10) {
            // Too few packets for the loss ratio to be meaningful.
            return ConferenceStatistic(conferenceId, 0f, perSecond, bucketIndex)
        }
        val lossRatio = lost.toFloat() / (received + lost)
        return ConferenceStatistic(conferenceId, lossRatio, perSecond, bucketIndex)
    }

    /**
     * Stops publishing and removes this conference's gauge labels.
     *
     * It holds the same lock as [publish], so no publish can run after the labels are removed and
     * re-create them.
     */
    @Synchronized
    fun stop() {
        running = false
        publisherJob.cancel()
        PromHelper.getConferenceForwardMergeVideoRtpLossGauge()?.remove(conferenceId)
        PromHelper.getConferenceForwardMergeVideoRtpCountGauge()?.remove(conferenceId)
    }

    // One iteration of the background loop: rotate the window and push the current figures.
    // It checks `running` under the instance lock so a stopped monitor never republishes.
    @Synchronized
    private fun publish() {
        if (!running) {
            return
        }
        switchBucket()
        val stat = scrape()
        PromHelper.getConferenceForwardMergeVideoRtpLossGauge()?.labelValues(conferenceId)?.set(stat.rtpLoss.toDouble())
        PromHelper.getConferenceForwardMergeVideoRtpCountGauge()?.labelValues(conferenceId)?.set(stat.rtpTotal.toDouble())
    }
}

测试代码


class ConferenceMonitorTest {
    /**
     * Manual smoke test. One thread feeds sequence numbers 0..10000 at about one packet per 4 ms,
     * skipping every 10th number (about 10% loss). Another thread prints the monitor's statistics
     * once per second for 16 seconds. Results are printed, not asserted.
     */
    @Test
    fun test1() {
        // Sleep (instead of spinning) until the next second boundary, so the first bucket covers a whole second.
        Thread.sleep(1000 - System.currentTimeMillis() % 1000)

        val cs = CoroutineScope(Dispatchers.IO)
        val cm = ConferenceMonitor("test", cs)
        val done = CountDownLatch(2)
        try {
            thread(start = true) {
                try {
                    repeat(16) {
                        Thread.sleep(1000)
                        val stat = cm.scrape()
                        println("from monitor, loss is ${stat.rtpLoss}, received is ${stat.rtpTotal}")
                    }
                } finally {
                    // Always release the latch so a failure here cannot hang await().
                    done.countDown()
                }
            }

            thread(start = true) {
                try {
                    for (seq in 0..10000) {
                        if (seq % 10 == 0) {
                            // 10% packet loss
                            continue
                        }
                        cm.addRtp(seq)
                        Thread.sleep(4)
                    }
                } finally {
                    done.countDown()
                }
            }

            done.await()
        } finally {
            // Stop the monitor so its background loop ends and its gauges are removed.
            cm.stop()
        }
    }
}