3-1-3 高阶函数-深入解析

18 阅读 · 约5分钟读完

Kotlin高阶函数深入解析

一、函数类型与函数引用

1.1 函数类型声明

// Plain function type: takes two Ints and returns their sum.
val sum: (Int, Int) -> Int = { left, right -> left + right }

// Nullable function type: the variable may hold no function at all.
var nullableFun: ((String) -> Unit)? = null

// Function type with receiver: inside the lambda, `this` is the Int being converted.
val toString: Int.() -> String = { "$this" }

// typealias gives a short, readable name to a (possibly generic) function type.
typealias Predicate<T> = (value: T) -> Boolean
// NOTE(review): `View` is not declared here — presumably android.view.View; confirm the import.
typealias ClickHandler = (view: View) -> Unit

1.2 函数引用

/** True when [n] is even; checks the lowest bit, which is also correct for negative numbers. */
fun isEven(n: Int): Boolean = (n and 1) == 0

/** True when the string reads the same backwards as forwards (exact, case-sensitive comparison). */
fun String.isPalindrome(): Boolean = reversed() == this

// Constructor reference: ::Person has the same shape as the primary constructor.
data class Person(val name: String, val age: Int)
val createPerson: (String, Int) -> Person = ::Person

// Reference to a top-level function.
val predicate: (Int) -> Boolean = ::isEven

// Reference to an extension function, stored as a function type with receiver.
val palindromeCheck: String.() -> Boolean = String::isPalindrome

二、高阶函数设计模式

2.1 策略模式

/**
 * Processes a [Payment] using pluggable strategies (the strategy pattern expressed with function types).
 *
 * @property validationStrategy decides whether a payment may be processed at all.
 * @property feeCalculation computes the fee for a payment that passed validation.
 */
class PaymentProcessor(
    private val validationStrategy: (Payment) -> Boolean,
    private val feeCalculation: (Payment) -> Double
) {
    /**
     * Validates [payment] and, if it is valid, forwards it with its computed fee to `processWithFee`.
     *
     * @return the result of `processWithFee`, or a [Result.Failure] wrapping an
     *   [IllegalArgumentException] when validation fails.
     */
    fun process(payment: Payment): Result<Any?> {
        // Failure carries a Throwable (not a String), so the message is wrapped in an exception.
        if (!validationStrategy(payment)) {
            return Result.Failure(IllegalArgumentException("Validation failed"))
        }
        val fee = feeCalculation(payment)
        // NOTE(review): `processWithFee` is not shown here; assumed to return a Result — confirm.
        return processWithFee(payment, fee)
    }
}

// Usage: charge a flat 2% fee, and only accept payments with a positive amount.
val processor = PaymentProcessor(
    feeCalculation = { payment -> payment.amount * 0.02 },
    validationStrategy = { payment -> payment.amount > 0 }
)

2.2 模板方法模式

/**
 * Runs [block], retrying when it throws, for at most [maxRetries] attempts in total.
 *
 * Between attempts the calling thread is blocked with `Thread.sleep(delay)`.
 *
 * @param maxRetries total number of attempts including the first; must be at least 1.
 * @param delay pause between attempts, in milliseconds.
 * @param shouldRetry decides whether a failure is retryable; returning `false` rethrows it immediately.
 * @param block the operation to run.
 * @return the first successful result of [block].
 * @throws IllegalArgumentException if [maxRetries] is less than 1.
 * @throws Throwable the failure from the last attempt, or the first failure [shouldRetry] rejected.
 */
inline fun <T> withRetry(
    maxRetries: Int = 3,
    delay: Long = 1000,
    shouldRetry: (Throwable) -> Boolean = { true },
    block: () -> T
): T {
    // Without this, zero attempts would fall through to a null "last exception".
    require(maxRetries >= 1) { "maxRetries must be at least 1, was $maxRetries" }

    repeat(maxRetries) { attempt ->
        try {
            // Non-local return: allowed because both withRetry and repeat are inline.
            return block()
        } catch (e: Throwable) {
            val isLastAttempt = attempt == maxRetries - 1
            if (!shouldRetry(e) || isLastAttempt) {
                throw e
            }
            Thread.sleep(delay)
        }
    }
    // Unreachable: the final attempt either returns or rethrows.
    error("withRetry: unreachable")
}

三、高级函数组合

3.1 函数组合操作符

/** Mathematical composition: the result computes `this(f(x))`, so [f] runs first. */
infix fun <A, B, C> ((B) -> C).compose(f: (A) -> B): (A) -> C {
    val outer = this
    return { input: A -> outer(f(input)) }
}

/** Pipeline composition: the result computes `f(this(x))`, so this function runs first. */
infix fun <A, B, C> ((A) -> B).andThen(f: (B) -> C): (A) -> C {
    val first = this
    return { input: A -> f(first(input)) }
}

// Usage example
val add5: (Int) -> Int = { n -> n + 5 }
val multiplyBy2: (Int) -> Int = { n -> n * 2 }

// Both pipelines mean "add 5, then multiply by 2":
val composed = multiplyBy2 compose add5   // right operand runs first
val pipelined = add5 andThen multiplyBy2  // left operand runs first

println(composed.invoke(10))  // prints 30
println(pipelined.invoke(10)) // prints 30

3.2 柯里化与部分应用

/** Currying: turns a two-argument function into a chain of one-argument functions. */
fun <A, B, C> curry(f: (A, B) -> C): (A) -> (B) -> C = { first ->
    { second -> f(first, second) }
}

/** Uncurrying, the inverse of [curry]: collapses a chain of one-argument functions into a two-argument one. */
fun <A, B, C> uncurry(f: (A) -> (B) -> C): (A, B) -> C = { first, second ->
    val partiallyApplied = f(first)
    partiallyApplied(second)
}

// Usage: Int::plus resolves to plus(Int) from the expected function type.
val add: (Int, Int) -> Int = Int::plus
val curriedAdd: (Int) -> (Int) -> Int = curry(add)
val add5: (Int) -> Int = curriedAdd(5)

println(add5.invoke(10)) // prints 15

// Partial application

/** Fixes the first argument of [f] to [a], leaving a function of the second argument. */
fun <A, B, C> partial1(f: (A, B) -> C, a: A): (B) -> C =
    fun(second: B): C = f(a, second)

/** Fixes the second argument of [f] to [b], leaving a function of the first argument. */
fun <A, B, C> partial2(f: (A, B) -> C, b: B): (A) -> C =
    fun(first: A): C = f(first, b)

四、DSL构建与接收者

4.1 构建类型安全的DSL

/**
 * Root receiver of the HTML DSL: collects child elements and renders them one per line.
 */
class HtmlDsl {
    // Either DivDsl instances or ready-made strings; each is rendered through toString() on output.
    private val nodes: MutableList<Any> = ArrayList()

    /** Appends a `<div>` configured by [block]. */
    fun div(block: DivDsl.() -> Unit) {
        val element = DivDsl()
        element.block()
        nodes += element
    }

    /** Appends a `<span>` containing [text] verbatim (not escaped). */
    fun span(text: String) {
        nodes += "<span>$text</span>"
    }

    override fun toString(): String = nodes.joinToString(separator = "\n")
}

/**
 * Receiver for `div { ... }` blocks: collects attributes and text, and renders them as a `<div>` element.
 */
class DivDsl {
    // mutableMapOf keeps insertion order, so attributes render in the order they were first set.
    private val attributes = mutableMapOf<String, String>()
    private val content = StringBuilder()

    /** Sets attribute [key] to [value]; setting the same key again overwrites the previous value. */
    fun attr(key: String, value: String) {
        attributes[key] = value
    }

    /** Appends [content] verbatim (not HTML-escaped) to the element body. */
    fun text(content: String) {
        this.content.append(content)
    }

    /**
     * Renders `<div k="v" ...>content</div>`, or `<div>content</div>` when no attributes are set.
     * Attribute values are escaped so that a `"` or `&` inside a value cannot break the markup.
     */
    override fun toString(): String {
        // Each attribute carries its own leading space, so an empty map adds nothing.
        val attrs = attributes.entries.joinToString(separator = "") { (key, value) ->
            " $key=\"${escapeAttribute(value)}\""
        }
        return "<div$attrs>$content</div>"
    }

    // Minimal escaping for a double-quoted attribute value; `&` must be replaced first.
    private fun escapeAttribute(value: String): String =
        value.replace("&", "&amp;").replace("\"", "&quot;")
}

// Using the DSL: a container div followed by a span.
val html = HtmlDsl().apply {
    div {
        text("Hello")
        attr("class", "container")
    }
    span(text = "World")
}

4.2 带接收者的函数类型应用

// Transaction DSL
/**
 * Minimal transaction DSL. Operations registered with [operation] inside [execute] are
 * queued, and only run on commit, after the whole block has completed.
 */
class Transaction {
    private val operations = mutableListOf<() -> Unit>()

    /**
     * Runs [block] as one transaction: begin, collect operations, then commit (which runs them).
     *
     * If [block] or a queued operation throws an [Exception], [rollback] runs and the exception
     * is rethrown. [rollback] here only prints a message; operations that already ran are not
     * undone. The queue is always cleared afterwards, so each call starts fresh and operations
     * from an earlier or failed transaction are never replayed.
     */
    fun execute(block: Transaction.() -> Unit) {
        try {
            begin()
            block()
            commit()
        } catch (e: Exception) {
            rollback()
            throw e
        } finally {
            operations.clear()
        }
    }

    /** Queues [op] to run when the current transaction commits. */
    fun operation(op: () -> Unit) {
        operations.add(op)
    }

    private fun begin() { println("开始事务") }
    private fun commit() {
        operations.forEach { it() }
        println("提交事务")
    }
    private fun rollback() { println("回滚事务") }
}

// Query builder DSL
/**
 * Assembles a simple `SELECT ... FROM ... WHERE ...` statement from optional clauses.
 *
 * NOTE: values are inlined into the SQL text (quoted and escaped by [eq]). For real
 * database access, prefer parameterized statements (e.g. JDBC `PreparedStatement`).
 */
class QueryBuilder {
    private var select = ""
    private var from = ""
    private var where = ""

    /**
     * Builds the condition `column = 'value'`. Single quotes in [value] are doubled
     * (standard SQL string-literal escaping), so the value cannot close the literal early.
     */
    infix fun String.eq(value: Any): String =
        "$this = '${value.toString().replace("'", "''")}'"

    /** Sets the SELECT clause; selects `*` when no [columns] are given. */
    fun select(vararg columns: String) {
        val columnList = if (columns.isEmpty()) "*" else columns.joinToString()
        select = "SELECT $columnList"
    }

    /** Sets the FROM clause. */
    fun from(table: String) {
        from = "FROM $table"
    }

    /** Sets the WHERE clause from a ready-made [condition] string (inserted verbatim). */
    fun where(condition: String) {
        where = "WHERE $condition"
    }

    /** Sets the WHERE clause from [block], which can use [eq] to build the condition. */
    fun where(block: QueryBuilder.() -> String) {
        where = "WHERE ${block()}"
    }

    /** Joins the clauses that were set, in SELECT/FROM/WHERE order, separated by single spaces. */
    fun build(): String = listOf(select, from, where)
        .filter { it.isNotEmpty() }
        .joinToString(" ")
}

五、函数式编程模式

5.1 Monad模式

// Result monad
/**
 * The outcome of a computation: either [Success] with a value, or [Failure] with the causing [Throwable].
 *
 * NOTE: this declaration shadows the standard `kotlin.Result` wherever it is in scope.
 */
sealed class Result<out T> {
    data class Success<out T>(val value: T) : Result<T>()
    data class Failure(val error: Throwable) : Result<Nothing>()

    /** Functor map: transforms a success value; a failure is returned unchanged. */
    fun <R> map(transform: (T) -> R): Result<R> = when (this) {
        is Success -> Success(transform(value))
        is Failure -> this
    }

    /** Monadic bind: chains a computation that may itself fail; a failure short-circuits. */
    fun <R> flatMap(transform: (T) -> Result<R>): Result<R> = when (this) {
        is Success -> transform(value)
        is Failure -> this
    }

    /**
     * Applicative apply: applies the function wrapped in [fn] to this value.
     *
     * If this result is a failure, it is returned (even when [fn] also failed). Otherwise a
     * failed [fn] is returned. Note that the name overlaps the stdlib `apply` scope function.
     */
    fun <R> apply(fn: Result<(T) -> R>): Result<R> = when (this) {
        // Exhaustive over the sealed hierarchy: no synthetic else-branch is needed.
        is Failure -> this
        is Success -> when (fn) {
            is Success -> Success(fn.value(value))
            is Failure -> fn
        }
    }
}

// Usage example

/** Parses [s] as an Int; a [NumberFormatException] is captured as a [Result.Failure] instead of thrown. */
fun parseNumber(s: String): Result<Int> {
    return try {
        Result.Success(s.toInt())
    } catch (e: NumberFormatException) {
        Result.Failure(e)
    }
}

/** Doubles [n], with ordinary Int wrap-around on overflow. */
fun double(n: Int): Int = n + n

// "42" -> Success(42) -> Success(84) -> re-parsed from "84" -> Success(84)
val result = parseNumber("42")
    .map { double(it) }
    .flatMap { doubled -> parseNumber(doubled.toString()) }

5.2 尾递归优化

// tailrec compiles the self-call into a loop, so deep recursion cannot overflow the stack.
/**
 * Computes n! as an [Int]. The running product is kept in [acc], so the recursive call is a tail call.
 *
 * Returns [acc] (1 by default) for any n <= 1.
 *
 * @throws ArithmeticException if the product overflows [Int] (n >= 13 with the default [acc]),
 *   instead of silently returning a wrapped-around value.
 */
tailrec fun factorial(n: Int, acc: Int = 1): Int =
    if (n <= 1) acc else factorial(n - 1, Math.multiplyExact(n, acc))

// Binary tree and a stack-safe fold over it.

/** Immutable binary tree; [Leaf] marks an empty subtree. */
sealed class Tree<out T> {
    data class Node<out T>(
        val value: T,
        val left: Tree<T> = Leaf,
        val right: Tree<T> = Leaf
    ) : Tree<T>()
    object Leaf : Tree<Nothing>()
}

/**
 * Folds every value in the tree into an accumulator, visiting nodes in pre-order
 * (a node, then its left subtree, then its right subtree).
 *
 * Uses an explicit stack instead of recursion. With two subtrees, at most one recursive call
 * could be a tail call, so `tailrec` cannot apply, and deep (e.g. list-shaped) trees would
 * overflow the call stack.
 */
fun <T, R> Tree<T>.fold(
    initial: R,
    operation: (R, T) -> R
): R {
    var acc = initial
    val pending = mutableListOf<Tree<T>>(this)
    while (pending.isNotEmpty()) {
        val current = pending.removeAt(pending.lastIndex)
        if (current is Tree.Node) {
            acc = operation(acc, current.value)
            // Push right before left so the left subtree is popped (visited) first.
            pending.add(current.right)
            pending.add(current.left)
        }
    }
    return acc
}

六、性能优化与内联

6.1 内联函数高级用法

// crossinline and noinline
/**
 * Runs [block] and reports the elapsed time, measured with `System.nanoTime()`, to [onComplete].
 *
 * `crossinline` forbids a non-local `return` from [block]. `noinline` keeps [onComplete] as a
 * real function object. If [block] throws, [onComplete] is not called.
 */
inline fun measureTime(
    crossinline block: () -> Unit,
    noinline onComplete: (Long) -> Unit
) {
    val startedAt = System.nanoTime()
    block()
    val elapsedNanos = System.nanoTime() - startedAt
    onComplete(elapsedNanos)
}

// Reified type parameters
/**
 * Creates a `T` via reflection, using the first public constructor whose parameter list accepts [args].
 *
 * A constructor accepts [args] when it has the same number of parameters and each argument is
 * an instance of the parameter's type. Primitive parameters are matched against their boxed
 * wrappers, because vararg `Any` values arrive boxed.
 *
 * @throws IllegalArgumentException if no public constructor of `T` accepts [args].
 */
inline fun <reified T> createInstance(vararg args: Any): T {
    val constructor = T::class.java.constructors.firstOrNull { candidate ->
        val paramTypes = candidate.parameterTypes
        paramTypes.size == args.size &&
            paramTypes.indices.all { i -> paramTypes[i].kotlin.javaObjectType.isInstance(args[i]) }
    } ?: throw IllegalArgumentException("No constructor found")
    return constructor.newInstance(*args) as T
}

// Inline extension property
/**
 * The last element of the list.
 *
 * Reading throws [NoSuchElementException] when the list is empty. Writing replaces the last
 * element, or appends the value when the list is empty.
 */
inline var <T> MutableList<T>.lastOrThrow: T
    get() {
        if (isEmpty()) throw NoSuchElementException()
        return this[lastIndex]
    }
    set(value) {
        if (isNotEmpty()) {
            this[lastIndex] = value
        } else {
            add(value)
        }
    }

6.2 避免内联的情况

// When NOT to inline
// Case 1: a large function body.
/**
 * Stand-in for a function with a large body. Because it is `inline`, the whole body (not only
 * [block]) is copied into every call site, which bloats the callers' bytecode.
 */
inline fun largeFunction(block: () -> Unit) {
    // ... lots of code ...
    block()
    // ... more code ...
}

// Case 2: a recursive function with a function parameter.
/**
 * Calls [block] with n, n - 1, ..., 1, in that order; does nothing when [n] <= 0.
 *
 * Deliberately NOT `inline`: an inline function cannot call itself, because the compiler
 * rejects the inline cycle (it would have to expand the body endlessly). `tailrec` instead
 * compiles the self-call into a loop, so a large [n] cannot overflow the stack.
 */
tailrec fun recursive(n: Int, block: (Int) -> Unit) {
    if (n > 0) {
        block(n)
        recursive(n - 1, block)
    }
}

七、协程与挂起函数

7.1 高阶挂起函数

// Retry with a per-attempt timeout
/**
 * Runs [block] with a [timeout] (in milliseconds) per attempt. On a timeout or failure it
 * retries, for at most [maxRetries] attempts in total, with exponential backoff between attempts.
 *
 * Backoff before retry k (0-based) is 1000 ms * 2^k, with the exponent capped at 5 (32 s).
 * Cancellation of the calling coroutine is never retried: any cancellation other than a
 * timeout is rethrown immediately, keeping cancellation cooperative.
 *
 * @throws IllegalArgumentException if [maxRetries] is less than 1.
 */
suspend fun <T> withTimeoutAndRetry(
    timeout: Long,
    maxRetries: Int = 3,
    block: suspend () -> T
): T {
    require(maxRetries >= 1) { "maxRetries must be at least 1, was $maxRetries" }

    repeat(maxRetries) { attempt ->
        val isLastAttempt = attempt == maxRetries - 1
        try {
            return withTimeout(timeout) {
                block()
            }
        } catch (e: TimeoutCancellationException) {
            // Must come before the generic cancellation clause: it is a CancellationException subtype.
            if (isLastAttempt) throw e
        } catch (e: kotlin.coroutines.cancellation.CancellationException) {
            // The caller's coroutine was cancelled; do not swallow it by retrying.
            throw e
        } catch (e: Exception) {
            if (isLastAttempt) throw e
        }
        // Exponential backoff as a Long (delay takes Long milliseconds): 1s, 2s, 4s, ... capped at 32s.
        delay(1000L shl minOf(attempt, 5))
    }
    // Unreachable: the final attempt either returns or rethrows.
    error("withTimeoutAndRetry: unreachable")
}

// Resource management
/**
 * Runs [block] with [resource] as its receiver, then closes [resource], whether [block]
 * completes normally or throws.
 *
 * Delegates to the standard `use` function. If both [block] and `close()` throw, the exception
 * from [block] propagates and the one from `close()` is attached to it as suppressed. A plain
 * try/finally would let the close failure replace the original exception.
 */
suspend fun <T : AutoCloseable, R> withResource(
    resource: T,
    block: suspend T.() -> R
): R = resource.use { it.block() }

八、实战案例:事件处理系统

/**
 * Minimal in-process event bus, keyed by each event's exact runtime class.
 *
 * Handlers registered for a supertype are not invoked for subtype events. The handler map is
 * a plain mutable map with no synchronization, so the bus is not thread-safe.
 */
class EventBus {
    // @PublishedApi internal: the public inline `subscribe` is expanded at call sites and
    // may not reference private members.
    @PublishedApi
    internal val handlers = mutableMapOf<Class<*>, MutableList<(Any) -> Unit>>()

    /** Registers [handler] for events whose runtime class is exactly `T`. */
    inline fun <reified T : Any> subscribe(
        noinline handler: (T) -> Unit
    ) {
        // javaObjectType maps primitives to their wrappers (Int -> java.lang.Integer), so the key
        // matches the boxed runtime class that post() looks up; T::class.java would give `int`.
        val eventClass = T::class.javaObjectType
        handlers.getOrPut(eventClass) { mutableListOf() }
            .add { handler(it as T) }
    }

    /** Delivers [event] to every handler subscribed to its exact runtime class, in subscription order. */
    fun <T : Any> post(event: T) {
        // Iterate over a snapshot so a handler can subscribe new handlers without a
        // ConcurrentModificationException; handlers added during delivery start with the next post.
        handlers[event::class.java]?.toList()?.forEach { it(event) }
    }
}

// State machine
/**
 * Table-driven finite state machine.
 *
 * @param initialState the state the machine starts in.
 * @param transitions for each state, its event-to-next-state table. Events missing from the
 *   current state's table are rejected.
 * @param onTransition optional callback, invoked as (from, event, to) before the state changes.
 */
class StateMachine<S, E>(
    initialState: S,
    private val transitions: Map<S, Map<E, S>>,
    private val onTransition: ((S, E, S) -> Unit)? = null
) {
    private var currentState: S = initialState

    /** The state the machine is currently in. */
    val state: S
        get() = currentState

    /**
     * Applies [event] to the current state.
     *
     * @return `true` if a transition was defined and taken; `false` if [event] is not accepted
     *   in the current state, in which case the state is unchanged and no callback runs.
     */
    fun transition(event: E): Boolean {
        val nextState = transitions[currentState]?.get(event) ?: return false
        onTransition?.invoke(currentState, event, nextState)
        currentState = nextState
        return true
    }

    /** DSL builder used by [create]. Usage: `"idle" on "start" to "running"`. */
    class Builder<S, E> {
        private var initialState: S? = null
        private val transitions = mutableMapOf<S, MutableMap<E, S>>()
        private var onTransition: ((S, E, S) -> Unit)? = null

        /** Sets the state the built machine starts in (required). */
        fun initialState(state: S) {
            initialState = state
        }

        /** Starts a transition declaration `from on event`; finish it with [TransitionBuilder.to]. */
        infix fun S.on(event: E) = TransitionBuilder(this, event)

        inner class TransitionBuilder(
            private val from: S,
            private val event: E
        ) {
            /** Records that [event] moves the machine from `from` to [to]; redeclaring overwrites. */
            infix fun to(to: S) {
                transitions.getOrPut(from) { mutableMapOf() }[event] = to
            }
        }

        /** Sets the callback invoked on every successful transition. */
        fun onTransition(handler: (S, E, S) -> Unit) {
            onTransition = handler
        }

        /**
         * Builds the machine.
         *
         * @throws IllegalArgumentException if no initial state was set.
         */
        fun build(): StateMachine<S, E> {
            val start = requireNotNull(initialState) { "Initial state must be set" }
            // Snapshot the tables so later builder mutations cannot change an already-built machine.
            val table = transitions.mapValues { (_, byEvent) -> byEvent.toMap() }
            return StateMachine(start, table, onTransition)
        }
    }

    companion object {
        /** Creates a machine configured by the [init] DSL block. */
        fun <S, E> create(init: Builder<S, E>.() -> Unit): StateMachine<S, E> {
            val builder = Builder<S, E>()
            builder.init()
            return builder.build()
        }
    }
}

最佳实践建议

  1. 函数纯度:尽可能编写纯函数,避免副作用
  2. 组合优先:使用函数组合代替复杂的控制流
  3. 合理内联:只对小型高阶函数使用inline
  4. 类型安全:充分利用Kotlin类型系统
  5. 性能考量:注意lambda的性能开销,必要时使用内联
  6. 可读性:保持DSL的直观性和可读性
  7. 错误处理:使用Result或Either类型处理错误

这些高阶函数特性使Kotlin能够编写出表达力强、类型安全且高效的代码,特别适合构建DSL、响应式系统和函数式编程架构。