Scala:在纯函数中优雅地转移状态

850 阅读2分钟

纯函数式状态

首先从 Scala / Java 自带的随机数生成器开始说起。

val rng = new Random()
println(rng.nextInt(6))  // 两次调用,得到不同的结果。
println(rng.nextInt(6))  

这个 rng.nextInt(6) 不是引用透明的,因为 Random 对象内部的状态 seed 在 以副作用的形式 发生了更改,而这个更改用户不可见。每一次调用 nextInt,都意味着内部的状态被破坏 ( 这擦除了 seed 上一次的状态,使得状态不可回溯 )。像这类函数都是难以测试,组合,模块化和并行化的。

想要恢复引用透明,其关键就是要让状态更新变为 显式 的。不要以副作用的方式更新状态,而是连同生成的值一起返回一个新的状态。

trait RandomNrGen :
  def nextInt : (Int,RandomNrGen)

不同于原 nextInt 只返回一个新的随机数,现在我们返回一个随机数 和一个新的状态,而之前状态保持不变。

这样,状态变化不再以擦除进行,而是衍生出下一个状态传递出去。这样,nextInt 的调用者有机会复用同一个状态。理想的情况是状态本身依然是被封装的,API 的使用者不需要直到状态传递的细节。

传递状态的思想还可以应用在 累积计算过程的尾递归函数:将累积过程通过隐式参数的途径传播。这个套路可以把一部分非尾递归函数优化为尾递归的,从而规避栈溢出的风险,同时用户可以不关注累积的过程。比如斐波那契数列的尾递归实现:

// 向右折叠。
@tailrec
def fib(n : Int)(using left : Int = 0, right : Int = 1, list : LazyList[Int] = LazyList(0,1).take(n)) : List[Int] =
  if(n <= 2) list.toList else fib(n-1)(using right,left + right,list :+ (right + left))

// 用户可以不关心 left, right, list 的中间状态。
// 0, 1, 1, 2, 3, 5, 8, ...
fib(40)

下面是使用 Scala 实现的 线性同余生成器。seed 是如何计算的并不重要,仅仅需要注意 nextInt 生成的是二元组:根据当前状态计算出的随机数,还有下一个状态。

case class SimpleRandomNrGen(seed : Long) extends RandomNrGen :
  private val nexus: Long = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFL
  private lazy val nextRandomNrGen: SimpleRandomNrGen = SimpleRandomNrGen(nexus)
  override def nextInt: (Int, RandomNrGen) = ((nexus >>> 16).toInt,nextRandomNrGen)

注意,nextRandomNrGen 必须被声明为延迟加载的,否则创建实例时程序会递归调用 apply 方法直到栈溢出。见:

现在,对同一个随机数生成器的 nextInt 调用将是幂等的。它具备引用透明特性,只要生成器的状态保持不变,生成的随机数就肯定是同一个。因此,可以将 simpleRandomNrGen.nextInt 看作是一个定值 value。

val simpleRandomNrGen = SimpleRandomNrGen(3789012)
val (n,_) = simpleRandomNrGen.nextInt
val (m,_) = simpleRandomNrGen.nextInt

println(n == m)

通过连续调用下一个状态的 SimpleRandomNrGen,可以获得连续不同的随机数,同时每一次调用的状态都得以保留。

val simpleRandomNrGen = SimpleRandomNrGen(3789012)
val (n,nextGen1) = simpleRandomNrGen.nextInt
val (m,nextGen2) = nextGen1.nextInt
val (k,nextGen3) = nextGen2.nextInt
println(s"$n,$m,$k")

也可以借助上一章介绍的惰性加载流构建一个惰性随机序列。

// Scala 2.13 之后,Stream 被 LazyList 替换。
// #:: 是构造 LazyList 的方法,类比 ::。
def RandomStream(seed : Long = new Date().getTime)(using gen :RandomNrGen = SimpleRandomNrGen(seed)) : LazyList[Int] =
  gen.nextInt._1 #:: RandomStream()(using gen.nextInt._2)

用纯函数式实现带状态的 API

将带状态的 API 改造成这种传递状态的函数式风格并不是随机数生成器独有的,它是一个普遍性的问题,我们都可以用相同的方式来处理。比如:

class Obj :
  private var state : Short = 0
  def bar : Int = ???  // 会修改 state 变量。

假设 bar 每次都会以某种形式改变 state,那就可以通过在函数签名中明确地声明转换到下一个状态,将它翻译成纯函数式 API:

class Obj2 :
  def bar : (Int,Obj2) = ???

当然,相比直接擦除方式的就地更新,使用纯函数计算下一状态会带来一些性能上的损失。这取决于我们的程序是否需要保留状态。

首先实现几个简单的随机数生成方法来做热身:

  1. 生成 0 到 Int.MaxValue ( 包含 ) 的随机数。这可以使用 % 运算来实现,但是需要注意处理 Int.MaxValue 的情况。
  2. 生成一个 [0,1) 区间的 Double 数。
  3. 生成 (Int,Double) 随机数对,可以用已有的方法去实现。
case class SimpleRandomNrGen(seed: Long = new Date().getTime) extends RandomNrGen :
  // 包含这个上界, 防止溢出 [0,bound]
  // 这里需要考虑极限情况: Int.MaxValue。如果使用取模运算,我们最多只能获得 Int.MaxValue - 1。
  // 因此在 Long 的精度中计算,再转换回 Int。
  override def nonNegativeInt(bound: Int = Int.MaxValue): (Int, RandomNrGen) =
    val (v, n) = nextInt
    (if v < 0 then (math.abs(v + 1) % (bound.toLong + 1)).toInt else (v % bound.toLong + 1).toInt, n)

  override def _0_1double: (Double, RandomNrGen) =
    val tuple: (Int, RandomNrGen) = nonNegativeInt(bound = 99)
    (tuple._1.toDouble / 100, tuple._2)

object SimpleRandomNrGen:
  // 伴生对象版本的实现,它对外接收 randomNrGen。
  def intDouble(rdmGen : RandomNrGen) : ((Int,Double),RandomNrGen) =
    val (i,nxt) = rdmGen.nextInt
    val (d,nnxt) = nxt._0_1double
    ((i,d),nnxt)
  def nonNegativeInt(randomNrGen: RandomNrGen) : (Int,RandomNrGen) = randomNrGen.nonNegativeInt()
  def _0_1double(randomNrGen : RandomNrGen) : (Double,RandomNrGen) = randomNrGen._0_1double

  // 生成连续的 Int 数串
  def RdmIntStream(seed: Long = new Date().getTime, bound: Int = Int.MaxValue)(using gen: RandomNrGen = SimpleRandomNrGen(seed)): LazyList[Int] =
    val (i, nxt) = gen.nonNegativeInt(bound)
    i #:: RdmIntStream(bound = bound)(using nxt.asInstanceOf[SimpleRandomNrGen])

  // 生成连续的 Double [0,1) 数串
  def Rdm0_1DoubleStream(seed: Long = new Date().getTime)(using gen: RandomNrGen = SimpleRandomNrGen(seed)): LazyList[Double] =
    val (d, nxt) = gen._0_1double
    d #:: Rdm0_1DoubleStream()(using nxt)

行为状态更好的 API

回顾之前的所有实现。我们的每一个函数形式上都类似:

// S: state
// V: value
(State) => (V,State)

这种类型的函数被称之为状态转移:从一个状态 S 变换到了另一个状态 S。如本文的随机数生成器可以定义为:

type Rdm[+A] = RandomNrGen => (A,RandomNrGen)

注意,Rdm[+A] 是函数类型。采取这种类型的目的是对用户屏蔽 RamdomNrGen 的状态传递过程而只去关注值 A任何返回 Rdm[A] 的函数都是高阶函数。为了便于理解,我们后文称这样的高阶函数为 行为。比如:

def f[A]: Rdm[A] = ???
def f[A](args : Any*) : Rdm[A] = ???     // 可以使用自由变量。

f 的具体实现应该是以初始状态 S 为参数,返回值为 (A,S) 二元组的闭包。如:

def f[A]: Rdm[A] = s => 
  val (a,nxt) = s.g    // 通过上一个 s 计算获取了当前状态下的值和状态。
  (a,nxt)

行为f 只有接收到一个 RandomNrGen 之后才会被驱动执行,因为 使用高阶函数相当于借助柯里化实现了延迟执行的效果。柯里化的延迟执行 ( 或称对参数的延迟确认 ) 特性在最后一个例子:模拟有限状态机中还会用到。

现在,我们不再关注外界会传入什么样的 ( 初始 ) 状态进来,而是将目光聚焦到行为本身并做进一步抽象。行为本身是函数,因此我们本能地认为行为本身也可以映射,组合,嵌套,这种直觉是正确的,是时候针对 Rdm[+A] 类型做一些 "基建" 了。

首先实现 unit 方法,它接收一个 a,然后仅简单地传递一个固定的 randomNrGen 状态。可以将 unit(100) 调用看作是一个字面量;换一个角度,也可以把 unit 本身看作是一个从 [A]Rdm[A]的 **提升 (lift) ** 行为。

def unit[A](a: A): Rdm[A] = randomNrGen => (a, randomNrGen)

然后实现基础的 map 组合子,它是一个转换行为:提取 Rdm[A] 行为的结果 A,将它映射为 B 之后继续传递。

def map[A, B](rdm: Rdm[A])(f: A => B): Rdm[B] = rdmNrGen =>
  val (a, nxt) = rdm(rdmNrGen)
  (f(a), nxt)

组合,嵌套状态行为

map 还不够强大到表达组合行为。我们还需要另创建一个 map2 组合子,用来 将两个行为叠加成一个行为。实现如下:

def map2[A,B,C](rdmA : Rdm[A],rdmB : Rdm[B])(f : (A,B) => C) : Rdm[C] =
  randomNrGen =>
	// randomNrGen 的状态在此转移两次。
	val (value_A,nxt) = rdmA(randomNrGen)
	val (value_B,nnxt) = rdmB(nxt)
	(f(value_A,value_B),nnxt)

有了 map2 方法,上文的 intDouble ( 或者 doubleInt 等 ) 就可以转而用非常简短的表述来代替:

def intDouble(firstDo : Rdm[Int],andThen : Rdm[Double]): Rdm[(Int, Double)] = map2(firstDo,andThen)((_,_))

稍微对这类方法做一层泛化,可以得到 both 方法,它允许随机生成任意类型的 (A,B) 对。

def both[A,B](rdmA : Rdm[A],rdmB : Rdm[B]): Rdm[(A, B)] = map2(rdmA,rdmB)((_,_))

一个稍稍难以理解的是 sequence 方法。这个通用组合子在之前的章节已经出现过多次,它代表 "翻转",比如将 List[Option[A]] 变换为 Option[List[A]]。同理,我们在这里预期实现将 List[Rdm[A]] 翻转为 Rdm[List[A]] 的方法:将一连串生成单值的行为翻转为一个生成多个值的行为。

def sequence[A](fs : List[Rdm[A]]) : Rdm[List[A]] =
  fs.foldRight(unit[List[A]](List[A]())){
    (rdm,acc) => {map2(rdm,acc)(_ :: _)}
  }

sequence 方法在 map2 和 List 提供的 foldRight 的基础上进行。注意,右折叠的初值传递 List[A]() 而非 Nil,因为后者无法让编译器进行有效的类型推断。

flatMap 是另一个具有强大表达能力的组合子,用于提供嵌套组合 Rdm[List[A]] 的能力:

  // Rdm[B] = (Rng) => (B,Rng)
  def flatMap[A,B](rdmA : Rdm[A])(f : A => Rdm[B]) : Rdm[B] = rdm =>
    val (value_A,nxt) = rdmA(rdm)
    f(value_A)(nxt)

下一节展示了 flatMap 是如何实现 mapmap2 方法的,这也是 flatMap 比其它两个表达能力更强大的原因。实现 flatMap 的另一个重要用途是 令 Scala 编译器支持用 for 表达式 表达行为的嵌套组合,以及归约,见后文。

提炼通用表达

我们到目前为止已经实现了对随机数生成器的 unitmapmap2flatMapsequence 函数。这些都是函数式编程中的通用行为,不关心状态类型。于是我们提取出了更加泛化的签名:

def map[S,A,B](a :S => (A,S))(f: A => B) : S => (B,S)

之前的 Rdm[A] 也可以有更通用的形式:

type State[S,+A] = S => (A,S)

这里,State 可以代指 "状态",甚至延伸为 "指令" statement 的缩写

我们从编写随机数生成器的例子总结经验,最终完成一个的通用模式。mapmap2flatMap 这三个组合子在类定义中实现,unitsequence 在伴生对象中实现。

case class State[S, +A](run: S => (A, S)):
  // 当参数是单个参数时,可以使用花括号 {} 代替 () 。
  def map[B](from: A => B): State[S, B] = flatMap { a => unit(from(a)) }

  // 与另一个 State 合并出一个新的 State。
  // 理解了它,可以结合 foldRight 实现 sequence 方法。
  def map2[B, C](otherState: State[S, B])(zip: (A, B) => C): State[S, C] =
	flatMap { a => { otherState.map { b => zip(a, b) }}}

  // 最基础组合子。使用 A 生成了下一个状态 nxt,然后返回包含下一个状态 nxt 的 statement。
  def flatMap[B](f: A => State[S, B]): State[S, B] = State { s => val (a, nxt) = run(s);f(a).run(nxt) }

object State:
  // 可以看作是将单个值 a : S 结合另一个状态 S 升级成 State[S,A] 的过程。
  // 如果 a 是 List[T] 类型,那么 unit 方法会提升为 State[S,List[T]], 见 sequence。
  def unit[S, A](a: A): State[S, A] = State { s => (a, s) }

  def sequence[S, A](ss: List[State[S, A]]): State[S, List[A]] =
    ss.foldRight(unit[S, List[A]](List[A]()))((statement,acc) => { statement.map2(acc)( _ :: _) })

mapmap2 本质上是对flatMap 的复用,而 flatMap 内部包含隐式的状态转移。因此,所有的上层 API 逻辑调用均会触发状态转移,但用户无需对此进行过多关注

纯函数式命令编程与 For 表达式

在命令式编程中,程序是由一系列的指令 statement 组成的。每个指令可以修改状态,而在本章每一个指令是一个函数:他们接受参数读取程序的当前状态,然后返回一个值代表写入程序状态。

因此,函数式编程和命令式编程并不是对立的,使用无副作用的函数来维护程序状态也是完全合理的。函数式编程对写命令式程序也有很好的支持,还有额外的好处,比如程序可以被等式推理。

在之前已经实现了 mapmap2 以及 flatMap 终极组合子,来处理状态从一个指令到另一个指令的传播。随着状态的转移,返回的值类型也许会产生变化。仍以随机数生成器为例子:

val int = State[RandomNrGen,Int]{_.nextInt}
val action : State[RandomNrGen,List[Int]] = int.flatMap {
  x => int.flatMap {
    y => int.map {
        z => List(x,y,z)
    }
  }
}

// 传入生成器驱动行为执行。
println(action.run(new SimpleRandomNrGen(3000)))

这种代码风格看起来还是少了些 "命令式" 的语气,不太容易一下看出这段代码在做什么。 Scala 的 for 表达式推导可以还原 "命令式" 的风格:

val value: State[RandomNrGen, List[Int]] =
for {
  x <- int  // 从 int 行为中获取 x
  y <- int  // 从 int 行为中获取 y
  z <- int  // 从 int 行为中获取 z
} yield List(x, y, z)

// 传入生成器驱动行为执行。
val(a,nxt) = action.run(new SimpleRandomNrGen(3000))
// 打印随机数列
println(a.mkString(","))

for 表达式推导的代码可读性更强,它看起来是一段命令式的程序,但实际上和上一段 flatMapmap 的组合代码 完全等价

进一步,假设我们有一个 get 组合子来获取当前状态,set 组合子设置当前状态,那么就可以实现一个以任意方式修改状态的组合子:

def modify[S](f : S => S ): State[S, Unit] = for {
  s <- get
  _ <- set(f(s))
} yield ()

// 保持状态类型不变为 S,同时将当前状态 S 作为值返回,因此是 State[S,S]
def get[S] : State[S,S] = State {s => (s,s)}

// 设置状态是一个副作用,不要求返回值,因此将返回值 A 标注为 Unit。
// () 是 Unit 的字面量。
def set[S](s : S) : State[S,Unit] = State { _ => ((),s)}

modify 行为可以通过组合偏函数的方式来 有选择性地修改状态 ( 之前提到过偏函数也是 Function01 类型 )。另一方面,为了便于编译器进行类型推断,调用时最好显示地标注类型参数。

modify[RandomNrGen]{
  case SimpleRandomNrGen(100) => SimpleRandomNrGen(200)
  case s => s
}.run(new SimpleRandomNrGen(100))

modify 行为的大意是:如果传入的生成器的种子是 100,就将其替换成 200 的那个。

这段演示还是有些拙劣。我们不如看看一个新的例子:糖果机问题。

模拟有限状态机

这个例子是书中练习 6.11 的最后一道难题。给定以下实现:

// Scala 2.x 可以用 trait + case class 来表示代数类型。
enum Input:
  case Coin extends Input
  case Turn extends Input

case class Machine(locked : Boolean, candies : Int, coins : Int)
def simulateMachine(inputs : List[Input]) : State[Machine,(Int,Int)] = ???

机器遵循这样的规则:

  1. 对锁定状态 ( locked = true ) 的售货机投入一枚硬币,如果有剩余的糖果 ( coins != 0 ) 就将它变成非锁定状态。
  2. 对一个非锁定的 ( locked = false ) 的售货机投入一枚硬币,将给出一枚糖果然后变成锁定状态。
  3. 对一个锁定状态的售货机按下按钮或对非锁定状态的售货机投入硬币则什么都不发生。
  4. 售货机在输出糖果时忽略其它输入 ( 一次只处理一个状态,串行的 )。

simulateMachine 的输入很明确,如 List(Coin,Turn,Coin,Turn) 这样的 List[Input] ,用户期望从最后一个 State[Machine,(Int,Int)] 输出中提取 (Int,Int) 元组。

这道题的答案可在下面的 github 链接当中找到: fpinscala/11.answer.md at second-edition · fpinscala/fpinscala (github.com) 。想要直接看懂这段代码是一段困难的事,这里做逐步分析。

糖果机的状态用前文声明的 State 类型包装并传递。这个例子 不关注中间状态是如何变化的,不妨将中间状态设置为 State[Machine,Unit] ( 参考之前的 set 方法 )。在最后一次调用中,使用二元组表示剩下的糖果数量和硬币数量,同时传递下一个 Machine 状态 ( 虽然之后它不会再被使用了 ),此时为 State[Machine,(Int,Int)]

有了上述线索之后,下一步就是解构 List[Int] 转化为 State[Machine,(Int,Int)] 的中间过程。首先,List[Inputs] 一定会映射 map 成记录机器中间状态的序列:List[State[Machine,Unit]]

// 暂时写到这里。
inputs.map{ input  => ??? }

状态转移的具体规则根据题目的要求实现。在上手之前,我们注意到,机器的每一次状态转移还可拆分为 有序的 三个步骤:

  1. 首先确认用户的输入 Input
  2. 检查机器的当前状态 Machine
  3. 返回机器的下次状态 Machine

想要把这三个步骤集成到一个函数,并实现延迟确认的效果,因此这里再次引入柯里化。它的函数签名应该是这样的:

def update: Input => Machine => Machine

根据题意完善逻辑,可得:

def update: Input => Machine => Machine = (i: Input) => (s: Machine) =>
  (i, s) match {
    case (_, Machine(_, 0, _)) => s
    case (Coin, Machine(false, _, _)) => s
    case (Turn, Machine(true, _, _)) => s
    case (Coin, Machine(true, candy, coin)) =>
      Machine(false, candy, coin + 1)
    case (Turn, Machine(false, candy, coin)) =>
      Machine(true, candy - 1, coin)
  }

整个模式匹配描述了状态转移的所有过程,因此将它集成到 modify 行为即可。当 update 接受一个 Input 输入后,并不会立刻得到结果,而是返回一个 Machine => Machine 的偏函数,它对应 modify 行为要求的 f: S => S 参数。

inputs.map{ input  =>
    val machineToMachine: Machine => Machine = update(input)
    modify[Machine](machineToMachine)
}

官网使用 compose 给出了更加紧凑且抽象的实现,这要求我们对这类组合子比较熟悉 ( 至少要分清它和 andThen 的区别:Java 8 compose 和 andThen 方法区别-CSDN)。两者的语义是等价的:

// Scala 3 之后, Eta 拓展的过程是自动的,可以写成:
// inputs.map {modify[Machine].compose(update)}

// compose 的计算顺序是从右向左,相当于隐式地将 map 的 input 参数传递到 update,
// 得到 Machine => Machine,再将它传递到 modify[Machine] 内。
inputs.map {modify[Machine] _ compose update}

最后一个问题是:如何获取最后一个状态。思考我们以前获取列表末尾元素时是怎么做的 —— 创建一个迭代器遍历,从头开始,直到末尾;这往往通过一个典型的命令式 for 循完成。

前文已经提示过 Scala 的 for 表达式可以给出关注分离的等效实现:在 "遍历" 所有 Input 状态之后,调用之前的 get 方法将最后一个 Machine 状态赋值给 s,然后从中提取 candiescoins 信息。

def simulateMachine(inputs : List[Input]) : State[Machine,(Int,Int)] = for {
  // 不关心中间结果,将 A 类型设置为 Unit 即可。
  _ <- sequence[Machine,Unit](inputs.map{
    input  =>
    val machineToMachine: Machine => Machine = update(input)
    modify[Machine](machineToMachine)
  })
  // 获取最后一刻的状态
  s <- get
} yield (s.candies,s.coins)

它的简化版本就是链接中的源码:

def simulateMachine(inputs: List[Input]): State[Machine, (Int, Int)] = for {
  _ <- sequence(inputs map (modify[Machine] _ compose update))
  s <- get
} yield (s.coins, s.candies)

下面做一个简单的测试:

val ((candies,coins),machine) = simulateMachine(List(Coin,Turn,Coin,Turn)).run(Machine(true,10,0))
println(s"candies = $candies, coins =$coins")