Kotlin 合并函数 fold 和 reduce
在 Kotlin 中,fold 和 reduce 是用于将集合元素合并为单个值的函数。它们被称为"折叠"操作,是函数式编程中的核心概念。
1. fold 函数详解
基本语法
// Signature of fold (declaration only, shown for illustration; no body is given here)
// T is the element type of the receiver, R is the type of the accumulator and of the result.
fun <T, R> Iterable<T>.fold(
initial: R, // initial accumulator value
operation: (acc: R, T) -> R // combining operation: takes the accumulator and an element, returns the new accumulator
): R
基础示例
val numbers = listOf(1, 2, 3, 4, 5)

// Sum: the accumulator is seeded with 0
val sum = numbers.fold(0) { total, value -> total + value }
println(sum) // 15

// Product: the accumulator is seeded with 1
val product = numbers.fold(1) { running, value -> running * value }
println(product) // 120

// String concatenation: the accumulator is seeded with the empty string
val strings = listOf("a", "b", "c", "d")
val concatenated = strings.fold("") { joined, piece -> joined + piece }
println(concatenated) // "abcd"

// Accumulator of a different type than the elements:
// a (sum, count) pair is threaded through the fold and divided at the end.
val average = numbers
    .fold(Pair(0.0, 0)) { totals, value -> Pair(totals.first + value, totals.second + 1) }
    .run { first / second }
println(average) // 3.0
2. reduce 函数详解
基本语法
// Signature of reduce (declaration only, shown for illustration; no body is given here)
// There is no initial-value parameter; the accumulator type S is a supertype of the element type T.
fun <S, T : S> Iterable<T>.reduce(
operation: (acc: S, T) -> S // combining operation: takes the accumulator and an element, returns the new accumulator
): S
与 fold 的区别
val numbers = listOf(1, 2, 3, 4, 5)

// fold takes an explicit initial value
val sumWithFold = numbers.fold(0) { total, value -> total + value }
// reduce seeds the accumulator with the first element instead
val sumWithReduce = numbers.reduce { total, value -> total + value }
println("fold: $sumWithFold, reduce: $sumWithReduce") // both are 15

// Key difference: behavior on an empty collection
val emptyList = emptyList<Int>()
try {
    // There is no first element to seed the accumulator, so reduce throws UnsupportedOperationException
    emptyList.reduce { total, value -> total + value }
} catch (e: UnsupportedOperationException) {
    println("reduce 不能用于空集合")
}

// fold never touches the lambda here and simply returns its initial value, 0
val safeResult = emptyList.fold(0) { total, value -> total + value }
println("fold 空集合: $safeResult")
reduce 的注意事项
// 1. An empty collection throws (reduceOrNull returns null instead of throwing)
// emptyList<Int>().reduce { acc, n -> acc + n } // throws UnsupportedOperationException

// 2. A single-element collection returns that element as-is
val single = listOf(42).reduce { acc, n -> acc + n }
println(single) // 42, the operation is never invoked

// 3. Evaluation order matters for non-associative operations.
// reduce always combines strictly left to right, so its result is deterministic;
// it is only "wrong" if a different grouping was expected.
val numbers = listOf(1, 2, 3)

// Subtraction is not associative: grouping from the left and from the right differ
val subtractReduce = numbers.reduce { acc, n -> acc - n }
println("reduce 减法: $subtractReduce") // (1 - 2) - 3 = -4
val subtractReduceRight = numbers.reduceRight { n, acc -> n - acc }
println("reduceRight 减法: $subtractReduceRight") // 1 - (2 - 3) = 2

// fold differs from reduce here because of its seed, not because of associativity:
// the first element is subtracted from 0 instead of being used as the starting value.
val subtractFold = numbers.fold(0) { acc, n -> acc - n }
println("fold 减法: $subtractFold") // ((0 - 1) - 2) - 3 = -6
3. fold 和 reduce 的变体
foldRight 和 reduceRight
// Folding from right to left
val numbers = listOf(1, 2, 3, 4)

// foldRight: starts from the last element; note the lambda takes (element, accumulator) in that order
val rightFold = numbers.foldRight(0) { number, acc ->
    println("处理 $number, 累计值: $acc")
    acc + number
}
// Output: 处理 4, 累计值: 0
//         处理 3, 累计值: 4
//         处理 2, 累计值: 7
//         处理 1, 累计值: 9
// Result: 10

// reduceRight: same direction, seeded with the last element
val rightReduce = numbers.reduceRight { number, acc -> acc + number }
println(rightReduce) // 10

// Reversing a string: characters are visited last-to-first and appended in that order.
// foldRight is available on CharSequence directly, so no intermediate list is needed.
val str = "Kotlin"
val reversed = str.foldRight(StringBuilder()) { char, acc -> acc.append(char) }.toString()
println(reversed) // "niltoK" (case is preserved, so the capital K ends up last)
runningFold 和 runningReduce
// Running variants return every intermediate accumulator value
val numbers = listOf(1, 2, 3, 4, 5)

// runningFold: the result starts with the initial value
val runningSum = numbers.runningFold(0) { acc, n -> acc + n }
println(runningSum) // [0, 1, 3, 6, 10, 15]

// runningReduce: the result starts with the first element
val runningProduct = numbers.runningReduce { acc, n -> acc * n }
println(runningProduct) // [1, 2, 6, 24, 120]

// Practical use: cumulative average
val cumulativeAverage = numbers
    .runningFold(0.0 to 0) { (sum, count), n -> (sum + n) to (count + 1) }
    .drop(1) // skip the seed pair (0.0, 0), which would otherwise yield 0.0 / 0 = NaN
    .map { (sum, count) -> sum / count }
println(cumulativeAverage) // [1.0, 1.5, 2.0, 2.5, 3.0]
4. foldIndexed 和 reduceIndexed
// Folding with access to the element index
val numbers = listOf(10, 20, 30, 40)

// foldIndexed: the lambda receives (index, accumulator, element); index starts at 0
val indexedFold = numbers.foldIndexed(0) { index, acc, number ->
    acc + number * index // each number multiplied by its index
}
println(indexedFold) // 0*10 + 1*20 + 2*30 + 3*40 = 200

// reduceIndexed: the first element seeds the accumulator without invoking the lambda,
// so the lambda is first called with index 1 and never sees index 0.
val indexedReduce = numbers.reduceIndexed { index, acc, number ->
    acc + number * (index + 1)
}
println(indexedReduce) // 10 + 2*20 + 3*30 + 4*40 = 300
// Practical use: weighted sum of scores
val scores = listOf(85, 90, 78, 92, 88)
val weights = listOf(0.1, 0.2, 0.15, 0.25, 0.3)

// Pair each score with the weight at the same position, then accumulate score * weight
val weightedSum = scores.zip(weights).fold(0.0) { running, (score, weight) ->
    running + score * weight
}
println("加权总分: $weightedSum")
5. 实际应用场景
统计计算
/** A single measurement: a value and the timestamp it was taken at. */
data class DataPoint(val value: Double, val timestamp: Long)

/**
 * Running totals threaded through the fold.
 *
 * [min] and [max] have no defaults, so the caller chooses the starting sentinels.
 * Declared before its first use so the snippet also works inside a function body.
 */
data class StatsAccumulator(
    val sum: Double = 0.0,
    val count: Int = 0,
    val min: Double,
    val max: Double,
    val timestamps: List<Long> = emptyList()
)

val dataPoints = listOf(
    DataPoint(10.5, 1000),
    DataPoint(12.3, 2000),
    DataPoint(9.8, 3000),
    DataPoint(11.2, 4000),
    DataPoint(13.1, 5000)
)

// Compute several statistics in a single pass with fold.
// The sentinels are +/- infinity: Double.MIN_VALUE is the smallest POSITIVE double,
// so using it as the starting max would give a wrong result for all-negative data.
val stats = dataPoints.fold(
    StatsAccumulator(min = Double.POSITIVE_INFINITY, max = Double.NEGATIVE_INFINITY)
) { acc, point ->
    acc.copy(
        sum = acc.sum + point.value,
        count = acc.count + 1,
        min = minOf(acc.min, point.value),
        max = maxOf(acc.max, point.value),
        timestamps = acc.timestamps + point.timestamp
    )
}.let { acc ->
    // With no timestamps the range falls back to 0 instead of crashing on a null
    val earliest = acc.timestamps.minOrNull() ?: 0L
    val latest = acc.timestamps.maxOrNull() ?: 0L
    mapOf(
        "平均值" to acc.sum / acc.count,
        "最小值" to acc.min,
        "最大值" to acc.max,
        "数据量" to acc.count,
        "时间范围" to (latest - earliest)
    )
}
println(stats)
数据转换和聚合
/** One line of an order: which product, how many, and at what unit price. */
data class OrderItem(
    val productId: Int,
    val quantity: Int,
    val price: Double
)

/** An order placed by a customer, with its total amount and line items. */
data class Order(
    val id: Int,
    val customerId: Int,
    val amount: Double,
    val items: List<OrderItem>
)

/** Per-customer totals; mutated in place while the orders are folded. */
data class CustomerData(
    val customerId: Int,
    var totalAmount: Double,
    var orderCount: Int,
    val productQuantities: MutableMap<Int, Int>
)

val orders = listOf(
    Order(
        1, 101, 150.0,
        listOf(OrderItem(1, 2, 50.0), OrderItem(2, 1, 50.0))
    ),
    Order(
        2, 102, 75.0,
        listOf(OrderItem(3, 3, 25.0))
    ),
    Order(
        3, 101, 200.0,
        listOf(OrderItem(1, 1, 50.0), OrderItem(4, 3, 50.0))
    )
)

// Aggregate the orders per customer; the map itself is the accumulator
val customerSummary = orders.fold(mutableMapOf<Int, CustomerData>()) { summary, order ->
    // Look up this customer's entry, creating an empty one on first sight
    val entry = summary.getOrPut(order.customerId) {
        CustomerData(order.customerId, 0.0, 0, mutableMapOf())
    }
    entry.totalAmount += order.amount
    entry.orderCount += 1
    // Tally purchased quantity per product
    for (item in order.items) {
        val soFar = entry.productQuantities[item.productId] ?: 0
        entry.productQuantities[item.productId] = soFar + item.quantity
    }
    summary
}

// Print the aggregated result
for ((customerId, data) in customerSummary) {
    println("客户 $customerId:")
    println(" 订单数: ${data.orderCount}, 总金额: ${data.totalAmount}")
    println(" 购买产品: ${data.productQuantities}")
}
构建复杂数据结构
/** A node of the tree: a name plus a mutable list of child nodes. */
data class TreeNode(
    val name: String,
    val children: MutableList<TreeNode> = mutableListOf()
)

val paths = listOf(
    "root/home/user/documents",
    "root/home/user/downloads",
    "root/home/user/pictures/vacation",
    "root/var/log",
    "root/usr/bin"
)

// Build the file tree: the outer fold threads the root through every path,
// the inner fold walks one path downward, creating missing nodes on the way.
val fileTree = paths.fold(TreeNode("root")) { tree, fullPath ->
    fullPath.split('/')
        .drop(1) // skip the leading "root" segment
        .fold(tree) { parent, segment ->
            parent.children.firstOrNull { it.name == segment }
                ?: TreeNode(segment).also { parent.children += it }
        }
    tree
}

/** Prints [node] and, recursively, its children; each level extends [indent]. */
fun printTree(node: TreeNode, indent: String = "") {
    println("$indent${node.name}/")
    for (child in node.children) {
        printTree(child, "$indent ")
    }
}
printTree(fileTree)
解析和处理文本
/** Parser state: the header row (empty until the first line is seen) and the records parsed so far. */
data class ParsingState(
    val headers: List<String> = emptyList(),
    val records: List<Map<String, String>> = emptyList()
)

/** Totals collected while folding over the parsed records. */
data class AnalysisResult(
    val totalCount: Int = 0,
    val totalAge: Int = 0,
    val totalScore: Double = 0.0,
    val cities: List<String> = emptyList(),
    val nameLengths: List<Int> = emptyList()
)

// CSV input to parse
val csv = """
    name,age,city,score
    Alice,25,New York,85.5
    Bob,30,London,92.0
    Charlie,22,Paris,78.5
    Diana,35,Tokyo,88.0
""".trimIndent()

// Parse the CSV with fold: the state carries the headers and the records built so far
val parsedData = csv.lineSequence().fold(ParsingState()) { state, line ->
    when {
        // The first line seen is the header row
        state.headers.isEmpty() -> state.copy(headers = line.split(','))
        // Data row: pair each header with the value in the same column
        line.isNotBlank() -> {
            val values = line.split(',')
            val record = state.headers.zip(values).toMap()
            state.copy(records = state.records + record)
        }
        // Blank lines are skipped
        else -> state
    }
}

// Analyze the records.
// The elvis fallbacks are parenthesized: `+` binds tighter than `?:`, so without the
// parentheses `a + x?.f() ?: 0` parses as `(a + x?.f()) ?: 0` and does not compile.
val analysis = parsedData.records.fold(AnalysisResult()) { result, record ->
    result.copy(
        totalCount = result.totalCount + 1,
        totalAge = result.totalAge + (record["age"]?.toIntOrNull() ?: 0),
        totalScore = result.totalScore + (record["score"]?.toDoubleOrNull() ?: 0.0),
        cities = result.cities + record["city"].orEmpty(),
        nameLengths = result.nameLengths + (record["name"]?.length ?: 0)
    )
}.let { result ->
    mapOf(
        "平均年龄" to result.totalAge.toDouble() / result.totalCount,
        "平均分数" to result.totalScore / result.totalCount,
        "城市数量" to result.cities.distinct().size,
        "平均姓名长度" to result.nameLengths.average()
    )
}
println(analysis)
6. 性能优化和注意事项
性能考虑
// Requires: import kotlin.system.measureTimeMillis

// 1. For multi-step pipelines over large collections, consider sequences.
// A lone fold already makes a single pass and allocates no intermediate collection,
// so asSequence() only pays off when operations such as map/filter come before it.
val largeList = (1..1_000_000).toList()

// Eager pipeline: map and filter each build a full intermediate list
val regularTime = measureTimeMillis {
    largeList
        .map { it * 2L }
        .filter { it % 3 == 0L }
        .fold(0L) { acc, n -> acc + n }
}
// Lazy pipeline: each element flows through map, filter and fold one at a time
val sequenceTime = measureTimeMillis {
    largeList.asSequence()
        .map { it * 2L }
        .filter { it % 3 == 0L }
        .fold(0L) { acc, n -> acc + n }
}
println("常规: ${regularTime}ms, 序列: ${sequenceTime}ms")

// 2. Avoid creating lots of temporary objects inside fold.
// The demonstration runs on a small sample: string concatenation in a fold copies the
// whole accumulated string on every step, which is quadratic and would take minutes
// on the full one-million-element list.
val sample = largeList.take(1_000)

// Bad: every iteration allocates a new, longer string
val badFold = sample.fold("") { acc, n ->
    acc + n.toString()
}
// Good: append to a single StringBuilder
val goodFold = sample.fold(StringBuilder()) { acc, n ->
    acc.append(n)
}.toString()
常见错误
// 1. Calling reduce on an empty collection
// emptyList<Int>().reduce { acc, n -> acc + n } // throws UnsupportedOperationException

// 2. Forgetting that evaluation order matters for non-associative operations.
// reduce is deterministic (strictly left to right); the surprise is only in the grouping.
val numbers = listOf(1, 2, 3)
val wrong = numbers.reduce { acc, n -> acc - n } // (1 - 2) - 3 = -4, not 1 - (2 - 3) = 2

// 3. Mutating state inside fold
// Risky: the lambda has a side effect on the accumulator. This is harmless only as long as
// the mutable accumulator is created for this fold and never shared elsewhere.
val mutableFold = numbers.fold(mutableListOf<Int>()) { acc, n ->
    acc.add(n * 2)
    acc
}
// Side-effect free: every step returns a new list.
// Note that this copies the accumulated list on each step, so it is quadratic for long inputs.
val immutableFold = numbers.fold(emptyList<Int>()) { acc, n ->
    acc + (n * 2)
}
// Simplest: an element-wise transformation does not need fold at all
val mapped = numbers.map { it * 2 }
7. 自定义扩展函数
创建专用的 fold 函数
/**
 * Weighted average of the receiver's values.
 *
 * A single fold accumulates both the weighted sum and the sum of weights;
 * the result is their quotient.
 *
 * @param weights one weight per value, in the same order
 * @return weighted sum divided by the sum of weights
 * @throws IllegalArgumentException if the number of weights differs from the number of values
 */
fun List<Double>.weightedAverage(weights: List<Double>): Double {
    require(size == weights.size) { "数据数量和权重数量必须相等" }
    val totals = zip(weights).fold(Pair(0.0, 0.0)) { running, pair ->
        val (value, weight) = pair
        Pair(running.first + value * weight, running.second + weight)
    }
    return totals.first / totals.second
}

// Usage
val values = listOf(85.0, 90.0, 78.0, 92.0)
val weights = listOf(0.2, 0.3, 0.25, 0.25)
println(values.weightedAverage(weights))
/**
 * Counts how many elements map to each key.
 *
 * @param keySelector computes the grouping key of an element
 * @return a map from each key to the number of elements that produced it
 */
fun <T, K> List<T>.groupAndCount(keySelector: (T) -> K): Map<K, Int> {
    val tally: MutableMap<K, Int> = fold(mutableMapOf()) { counts, element ->
        val bucket = keySelector(element)
        // A key seen for the first time starts from 0
        counts[bucket] = (counts[bucket] ?: 0) + 1
        counts
    }
    return tally
}

// Usage
val words = listOf("apple", "banana", "apple", "cherry", "banana", "apple")
val counts = words.groupAndCount { it }
println(counts) // {apple=3, banana=2, cherry=1}
链式 fold 操作
/**
 * A small chainable wrapper around a list, for pipeline-style processing.
 *
 * [filter] and [map] return a new pipeline over the transformed data;
 * [fold] prints the folded value and returns the same pipeline so the chain can continue.
 */
class DataPipeline<T>(private val data: List<T>) {
    /** Folds the data with [operation] starting from [initial], prints the result, and returns this pipeline. */
    fun <R> fold(initial: R, operation: (R, T) -> R): DataPipeline<T> {
        val result = data.fold(initial, operation)
        println("折叠结果: $result")
        return this
    }

    /** Returns a pipeline containing only the elements matching [predicate]. */
    fun filter(predicate: (T) -> Boolean): DataPipeline<T> =
        DataPipeline(data.filter(predicate))

    /**
     * Returns a pipeline of the elements transformed by [transform].
     * The type parameter is declared before the function name, as Kotlin syntax requires.
     */
    fun <R> map(transform: (T) -> R): DataPipeline<R> =
        DataPipeline(data.map(transform))
}

// Usage
val pipeline = DataPipeline(listOf(1, 2, 3, 4, 5))
    .filter { it % 2 == 0 }
    .fold(0) { acc, n -> acc + n }
    .map { it * 2 }
    .fold(emptyList<Int>()) { acc, n -> acc + n }
8. 与 scan 的对比
// scan is another name for runningFold: it returns every intermediate result
val numbers = listOf(1, 2, 3, 4, 5)

// scan / runningFold
val intermediateResults = numbers.scan(0) { total, value -> total + value }
println("中间结果: $intermediateResults") // [0, 1, 3, 6, 10, 15]

// fold returns only the final result
val finalResult = numbers.fold(0) { total, value -> total + value }
println("最终结果: $finalResult") // 15

// Practical use: moving average.
// Each scan step appends the new element and keeps at most the last windowSize elements.
val windowSize = 3
val windows = numbers.scan(emptyList<Int>()) { previous, value ->
    (previous + value).takeLast(windowSize)
}
val movingAverages = windows
    .drop(1) // discard the initial empty window
    .map { it.average() }
println("移动平均值: $movingAverages")
总结
fold 和 reduce 是 Kotlin 中强大的合并函数:
fold 的特点:
- 需要提供初始值
- 可以处理空集合
- 更通用,适用于各种场景
- 支持从左右两个方向折叠(从右向左时使用 foldRight)
reduce 的特点:
- 使用第一个元素作为初始值
- 不能处理空集合(会抛出异常)
- 更简洁,适用于非空集合
选择建议:
- 安全性优先:使用 fold(特别是处理可能为空的集合时)
- 简洁性优先:使用 reduce(确保集合非空)
- 需要中间结果:使用 runningFold / scan 或 runningReduce
- 需要索引:使用 foldIndexed 或 reduceIndexed
最佳实践:
- 使用不可变累加器避免副作用
- 对于大型集合上的多步链式操作,考虑使用序列(asSequence())
- 复杂的聚合操作可以封装为自定义扩展函数
- 理解操作的结合律:fold 和 reduce 始终按固定方向依次计算,对减法等不满足结合律的操作,结果取决于折叠方向和初始值,使用前应确认计算顺序符合预期
这些函数是函数式编程的核心,掌握它们可以让你写出更简洁、表达力更强的代码。