探究 Scala 非严格求值与流式数据结构

732 阅读10分钟

关注分离

俄国著名作家契科夫有一句著名的戏剧理论:假如不打算开火,就别让一支上膛的来福枪出现。

知乎有一个对懒加载和 Stream 的简要总结,可参考:Scala函数式编程(六) 懒加载与Stream - 知乎 (zhihu.com)

FP 的天性就是 "不可变" 的,这意味对每做一次变换就必须新建一个数据副本,因此不可变序列的 mapfilterfoldLeftfoldRight 方法等总会构造一个新的序列。当多个变换相互衔接时,每一次转换的输出会作为输入传给下一个操作,之后就会被立刻丢弃。

val l = List(1, 2, 3, 4, 5).map(_ + 1)    // List(2,3,4,5,6)
  .filter(_ > 2)                          // List(3,4,5,6)
  .map(_ * 3)                             // List(9,12,15,18)
println(l.mkString(",")) 

显然,函数式编程的不变性是以牺牲空间为代价的。避免产生中间结果的一个方式是使用 for 或者是 while 循环重写代码,每迭代一个元素就让其执行全部变换:

val l = List(1,2,3,4,5)
// 这里只演示大意,
// 因为 Scala 没有传统意义上的 for 循环,底层仍然是 flatMap, map, fiterWith 。
val r = for i <- l if i + 1 > 2 yield (i + 1) * 3
println(r.mkString(","))

而理想的情况是,我们既不必手动完成这部分工作,同时还保留高阶组合的编程风格,因此本章引入惰性求值 ( 正式的叫法是非严格 non-strictness 求值 ) 来解决。

Scala 已经给出了关于惰性计算的解决方案,如 lazy 关键字 ( 底层是线程安全的懒汉式单例,只能用于不可变变量 val ),StreamView。它们的设计遵从一个主旨,那就是 关注分离 ( separation of concerns ),具体点说就是将计算的声明和实际执行区分开来,具体来说程序只在需要的时候才执行必要的求值,而未必在数据被声明的一刻。这有两个好处:

  1. 避免潜在的无用但耗费大量资源的计算。
  2. 避免产生大量中间结果。

在 Scala 2.13 版本之后,Stream 已被标记为过时的,官方推荐使用 LazyList 作为代替。见:Scala Standard Library 2.13.7 - scala.collection.immutable.LazyList (scala-lang.org)

// LazyList(<not computed>)
LazyList(1, 2, 3, 4).map(_ + 1).filter(_ > 2).map(_ * 3)

// 9,12,15,18
println(LazyList(1, 2, 3, 4, 5).map(_ + 1).filter(_ > 2).map(_ * 3).mkString(","))

有关于流和视图的内容可以参考笔者的旧笔记:Scala +:类型推断,列表操作与 for loop - 掘金 (juejin.cn)

非严格求值 ( 惰性求值 )

**非严格求值 **是函数的一个属性,称一个函数进行非严格求值的意思是该函数可以选择不对传入的所有参数进行求值。在 Scala 中,默认情况下都是严格求值的。比如下面的程序执行会抛出异常,因为在调用 div 函数之前,参数 b 位置上的 IllegalArgumentException 被率先创建并抛出。

def div(a : Double, b : Double): Double = a / b
div(1,throw new IllegalArgumentException("crash this."))

在很多编程语言 ( 包括 Scala ) 中,典型非严格求值的例子是:||&& 。这两个运算可能只需接受前一个参数就能得到确切的结果,因此它们也被称之 短路运算。另一个非严格求值的例子是 if 条件分支,如三目运算符:con ? A : B ,显然,当 contrue 时,B 分支会被丢弃,反之亦然。

我们已在 Scala 中接触过如何编写非严格求值函数,那就是 传名调用 ( call by name ):

def div(a : => Double, b : => Double): Option[Double] = try Some(a / b) catch case _ => None
div(1,throw new IllegalArgumentException("crash this."))

由于 ab 的计算被推迟到调用 div 之后,因此 try-catch 块得以捕获异常并返回 None,这一点是理解后文的基础。另一个等价表达是传入 () = > A ( 它实际上是 Function0[A] ) 形式的空括号函数,但是调用的写法更麻烦一些:

def div(a : () => Double, b : () => Double): Option[Double] = try Some(a() / b()) catch case _ => None
div(()=>1,()=>throw new IllegalArgumentException("crash this."))

无论是哪种表达方式,这类未求值的语句块被称为 thunk 。而 a()b() 又称之为延迟求值 ( 原书中称强制求值 ),代表此刻 thunk 被真正计算。

惰性列表

下面通过自实现一个惰性列表 ( lazy list ) 或 Stream 的例子来演示 Scala 如何利用惰性操作提升效率。首先给出基础定义:

enum _Stream[+A]:
  case Empty extends _Stream[Nothing]
  // Scala 规定属性不能定义传名参数,这里使用空括号函数代替。
  case Cons[+A](head: () => A, tail: () => _Stream[A]) extends _Stream[A]

  import _Stream.*
  def headOption: Option[A] = this match
    case Empty => None
    case Cons(h,t) => Some(h())

object _Stream:
  def cons[A](hd: => A, tl: => _Stream[A]): _Stream[A] =
    lazy val h: A = hd
    lazy val t: _Stream[A] = tl
    Cons(() => h, () => t)

  def empty[A] : _Stream[A] = Empty
  // Scala 2: as.tail : _*
  // Scala 3: as.tail*
  // 方便我们构造一个流,如 _Stream(1,2,3,4,5)。
  def apply[A] (as : A*) : _Stream[A] = if as.isEmpty then empty else cons(as.head,apply(as.tail*))

由于编译器技术的限制,Cons[+A] 的属性不能定义成传名参数,因此这里替换成了空括号函数的形式。而前文提到空括号函数的声明会影响用户体验,故我们还额外定义了首字母小写的同名 "伪构造器",如 consempty 等,这让用户规避了调用时冗余的 ()=>... 写法。

//  使用 "伪构造器" 可以简化用户的体验。
//  val y = cons({println("init y(0)");10},empty)
val x = Cons(()=>{println("init x(0)");10},()=>empty)

采纳 "伪构造器" 的另一个原因是直接使用 Cons(head,tail) 构造器虽然能够延迟加载的时机,但不能缓存 thunk 的计算结果。为了证明这一点,不妨在构造 Cons 时插入一点副作用,然后多次调用 headOption 方法。可以发现控制台打印了多行 init x(0),这表示传名调用被反复地计算了。

// 延迟加载,声明 x 时不打印任何信息。
val x = Cons(()=>{println("init x(0)");10},()=>empty) 
val h1 = x.headOption		// init x
val h2 = x.headOption		// init x 

cons 数据构造器内部使用两个 lazy 变量 ht。在 hdtl 被首次计算之后,该结果会被缓存。此时无论反复调用多少次 headOption,控制台也仅打印一行 init y(0)

// 延迟加载,声明 y 时不打印任何信息。
val y = cons({println("init y(0)");10},empty)

val h1 = y.headOption	// init y
val h2 = y.headOption	// 使用 lazy h 保存的结果,不打印。

注:类似这种伪构造器 ( 它其实只是个普通的方法 ) 的创建在 Scala 中是比较常用的编程技巧。

方法实现

首先尝试着在一般的序列中常用的其它方法:toListtakedrop,这些方法都可以使用递归 + 模式匹配的组合实现,因为 _Stream 自身的定义就是递归的。

def toList : List[A] =  this match
  case Empty => Nil
  case Cons(h,t) => h() :: t().toList

def take(n : Int) : _Stream[A] = this match
  case Cons(h,t) if n > 0 => cons(h(),t().take(n-1))
  case _ => empty

def drop(n : Int) : _Stream[A] = this match
  case Cons(_,t) if n > 0 => t().drop(n-1)
  case _ => this

def takeWhile(p : A => Boolean) : _Stream[A] = this match
  case Cons(h,t) if p(h()) => cons(h(),t().takeWhile(p))
  case Cons(_,t) => t().takeWhile(p)
  case _ => empty

def forAll(p : A => Boolean) : Boolean = this match
  case Cons(h,t) if !p(h()) => false
  case Cons(_,t) => t().forAll(p)
  case _ => true

注意,只有 h()t() 代表实际求值。假设判断发生了短路,那么意味着后续的 t 根本就不会被计算。forAll 的两个 case 分支可以用短路运算符 &&|| 合并为更紧凑的判断逻辑。同理,给出 exists 的实现:

def forAll0(p : A => Boolean) : Boolean = this match
  case Cons(h,t) => p(h()) && t().forAll0(p)
  case _ => true

def exists(p : A => Boolean) : Boolean = this match
  case Cons(h,t) => p(h()) || t().exists(p)
  case _ => false

takeWhileforAllexists 这三个方法的逻辑存在高度重合,它们可以被提炼为更加纯粹的 foldRight 方法。值得注意的是,f 函数对第二个参数也是非严格求值的。如果 f 不对它求值,那么则相当于终止了递归。

// z 是初始值。
def foldRight[B](z : => B)(f : (A,=> B) => B) : B = this match {
  case Cons(h,t) => f(h(),t().foldRight(z)(f))
  case _ => z
}

foldRight 是非常基础的抽象逻辑,因为它能够进一步复现 mapfilterappend 方法:

def exists0(p : A=> Boolean) : Boolean = foldRight(false)((a,b) => p(a) || b)
def takeWhile1(p :A => Boolean) : _Stream[A] = foldRight(empty){
  (h,t)=> if(p(h)) cons(h,t) else t
}

def map[B]( f : A => B) : _Stream[B] = foldRight(empty){
  (h,t)=> cons(f(h),t)
}

def filter(f : A => Boolean) : _Stream[A] = foldRight(empty){
  (h,t) => if f(h) then cons(h,t) else t
}

def append[B >: A](other : _Stream[B] ) : _Stream[B] = foldRight(other){cons(_,_)}

flatMap 则可以通过 foldRightappend 方法组合:

def flatMap[B](f : A => _Stream[B]) : _Stream[B] = foldRight(empty){
  (h,t) => f(h).append(t)
}

同时实现 flatMapmap_Stream[+A] 现在可以应用 For 表达式。如:

val value: _Stream[Int] = for {subStream <- s ; i <- subStream} yield i

到目前为止,我们应该理解 FP 范式中 "模块化编程" 的概念了。另一方面,Scala 提供的 For 表达式允许用户以命令式声明逻辑,但实际以递归执行 —— 这就是关注分离的宗旨。

以上所有的方法实现全部是 增量( incremental ) 的 —— 它们不会一次性生成所有答案,直到需要的元素被外部 "观测" 到。因为这种增量性质,我们可以一个接一个地调用函数但不对中间结果进行实例化,这个性质是共递归以及构建无限流的基础

Why Stream is lazy

下面只需跟踪一个流计算是如何交替执行 filtermap 的,就可以解释为什么 Stream 的各种转换只使用必要的空间,而不会创建大量的中间结果。

纯函数使得运算具备等式推理的能力,因此在这里大可放心地用返回值替换原本的函数调用,使得解析步骤看起来如同脱式运算 ( 见:Scala 函数式数据结构与递归的艺术 - 掘金 (juejin.cn) | 引用透明部分 )。整个过程完全没有使用任何显式的 for,while 这类循环,但是我们已经实现了一个等效于迭代的递归逻辑。出于这个原因, Stream 也被称描述为 "一等循环" ( first-class loop )。

// 下面所有的声明等价,均返回 List(12,14)
println(cons(1, cons(2, cons(3, cons(4, empty)))).map(_ + 10).filter(_ % 2 == 0).toList)

println(cons(11, cons(2, cons(3, cons(4, empty))).map(_ + 10)).filter(_ % 2 == 0).toList)

println(cons(12, cons(3, cons(4, empty)).map(_ + 10)).filter(_ % 2 == 0).toList)

println(12 :: cons(13, cons(4, empty).map(_ + 10)).filter(_ % 2 == 0).toList)

// 依照 map 的定义,匹配非 Cons 时默认返回 empty。
// 故 empty.map((a : Int) => a + 10) === Empty
println(12 :: cons(14, empty.map((a: Int) => a + 10)).filter(_ % 2 == 0).toList)

println(12 :: 14 :: empty.toList)

println(12 :: 14 :: Nil)

惰性列表的增量性质对内存使用有重要的影响。比如在这个例子中,数据 1113 在计算过程中被 filter 过滤掉了。在处理大量元素或者是处理大对象的情况下,及时地回收它们的内存可以降低程序对资源的占用。

关于提升内存效率的流式计算 ( streaming calculation ),尤其是涉及 I/O 的计算会在后面的章节提及。

无限流与共递归

可以利用增量的性质创建一个 无限流 ( infinite stream )。无限流的概念在 Java 8 中也有提及,它就像源源不断的水龙头,只要打开它就能不断地产生数据。下面是一个可生成无限个数字 1 的例子:

// 声明在 object _Stream 中。
val ones : _Stream[Int] = cons(1,ones)
// 根据统一访问原则,使用 val 还是 def 定义它并不重要。
// def ones1 : _Stream[Int] = cons(1,ones1)

使用无限流要格外谨慎,因为它会引发栈溢出异常。

println(ones)     // ok
println(ones.take(5).toList)    // ok
println(ones.toList)  // err : StackOverflowError

将这个 ones 函数稍微泛化一下,即可得到给定值的无限流生成器 constant

def constant[A](i : A) : _Stream[A] = cons(i,constant(i))
// println(constant("java").take(5).toList)

为了进一步发现规律,我们不妨再实现两个无限流:一个生成 _Stream(n, n+1, ...)from 方法和生成斐波那契数列无限流的 fib 方法:

def constant[A](i : A) : _Stream[A] = cons(i,constant(i))

def from(i : Int) :  _Stream[Int] = cons(i,constant(i + 1))

def fib(using init :(Int,Int) = (0,1)) : _Stream[Int] = cons(init._1,fib(using (init._2,init._1 + init._2)))

这三个方法的逻辑大部分是也是相同的,可进一步提取出 unfold 方法。

def unfold[A,S](z : S)(f : S=>Option[(A,S)]) : _Stream[A] = f(z) match 
  case Some((h,s)) => cons(h,unfold(s)(f))
  case None => empty  

类型参数 S 代表 状态 ( state )unfold 的递归伴随着 状态转移,我们在下一个专题继续讨论它。

unfold 是一个典型的 共递归 ( corecursive ) 函数 ( 或称之守护递归 guarded cursive )。在普通的递归中,传入的参数范围会逐步缩小,以达到收敛的目的;而共递归不设置收敛条件 ( 或称临界条件 ),因此它依赖惰性加载的数据结构,并构建出无限流。

def constant1[A](i: A): _Stream[A] = unfold(true)(Some(i, _))

def from1(i: Int): _Stream[Int] = unfold(true)(Some(i + 1, _))

def fib1: _Stream[Int] = unfold((0, 1)) { case (f0, f1) => Some((f0, (f1, f0 + f1))) }

unfold 不仅仅提供无限流。如果给定一个收敛的偏函数,它同样能去实现 maptaketakeWhile 方法。如:

// 定义在 enum _Stream[+A] 下
def map1[B]( f : A => B) : _Stream[B] = unfold(this){
  case Cons(h,t) => Some((f(h()),t()))
  case _ => None
}

def take1(n: Int): _Stream[A] = unfold(this,n){
  case (Cons(h,t), m) if m > 0 => Some((h(),(t(),m-1)))
  case _ => None
}

def takeWhile2(f : A => Boolean) : _Stream[A] = unfold(this){
  case Cons(h,t) if f(h()) => Some((h(),t()))
  case _ => None
}