Scala 纯函数式库设计:并行编程

452 阅读19分钟

回顾

前一部分我们学习了关于函数式设计的基本技能,包括基于递归的数据结构,异常包装,惰性求值,以及函数式的状态转移。

从本篇起是 《Functional Programming in Scala 》 的第二部分,原书通过三个例子 —— 并行计算,测试程序和文本解析来让我们了解设计函数式设计的过程,这篇笔记是第一个例子。了解这一章需要更多的时间成本,或许这不意味着能直接参透函数式设计的技能,但对于提升代码品味绝对是足够了。

设计库未必是一帆风顺的过程,比如在途中经常遇到的关乎 设计选择 的问题。通常情况下,我们不能决定什么样的设计应该是最合理的,这必须在一次次的小实验,原型测试中归纳出每个方案适用于哪些场景。同时,我们会在设计时候遇到反复的模式,之后的部分会讨论如何消除这种重复并提取出通用的模式。

从简单的例子开始

在前几章我们一直强调 "关注分离" 的概念 —— 用户的 声明 和实际 执行 是分开的。比如,在之前的惰性求值中,用户的声明是一回事,但数据流何时真正创建又是另一回事。类似地,我们的设计重点在于 描述计算 本身,而不在于实际的运行。有时用户所认为的纯函数在底层甚至是存在副作用的,但这并不妨碍它们在更抽象的层面去组合模块。

设计是一个循序渐进的过程。首先,挑选一个最基本的使用场景 ( use case ),其次,设计接口去契合这个场景,最后才是深入到每个接口具体是如何实现的。随着设计的不断进行,我们或许会推倒原有的实现,又或者是新增一些接口以完善更加丰富的功能,更重要的一点是:有些情况下,我们不必实现专用的功能,因为它们很可能是一些更通用函数 ( 组合子 ) 的结合。

文中的稍后部分会简要提到代数推理 ( algebraic reasoning ) 的过程。在特定的法则 ( law ) 下,API 的一些性质和功能可以用代数 ( algebra ) 来描述。

我们或许现在对创建一个并行计算只有一个模糊的想法。从一个简单的例子开始:利用分治 ( divide-and-conquer ) 法对一连串连续的序列进行求和。

def parallelSum(ints : List[Int]) : Int =
  if ints.length <= 1 then ints.headOption getOrElse 0
  else
    val (l,r) = ints.splitAt(ints.size / 2)
    parallelSum(l) + parallelSum(r)

使用 splitAt 方法对半分开传入的 ints,然后递归求和,返回它们的结果。重点是:左右子列表的计算是可以并行的。

在这个例子中使用并行计算会更慢,并行的花销远多余节省的时间。但在设计库的过程中,首先使用简单的例子有助于理清思路,而不是深陷于各种细枝末节。我们期望去阐明问题本身,而最好的方法就是从简单的例子入手,提取问题的共同点,再逐渐增加代码的复杂度。在函数式设计中,构建出一个可组合的核心数据类型,然后用它解决实际问题才是我们的终极目的。

本章不会基于 scala.concurrent.Future 库来组合并行计算,而是从稍底层的视角设计一套自己的 API。至于更底层的东西,如线程调度,本文使用一些 Java 工具来完成。假定有一个这样的容器 Par[A] ,它具备下面的方法:

case class Par[A](a : A)
object Par :
  // 将一个普通值 A 提升为 Par[A]。
  def unit[A](a : => A) : Par[A] = ???
  // 从一个并行计算中得到结果。
  def get[A](a : Par[A]) : A = ???

我们目前希望 unit 方法接收一个未求值的 A,返回的结果在独立的线程中求值,然后通过 get 方法提取它 ( 这可能会引起阻塞 )。不必过早地纠结具体实现是什么,现在暂时将它们看作是黑盒,只需要通过方法签名明确输入和输出。更新版的 parallelSum 方法如下:

def parallelSum2(ints : List[Int]) : Int =
  if ints.length <=1 then ints.headOption getOrElse 0
  else
    val (l,r) = ints splitAt (ints.size / 2)
    val parLeft: Par[Int] = Par.unit {parallelSum2(l)}
    val parRight: Par[Int] = Par.unit {parallelSum2(r)}
    Par.get(parLeft) + Par.get(parRight)

直接使用 Java Thread 的问题

我们首先排除基于 java.lang.Thread 的方案。下面是一部分相关的 API:

// line: 699
public synchronized void start(){/*...*/}
// line: 746
public void run() {/*..*/}

Java Thread 的方法都没有返回值。想要从控制抽象 Runnable 当中获取有用的信息,就必须以副作用的形式介入其中。我们必须要关注每个 run 方法的内部行为,而这显然不利于函数组合。

直接使用 java.lang.Thread 实现 Par[A] 还有一个缺点:Java 会直接创建一个对应操作系统的线程,而这是十分稀缺的资源。最好是创建一些 "逻辑线程",然后将它们映射到少量的系统线程上,比如,从 java.concurrent.Executors 那里可以获取多种模式的 Java 线程池。

为避免混淆,后文的 "逻辑线程",使用 worker 来称呼。不同的 worker 可以复用 Java 线程,在本篇中可以不去关注线程分配的细节。

在后文,我们会介绍如何使用 java.util.concurrent.Future ( JDK 5 ) 和 ExecutorService 来提交异步任务。见:future.get 方法阻塞问题的解决,实现按照任务完成的先后顺序获取任务的结果_傅里叶、的博客-CSDN博客

// java.util.concurrent.Future
// line: 151
 V get() throws InterruptedException, ExecutionException;

需要注意的是,Java Future 虽然提供 get 方法取出异步计算的结果,但该方法会通过阻塞的方式获取值。

JDK 8 之后提供了CompletableFuture,它提供了 thenApplythenCombineapplyToEither 等方法描述并组合 Java Future 任务。有关于 Scala 原生库的 Future,见:Scala + Future 实现异步编程 - 掘金 (juejin.cn) [ 文章 1 ]| 基于组合 Future 的并行任务流 - 掘金 (juejin.cn) [ 文章 2 ]

内联代码的性能问题

是让 unit 在 worker 中立即求值,还是等到 get 被调用时再求值呢?先把这个问题搁置到一边,因为我们目前不明确严格 / 惰性计算会有什么优势。

需要注意,Scala 严格地按照从左到右的顺序对表达式的每一项求值。如果 unit 是延迟计算的,那么一旦调用 get,程序会临时衍生 ( spawn ) 出并行计算,并在创建下一个并行计算之前必须执行完当前的任务,这样的计算实际上是串行的

典型的案例是 Scala 原生库中的 Future。在文章 2 的示例中,for 表达式内部创建 Future 块会导致任务串行化执行。

/*
 [基于组合 Future 的并行任务流 - 掘金 (juejin.cn)](https://juejin.cn/post/7077844813965426702)
*/
val futures : Future[Int] = for {
    l1            <- Future {Thread.sleep(100);1}
    l2            <- Future {Thread.sleep(200);2}
    l3            <- Future {Thread.sleep(300);3}
    l4            <- Future {Thread.sleep(400);4}
} yield l1 + l2 + l3 + l4
// 程序至少需要 1s 的时间去运行,因此这段代码会抛出异常。
val r2 = Await.result(futures, .9 second)

类似地,我们如果将代码中的 parLeftparRight 替换成它们原本的定义:

Par.get(Par.unit {parallelSum2(l)}) + Par.get(Par.unit {parallelSum2(r)})

直接内联两个 unit 之后,程序将无法获得并行,因为 + 号前面的表达式总是先计算。由此可见,unit 方法存在着副作用,而这个副作用仅仅和 get 相关:虽然 unit 执行的是异步计算,但是却在阻塞式的等待过程中暴露了副作用。

想要获得最大程序的并行,必须令两边 unit 同时对参数求值并立刻返回。另一方面,我们必须避免在并行任务的过程中调用 get

组合并行计算

如果不调用 get 来立刻求值,那么函数 parallelSum 的返回值就必须使用 Par[Int] 来代替之前的 Int。依照这个思想,我们首先构思出 map2 的函数签名:

def map2[A,B,C](a : Par[A],b : Par[B])(f : (A,B) => C) : Par[C] = ???

用它来实现第三版的 parallelSum 函数:

def parallelSum3(ints : List[Int]) : Par[Int] =
  if ints.length <=1 then unit {ints.headOption getOrElse 0}
  else
    val (l,r) = ints splitAt (ints.size / 2)
    map2(parallelSum3(l),parallelSum3(r))(_ + _)

map2 是惰性的,该函数的最终产出并不是值类型 A,而是 Par[A]。换句话说,map2 虽然会按照参数从左到右的顺序逐渐展开计算路径 ( 或者说构建对计算的 描述 ),但直到调用 get 时才会对这个计算 强制求值

List(1,2,3,4) 为测试案例,map2 最终会展开为下文所示的树形计算路径:

map2(
	map2(
		unit(1),
		unit(2)
	)(_ + _),
	map2(
		unit(3),
		unit(4)
	)(_ + _)
)(_ + _)

惰性计算的 map2 赋予了左右参数均等的计算机会。

显性分流

观察下面的计算:

Par.map2(Par.unit(1), Par.unit(2))(_ + _)

很明显,整合两个字面量完全没必要再另启 worker。我们引入另一个 fork 函数来明确表示将某个异步计算分流 ( forked-off ) 到另一个 worker 执行:

def fork[A](a : => Par[A]) : Par[A]

这样,unit 的职责就可以变得更加纯粹了:将一个字面量值 a 提升为一个 Par[A],在 当前 的线程中返回即时结果,但不进行分流。此时 unit 即便是 严格 ( strict ) 的也无妨,因为它现在只用来包装字面量。

def unit[A](a  : A) : Par[A]

forkunit 那样返回 Par[A] ,因此我们并不需要再去修改 map2 函数。现在,我们将 "归并" 和 "分流" 这两个概念解耦了:

  1. map2 仅表示将两个 Par[A] 任务进行合并。
  2. map2 传入的两个 Par[A] 独立决定是否分流。

fork 应当是在另一个 worker 中被立即求值呢?还是等到 get 被调用之后呢 ( 延迟计算 ) ?从实现来看,答案无疑是后者。这样的问题要从两者的职责分配角度分析:

  1. fork 表示一个分流的计算行为。
  2. get 获取一个 Par[A] 的计算结果。

如果 fork 只是简单地持有计算任务,那么它就可以不再去关注线程实现的具体细节,这实际上只是给需要分流的任务做一个标记而已。这样看来,worker 相关的工作交给 get 显得更加合理。现在,重新将 get 函数命名为 run,以表明这里才是计算实际发生的地方:

def run[A](a : Par[A]) : A = ???

Par[A] 现在仅仅作为一个值的容器,或者称之为我们并行库的 一等对象 ( first-class ) ,unit 函数则相当于一个 Par[A] 的简单生成器。

实现并行

在确定了 Par[A] 应当具备的基本行为之后,下一步就是动手实现了。我们明确了 run 需要以某种异步的形式完成任务,这里使用 java.util.concurrent.Future ( JDK 5 ) 和 ExecutorService 来实现。

ExecutorService 允许提交 submit 一个 Callable[A] 的参数,这是一个 Java 函数式接口。在 Scala 2.12 版本后,Java 的函数式接口可以被视作 Abstract Simple Method ( 简称 ASM ),因此调用时可以直接传入 () => A 的函数替代 Callable[A]。然而,submit 是一个重载方法,编译器有时可能会无法区分一个 ASM 实际指代 Callable[Unit] 还是 Runnable

Java Future 的 get 方法会阻塞式地等待返回结果,这一点前文已经说过了。除此之外,用户可以自定义一些额外的性质,比如阻塞到一定时间之后抛出异常等等。现在假设 run 函数可以访问一个 ExecutorService 来调度任务,以此来推断 Par[A] 具体应该是什么样子:

def run[A](s : ExecutorService)(a : Par[A]) : A = ???

最直接的形式,将 Par[A] 设置为一个 (ExecutorService) => Future[A] 的类型。

type Par[A] = ExecutorService => Future[A]
def run[A](s : ExecutorService)(a : Par[A]) : Future[A] = a(s)

现在,Par[A] 是一个 函数,它只在接受外界传入的 ExecutorService 之后才会被驱动生成一个 Future。

完善 API

我们已经描述了并行计算的草图,现在是时候实现它们了。

object Par :

  type Par[A] = ExecutorService => Future[A]

  // unit 返回一个即时结果,传入 exe 只是为了让 unit 执行。
  def unit[A](a : A) : Par[A] = (exe : ExecutorService) => UnitFuture(a)

  // 暂不考虑 Java Future 的其它拓展。
  private case class UnitFuture[A](get : A) extends Future[A] :
    override def cancel(mayInterruptIfRunning: Boolean): Boolean = false // not important
    override def isCancelled: Boolean = false // not important
    override def isDone: Boolean = true // not important
    override def get(timeout: Long, unit: TimeUnit): A = get

  def map2[A,B,C](a : Par[A],b : Par[B])( f : (A,B) => C) : Par[C] =
    (exe : ExecutorService) =>
      // 不要内联 Future,否则会影响性能。
      val af: Future[A] = a(exe)
      val bf: Future[B] = b(exe)
      UnitFuture(f(af.get,bf.get))

  def run[A](s : ExecutorService)(a : Par[A]) : Future[A] = a(s)
  def fork[A]( a  : => Par[A]) : Par[A] =
    // Callable 是一个 Java 中定义的函数式接口,在 Scala 中被称之为 Abstract Simple Method.
    exe => exe.submit(()=> a(exe).get)

需要稍微留意 fork 方法。最简单又自然的想法是,向这个调度器内直接提交 Par[A] 容器 a 就可以了。

def parallelSum4(list : List[Int]) : Par[Int] = {

  if list.length <=1 then unit {
    println(Thread.currentThread().getName)
    list.headOption getOrElse 0
  }
  else
    val (l,r) = list splitAt (list.length / 2)
    map2(
      // always fork  
      Par.fork(parallelSumS(l)),
      Par.fork(parallelSumS(r))
    ){ (a,b) => a + b }
}

可以设计一个简单的对照试验来验证并行版本的性能。考虑到纯粹的加和计算实在是太简单了,不妨再插入一段 Thread.sleep(delay) 来假设每一次独立的计算都需要至少 delay 左右的延时 :

def parallelSum4(list : List[Int]) : Par[Int] = {
  if list.length <=1 then unit {list.headOption getOrElse 0}
  else
    val (l,r) = list splitAt (list.length / 2)
    map2(
      Par.fork(parallelSumS(l)),
      Par.fork(parallelSumS(r))
    ){(a,b) => {Thread.sleep(50);a+b}}
}
/** 串行版本 */
def sumByFor(list : List[Int]) : Int = {
  var sum = 0
  for(i <- list) {Thread.sleep(50);sum = sum + i}
  sum
}

下面是不同变量下的测试数据 ( parallelSum4 使用 forJoin 线程池,采取默认配置 )

delaydata sizeparallelserial
50 ms501100 ms3100 ms
10 ms50300 ms700 ms
5 ms50130 ms290 ms
5 ms100252 ms568 ms
5 ms10001000 ms5600 ms
01000360 ms4 ms

测试的机器配置不同,时延会有所不同。但显而易见的是,只有每个 worker 的计算量越复杂 ( 基于时延的评判标准 ),数据量越大时,并行计算才有意义,性能优势愈加明显。

Fixed 线程池的死锁问题

需要指出的是,当使用 固定线程数的线程池 ( 比如 FixedThreadPool ) 时,基于目前 fork 函数实现的 parallelSum4 会产生 死锁。要了解 Java 的四个基本线程池,见:Java 四种线程池 - 博客园 (cnblogs.com)。一个最简单却能引发 BUG 的例子:在单线程的环境下,parallelSum4 一旦处理三个及以上元素的列表,就会出现问题:

// 死锁
val testList = (1 to 3).toList
parallelSum4(testList)(Executors.newFixedThreadPool(1)).get()

跟踪这个例子的计算步骤来分析上段程序发生死锁的过程:

  1. 在首次执行时,List(1,2,3) 被 main 线程分成 List(1)List(2,3)。随后,main 线程阻塞,并准备使用线程池 exe fork 出其它线程递归处理这两个子任务。这里的线程池只有一个工作线程,不妨称它 Thread-1。
  2. List(1) 首先被 fork 并计算。它向 Thread-1 返回一个即时结果,因此不会引起阻塞。
  3. List(2,3) 被 Thread-1 继续递归处理 —— Thread-1 对 List(1,2) 切片后,准备 fork 出线程池内其它可用的工作线程继续处理 List(2)List(3)
  4. 问题出现在这里。由于没有其它的工作线程,Thread-1 无法执行 fork,未完成的任务和未启动的任务互相等待,死锁在此发生了。

完成这个任务至少要两个线程,此时的计算过程是:

  1. 在首次执行时,List(1,2,3) 被 main 线程分成 List(1)List(2,3)。随后,main 线程阻塞。和上次不一样的是,这里的线程池有两个工作线程,标记为 Thread-1 和 Thread-2。
  2. List(1) 首先被 fork 计算。它向 Thread-1 返回一个即时结果,因此不会引起阻塞。
  3. List(2,3) 则被分配给 Thread-2 处理 —— Thread-2 对 List(1,2) 切片后,准备 fork 出线程池内其它可用的工作线程继续处理 List(2)List(3)
  4. Thread-2 陷入阻塞,重新可用的 Thread-1 串行处理剩下的 List(2)List(3) 并返回。
  5. Thread-2 和 main 线程依次完成加和并返回。

基于目前的 fork 实现,可以证明任何固定大小的线程池都会产生死锁。我们有可能通过修复 fork 的方式来避免死锁吗?看看下面的实现:

def fork[A]( a  : => Par[A]) : Par[A] = exe => a(exe)

这无疑可以避免死锁,但不是解决办法。如此设计不会分流出新的线程,这会使得整个任务都在 main 线程内串行完成。尽管如此,它这个还是一个有用的组合子,它使得计算被推迟到被真正需要的那一刻。不如将这样的组合子改名成 delay

def delay[A](a : => Par[A]) : Par[A] = exe => a(exe)

这里搁置关于 fork 的 Bug,将重点放回到函数式设计的过程。

使用 map2 实现 map

假设现在有一个 Par[List[Int]] 表示一个返回 List[Int] 的并行计算,比如说排序:

def sortPar(parList : Par[List[Int]]) : Par[List[Int]] = ???

我们目前能用的组合子只有 map2。如果要是有一个 map 组合子可以接受用于排序的 f : List[Int] => List[Int] 然后并行地执行就好了。实际上,map 完全可以通过 map2 来实现:只将 parList 传给 map2 的其中一侧,至于另一侧的参数则直接忽略:

def map[A,B](pa : Par[A])(f : A => B ) : Par[B] = map2(pa,unit(()))((a,_) => f(a))

这种复用套路在设计库内很常见:看似是一个基础的函数实际上却是由另一个更基础的函数完成的。现在,sortPar 可以这样简单地实现:

def sortPar(parList : Par[List[Int]]) : Par[List[Int]] = map(parList)(_.sorted)

在基于已有的实现下,我们不费力地将原有的排序 sorted 方法变为了并行模式,并且完全不用关注在底层各个 worker 如何完成任务 ( 尽管在这个例子中,并行排序的做法并不是更好的 )。

进一步迭代

下面再实现一些更高级的功能。我们不满足于只使用 map2 组合两个并行计算,现在还想组合 N 个并行计算。等待计算的多个初值 A 构成了一个列表 List[A],最终计算的结果是另一个类型 List[B]

暂且将这个函数命名为 parMap 。它的函数签名应该是:

def parMap[A,B](pars : List[A])(f : A => B) : Par[List[B]] = ???

首先,创建一个 asyncFunc 函数把一个普通的 f 提升为一个异步计算:

def lazyUnit[A](a : A) : Par[A] = fork(unit(a))
// or a => fork(unit(f(a)))
def asyncFunc[A,B](f : A=> B) : A => Par[B] = a => lazyUnit(f(a))

接下来使用上文的 map 方法将整个元素序列全部转换成异步任务,这会得到一个 List[Par[B]] 类型。显然,若要返回一个 Par[List[B]] 类型,还需要实现用于翻转的 sequence 函数。

def sequence[A](pars : List[Par[A]]) : Par[List[A]]

有了在前几个章节中实现 sequence 的经验,下面应该算是最容易想到且实现的方式了:

def sequence[A](pars : List[Par[A]]) : Par[List[A]] =
  pars.foldLeft(unit(List.empty[A])){
    (acc,b) => map2(acc,b)(_ :+ _)
  }

目前的 sequence 会不断地嵌套创建 map2 计算路径,因此它是串行的,不妨先命名为 sequenceNotBalanced。参考前文例子中结合二分法实现的 parallelSum,下面给出两侧均衡的 sequence 方法:

def BalancedSequence[A](idxPars: List[Par[A]]) : Par[List[A]] =
  if idxPars.isEmpty then unit {List.empty[A]}
  else if idxPars.size == 1 then map(idxPars.head)(List(_))
  else
    val (l,r) = idxPars.splitAt (idxPars.length / 2)
    map2(BalancedSequence(l),BalancedSequence(r))(_ ++ _)

def sequence[A](pars : List[Par[A]]) : Par[List[A]] = map(BalancedSequence(pars))(l => l)

经过验证,在少量简单数据的情形,串行的 sequenceNotBalanced 效率更高。在大量数据的情况,并行的 sequence 效率更高。同时,sequence 将计算路径的深度从 n 折叠成 log2(n),因此它更不容易出现栈溢出的问题。

val n = 10000
// 快速生成 List[Par[Int]].
val testPars: List[Par[Int]] = Random.shuffle((1 to n ).toList).map {unit}
val a = sequence(testPars)(Executors.newWorkStealingPool(8))				// 100 ms
val b = sequenceNotBalanced(testPars)(Executors.newWorkStealingPool(8))		 // 700 ms

完善更通用的形式

函数式设计是一个反复的过程。从写下 API 到至少有一个原型实现,它都会不断被应用在越来越复杂的场景当中。有时候,我们又会发现某些场景中需要添加新的组合子。与其实现一个专用的功能,不妨先尝试着设计一个更通用的形式,然后看看它是否可以通过结合其它组合子的方式实现预期的功能。

比如,我们最开始期望的 choice 函数是一个基于 Boolean 结果来二选一执行的运算:

def choice[A](cond : Par[Boolean])(t : Par[A],f : Par[A]) : Par[A] = 
  if cond(exe).get then t(exe) else f(exe)

目前而言,一个直接的实现办法就是首先阻塞式地等待 cond 的结果,然后在这两个可能的并行计算当中选择一个执行。其实,完全可以选择先实现更通用的 choiceN 方法,然后将 choice 视作 choiceN 的一个特例。

def choiceN[A](n : Par[Int])(choices : List[Par[A]]) : Par[A] = exe =>
  val n_ : Int = n(exe).get
  Par.run(exe)(choices(n_))

再进一步,前置计算的判定结果未必要用 Boolean 或者是 Int 表示,可以将其泛化为 C 类型 ( 表示 Condition )。另外,所有后续的并行计算也不需要使用固定的 ListMap 或是其它什么表示,使用模式匹配来表明计算策略就可以了。

def choiceN[C,A](n : Par[C])(choices : PartialFunction[C,Par[A]]) : Par[A] = exe =>
  val n_ : C = n(exe).get
  Par.run(exe)(choices(n_))

现在可以将 choice 看作是 choiceN 的一种特殊情况:

def choice[A](cond : Par[Boolean])(t : Par[A],f : Par[A]) : Par[A] = exe =>
  Par.run(exe)(choiceN(cond){
    case true => t
    case false => f
  })

我们自认为从 choice 方法当中提炼了最简洁的模式,但还是要用批判的眼光重新审视一下。函数可能会被用在特定的情况和场景下,但是函数命名本身可以有更普遍的意义。

如果将 PartialFunction[C,Par[A]] 再特化成 C => Par[A] 类型呢?不难发现,choice 的本质就是 flatMap 组合子。

def flatMap[A,B](a : Par[A])(f : A => Par[B]) : Par[B] = exe =>
  val n_ : A = a(exe).get
  Par.run(exe)(f(n_))

比如,下面的代码表示了根据 Int 结果来打印不同的内容到控制台:

Par.flatMap(unit(1)){
  case 1 => unit {println("result is 1")}
  case 2 => unit {println("result is 2")}
}

虽然逻辑上没有问题,但是写起来十分别扭。我们更习惯于将 flatMap,以及前文的 runmap 等方法看作是一个中缀运算符,就像这样:

unit(1) flatMap {/**/}

通过隐式转换变换中缀运算

下面的实现基于 Scala 3。Scala 2 的隐式类 implicit class 也完全可以做到这一点,这里使用 opaque type 是为了不对外暴露 Par[A] 的真实类型。参见:Scala 3 新特性一览 - 掘金 (juejin.cn)

目前为止,所有的功能都是作为函数实现在单例对象 object Par 中的,因为 Par[A] 本身只是一个 Function1 的类型别称,无法直接为它添加方法。类似这样的问题可以使用隐式方法实现一个适配器,同时不用改动原有的内部实现。如:

object Par:
  opaque type Par[A] = ExecutorSerivce => Future[A]
  extension [A](ths : Par[A])
    infix def eqs(that : Par[A]): ExecutorService => Boolean = (exe : ExecutorService) => ths(exe).get == that(exe).get
    infix def map[B](f : A => B): Par[B] = Par.map$(ths)(f)
    infix def flatMap[B](f : A => Par[B]) : Par[B] = Par.flatMap$(ths)(f)
    infix def run(exe : ExecutorService) : Future[A] = ths(exe)

  /** other functions */

  private[this] def map$[A,B](pa : Par[A])(f : A => B) : Par[B] = map2(pa,unit(()))((a,_) => f(a))
  def map2[A,B,C](a : Par[A],b : Par[B])( f : (A,B) => C) : Par[C] =
    (exe : ExecutorService) =>
      val af: Future[A] = a(exe)
      val bf: Future[B] = b(exe)
      UnitFuture(f(af.get,bf.get))

  private[this] def flatMap$[A,B](a : Par[A])(f : A => Par[B]) : Par[B] = exe =>
    val n_ : A = a(exe).get
    Par.run$(exe)(f(n_))
  
  private[this] def run$[A](s : ExecutorService)(a : Par[A]) : Future[A] = a(s)

尽管我们将 flatMapmap 等组合子包装成了中缀运算符,但是在 Scala 3 中,函数风格的调用也会被保留 ( Scala 2 的 implicit class 不支持这样的自动翻译 ):

import Par.*
// 会被翻译成下面的表示。
flatMap(unit(1)){
  case 1 => unit {println("result is 1")}
  case 2 => unit {println("result is 2")}
}

unit(1) flatMap {
  case 1 => unit {println("result is 1")}
  case 2 => unit {println("result is 2")}
}

Scala 3 新引入的关键字 opaque 会对外隐匿 Par[A] 的真实类型,换句话说,用户不能再将 Par[A] 看作是柯里化的函数调用了。如:

// 无法柯里化调用,因为现在的 Par[A] 隐藏真实类型。
parallelSum4(testList)(Executors.newWorkStealingPool())

这样设计的好处是用户只能在开发者限定的规则内使用 Par[A],坏处则可能是丧失一些灵活性。

parallelSum4(testList).run(Executors.newWorkStealingPool())

关于 opaque 的设计并不是必须的,一切都取决于设计者如何权衡。这里的 eqs 是由我们自定义的,用于比较两个并行任务相等性的方法,见下一节。

API 和代数

在前面的过程中,我们首先根据功能标注函数签名,然后按照它的返回值给出具体的实现。以这种方式工作,我们在设计时就可以摆脱繁文缛节,只需专注类型的串联。将 API 看作是一个代数 ( algebra ),并在此基础上制定一些 法则 ,以此来设计函数库。

法则的选择会限定什么操作会有什么样的含义,同时保证了我们所设计的 API 的性质是稳定的。

映射法则

首先,基于某种假设来编写一个恒等式并认为它应当成立。前文 eqs 的比较规则是:如果两个并行任务 Par[A] 最终返回的数值相同,则返回 true,反之为 false

val exe = Executors.newWorkStealingPool()

val eg1 = map(unit(1))(_ + 1) eqs unit(2)
assert(eg1(exe))   // true

验证这个例子成立,然后提取这个模式,我们便能得到其中的法则。函数 _ + 1unit(1) 映射为 unit(2),且断言成功。不妨将 _ + 1 泛化为 f,将值 1 泛化为 x,从中可以泛化出这样的法则:

val x = 1  //  x 可以是任意的值
val f: Int => Int = (x : Int) => x  // f 可以是任何的 Int 映射。
// 法则 1
val law1 = map(unit(x))(f) eqs unit(f(x))

从这个法则中可以发现,对于单个的 unit 而言,map 是一个 "多余" 的步骤。进一步,假设 y(x) = unit{x},我们还能证明下面的断言是成立的:

// u compose v == v andThen u
val law2  = map(map(y(x))(g))(f) eqs map(y(x))( f compose g ) 

这个推论来自 1989 年 Philip Wadler 的论文《 Theorems for free 》,名为免费定理,论文可以去参考 free.dvi (ttic.edu)。这也被成为 map 融合,这个定理可以作为一个优化来使用,思想是:与其用两个 worker 分别进行 gf 变换,不如用一个 worker 一次性执行 f compose g

分流法则

我们对分流制定一个强规则,即分流不应当影响计算的结果。

val law3 = fork(y(x)) eqs y(x)

这个性质的成立应当是显而易见的,fork 只不过是在一个主线程分离的线程当中异步完成了一模一样的运算。如果这一项法则不成立,那我们就必须得知道什么时候才是不改变含义的调用。

为什么关于代码的法则很重要

声明和证明 API 的性质不是一个常见的做法,至少我们在传统的编程中几乎不会考虑这样的问题。

在函数式编程中,我们能感受到:将公共的功能分解成更通用,可重用的组件是很容易的事情。副作用会破坏函数之间的可组合性,任何隐藏 ( 或不可见的,out of band ) 的假设或者行为都会妨碍我们将组件看作是一个黑盒,把它们组合起来是十分困难的。

比如,如果上述的分流法则不成立,那么意味着我们原本设计的很多组合子都存在着风险,通俗地说就是计算结果会依赖于线程池分配的计算次序,这会脱离开发者的掌控和期望。赋予 API 以代数的性质,在法则的限定下,它们应该会变得更有意义且有助于推导,同时也意味着 API 的用户可以将其视作是黑盒子。

目前的 fork 就是破坏法则的一个反例。当传入固定线程的 pool 时,fork(y(x)) eqs y(x) 就不再成立了。当我们发现这样的反例时,要么明确的声明法则成立的条件,要么通过修改内部实现来修复这个法则。原书提出了一个基于 Actor 的并行模式,它保证在固定线程的环境下,分流法则也不会被破坏。见:

fpinscala/Actor.scala at second-edition · fpinscala/fpinscala · GitHub

我们在以后会意识到,从不同的库之间分解出共同的模式,对函数式设计的能力至关重要。函数式编程写得越多,我们识别并提取模式的能力就越强,就像我们一开始其实并不知道 choice 本身是 flatMap 的一个特殊用法。同时,从实践的角度看,缩小基础 API 的范围是有好处的,尽可能地重用它们可以避免重复逻辑。