使用 Kotlin 函数式地解决问题:从希思罗机场到伦敦

1,708 阅读8分钟

本文转载与本人个人公众号:Kotlin 维修车间 原文发布日期:2020-09-09

这篇文章算是一篇读书笔记,最近在同时读两本 Haskell 的书:《Haskell 趣学指南》与《魔力 Haskell》,作为一个函数式编程领域的初学者,在读书的过程中我经常需要反复阅读和反复思考,也因此经常能获得一些新的灵感。本文描述的问题就可以紧紧围绕函数式编程中的两个核心点来解决:

1.使用正确的类型设计来描述问题 2.几乎所有的递归处理列表的函数都可以使用折叠(fold)实现

在《Haskell 趣学指南》的前言中作者就说过:“你不再像命令式语言那样命令电脑‘要做什么’,而是通过用函数来描述问题‘是什么’......”那么如何描述问题是什么,类型的设计就显得非常重要。在函数式编程中,函数和值之间的界限并没有非常明显,我们可以将它们都视为函数,那么一个值只不过是一个没有参数只有返回值的函数而已;我们也可以将它们都视为值,那么函数只不过是一种由多种类型以特定形式组合在一起的类型更复杂的值罢了;也可以再换一个角度看,函数通过传参并得到结果实际上也是柯里化中通过绑定参数来实现函数类型的变化。因此在函数式编程中,使用类型系统对整个问题的抽象甚至比使用面向对象式编程对类型的要求更高。

fold 函数(包括 foldl、foldr 等)是函数式编程中的一个核心函数,在 Haskell 中非常重要,我们经常会用到它们。Kotlin 标准库中也提供有相应的函数实现,但是由于 Kotlin 不是纯函数式编程语言,所以 fold 函数的使用场景大大减少,而本文讲解的问题也将体现 fold 函数的神奇之处。

在 《Haskell 趣学指南》一书中,第 10 章列举了两个很典型的可以用函数式编程思想解决的问题。一个是“逆波兰式计算器”,另一个就是“从希思罗机场到伦敦”,其中后者相对来说更加复杂,也更能体现本文想要传达的思想,所以我们将后面这个问题通过逐步讲解,并使用 Kotlin 进行实现,从而加深我们对函数式编程思想的运用。

我们先来描述我们的问题,从希思罗机场到伦敦:

假设我们在出差,飞机已经抵达英国了,我们租了一辆车。很快就有一场会议,我们需要尽可能快地从希思罗机场到伦敦(前提是要保证安全)。

从希思罗机场到伦敦有两条主干道,他们之间有一些联通的小路。两个路口之间需要花费的时间是固定的。找到最优路径从而准时到达伦敦开会就靠我们了。我们从左边那条路出发,既可以换行到另一条路,也可以笔直向前。

那么我们直接使用原书中的说明图来直观的看一下问题描述,并了解每个路段花费的时间具体是多少:

640.webp

两条主干道分别是 A 与 B,最终的目的地是 A4 或 B4。中间的小路我们可以统称为 C,小路与主干道交汇的路口我们都用 A1、A2......B1、B2......来表示。我们的目的是找出最快的路线。

由于篇幅问题,我尽量使用自己组织的简洁语言来描述原书中的解法,因为本文的重点还是 Kotlin 代码,想通过文字来了解解法思想还是建议阅读原书。

先忘记代码和程序,如果我们用人工来解这个问题会如何思考?我们会分别从 A 和 B 出发,建立两条路径,然后将出发点到下一个路口(可能是 A 的,也可能是 B 的)最短时间作为新的路线累加到原路线上,然后依次向后计算,最终比较这两条路线哪一条更短。基于这样的解题思想,我们这样抽象一下,将从 A 干道始发点到 A1 的路段,与从 B 始发点到 B1 的路段,以及 A1 与 B1 之间的联通小路作为一个地段(Section),后面的也依此类推。

那我们先对 Section 进行一下类型定义:

data class Section(val a: Int, val b: Int, val c: Int) 
typealias RoadSystem = List<Section>

我们首先用 data class 定义了 Section 类型。那么这条道路系统就是 Section 的列表,为了方便,我们使用类型别名将其定义为 RoadSystem。

那么从希思罗机场到伦敦的道路系统就可以这样表示:

val heathrowToLondon: RoadSystem = listOf(    
    Section(a = 50, b = 10, c = 30),    
    Section(a = 5 , b = 90, c = 20),    
    Section(a = 40, b = 2 , c = 25),    
    Section(a = 10, b = 8 , c = 0 )
)

我们要求得最佳路线的函数的类型应该是什么?它应该接受道路系统类型为参数,并返回路线。路线也应该表示为一个列表,它应该包含两个信息,1.位于哪条路,2.花费的时间是多少。因此路线可以表示为:

enum class Label {    
    A { override fun toString(): String = "A" },    
    B { override fun toString(): String = "B" },    
    C { override fun toString(): String = "C" };
}
typealias Path = List<Pair<Label, Int>>

当然,我们这里还是先定义了一个枚举类型 Label 来表示路。

所以我们求得最优路线的函数应该定义为:

fun optimalPath(roadSystem: RoadSystem): Path

回想一下我们刚才人工求解这个问题时的步骤,我们通过遍历每一个路段(Section)得到路线(Path),然后将得到的路线累加。这正是 fold 函数折叠操作的用武之地,我们不断的折叠 Section 列表(RoadSystem),并将其归约到一个表示路线的列表(Path)。

我们在人工求解的时候一直在重复一个步骤,就是根据抵达当前 A 路与 B 路的最佳路线,计算到达下一地段 A 路与 B 路的最佳路线。所以如果将这个步骤编写成一个函数的话,函数类型应该如下:

(Pair<Path, Path>, Section) -> Pair<Path, Path>

这个函数实际上也就是要作为我们 fold 函数参数的函数,我们现在来实现该函数:

fun roadStep(pair: Pair<Path, Path>, section: Section): Pair<Path, Path> {    
    val (pathA, pathB) = pair    
    val (a, b, c) = section    
    val timeA = pathA.map(Pair<Label, Int>::second).sum()    
    val timeB = pathB.map(Pair<Label, Int>::second).sum()    
    val forwardTimeToA = timeA + a    
    val crossTimeToA = timeB + b + c    
    val forwardTimeToB = timeB + b    
    val crossTimeToB = timeA + a + c    
    val newPathToA = mutableListOf<Pair<Label, Int>>().apply {        
        if (forwardTimeToA <= crossTimeToA) {            
            addAll(pathA)            
            add(Label.A to a)        
        } else {            
            addAll(pathB)            
            add(Label.B to b)            
            add(Label.C to c)        
        }    
    }    
    val newPathToB = mutableListOf<Pair<Label, Int>>().apply {        
        if (forwardTimeToB <= crossTimeToB) {            
            addAll(pathB)            
            add(Label.B to b)        
        } else {            
            addAll(pathA)            
            add(Label.A to a)            
            add(Label.C to c)        
        }    
    }    
    return newPathToA to newPathToB
}

这么看起来这个函数有点典型的命令式写法,不过《Haskell 趣学指南》中作者使用 Haskell 的 let...in... 表达式也写出了类似的效果,这里是对书中代码的还原。

在这个函数中我们分别计算从 A 直接到达 A 与从 B 穿过 C 到达 A 所需的时间,以及从 B 直接到达 B 与从 A 穿过 C 到达 B 的时间。然后分别计算时间的长短,最后通过比较并最终返回新的路线。

最后我们以 fold 函数作为整个计算的核心,通过递归来确定两条路线,并比较哪一条时间更短,最终编写 optimalPath 函数的实现:

fun optimalPath(roadSystem: RoadSystem): Path =    
    roadSystem.fold(listOf<Pair<Label, Int>>() to listOf(), ::roadStep).let { (bestAPath, bestBPath) ->        
        if (bestAPath.map(Pair<Label, Int>::second).sum() <= bestBPath.map(Pair<Label, Int>::second).sum())            
        bestAPath        
        else            
        bestBPath    
    }

最后打印输出结果:

fun main() = println(optimalPath(heathrowToLondon))

结果为:

[(B, 10), (C, 30), (A, 5), (C, 20), (B, 2), (B, 8), (C, 0)]

由于最后的 A4 和 B4 是同一个地点(中间的路段 C 耗时为 0)。所以,最终的结果多了一个 (C, 0)。

下面给出整个问题的完整代码:

fun main() = println(optimalPath(heathrowToLondon))

data class Section(val a: Int, val b: Int, val c: Int)
typealias RoadSystem = List<Section>

val heathrowToLondon: RoadSystem = listOf(    
    Section(a = 50, b = 10, c = 30),    
    Section(a = 5 , b = 90, c = 20),    
    Section(a = 40, b = 2 , c = 25),    
    Section(a = 10, b = 8 , c = 0 )
)

enum class Label {    
    A { override fun toString(): String = "A" },    
    B { override fun toString(): String = "B" },    
    C { override fun toString(): String = "C" };
}
typealias Path = List<Pair<Label, Int>>

fun optimalPath(roadSystem: RoadSystem): Path =    
    roadSystem.fold(listOf<Pair<Label, Int>>() to listOf(), ::roadStep).let { (bestAPath, bestBPath) ->        
        if (bestAPath.map(Pair<Label, Int>::second).sum() <= bestBPath.map(Pair<Label, Int>::second).sum())            
        bestAPath        
        else            
        bestBPath    
    }
    
fun roadStep(pair: Pair<Path, Path>, section: Section): Pair<Path, Path> {    
    val (pathA, pathB) = pair    
    val (a, b, c) = section    
    val timeA = pathA.map(Pair<Label, Int>::second).sum()    
    val timeB = pathB.map(Pair<Label, Int>::second).sum()    
    val forwardTimeToA = timeA + a    
    val crossTimeToA = timeB + b + c    
    val forwardTimeToB = timeB + b    
    val crossTimeToB = timeA + a + c    
    val newPathToA = mutableListOf<Pair<Label, Int>>().apply {        
        if (forwardTimeToA <= crossTimeToA) {            
            addAll(pathA)            
            add(Label.A to a)        
        } else {            
            addAll(pathB)            
            add(Label.B to b)            
            add(Label.C to c)        
        }    
    }    
    val newPathToB = mutableListOf<Pair<Label, Int>>().apply {        
        if (forwardTimeToB <= crossTimeToB) {            
            addAll(pathB)            
            add(Label.B to b)        
        } else {            
            addAll(pathA)            
            add(Label.A to a)            
            add(Label.C to c)        
        }    
    }    
    return newPathToA to newPathToB
}

当然原书中的  Haskell 代码这里一并给出:

main :: IO ()
main = print . optimalPath $ heathrowToLondon

data Section = Section { getA :: Int, getB :: Int, getC :: Int }
type RoadSystem = [Section]

heathrowToLondon :: RoadSystemheathrowToLondon = [
    Section 50 10 30,                     
    Section 5 90 20,                     
    Section 40 2 25,                     
    Section 10 8 0
]

data Label = A | B | C deriving (Show)
type Path = [(Label, Int)]

optimalPath :: RoadSystem -> Path
optimalPath roadSystem =    
    let (bestAPath, bestBPath) = foldl roadStep ([], []) roadSystem    
    in if sum (map snd bestAPath) <= sum (map snd bestBPath)           
        then reverse bestAPath           
        else reverse bestBPath 
        
roadStep :: (Path, Path) -> Section -> (Path, Path)
roadStep (pathA, pathB) (Section a b c) =    
    let timeA = sum (map snd pathA)        
        timeB = sum (map snd pathB)        
        forwardTimeToA = timeA + a        
        crossTimeToA = timeB + b + c        
        forwardTimeToB = timeB + b        
        crossTimeToB = timeA + a + c        
        newPathToA = if forwardTimeToA <= crossTimeToA                         
                     then (A, a) : pathA                         
                     else (C, c) : (B, b) : pathB        
        newPathToB = if forwardTimeToB <= crossTimeToB                         
                     then (B, b) : pathB                         
                     else (C, c) : (A, a) : pathA    
    in (newPathToA, newPathToB)

Kotlin 版本相比 Haskell 的原版,这里有几个点要注意一下。1.由于 Haskell 中变量不可变,所以所有的修改列表的操作都会产生一个新列表,在 Kotlin 的实现版中我也模仿了这一点,在静态类型声明的时候都使用的是 List 类型而不是 MutableList 类型,而且在每次修改列表后,都返回了一个新列表而没有修改之前的列表。但是如果要写一个对空间复杂度更友好的版本,这里建议采用可变列表 MutableList。2.由于 Haskell 中列表的特殊性,即元素头插的效率远高于尾插,所以 Haskell 的原代码中 roadStep 函数都使用列表头插来更新 Path,而在 optimalPath 函数返回要展示的结果时,使用 reverse 函数将列表反转以得到正确的顺序。但在 Kotlin 中我们并不存在这个问题,add 函数本身就是尾部插入元素,所以 optimalPath 函数中也就没有这个列表反转操作。

函数式编程是解决某些问题的利器,不过由于其过于“数学化”也会在有些时候造成对内存的无谓开销,在日常的工作中编写代码的时候,要善用函数式编程思想解决问题,也要利用现代编程语言通常是多范式的这一优势来规避一些会在“计算机科学”上产生的负面作用。