2-2-28 快速掌握Kotlin-变换函数flatMap

30 阅读5分钟

Kotlin 中的 flatMap 函数详解

flatMap 是 Kotlin 函数式编程中非常重要的一个变换函数,它结合了 mapflatten 的功能,用于处理嵌套集合结构。

1. 基本概念

1.1 flatMap 的定义

// 官方定义
inline fun <T, R> Iterable<T>.flatMap(
    transform: (T) -> Iterable<R>
): List<R>

// 核心思想:map + flatten

1.2 工作原理示意图

原始数据: [A, B, C]
transform: A -> [A1, A2], B -> [B1], C -> [C1, C2, C3]
map 结果: [[A1, A2], [B1], [C1, C2, C3]]
flatMap 结果: [A1, A2, B1, C1, C2, C3]

2. 基本用法

2.1 简单示例

fun main() {
    val words = listOf("Hello", "World")
    
    // map 产生嵌套列表
    val mapped = words.map { it.toList() }  // [[H, e, l, l, o], [W, o, r, l, d]]
    println("map 结果: $mapped")
    
    // flatMap 展平结果
    val flatMapped = words.flatMap { it.toList() }  // [H, e, l, l, o, W, o, r, l, d]
    println("flatMap 结果: $flatMapped")
    
    // 等价于 map + flatten
    val mapThenFlatten = words.map { it.toList() }.flatten()
    println("map + flatten 结果: $mapThenFlatten")
}

2.2 不同类型集合的处理

fun main() {
    val numbers = listOf(1, 2, 3, 4)
    
    // 1. 每个数字生成一个列表
    val result1 = numbers.flatMap { listOf(it, it * 10) }
    println(result1)  // [1, 10, 2, 20, 3, 30, 4, 40]
    
    // 2. 生成空列表(过滤效果)
    val result2 = numbers.flatMap { 
        if (it % 2 == 0) listOf(it) else emptyList() 
    }
    println(result2)  // [2, 4] - 相当于 filter
    
    // 3. 生成可变数量的元素
    val result3 = numbers.flatMap { number ->
        (1..number).map { "Item$it" }
    }
    println(result3)
    // [Item1, Item1, Item2, Item1, Item2, Item3, Item1, Item2, Item3, Item4]
}

3. 实际应用场景

3.1 处理嵌套数据结构

data class Department(
    val name: String,
    val employees: List<Employee>
)

data class Employee(
    val name: String,
    val skills: List<String>
)

fun main() {
    val departments = listOf(
        Department("Engineering", listOf(
            Employee("Alice", listOf("Kotlin", "Java")),
            Employee("Bob", listOf("Python", "Go"))
        )),
        Department("Design", listOf(
            Employee("Charlie", listOf("Figma", "Sketch"))
        ))
    )
    
    // 获取所有员工
    val allEmployees = departments.flatMap { it.employees }
    println("所有员工: ${allEmployees.map { it.name }}")
    // 所有员工: [Alice, Bob, Charlie]
    
    // 获取所有技能(去重)
    val allSkills = departments
        .flatMap { it.employees }
        .flatMap { it.skills }
        .toSet()
    println("所有技能: $allSkills")
    // 所有技能: [Kotlin, Java, Python, Go, Figma, Sketch]
    
    // 统计每个部门的技能数量
    val skillsByDept = departments.associate { dept ->
        dept.name to dept.employees.flatMap { it.skills }.size
    }
    println("部门技能统计: $skillsByDept")
    // 部门技能统计: {Engineering=4, Design=2}
}

3.2 解析和转换数据

fun main() {
    // CSV 数据解析
    val csvData = """
        1,Alice,Kotlin,Java
        2,Bob,Python,
        3,Charlie,Kotlin,Python,Go
    """.trimIndent()
    
    val allSkills = csvData.lineSequence()
        .filter { it.isNotBlank() }
        .flatMap { line ->
            val parts = line.split(",")
            if (parts.size > 2) {
                parts.drop(2).filter { it.isNotBlank() }
            } else {
                emptyList()
            }
        }
        .toSet()
    
    println("所有技术栈: $allSkills")  // [Kotlin, Java, Python, Go]
    
    // 解析 JSON 类似结构
    val userActivities = listOf(
        mapOf("user" to "Alice", "actions" to listOf("login", "view", "logout")),
        mapOf("user" to "Bob", "actions" to listOf("login", "edit", "save")),
        mapOf("user" to "Alice", "actions" to listOf("login", "delete"))
    )
    
    val allActions = userActivities.flatMap { it["actions"] as List<String> }
    val actionCounts = allActions.groupingBy { it }.eachCount()
    
    println("活动统计: $actionCounts")
    // 活动统计: {login=3, view=1, logout=1, edit=1, save=1, delete=1}
}

3.3 组合多个数据源

data class Product(
    val id: Int,
    val name: String,
    val categoryIds: List<Int>
)

data class Category(
    val id: Int,
    val name: String
)

fun main() {
    val products = listOf(
        Product(1, "Laptop", listOf(1, 2)),
        Product(2, "Mouse", listOf(2)),
        Product(3, "Monitor", listOf(1, 2, 3))
    )
    
    val categories = listOf(
        Category(1, "Electronics"),
        Category(2, "Computer"),
        Category(3, "Display")
    )
    
    // 获取产品及其类别的所有组合
    val productCategoryPairs = products.flatMap { product ->
        product.categoryIds.map { categoryId ->
            val category = categories.find { it.id == categoryId }
            product.name to (category?.name ?: "Unknown")
        }
    }
    
    println("产品-类别关联:")
    productCategoryPairs.forEach { (product, category) ->
        println("  $product -> $category")
    }
    // 输出:
    // Laptop -> Electronics
    // Laptop -> Computer
    // Mouse -> Computer
    // Monitor -> Electronics
    // Monitor -> Computer
    // Monitor -> Display
    
    // 按类别分组产品
    val productsByCategory = productCategoryPairs
        .groupBy { it.second }
        .mapValues { (_, pairs) -> pairs.map { it.first } }
    
    println("\n按类别分组的产品:")
    productsByCategory.forEach { (category, products) ->
        println("  $category: $products")
    }
}

4. 与相关函数的对比

4.1 flatMap vs map vs flatten

fun main() {
    val numbers = listOf(1, 2, 3)
    
    // map: 一对一的转换
    val mapped = numbers.map { listOf(it, it * 2) }
    println("map: $mapped")  // [[1, 2], [2, 4], [3, 6]]
    
    // flatten: 展平嵌套列表
    val flattened = mapped.flatten()
    println("flatten: $flattened")  // [1, 2, 2, 4, 3, 6]
    
    // flatMap: map + flatten 一步完成
    val flatMapped = numbers.flatMap { listOf(it, it * 2) }
    println("flatMap: $flatMapped")  // [1, 2, 2, 4, 3, 6]
    
    // 性能比较
    val largeList = (1..1000000).toList()
    
    val time1 = measureTimeMillis {
        val result = largeList.map { listOf(it) }.flatten()
    }
    
    val time2 = measureTimeMillis {
        val result = largeList.flatMap { listOf(it) }
    }
    
    println("map+flatten 时间: ${time1}ms")
    println("flatMap 时间: ${time2}ms")
    // flatMap 通常更高效,因为它避免了创建中间集合
}

4.2 flatMap vs mapNotNull

fun main() {
    val numbers = listOf(1, 2, 3, 4, 5)
    
    // mapNotNull: 过滤 null,但保持一对一关系
    val mapNotNullResult = numbers.mapNotNull { 
        if (it % 2 == 0) it * 2 else null 
    }
    println("mapNotNull: $mapNotNullResult")  // [4, 8]
    
    // flatMap: 可以过滤并产生多个元素
    val flatMapResult = numbers.flatMap { 
        if (it % 2 == 0) listOf(it, it * 2) else emptyList() 
    }
    println("flatMap: $flatMapResult")  // [2, 4, 4, 8]
    
    // 使用 flatMap 实现 mapNotNull 的效果
    val flatMapAsMapNotNull = numbers.flatMap { 
        if (it % 2 == 0) listOf(it * 2) else emptyList() 
    }
    println("flatMap 作为 mapNotNull: $flatMapAsMapNotNull")  // [4, 8]
}

5. 高级用法

5.1 在序列中使用 flatMap

fun main() {
    // 使用序列进行惰性求值
    val result = (1..5).asSequence()
        .flatMap { number ->
            // 每个数字生成一个序列
            generateSequence(1) { it + 1 }
                .map { "N${number}_$it" }
                .take(3)  // 每个数字只取前3个
        }
        .toList()
    
    println(result)
    // [N1_1, N1_2, N1_3, N2_1, N2_2, N2_3, N3_1, N3_2, N3_3, N4_1, N4_2, N4_3, N5_1, N5_2, N5_3]
    
    // 无限序列处理
    val infinite = generateSequence(1) { it + 1 }
        .flatMap { n ->
            (1..n).asSequence().map { "$n-$it" }
        }
        .take(10)
        .toList()
    
    println("无限序列的前10个: $infinite")
    // [1-1, 2-1, 2-2, 3-1, 3-2, 3-3, 4-1, 4-2, 4-3, 4-4]
}

5.2 处理可选值(Option/Either)

sealed class Option<out T> {
    data class Some<T>(val value: T) : Option<T>()
    object None : Option<Nothing>()
    
    fun <R> flatMap(f: (T) -> Option<R>): Option<R> = when (this) {
        is Some -> f(value)
        None -> None
    }
}

fun parseNumber(str: String): Option<Int> {
    return try {
        Option.Some(str.toInt())
    } catch (e: NumberFormatException) {
        Option.None
    }
}

fun divide(a: Int, b: Int): Option<Int> {
    return if (b == 0) Option.None else Option.Some(a / b)
}

fun main() {
    val result = parseNumber("10")
        .flatMap { a -> parseNumber("2").flatMap { b -> divide(a, b) } }
    
    when (result) {
        is Option.Some -> println("结果: ${result.value}")  // 结果: 5
        Option.None -> println("计算失败")
    }
    
    // 使用标准库的 flatMap 处理可空类型
    val nullableResult: Int? = "10".toIntOrNull()
        ?.let { a -> "2".toIntOrNull()?.let { b -> if (b != 0) a / b else null } }
    
    println("可空结果: $nullableResult")  // 5
}

5.3 实现笛卡尔积

fun main() {
    val colors = listOf("Red", "Green", "Blue")
    val sizes = listOf("S", "M", "L")
    val materials = listOf("Cotton", "Polyester")
    
    // 生成所有组合(笛卡尔积)
    val allCombinations = colors.flatMap { color ->
        sizes.flatMap { size ->
            materials.map { material ->
                Triple(color, size, material)
            }
        }
    }
    
    println("所有商品组合 (${allCombinations.size} 种):")
    allCombinations.forEachIndexed { index, (color, size, material) ->
        println("${index + 1}. $color $size $material")
    }
    
    // 更复杂的嵌套 flatMap
    val categories = listOf(
        listOf("上衣", "裤子"),
        listOf("男装", "女装"),
        listOf("夏季", "冬季")
    )
    
    val allCategoryCombinations = categories
        .fold(listOf(listOf<String>())) { acc, list ->
            acc.flatMap { combination ->
                list.map { item -> combination + item }
            }
        }
    
    println("\n所有分类组合:")
    allCategoryCombinations.forEach { println(it) }
    // [上衣, 男装, 夏季], [上衣, 男装, 冬季], [上衣, 女装, 夏季], ...
}

6. 性能优化和注意事项

6.1 避免不必要的嵌套

fun main() {
    val data = listOf(
        listOf(1, 2, 3),
        listOf(4, 5),
        listOf(6, 7, 8, 9)
    )
    
    // ❌ 不推荐:双重 flatMap
    val badResult = data.flatMap { innerList ->
        innerList.flatMap { listOf(it, it * 2) }
    }
    
    // ✅ 推荐:先展平再处理
    val goodResult = data.flatten().flatMap { listOf(it, it * 2) }
    
    println("结果相同: ${badResult == goodResult}")  // true
    
    // 性能测试
    val largeNested = List(10000) { List(100) { it } }
    
    val time1 = measureTimeMillis {
        largeNested.flatMap { inner -> inner.flatMap { listOf(it, it * 2) } }
    }
    
    val time2 = measureTimeMillis {
        largeNested.flatten().flatMap { listOf(it, it * 2) }
    }
    
    println("双重 flatMap: ${time1}ms")
    println("先 flatten 再 flatMap: ${time2}ms")
    // 第二种方式通常更快
}

6.2 使用 asSequence() 处理大数据

fun processLargeDataset() {
    val largeDataset = (1..1_000_000).toList()
    
    // 使用 List(急切实例化)
    val listTime = measureTimeMillis {
        val result = largeDataset
            .flatMap { (1..it).take(3) }  // 每个元素生成最多3个元素
            .filter { it % 2 == 0 }
            .take(1000)
            .toList()
    }
    
    // 使用 Sequence(惰性求值)
    val sequenceTime = measureTimeMillis {
        val result = largeDataset
            .asSequence()
            .flatMap { (1..it).take(3).asSequence() }
            .filter { it % 2 == 0 }
            .take(1000)
            .toList()
    }
    
    println("List 处理时间: ${listTime}ms")
    println("Sequence 处理时间: ${sequenceTime}ms")
    // Sequence 对于大数据集通常更高效
}

7. 自定义扩展函数

7.1 为特定类型创建 flatMap

// 为 Map 创建 flatMap
fun <K, V, R> Map<K, V>.flatMapValues(transform: (V) -> Iterable<R>): List<R> {
    return this.values.flatMap(transform)
}

// 为 Pair 创建 flatMap
fun <A, B, R> Pair<A, B>.flatMap(transform: (A, B) -> Iterable<R>): List<R> {
    return transform(first, second)
}

// 为 Result 类型创建 flatMap
sealed class Result<out T> {
    data class Success<out T>(val value: T) : Result<T>()
    data class Failure(val error: Throwable) : Result<Nothing>()
    
    fun <R> flatMap(transform: (T) -> Result<R>): Result<R> = when (this) {
        is Success -> transform(value)
        is Failure -> Failure(error)
    }
}

fun main() {
    // 使用自定义扩展
    val map = mapOf("a" to 1, "b" to 2, "c" to 3)
    val flattenedValues = map.flatMapValues { listOf(it, it * 2) }
    println(flattenedValues)  // [1, 2, 2, 4, 3, 6]
    
    val pair = Pair(10, 20)
    val pairResult = pair.flatMap { a, b -> listOf(a + b, a * b) }
    println(pairResult)  // [30, 200]
}

7.2 实现复杂的 flatMap 逻辑

// 递归 flatMap
fun <T, R> List<T>.recursiveFlatMap(transform: (T) -> Iterable<R>): List<R> {
    val result = mutableListOf<R>()
    
    fun process(item: T) {
        val transformed = transform(item)
        for (element in transformed) {
            // 如果元素可以继续转换,递归处理
            if (element is T) {
                process(element)
            } else {
                result.add(element)
            }
        }
    }
    
    for (item in this) {
        process(item)
    }
    
    return result
}

// 带索引的 flatMap
fun <T, R> Iterable<T>.flatMapIndexed(transform: (index: Int, T) -> Iterable<R>): List<R> {
    val result = mutableListOf<R>()
    var index = 0
    for (element in this) {
        val transformed = transform(index++, element)
        result.addAll(transformed)
    }
    return result
}

fun main() {
    val nested = listOf(
        listOf(1, 2),
        listOf(3, listOf(4, 5)),
        6
    )
    
    // 注意:这个示例需要类型安全处理,这里简化了
    println("带索引的 flatMap:")
    val indexed = listOf("a", "b", "c").flatMapIndexed { index, value ->
        listOf("$index-$value", "${value}${index}")
    }
    println(indexed)  // [0-a, a0, 1-b, b1, 2-c, c2]
}

8. 常见错误和调试技巧

fun debugFlatMap() {
    val data = listOf(1, 2, 3, 4, 5)
    
    // 错误示例:忘记返回 Iterable
    // val error = data.flatMap { it }  // 编译错误
    
    // 正确做法
    val correct = data.flatMap { listOf(it) }
    
    // 调试技巧:添加日志
    val result = data.flatMap { element ->
        println("处理元素: $element")
        val transformed = if (element % 2 == 0) {
            listOf(element, element * 2)
        } else {
            emptyList()
        }
        println("  生成: $transformed")
        transformed
    }
    
    println("最终结果: $result")
    // 输出:
    // 处理元素: 1
    //   生成: []
    // 处理元素: 2
    //   生成: [2, 4]
    // 处理元素: 3
    //   生成: []
    // 处理元素: 4
    //   生成: [4, 8]
    // 处理元素: 5
    //   生成: []
    // 最终结果: [2, 4, 4, 8]
}

总结

flatMap 的核心价值:

  1. 数据展平 - 处理嵌套集合结构
  2. 一对多映射 - 每个输入元素可以生成多个输出元素
  3. 过滤功能 - 通过返回 emptyList() 过滤元素
  4. 组合数据 - 连接多个数据源

使用建议:

  1. 优先使用 flatMap 而不是 map + flatten
  2. 对于大数据集使用 asSequence()
  3. 避免不必要的嵌套 flatMap 调用
  4. 在复杂转换中适当添加调试信息

适用场景:

  • 处理树形或图状数据结构
  • 解析嵌套格式(JSON、XML)
  • 生成笛卡尔积
  • 数据清洗和转换管道
  • 函数式错误处理链

flatMap 是函数式编程中的核心操作之一,掌握它可以帮助你写出更简洁、更强大的数据处理代码。