2-2-32 Mastering Kotlin Quickly: The Combining Functions fold and reduce

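Kotlin combining functions: fold and reduce

In Kotlin, fold and reduce are functions that combine the elements of a collection into a single value. They are known as "fold" operations and are a core concept in functional programming.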
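
To make the idea of folding concrete, the sketch below writes the same accumulation as an explicit loop. The helper name foldByHand is hypothetical and used only for illustration; it is not part of the standard library, and the real fold is an inline extension function rather than this plain function.

```kotlin
// Hypothetical helper (not a stdlib function): what a left fold does, as a loop.
fun <T, R> foldByHand(items: Iterable<T>, initial: R, operation: (R, T) -> R): R {
    var acc = initial                 // start from the initial value
    for (item in items) {
        acc = operation(acc, item)    // combine the accumulator with each element in order
    }
    return acc                        // an empty input returns the initial value unchanged
}

fun main() {
    val numbers = listOf(1, 2, 3, 4, 5)
    println(foldByHand(numbers, 0) { acc, n -> acc + n })  // 15
    println(numbers.fold(0) { acc, n -> acc + n })         // 15
}
```

Seen this way, the accumulator type R is free to differ from the element type T, which is what allows fold to build strings, pairs, or maps from a list of numbers.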

1. The fold function in detail

Basic syntax

// fold function signature
fun <T, R> Iterable<T>.fold(
    initial: R,                    // initial value
    operation: (acc: R, T) -> R    // combining operation
): R

Basic examples

val numbers = listOf(1, 2, 3, 4, 5)

// Sum
val sum = numbers.fold(0) { acc, number ->
    acc + number
}
println(sum)  // 15

// Product
val product = numbers.fold(1) { acc, number ->
    acc * number
}
println(product)  // 120

// String concatenation
val strings = listOf("a", "b", "c", "d")
val concatenated = strings.fold("") { acc, str ->
    acc + str
}
println(concatenated)  // "abcd"

// Compound accumulator: compute the average
val average = numbers.fold(0.0 to 0) { (sum, count), number ->
    sum + number to count + 1
}.let { (sum, count) -> sum / count }
println(average)  // 3.0

2. The reduce function in detail

Basic syntax

// reduce function signature
fun <S, T : S> Iterable<T>.reduce(
    operation: (acc: S, T) -> S    // combining operation
): S

How it differs from fold

val numbers = listOf(1, 2, 3, 4, 5)

// fold: requires an initial value
val sumWithFold = numbers.fold(0) { acc, n -> acc + n }

// reduce: uses the first element as the initial value
val sumWithReduce = numbers.reduce { acc, n -> acc + n }

println("fold: $sumWithFold, reduce: $sumWithReduce")  // both are 15

// Important difference: handling of empty collections
val emptyList = emptyList<Int>()

try {
    emptyList.reduce { acc, n -> acc + n }  // throws UnsupportedOperationException
} catch (e: UnsupportedOperationException) {
    println("reduce cannot be used on an empty collection")
}

val safeResult = emptyList.fold(0) { acc, n -> acc + n }  // safely returns 0
println("fold on an empty collection: $safeResult")

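Things to watch out for with reduce

// 1. An empty collection throws an exception
// emptyList<Int>().reduce { acc, n -> acc + n }  // throws an exception

// 2. A single-element collection returns that element directly
val single = listOf(42).reduce { acc, n -> acc + n }
println(single)  // 42, the operation is never invoked

// 3. Evaluation order matters for non-associative operations
// reduce always works left to right with the first element as the starting value,
// so its result can differ from a fold that starts from a separate initial value
val numbers = listOf(1, 2, 3)

// Subtraction is not associative
val subtractReduce = numbers.reduce { acc, n -> acc - n }
println("reduce subtraction: $subtractReduce")  // (1 - 2) - 3 = -4

val subtractFold = numbers.fold(0) { acc, n -> acc - n }
println("fold subtraction: $subtractFold")  // 0 - 1 - 2 - 3 = -6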
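
When the collection may be empty and no natural initial value exists, the standard library's reduceOrNull (available since Kotlin 1.4) returns null instead of throwing. The following is a minimal sketch; the helper name sumOrDefault is an assumption made for illustration.

```kotlin
// Hypothetical helper: reduce that falls back to a default for empty input.
fun sumOrDefault(numbers: List<Int>, default: Int = 0): Int =
    numbers.reduceOrNull { acc, n -> acc + n } ?: default

fun main() {
    println(emptyList<Int>().reduceOrNull { acc, n -> acc + n })  // null
    println(sumOrDefault(emptyList()))                            // 0
    println(sumOrDefault(listOf(1, 2, 3)))                        // 6
}
```

This keeps the brevity of reduce while making the empty case explicit at the call site.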

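3. Variants of fold and reduce

foldRight and reduceRight

// Folding from right to left
val numbers = listOf(1, 2, 3, 4)

// foldRight: right to left, the initial value sits on the right
val rightFold = numbers.foldRight(0) { number, acc ->
    println("Processing $number, accumulator: $acc")
    acc + number
}
// Output: Processing 4, accumulator: 0
//         Processing 3, accumulator: 4
//         Processing 2, accumulator: 7
//         Processing 1, accumulator: 9
// Result: 10

// reduceRight: right to left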
val rightReduce = numbers.reduceRight { number, acc ->
    acc + number
}
println(rightReduce)  // 10

// Reversing a string
val str = "Kotlin"
val reversed = str.toList().foldRight(StringBuilder()) { char, acc ->
    acc.append(char)
}.toString()
println(reversed)  // "niltoK"
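
With addition, the left and right folds give the same answer, so the direction is easy to overlook. The sketch below uses subtraction to show where the two diverge; the function names subtractLeft and subtractRight are hypothetical. Note that the lambda parameters swap places: fold passes (acc, element), foldRight passes (element, acc).

```kotlin
// Hypothetical helpers contrasting the two fold directions.
fun subtractLeft(numbers: List<Int>): Int =
    numbers.fold(0) { acc, n -> acc - n }        // ((0 - 1) - 2) - 3

fun subtractRight(numbers: List<Int>): Int =
    numbers.foldRight(0) { n, acc -> n - acc }   // 1 - (2 - (3 - 0))

fun main() {
    val numbers = listOf(1, 2, 3)
    println(subtractLeft(numbers))   // -6
    println(subtractRight(numbers))  // 2
}
```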

runningFold and runningReduce

// Return the intermediate result of every step
val numbers = listOf(1, 2, 3, 4, 5)

// runningFold: includes the initial value
val runningSum = numbers.runningFold(0) { acc, n -> acc + n }
println(runningSum)  // [0, 1, 3, 6, 10, 15]

// runningReduce: starts from the first element
val runningProduct = numbers.runningReduce { acc, n -> acc * n }
println(runningProduct)  // [1, 2, 6, 24, 120]

// Practical use: cumulative average
val cumulativeAverage = numbers.runningFold(0.0 to 0) { (sum, count), n ->
    (sum + n) to (count + 1)
}.drop(1)  // skip the initial (0.0, 0) pair, which would yield 0.0 / 0 = NaN
 .map { (sum, count) -> sum / count }
println(cumulativeAverage)  // [1.0, 1.5, 2.0, 2.5, 3.0]

4. foldIndexed and reduceIndexed

// Fold operations with an index
val numbers = listOf(10, 20, 30, 40)

// foldIndexed: with index
val indexedFold = numbers.foldIndexed(0) { index, acc, number ->
    acc + number * index  // multiply each number by its index
}
println(indexedFold)  // 0*10 + 1*20 + 2*30 + 3*40 = 200

// reduceIndexed: with index; the first element is the initial accumulator,
// so the lambda is first called with index 1 (never with index 0)
val indexedReduce = numbers.reduceIndexed { index, acc, number ->
    acc + (number * (index + 1))
}
println(indexedReduce)  // 10 + 2*20 + 3*30 + 4*40 = 300

// Practical use: weighted sum
val scores = listOf(85, 90, 78, 92, 88)
val weights = listOf(0.1, 0.2, 0.15, 0.25, 0.3)

val weightedSum = scores.foldIndexed(0.0) { index, acc, score ->
    acc + score * weights[index]
}
println("Weighted total: $weightedSum")
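
Indexing into a second list inside foldIndexed fails with an out-of-bounds error if the lists differ in length. One possible alternative, sketched below under the assumption that a size mismatch should be rejected up front, pairs the two lists with zip before folding. The function name weightedSum is hypothetical.

```kotlin
// Hypothetical helper: weighted sum over paired scores and weights.
fun weightedSum(scores: List<Int>, weights: List<Double>): Double {
    // zip silently truncates to the shorter list, so check the sizes explicitly
    require(scores.size == weights.size) { "scores and weights must have the same size" }
    return scores.zip(weights).fold(0.0) { acc, (score, weight) -> acc + score * weight }
}

fun main() {
    println(weightedSum(listOf(80, 90), listOf(0.5, 0.5)))  // 85.0
}
```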

5. Practical use cases

Statistical computation

data class DataPoint(val value: Double, val timestamp: Long)

val dataPoints = listOf(
    DataPoint(10.5, 1000),
    DataPoint(12.3, 2000),
    DataPoint(9.8, 3000),
    DataPoint(11.2, 4000),
    DataPoint(13.1, 5000)
)

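// Compute several statistics in one pass with fold
// (Double.MIN_VALUE is the smallest positive double, so the infinities are used as neutral starting points)
val stats = dataPoints.fold(
    StatsAccumulator(min = Double.POSITIVE_INFINITY, max = Double.NEGATIVE_INFINITY)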
) { acc, point ->
    acc.copy(
        sum = acc.sum + point.value,
        count = acc.count + 1,
        min = minOf(acc.min, point.value),
        max = maxOf(acc.max, point.value),
        timestamps = acc.timestamps + point.timestamp
    )
}.let { acc ->
    mapOf(
        "average" to acc.sum / acc.count,
        "min" to acc.min,
        "max" to acc.max,
        "count" to acc.count,
        "time range" to (acc.timestamps.maxOrNull()!! - acc.timestamps.minOrNull()!!)
    )
}

data class StatsAccumulator(
    val sum: Double = 0.0,
    val count: Int = 0,
    val min: Double,
    val max: Double,
    val timestamps: List<Long> = emptyList()
)

println(stats)

Data transformation and aggregation

data class Order(
    val id: Int,
    val customerId: Int,
    val amount: Double,
    val items: List<OrderItem>
)

data class OrderItem(
    val productId: Int,
    val quantity: Int,
    val price: Double
)

val orders = listOf(
    Order(1, 101, 150.0, listOf(
        OrderItem(1, 2, 50.0),
        OrderItem(2, 1, 50.0)
    )),
    Order(2, 102, 75.0, listOf(
        OrderItem(3, 3, 25.0)
    )),
    Order(3, 101, 200.0, listOf(
        OrderItem(1, 1, 50.0),
        OrderItem(4, 3, 50.0)
    ))
)

// Aggregate order data per customer
val customerSummary = orders.fold(mutableMapOf<Int, CustomerData>()) { acc, order ->
    val customerData = acc.getOrPut(order.customerId) {
        CustomerData(order.customerId, 0.0, 0, mutableMapOf())
    }
    
    customerData.apply {
        totalAmount += order.amount
        orderCount += 1
        
        // Count the quantity purchased per product
        order.items.forEach { item ->
            val current = productQuantities.getOrPut(item.productId) { 0 }
            productQuantities[item.productId] = current + item.quantity
        }
    }
    
    acc
}

data class CustomerData(
    val customerId: Int,
    var totalAmount: Double,
    var orderCount: Int,
    val productQuantities: MutableMap<Int, Int>
)

// Print the aggregated result
customerSummary.forEach { (customerId, data) ->
    println("Customer $customerId:")
    println("  Orders: ${data.orderCount}, total amount: ${data.totalAmount}")
    println("  Products purchased: ${data.productQuantities}")
}

Building complex data structures

// Build a tree structure
data class TreeNode(
    val name: String,
    val children: MutableList<TreeNode> = mutableListOf()
)

val paths = listOf(
    "root/home/user/documents",
    "root/home/user/downloads",
    "root/home/user/pictures/vacation",
    "root/var/log",
    "root/usr/bin"
)

// Build the file tree with fold
val fileTree = paths.fold(TreeNode("root")) { root, path ->
    val parts = path.split('/').drop(1)  // drop the leading "root"
    
    parts.fold(root) { currentNode, part ->
        val child = currentNode.children.find { it.name == part }
            ?: TreeNode(part).also { currentNode.children.add(it) }
        child
    }
    
    root
}

// Print the tree structure
fun printTree(node: TreeNode, indent: String = "") {
    println("$indent${node.name}/")
    node.children.forEach { child ->
        printTree(child, "$indent  ")
    }
}

printTree(fileTree)

Parsing and processing text

// Parse CSV data
val csv = """
    name,age,city,score
    Alice,25,New York,85.5
    Bob,30,London,92.0
    Charlie,22,Paris,78.5
    Diana,35,Tokyo,88.0
""".trimIndent()

// Parse the CSV with fold
val parsedData = csv.lineSequence().fold(ParsingState()) { state, line ->
    when {
        state.headers.isEmpty() -> {
            // The first line is the header
            state.copy(headers = line.split(','))
        }
        line.isNotBlank() -> {
            // Data row
            val values = line.split(',')
            val record = state.headers.zip(values).toMap()
            state.copy(records = state.records + record)
        }
        else -> state
    }
}

data class ParsingState(
    val headers: List<String> = emptyList(),
    val records: List<Map<String, String>> = emptyList()
)

// Analyze the data
val analysis = parsedData.records.fold(AnalysisResult()) { result, record ->
    result.copy(
        totalCount = result.totalCount + 1,
        // Parenthesize the elvis expressions: + binds tighter than ?:
        totalAge = result.totalAge + (record["age"]?.toIntOrNull() ?: 0),
        totalScore = result.totalScore + (record["score"]?.toDoubleOrNull() ?: 0.0),
        cities = result.cities + record["city"].orEmpty(),
        nameLengths = result.nameLengths + (record["name"]?.length ?: 0)
    )
}.let { result ->
    mapOf(
        "average age" to result.totalAge.toDouble() / result.totalCount,
        "average score" to result.totalScore / result.totalCount,
        "number of cities" to result.cities.distinct().size,
        "average name length" to result.nameLengths.average()
    )
}

data class AnalysisResult(
    val totalCount: Int = 0,
    val totalAge: Int = 0,
    val totalScore: Double = 0.0,
    val cities: List<String> = emptyList(),
    val nameLengths: List<Int> = emptyList()
)

println(analysis)

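6. Performance and caveats

Performance considerations

import kotlin.system.measureTimeMillis

// 1. For large collections, consider sequences. A lone fold iterates the list once
//    either way; sequences pay off when map/filter steps are chained before the fold
val largeList = (1..1_000_000).toList()

// Regular fold
val regularTime = measureTimeMillis {
    largeList.fold(0L) { acc, n -> acc + n }
}

// Sequence fold
val sequenceTime = measureTimeMillis {
    largeList.asSequence().fold(0L) { acc, n -> acc + n }
}

println("Regular: ${regularTime}ms, sequence: ${sequenceTime}ms")

// 2. Avoid creating many temporary objects inside fold
// Bad practice
val badFold = largeList.fold("") { acc, n ->
    acc + n.toString()  // creates a new string on every iteration
}

// Good practice: use StringBuilder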
val goodFold = largeList.fold(StringBuilder()) { acc, n ->
    acc.append(n)
}.toString()
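
The case where a sequence is more likely to help is a pipeline with intermediate steps before the fold. The sketch below contrasts an eager and a lazy version of the same computation; the function names are hypothetical, and no timing is claimed here because the actual difference depends on data size and the runtime.

```kotlin
// Hypothetical helper: eager pipeline, each step materializes a full list.
fun sumEvenSquaresEager(numbers: List<Int>): Long =
    numbers
        .filter { it % 2 == 0 }        // allocates an intermediate list
        .map { it.toLong() * it }      // allocates a second intermediate list
        .fold(0L) { acc, n -> acc + n }

// Hypothetical helper: lazy pipeline, elements flow one by one into the fold.
fun sumEvenSquaresLazy(numbers: List<Int>): Long =
    numbers.asSequence()
        .filter { it % 2 == 0 }        // no intermediate collection is built
        .map { it.toLong() * it }
        .fold(0L) { acc, n -> acc + n }

fun main() {
    val numbers = (1..10).toList()
    println(sumEvenSquaresEager(numbers))  // 220
    println(sumEvenSquaresLazy(numbers))   // 220
}
```

Both versions return the same value; the difference lies only in how many temporary collections are created along the way.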

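Common mistakes

// 1. Calling reduce on an empty collection
// emptyList<Int>().reduce { acc, n -> acc + n }  // throws an exception

// 2. Overlooking evaluation order with a non-associative operation
val numbers = listOf(1, 2, 3)
val wrong = numbers.reduce { acc, n -> acc - n }  // (1 - 2) - 3 = -4, which may not be what you expect

// 3. Mutating state inside fold (prefer an immutable style)
// Bad practice
val mutableFold = numbers.fold(mutableListOf<Int>()) { acc, n ->
    acc.add(n * 2)  // side effect
    acc
}

// Good practice
val immutableFold = numbers.fold(emptyList<Int>()) { acc, n ->
    acc + (n * 2)  // creates a new list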
}

7. Custom extension functions

Building specialized fold functions

// Weighted average
fun List<Double>.weightedAverage(weights: List<Double>): Double {
    require(this.size == weights.size) { "values and weights must have the same size" }
    
    return this.zip(weights).fold(0.0 to 0.0) { (sum, weightSum), (value, weight) ->
        (sum + value * weight) to (weightSum + weight)
    }.let { (sum, weightSum) -> sum / weightSum }
}

// Usage
val values = listOf(85.0, 90.0, 78.0, 92.0)
val weights = listOf(0.2, 0.3, 0.25, 0.25)
println(values.weightedAverage(weights))

// Grouped counts
fun <T, K> List<T>.groupAndCount(keySelector: (T) -> K): Map<K, Int> {
    return this.fold(mutableMapOf<K, Int>()) { map, item ->
        val key = keySelector(item)
        map[key] = map.getOrDefault(key, 0) + 1
        map
    }
}

// Usage
val words = listOf("apple", "banana", "apple", "cherry", "banana", "apple")
val counts = words.groupAndCount { it }
println(counts)  // {apple=3, banana=2, cherry=1}
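
For this particular aggregation the standard library already offers groupingBy followed by eachCount. The sketch below compares it with a fold-based counter; countByFold is a hypothetical name standing in for the extension above, rewritten as a self-contained function so the block can run on its own.

```kotlin
// Hypothetical stand-in for the fold-based counter shown above.
fun <T, K> countByFold(items: List<T>, keySelector: (T) -> K): Map<K, Int> {
    val counts = items.fold(mutableMapOf<K, Int>()) { map, item ->
        val key = keySelector(item)
        map[key] = (map[key] ?: 0) + 1   // increment the count for this key
        map
    }
    return counts
}

fun main() {
    val words = listOf("apple", "banana", "apple", "cherry", "banana", "apple")
    println(countByFold(words) { it })          // {apple=3, banana=2, cherry=1}
    println(words.groupingBy { it }.eachCount())  // {apple=3, banana=2, cherry=1}
}
```

Writing the fold by hand is still useful as an exercise and as a template for aggregations the standard library does not cover.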

Chained fold operations

// Pipeline-style data processing
class DataPipeline<T>(private val data: List<T>) {
    fun <R> fold(initial: R, operation: (R, T) -> R): DataPipeline<T> {
        val result = data.fold(initial, operation)
        println("Fold result: $result")
        return this
    }
    
    fun filter(predicate: (T) -> Boolean): DataPipeline<T> {
        return DataPipeline(data.filter(predicate))
    }
    
    fun <R> map(transform: (T) -> R): DataPipeline<R> {
        return DataPipeline(data.map(transform))
    }
}

// Usage
val pipeline = DataPipeline(listOf(1, 2, 3, 4, 5))
    .filter { it % 2 == 0 }
    .fold(0) { acc, n -> acc + n }
    .map { it * 2 }
    .fold(emptyList<Int>()) { acc, n -> acc + n }

8. Comparison with scan

// scan is an alias of runningFold and returns every intermediate result
val numbers = listOf(1, 2, 3, 4, 5)

// scan / runningFold
val intermediateResults = numbers.scan(0) { acc, n -> acc + n }
println("Intermediate results: $intermediateResults")  // [0, 1, 3, 6, 10, 15]

// fold returns only the final result
val finalResult = numbers.fold(0) { acc, n -> acc + n }
println("Final result: $finalResult")  // 15

// Practical use: moving average
val windowSize = 3
val movingAverages = numbers.scan(emptyList<Int>()) { window, n ->
    (window + n).takeLast(windowSize)
}.drop(1)  // drop the initial empty list
 .map { window -> window.average() }

println("Moving averages: $movingAverages")
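
The scan-based moving average includes the partial windows at the start, where fewer than windowSize elements are available. If only full windows are wanted, the standard library's windowed function is one alternative. The sketch below puts the two side by side; both function names are hypothetical.

```kotlin
// Hypothetical helper: the scan approach from above, partial windows included.
fun movingAveragesPartial(numbers: List<Int>, windowSize: Int): List<Double> =
    numbers.scan(emptyList<Int>()) { window, n -> (window + n).takeLast(windowSize) }
        .drop(1)                       // drop the initial empty window
        .map { it.average() }

// Hypothetical helper: full windows only, using windowed from the stdlib.
fun movingAveragesFull(numbers: List<Int>, windowSize: Int): List<Double> =
    numbers.windowed(windowSize).map { it.average() }

fun main() {
    val numbers = listOf(1, 2, 3, 4, 5)
    println(movingAveragesPartial(numbers, 3))  // [1.0, 1.5, 2.0, 3.0, 4.0]
    println(movingAveragesFull(numbers, 3))     // [2.0, 3.0, 4.0]
}
```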

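Summary

fold and reduce are powerful combining functions in Kotlin:

Characteristics of fold:

  • Requires an initial value
  • Can handle empty collections
  • More general, suited to a wide range of scenarios
  • Supports both directions, left and right (foldRight)

Characteristics of reduce:

  • Uses the first element as the initial value
  • Cannot handle empty collections (it throws an exception)
  • More concise, suited to non-empty collections

Recommendations:

  1. Safety first: use fold (especially when the collection may be empty)
  2. Conciseness first: use reduce (make sure the collection is non-empty)
  3. Need intermediate results: use runningFold/scan or runningReduce
  4. Need the index: use foldIndexed or reduceIndexed

Best practices:

  • Use immutable accumulators to avoid side effects
  • For large collections, consider sequences (asSequence())
  • Wrap complex aggregations in custom extension functions
  • Know whether your operation is associative; for non-associative operations the evaluation order and the starting value determine the result

These functions are at the heart of functional programming; mastering them lets you write more concise, more expressive code.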