Scala +:类型推断,列表操作与 for loop

1,260 阅读21分钟

列表 List 是 Scala 开发中最常用的工具,而许多对列表的操作又可以提取成相似的结构:以某种形式对内部所有元素进行变换,筛选出满足条件的元素,规定某个方法组合内部元素等等 ..... Java 8 推出了流操作的概念,并允许以 Lambda 语法描述出这些抽象过程,它使得我们逐渐从原先的 命令式编程 ( Imperative ) 过渡到 声明式编程 ( Declarative ) 1

Scala 允许使用高阶操作符来更精简,直接地表达出对列表的抽象操作,但这个前提是:我们需要对 IDE 和 Scala 编译器的类型推断机制做一些了解,否则很难理解诸如 (_ + _) 或者是 (_ - _) 之类的缩写形式,要么就是写出一份语法检查通过却在编译过程出错的代码。

全文涉及的列表操作适用于 Scala 所有符合 “集合” 特征的数据结构,并且永远不会改变原有的状态,无论该集合可变还是非可变的。笔者这里以统一的不可变列表 List 为例子介绍。在之后的文章中,笔者还会介绍如何结合函数式编程的逻辑重现一部分高阶操作。

Scala 的数据结构可以分为两类:一个是严格求值和惰性求值。对于之前学习的 ListArrayMap 等 "所见即所得" 的数据集合都属于前者,而本章的 Scala 流和视图则属于后者。

另外一个重点是:在了解了一些重要的列表操作之后,我们有必要重新审视 Scala 中的 for 表达式,它提供了非常多不局限于 "遍历" 的功能,我们由此引申出 monad 算子的概念。

1. 编译器的类型推断机制

首先给定一个 Scala 泛型函数,它接收两个元素 ab ( 两者的类型也许一致,也许不一致 ),并通过提供的 op 函数将两者合并:

/*
	下面声明在 Java 程序中的等价写法和使用演示:
		
    Integer a = 100;
    Integer b = 200;
	
	BiFunction<T,U,R> 形容的是 <T,U> -> R 的二元操作函数,是一个函数式接口。
    public static BiFunction<Integer,Integer,Integer> sum = (a, b) -> a + b;

    public static <A, B> B reduce(A a, B b, BiFunction<A, B, B> op) {
        return op.apply(a, b);
    }
*/  
def reduce[A, B](a: A, b: B, op: (A, B) => B): B = op(a, b)

泛型 AB 是任意一个继承于 Any 而抽象于 Nothing 的类型。 在实际类型被推断出来之前,IDE 会以 NotInferredANotInferredB 的方式标记它们。另定义两个符合 (A,B)=> B 类型的函数变量:

val intsSum: (Int, Int) => Int = (a, b) => a + b
val concat: (Char, Char) => String = (a, b) => a.toString + b.toString

如果我们这样调用 reduce

reduce(1,1,concat)

很明显这是一个问题代码,因为我们正企图用一个 (Char,Char) => String 类型的函数处理两个 Int 类型的数值。但是这种写法逃过了 IDE ( 笔者基于 IntelliJ IDEA 的 Scala 插件开发项目 ) 的语法检测,因为直到这段代码被送入到 Scala 编译器之前,NotInferredANotInferredB 类型都没有被推导出来 ,这意味着它们从语义上讲可以是任意一个介于 AnyNothing 之间的类型。

我们总是希望某个语法问题如果能够提前被 IDE 纠正,就不要推迟到编译代码时再去解决。一种解决方式是:在调用泛型函数时提供准确的类型。如:

reduce[Int,Int](1,1,intsSum)

这样,IDE 就可以将 op : (NotinferredA,NotInferredB) => NotInferredB 推导为 op : (Int,Int) => Int 。不仅如此,它还严格地提前限制了允许接收的 ab 两个元素的类型。比如下面的调用会报语法错误,尽管这从类型推导的角度看来,传入的参数全都是合理的:

reduce[Int,Int]('a','b',concat)

1.1 柯里化在类型推导的应用

对于接收函数作为参数的高阶函数,笔者从习惯上会对其进行柯里化2,以便将操作数和匿名函数分开赋值 ( 稍后会介绍为什么这么做 )。以刚才的 reduce 函数为例:

def reduce[A, B](a: A, b: B)(op : A => B): B = op(a, b)

如果我们事先传入参数 ab ,IDE 就会具备足够的线索推导 op 实际的类型。这意味着它能够提供更严谨的语法检查,而不是直到编译期才提示错误。这里利用了柯里化的 "延迟确认" 特性,这里指代延迟确认 op 函数的类型:先让高阶函数 reduce 接收参数,然后再接收对应类型的函数 op 。当然,从 "存储运算结果" 的角度去理解这步柯里化也是合理的。

val curriedHof: ((Int, Int) => Int) => Int = reduce1(1,1)
println(curriedHof(intsSum))

我们不需要手动补齐高阶函数的类型,仅需要按下 alt + Enter 键让 IDE 完成这步操作即可。如果 IDE 最终仍无法推断某个 NotInferredX 类型,如果它出现在返回值的位置,则 IDE 会将它解析成 Nothing 类型,如果它出现在了参数列表的位置,则将它解析成 Any 类型 3

但是对我们而言,没有任何 "边界" 的 AnyNothing 通常来说都没有什么实际的作用,因此在最终的代码中不应该还遗留着未被推导出来的类型参数。

1.2 更加精简的函数声明

现在以被柯里化的 reduce 函数为例,我们可以在调用时直接传入匿名函数:

reduce(1,1)((a :Int,b : Int) => a + b)

鉴于 IDE 和 Scala 编译器现在仅通过前两个操作数就能够推断 op 的正确类型,我们再去强调参数 abInt 类型就显得多余了。现在可以省略掉匿名函数的参数类型声明部分:

reduce(1,1)((a,b) => a + b)

甚至是更加简略的方式:

reduce(1,1)(_ + _)

这里的两个下划线 _ 依次 指代 reduce 函数中的第一个操作数 a 和第二个操作数 b 。为了避免混淆,Scala 规定每一个下划线 _ 仅指代一个对应位置上的操作数。在底层,编译器会将这些下划线 _ 依次翻译成具名的变量 x$1x$2 .... x$n 并根据 "线索" 推断出它们的具体类型。

在某些情况下不能使用这种省略写法,比如:

reduce(1,1)((a :Int, b : Int) => b - a)
// 可以省略 a,b 的类型
reduce(1,1)((a,b) => b - a)
// 不能省略成 _ - _ 的形式!!!!!因为这表示:(x$1 : Int, x$2 : Int) => x$1 - x$2 .
reduce(1,1)(_ - _)     // 语法检查和编译都可通过,但是得不到预期的结果。

或者在匿名函数中复用了某个变量:

reduce(1,1)((a : Int, b : Int) => a + a + b)
// 可以省略 a,b 的类型
reduce(1,1)((a,b) => a + a + b)
// 不能省略成 _ + _ + _ 的形式!!!!!因为这表示:(x$1 : Any, x$2 : Any, x$3 : Any) => x$1 + x$2 + x$3
// 实际上,由于提供的参数类型不足,这导致了 IDE 根本无法分辨 x$1,x$2,x$3 是何种类型。因此 + 操作本身变成了 "非法" 操作。
reduce(1,1)(_ + _ + _)    
/*
	如果一定要将这段代码送入编译器编译,编译器也会提示无法分辨这几个变量的实际类型。
	Error:(xx, xx) missing parameter type for expanded function ((x$1, x$2, x$3) => x$1.$plus(x$2).$plus(x$3))
    reduce1(1,1)(_ + _ + _)
*/

现在,我们将 reduce 的两个参数列表反转过来:先接收 op 函数,然后再处理两个操作数:

def reduce_R[A,B](op : (A,B)=>B)(a : A,b : B) = op(a,b)
reduce_R(_ + _)(1,1)

由于此处过早地让编译器进行类型推断,因此编译器没有有效信息来展开 _ + _ 操作。这时我们需要主动提供类型:

reduce_R[Int,Int](_ > _)(1,1)

综上,首先确保编译器能进行正确的类型推断,其次是被缩写的匿名方法逻辑最好不要太复杂。满足这两个前提下,笔者才推荐使用下划线的简写形式。

现在,我们应该能够理解 Scala 是基于程序流进行类型推断的。这虽然带来了一部分局限性,但是在大部分情况都可以通过显式提供参数类型来解决。当我们对高阶函数进行柯里化时,正确的做法是 "提前赋值" 那些能提供实际类型的参数,而越依赖类型推导的参数,则越应该 "推迟赋值"。

1.3 将伴生对象泛化成类的实例

在类型推断机制下,调用函数时传入单例对象还可能会导致一些 "歧义" 发生,这里举一个例子。假定 ABAC 都是 A 的子类,而 AC 出于单例模式的考虑被声明为了单例对象 ( 在这里比称呼它 "伴生对象" 更加贴切 ) 。

刚才的 reduce 方法被迁移到了 A 的伴生对象内,为了避免混淆将原来的泛型 [A,B] 替换成了 [T,U]

trait A
class AB extends A
object AC extends A

object A{def reduce[T,U](t : T,u: U)(op : (T,U) => U):U = op(t,u)}

如果传入两个 AB 的实例,IDE 尚且能正确地推导出 op 的类型为 (AB,AB) => (AB) 类型。

A.reduce(new AB,new AB)(???)

但如果将 AC 做为参数传递到 reduce 函数内:

A.reduce(new AB,AC)(???)

由于 AC 是单例对象,IDE 和编译器会给出这样的类型推导:(AB,AC.type) => AC.type 。在 Scala 中,任何以 .type 为后缀标识符都指代伴生对象 ( 或称单例对象 ) 本身。

这种类型推导会导致期望的 op 函数的输出值被限定成 AC 这个单例对象本身,而不是 AB 或者是其它继承于 A 的子类型。通常情况下,这都不如把 AC 当作是 A 的实例更加灵活:显然 (AB,A) => A 表达出了更加泛化的逻辑。

第一种可靠的方法仍然是显式地提供类型参数:

f[AB,A](new AB,AC)(???) // op 被推导成泛化的 (AB,A) => A
f[A,A](new AB,AC)(???)  // op 被推导成更泛化的 (A,A) => A

第二种方法是在传入 AC 时将它指派成是 A 类型的一个实例。写法如下:

f(new AB,AC : A)(???)  // op 被推导成泛化的 (AB,A) => A

2. 映射

2.1 map / flatMap

map方法接收一个 A => B 的函数,并将列表的所有 A 类元素变换成 B 类的元素 ( 类型 AB 可以是一个类型 )。举一个简单的例子:将某个 List[Int] 内部的元素全部执行一次 +1 操作:

println(List(1, 2, 3, 4, 5) map (_ + 1))

flatMap 方法和 map 函数相比多了一步 flatten 的步骤。下面的代码块演示了将 List[String] 类型的句子拆分成 List[List[String]] 然后将其 "展平" 为 List[String] 的过程。

/*
  List("if you want to do something","just do it") =>
  List(List("if","you","want","to","do","something"),List("just","do","it")) =>
  List("if","you","want","to","do","something","just","do","it")  
*/
println(
    List("if you want to do something","just do it")
    .flatMap(_.split(" ").toList)
)

2.2 foreach

foreach 方法接收的是 A => Unit 函数,用于对元素做一些具备副作用的处理,比如打印,提交数据等等。

// foreach 后传递的是语句块:(x : Int) => { println(x) }
// 使用中缀表达式写法时应该将语句块括起来。
List(1,2,3,4) foreach {println(_)}

3. 过滤

3.1 filter / partition / distinct

filter方法需要一个这样的检验函数:A => Boolean,该方法会保留检验结果为 true 的对象,并剔除掉检验结果为 false 的元素。

println(
	List(1,2,3,4,5) filter(_ > 1)
    //List(2,3,4,5)
)

另外有去重方法 distinct 用于删除掉集合中重复的元素:

println(List(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4).distinct.mkString(","))
//List(1,2,3,4)

partition 方法是 filter 方法的一个引申,它会产生一个对偶列表:第一个是所有筛选条件为真的列表,另一个是筛选条件为假的列表。

// xs partition p   <=> (xs filter p(_), xs filter !p(_))
println(
    List(1,2,3,4,5) partition(_>3)
    // (List(4, 5),List(1, 2, 3))
)

3.2 find

find 方法会返回满足条件的第一个元素,它的返回值是 Option[A] ,当没有找到任何一个满足条件的元素时返回 None

println(
    List(1,2,3,4,5) find (_ == 4),   // Some(4)
    List(1,2,3,4,5) find (_ == 6)    // None
)

3.3 takeWhile / dropWhile / span

xs takeWhile p 表示:从头部开始对每个元素按照 p 条件进行检查,直到有一次不满足为止。此时,该方法截取出从头部到该位置之前的所有元素并返回 ( 相当于是截取前缀 );而 dropWhile 则相当于它的取反操作 ( 可以理解成是返回 takeWhile 方法 "遗留" 的后缀 )。

字符串可被认为是不可变列表的一种,在这里用它举一个例子。比如获取某个电子邮箱地址 @ 符号的前半部分:

println(
    "123456@net.com" takeWhile(!_.equals('@')), // "123456"
    "123456@net.com" dropWhile(!_.equals('@')),  // "@net.com"
    "123456@7890@net.com" takeWhile(!_.equals('@')), // "123456" 不截取后面的 @7890部分
    "123456@7890@net.com" dropWhile(!_.equals('@')),  // "@7890@net.com"
    "123456@net.com" filter(!_.equals('@'))      // "123456net.com"
)

spantakeWhiledropWhile 的组合方法,像 partition 一样返回二元组:

println(
    "123456@7890@net.com" span(!_.equals('@')) // (123456,@7890@net.com)
)

3.4 forall / exists

forall 可以用于检查列表 xs 内的元素是否全部满足某一个条件,若是,则返回 true ,否则返回 false 。而 exists 的要求更加 "宽松",只要求 xs 内至少有一个元素满足条件。

println(
    List(1,2,3,4,5) forall (_ > 4), // 所有的元素都大于 4 吗? => false.
    List(1,2,3,4,5) exists (_ > 3)  // 至少有一个元素大于 3 吗? => true.
)

4. 化简

4.1 fold 折叠

给定一个 List[A] 类型的元素并提供 B 类型的初始值和折叠操作 opfold 可以将列表内的所有元素合并为一个 B 类型。折叠分为左折叠和右折叠:

左折叠代表从左向右折叠,计算开始时,初始元素会放在 list 的最左边。左折叠要求提供一个 (B,A)=> B 类型的函数lop ,计算结果向左聚合;右折叠代表从右向左折叠,初始元素会放在 list 的最右边。右折叠要求提供一个 (A,B)=> B 类型的函数rop ,计算结果向右聚合。

比如说对 List(1,2,3) 的折叠方式:

List(1,2,3).foldLeft(0)(lop) // =>  op(op(op(0,1),2),3)
List(1,2,3) foldRight(0)(rop) // => op(1,op(2,op(3,0)))

假定 loprop 都代表减法操作 (x : Int , y : Int) => x - y,下面是测试代码:

List(1,2,3).foldLeft(0)(_ - _)   //  -6
List(1,2,3).foldLRight(0)(_ - _) //  2

两者的计算结果是不同的。计算过程用图形化的表述就是:

可以使用(a /: list)(op)来代表左折叠:list.foldLeft(a)(op),使用(list :\ a)(op)来代表右折叠:list.foldRight(a)(op)

(init /: listBuffer)(sub) //== listBuffer.foldLeft(init)(sub)
(listBuffer :\ init)(sub) //== listBuffer.foldRight(init)(sub)

E.g 利用 fold 反转字符串

fold 折叠操作是一个基础的高阶操作,更加抽象的 map , filter 都可以看作是 fold 操作的特化版本。这里我们将利用左折叠或者右折叠实现对字符串的反转操作。具体的思路是,不断地取出原字符串的头部并放到另一个字符串中,以 "FILO" 的方式实现字符串的倒转。

def reverseString(str: String): String = {
    val f: (String, Char) => String = (str: String, head: Char) => head.toString ++ str
    str.foldLeft("")(f)
    // 这是一种更简短,更激进的写法。s 和 h 的实际类型将由 Scala 编译器自动推断,因此不再需要显式地声明类型。
    // str.foldLeft("")((s,h)=> h ++ s)
}

在未来的 Scala FP :: 函数式数据结构章节,我们会再一次讨论有关该内容的部分。

4.3 reduce 规约

规约也用于合并列表 List[A] 内的元素,但是不需要提供初始值,因此合并的结果也是 A 类型。规约操作也分为左规约和右规约,逻辑和折叠操作类似。

println(
    List(1,2,3) reduce(_ - _),    // (1-2)-3 = -4
    List(1,2,3) reduceLeft(_ - _), // (1-2)-3 = -4
    List(1,2,3) reduceRight(_ - _) // 1-(2-3) = 2
)

4.4 scan 扫描

scan 从功能上来看和 fold 差不多,但是 fold 方法只返回一个合并后的最终结果,而 scan 则会收集每一步折叠的计算结果到列表中返回。scan 也分为左扫描 scanLeft 和右扫描 scanRight

通过 scan 方法,有助于我们理解 fold 的计算过程。

val list =  List(1,2,3)

println("scan left:→")
//↓ 0
//↓ 0 - 1 = -1
//↓ -1 - 2 = -3
//↓ -3 - 3 = -6
//(0, -1, -3, -6)
list.scanLeft(0)(_-_).foreach(i => print(i.toString +"\t"))

println("\nscan right:←")
//↑ 1 - (-1) = 2 
//↑ 2 - 3 = -1
//↑ 3 - 0 = 3
//↑ 0
//(2, -1, 3, 0)
list.scanRight(0)(_-_).foreach(i => print(i.toString +"\t"))

5. 拉链

5.1 zip 方法

假如说现在有两个列表:("Java","Go","Python")(2018,2019,2020)。使用拉链 zip 方法,可以将这两个列表"缝合"成一个二元组 ( 或称对偶元组,可以放入 Map 中 ) 列表:(("Java",2018),("Go","2019"),("Python","2020"))

val strings: List[String] = List[String]("Java", "Go", "Python")
val ints: List[Int] = List[Int](2018, 2019, 2020)

//(Java,2018),(Go,2019),(Python,2020)
println(strings.zip(ints).mkString(","))

进行拉链操作的两个列表长度要一致,否则会造成数据的丢失

val strings: ListBuffer[String] = ListBuffer[String]("Java", "Go", "Python")
val ints: ListBuffer[Int] = ListBuffer[Int](2018, 2019)

//(Java,2018),(Go,2019),而"Python"会被丢弃掉。
println(strings.zip(ints).mkString(","))

6. 迭代器

Scala 保留了 Iterator,用法和 Java 的迭代器没有太大区别。在调用 next 方法之前,Iterator 总是默认指向序列的首个位置。

不过,Scala 的迭代器额外支持使用映射,过滤,化简等方法来返回新的数据,但这个作用域是从迭代器当前的位置开始直到序列的末尾。利用这个特性,我们可以像 "游标" 一样使用它:比如实现对序列的部分转换,需要从哪里开始,就通过 next 将迭代器移动到哪里。

val list: Range.Inclusive = 1 to 10
val iterator: Iterator[Int] = list.iterator

// iterator 最终会将移动到 6 的位置。
while(iterator.next() < 5) Unit

// 这个遍历操作将只从 6 开始。
iterator.foreach(println(_))

// 这个 map 操作生成了 7 8 9 10 11。
iterator.map(_+1).toList

注意,每一步 next 操作都会永久改变该迭代器记录的位置。当它已经被用作一次完整的遍历后 ( 即 Iterator.hasNext为空 ) 就失去了作用。如果只是对序列的简单遍历,我们通常都不会考虑到使用迭代器,因为 Scala 提供了很多其它的途径来实现,比如仅调用 foreach 方法。

while (iterator.hasNext) {
	println(iterator.next)
}

7. Stream 流

掌握 Java 8 语法特性的同学,对于数据流 ( Stream ) 的概念一定不会陌生。Scala 所有集合支持使用 toStream 方法转换成数据流。数据流支持执行映射,过滤,化简等方法,并转换成另一个数据流。使用流的一大目的是:利用其自身的延迟加载特性来避免了程序耗费资源生成大量且无用的数据。

相比于实际存放在内存中,且元素数量有限的集而言,流表述了有规律的,无穷数量的元素序列。比如说等差数列:1,2,3,4,.....,按照更 "更数学" 的说法,描述这个等差数列只需要两个元素:首项 a = 1 ,公差 d = 1。等比数列同理。

Scala 流是典型的函数式数据结构的缩影,我们以后再去讨论 "何为函数式数据结构"。

7.1 通过递归创建一个无限流

用递归来描述数据流最合适不过了。写法如下:

def generate(init: Int): Stream[Int] = int #:: generate(init + 1)

当创建一个流时,将 init 作为流的首项,后面的元素会之后调用流时递归生成。通常的用法是:通过 take 方法生成一段前缀序列,然后使用 toList 方法将抽象的数据流转化成实际的数据结构。

val stream: Stream[Int] = generate(1)
println(stream.take(10).toList)
println(stream.take(10).toList)

take 方法是不附带任何副作用的调用,因此它永远不会改变流自身的状态。我们在这里调用了两次 stream.take(10).toList ,但结果始终是 1 to 10 的序列,而不是在下一次调用take 方法时得到 11 to 20 ,或者是 21 to 30

我们称这种特性为引用透明:一切表达式都可以简单地使用它的结果代替。在上述代码中,我们使用 val 保证了 stream 在任何时刻都是引用不变的。换句话说,在程序中任何出现了 stream.take(10).toList 的地方,都可以使用 List(1,2,3,...,10) 对其替换。

满足引用透明的程序能够建立起一套代换模型,换句话说就是程序具备可推理的性质。

7.2 将已有序列转换为有限流

有限数量的数据集合能够通过 toStream 方法转为有限流,因此这里支持直接调用 last 方法来获取最后一个元素。

val stream: Stream[Int] = (1 to 1000).toStream
println(stream.take(10).toList)
println(stream.last)

不过,对于刚才使用迭代方式生成的无限流来说,直接调用 last 方法会导致程序陷入无限递归中。

8. 视图

除了通过 toStreamlazy 等方式来实现数据懒加载之外,对于任意一个 C[A] 类型 ( C 是集合类型,A 是集合内的元素类型),可以通过视图 view 的方式将原有的数据转换为惰性计算的 SeqView[A,C[A]] 类型。这对于 Map 映射也同样适用。

val ints = List(1,2,3,4,5)
println(ints)
//List(1,2,3,4,5)

val view: SeqView[Int, List[Int]] = List(1,2,3,4,5).view 
println(view)
//SeqView(...)

如果我们直接打印 view ,屏幕上输出的将是 SeqView(...) 而非期望的 1,2,...,5 这些数据。原因是:直到对 view 表示的抽象数据做一步终止操作之前,所有的中间操作都只能得到另一个惰性计算的视图4 ,这个规则同样适用于刚才介绍的 Scala 流。在视图不断转换为视图的过程中,所有的数据都不会被真正计算。

下面的例子是一个有力的证据:我们在中间操作 map 中掺杂一些副作用并执行代码,可以发现 println(x) 根本没有被执行。

val view: SeqView[Int, List[Int]] = List(1,2,3,4,5).view
view.map(x => {
  println(x)
  x
})

而终止操作,或者将视图重新转换为序列等行为将会迫使程序对视图数据进行计算:

val view: SeqView[Int, List[Int]] = List(1,2,3,4,5).view 

// 延迟操作的例子1 
view.map(x => {
  println(x)
  x
}).toList

// 延迟操作的例子2 
view foreach {println(_)}

惰性计算在涉及大量计算的场合通常有更好的性能。但是,如果惰性计算的数据流始终不通过终止操作来收束,那么所有的中间变换全部都是没有意义的,程序就永远都不会实际执行中间操作。比如,Spark - Streaming 框架对数据流处理有严格要求:程序一定要通过 foreach 算子以副作用的形式将数据作持久化处理,否则就抛出异常。

9. 重返 for 表达式

本节正式地讨论 Scala for 表达式背后的机制。Scala 中任何复杂的 for 表达式都可以用三个高阶函数 mapflatMapwithFilter 来表示。或者可以反过来想,有些 mapflatMapwithFilter 的组合可以直接使用一个 for 表达式简单表述出来。

首先是第一个简单的 for 表达式。其中 x 代表变量,而 x <- expr1 则被称之为 for 表达式的一个生成器。它可以被翻译为:

for (x <- expr1 ) yield expr2 // <=> expr1.map(x => expr2) 两者等效。

举个例子,下面两种表示法是等效的。

//    expr1 := 1 to 5
//    expr2 := x + 1
//
//    yield 会收集语句块内的返回值。
//    for ( x <- 1 to 5) yield {
//      x + 1
//    }
(1 to 5).map(x => x + 1)

接下来就是附带条件守卫的 for 表达式:

for (x <- expr1 if expr2) yield expr3 // <=> expr1.withFilter(expr2) yield expr3

expr1.withFilter(expr2) 看作是一个整体,根据第一个推导式,还可以进一步将其推导为:

for (x <- expr1 if expr2) yield expr3 // <=> expr1.withFitler(expr2) map (x => expr3)

我们还可以轻松推导出当存在多个条件守卫时的情形。下面是另一个实例:

/*        
   这个 for 循环将将奇数乘以2.
   expr1 := 1 to 5
   expr2 := x % 2 == 1
   expr3 := 2 * x
 */
for (x <- 1 to 5 if x % 2 == 1) yield { 2 * x }
(1 to 5) withFilter (_ % 2 == 1) map (2 * _)

下面是嵌套 for 循环的情况:

for ( x <- expr1; y <- expr2) yield expr3
			// expr1.flatmap(x => for(y <- expr2) yield expr3) 
			// expr1.flatmap(x => expr2.map(y => expr3))

下面再用简单的一个实例说明。比如 "展平" 一个二维数组:

val matrix: Array[Array[Int]] =Array[Array[Int]](
    Array(2,3,4),
    Array(5,6,7)
)

for ( line <- matrix; e <- line) yield e

/*
    expr1 := matrix
    expr2 := line
    expr3 := e
 */

matrix.flatMap( line => for (e <- line) yield e)
matrix.flatMap(line => line.map(e => e))

9.1 for 表达式中的模式匹配

如果 for 表达式中附带模式匹配,那么翻译起来就略显繁琐了。不妨从最简单的匹配元组的模式入手:

for ((x1,x2,...,xn) <- expr1) yield expr2  
/*
	<=> expr1.map { 
		case (x1,x2,...,xn) => expr2 
		case _ => Unit
	}
*/

下面的例子筛选出了 Map 中所有 value 值为 "scala" 的 key 值:

val stringMap = Map("akka" -> "scala","spark" -> "scala", "jvm" -> "java", "numpy" -> "python")
val stringBuffer: ListBuffer[String] = ListBuffer[String]()

for((k,"scala") <- stringMap) yield stringBuffer += k
	/*
        expr1 := stringMap
        expr2 := stringBuffer += k

        stringMap.map {
          case (k,"scala") => stringBuffer+= k;
          case _ => Unit
        }
     */
println(stringBuffer)

推广到更加一般的模式匹配 pat 的情形:

for (pat <- expr1) yield expr2 
/*
	<=> expr1.withFilter {
		case pat => true
		case _ => false
	} map {case pat => expr2}
*/

9.2 内嵌定义的 for 表达式

下面是最后一种情况,即 for 循环出现内嵌定义的时候:

for (x <- expr1; y = expr2) yield expr3 
/*			
	for ((x,y) <- for(x <- expr1) yield (x,expr2))
	yield expr3
	↓
	(expr1.map(x -> (x,expr2))).map({case (x,y) => expr3})
*/

下面附带内嵌定义的 for 循环实例:

/*
	par   := (x,y)
	expr1 := 1 to 5
	expr2 := 2
	expr3 := x * y
*/
for(x <- 1 to 5;y = 2) yield x * y
// ↓
for((x,y) <- for(x <- 1 to 5) yield (x,2)) yield x * y
// ↓
(1 to 5).map((_,2)).map({case (x,y) => x * y})

expr2 的计算与 x 无关时,这种定义方式就不再推荐了,尤其是 expr2 涉及到复杂计算时:因为每生成一对 (x,y),则 expr2 就需要被计算一次。这通常不如下面的写法来得好:

val y = 2
for(x <- 1 to 5) yield x * y

9.3 for 循环

以上是包含 yield 关键字的 for 表达式,而没有包含循环语句块 body 。一般来讲,我们使用 body 都是为了执行一些副作用,比如通过下标索引修改外部的 Array ,或者是向控制台输出结果等等。

我们回归到最熟悉的 for 循环写法中。Scala 可以用 foreach 将它翻译成更加简洁的写法:

for (x <- expr1) body  // <=> expr1.foreach(x => body)

如果是带条件守卫的嵌套 for 循环,则会多一个过滤器 withFilter

for(x <- expr1; if expr2; y <- expr3) body
/*
	<=> 
	(expr1.withFilter(expr2)).
	foreach(
		x => expr3.foreach( y => body )
	)
*/

下面的例子演示了如何用 for 循环实现二维矩阵中非负元素的累计和,以及它的等效写法:

val square = Array(
    Array(1, 2, 3),
    Array(-1, 5, -5),
    Array(4, -6, 9)
)

var sum = 0
for(line <- square; e <- line if e > 0) sum = sum + e

/*	
    sum = 0
    square.foreach(
        line => line.withFilter(_ > 0).foreach(
            e => sum = sum + e
        )
    )
*/
println(sum)

9.4 泛化的 for 表达式与 monad

综上所述,Scala 的 for 表达式不仅可以用在广义的上的 “序列","数组" 中,还能用在区间 ( range ),迭代器 ( iterator ),流 ( stream ),和所有集 ( Set ) 的实现类。甚至说,对于任何一个全部或部分实现了 flatMapmapwithFilterforeach 方法的类,我们都可以将它用在 for 表达式中。具体的规则如下:

  1. 如果该类型只定义了 map 方法,则它只允许使用包含单个生成器的 for 表达式。
  2. 如果该类型同时定义了 flatMapmap 方法,则它支持使用多个生成器的 for 表达式。
  3. 如果该类型额外定义了 foreach 方法,则它还支持使用带副作用的 for 循环。
  4. 如果该类型额外定义了 withFilter 方法,则它支持插入 if 开头的条件守卫。

Scala 没有对这四个函数做额外的任何约束,唯一的要求就是程序设计者保证被翻译后的 for 表达式能够通过类型检查。比如说笔者这里创建了自定义的类 Procedure

abstract class Procedure[A]{
    def map[B](f : A => B) : Procedure[B]
    def flatMap[B](f : A => Procedure[B]) : Procedure[B]
    def withFilter(f : A => Boolean) : Procedure[A]
    def foreach(f : A => Unit) : Unit
}

这四个函数应该具备的功能笔者已经在前文介绍过了。值得注意的是 withFilter 方法返回的是被过滤后的同类型 Procedure[A] 。此外,在 for 表达式的翻译过程中,这个被过滤后的 procedure[A] 总是会被再次传入到其它三个函数内 ( 因为条件守卫不可能单独存在 )。如果 A 是一个大对象,那么应该避免不断地创建中间结果。笔者建议使用 "标记 - 整理" 的套路来实现 withFilter 方法:先标记需要被筛除的对象,然后将剩下的对象集中装入到一个新的 Procedure[A] 中返回。

在后续的函数式编程中,我们会知道 Procedure[A] 有更加专业的名字 —— "单子" ( monad ),它用于解决包含大量运算的场合,包括从集合到对状态和 I/O 计算,回溯算法,事务管理等。

到这里,笔者了解了 Scala 中 for 表达式的高阶用法 —— 它表达的概念远远不局限于 "遍历数组" 那么简单。

10. 后续的计划

笔者目前对 Scala 的了解程度应该足够去应付简单的应用程序了。当然,还有一部分 “悬而未决” 的问题:比如在模式匹配中如何区分出不同泛型的 Map[_,_]5 此外还有:既然 Scala 将函数视作一等公民,那么两个函数 fg 之间是否存在继承关系?如果有,那么满足什么条件才能称函数 g 继承于函数 f3

笔者后续会花少许篇幅介绍有关于 SBT ( Scala 版本的 Maven ) ,Play 2! ( 基于 Scala 的 Web 框架 ),Akka ( Scala 提供的异步通讯模型 ) 的简单使用,然后介绍 Scala 的运行时反射和类型系统,最终过渡到 Spark/Kafka 框架的学习,以及 Scala 视角的纯函数式编程。

Footnotes

  1. 知乎:什么是声明式编程?

  2. 笔者曾介绍过有关于 Scala 柯里化 的内容。

  3. 解决这个问题,需要在 Scala 的型变机制下去阐述 "函数间的继承关系",这是未来的课题。 2

  4. 中间操作,终止操作的概念源自于 Java 8 的 Stream 流。读者可以自行百度,或者查看笔者以往的笔记:简单回顾 Java 8 (二)

  5. JVM 自身的类型擦除机制会影响类型推断,见Scala 之:模式匹配