Spark:Task not serializable 追根溯源
在修改别人的spark代码时遇到了org.apache.spark.SparkException: Task not serializable这个问题。并且还有一大段摸不着头脑的报错信息,咋一看还以为是乱码,仔细一看才发现是各种类和方法名称。
上网一搜,算作spark的一个高频问题,有些人甚至用了notorious这个词,不少有些经验的开发者都踩过这个坑,更别说我scala/spark 0基础的萌新了。
在chrome浏览器里面打开了N个标签之后,总结出搜索结果大概给了几种不一样的解决方案,但很多都类似csdn风格,抄来抄去。看这些文档里面的操作方法,排列组合试几种可能解决这个问题,之后可能就不再深究了。但是很不幸,这次遇到的问题和网上高频遇到的情况都不同,显然哪些方法在这个情况下不太靠谱,并且没有把问题讲透彻,所以有了这篇文章。
问题分析
如果有时间的话,还是希望深入一下这个问题,探究spark为啥会报这个错误,而不是简单的把网上的方法都试一遍。其实我感觉追根溯源也是编程,甚至是计算科学的乐趣所在。
尝试最小代码复现
class MyJob(spark: SparkSession) {
import spark.implicits._
def wordCount: UserDefinedFunction =
udf { str: String => str.length }
def run(): Unit =
spark
.sql("SELECT 'hello' as text")
.withColumn("word_count", wordCount($"text"))
.show()
}
这个代码很简单,目的是计算某个字符串的长度。 写个main函数运行一下:
object SparkTask {
def main(args: Array[String]): Unit = {
val conf = new SparkConf()
.setMaster("local")
.setAppName("SparkTask")
new SparkContext(conf)
val spark = SparkSession.builder()
.appName(this.getClass.getSimpleName)
.enableHiveSupport()
.getOrCreate()
val myJob = new MyJob(spark)
myJob.run()
}
}
输出很nice,编程的成就感油然而生:
作为解耦狂魔,为了应对产品随时可能提出的新的word_count需求,我们可以把代码改成这样:
class MyJob(spark: SparkSession) {
import spark.implicits._
def wordCount(str: String): Integer =
str.length
def wordCountUdf: UserDefinedFunction =
udf { str: String => wordCount(str) }
def run(): Unit =
spark
.sql("SELECT 'hello!' as text")
.withColumn("word_count", wordCountUdf($"text"))
.show()
}
看似完美,实则运行报错,还就是那个熟悉的错误:
试着加一个extends Serializable
class MyJob(spark: SparkSession) extends Serializable {
import spark.implicits._
def wordCount(str: String): Integer =
str.length
def wordCountUdf: UserDefinedFunction =
udf { str: String => wordCount(str) }
def run(): Unit =
spark
.sql("SELECT 'hello!' as text")
.withColumn("word_count", wordCountUdf($"text"))
.show()
}
可以了,没有报错。
后续写着写着试图加一个logger
class MyJob(spark: SparkSession) extends Serializable {
import spark.implicits._
val logger: Logger = LogManager.getLogger(getClass)
logger.info("Init")
def wordCount(str: String): Integer =
str.length
def wordCountUdf: UserDefinedFunction =
udf { str: String => wordCount(str) }
def run(): Unit =
spark
.sql("SELECT 'hello!' as text")
.withColumn("word_count", wordCountUdf($"text"))
.show()
}
发现那个错误又回来了😂
最后来一个终极版本解决
class MyJob(spark: SparkSession) extends Serializable {
import spark.implicits._
val logger: Logger = LogManager.getLogger(getClass)
logger.info("Init")
def wordCount(str: String): Integer =
str.length
def wordCountUdf: UserDefinedFunction =
udf { str: String => WordCountJob.wordCount(str) }
def run(): Unit =
spark
.sql("SELECT 'hello!' as text")
.withColumn("word_count", wordCountUdf($"text"))
.show()
}
object WordCountJob {
def wordCount(str: String): Integer =
str.length
}
为什么要序列化?
Spark最大的特点就是 分布式执行,处理的数据单元为RDD(弹性分布式数据集),表示分布在不同节点的数据。程序启动时,编译的代码被所有分布式节点加载。所以每个节点都需要有编译类的副本。
因为driver和excutor的运行在不同的jvm中,副本传输过程中势必会涉及到对象的序列化与反序列化,如果这个变量没法序列化,或则是引用的对象可以序列化,但是引用的对象本身引用的其他对象无法序列化,就会有异常。
- 代码中对象在driver本地序列化
- 传输到分发到远程executor
- 远程executor节序列化
- 执行
静态类的情况比较简单;编译期就可以确定。因此,可以将副本传递给每个executor节点、所以无需在runtime进行序列化。
还有一种情况:transient修饰的变量也不会被序列化,在被反序列化后,transient变量的值被设为0或者null。
在上面的错误例子中。show()执行时,由executor节点执行查询。此时,driver必须向executor发送要执行的code,有些不带static的会在运行时期传递,问题就出现在下面的这一行中,虽然看起来像是一个没有任何状态的简单的程序闭包。
str: String => wordCount(str)
虽然简单,但是也是一个对象的方法,而不是一个函数
str: String => this.wordCount(str)
这个可能会被JVM实现为
public class WordCountFunction {
private WordCountJob job;
public WordCountFunction(WordCountJob job) {
this.job = job;
}
public String wordCountJob(String str) {
return str.lenth
}
}
在序列化的时候,Spark会调用ObjectOutputStream.writeObject(),类必须实现可序列化接口java.io.Serializable。对于复杂的类,会在原始对象持有的每个成员上递归调用(比如Logger),只要遇到任何一个不支持Seriezable接口的对象,就会抛异常。
再回到我们的问题:
对于scala语言开发,解决序列化问题主要如下几点:
- 在Object中声明对象 (每个class对应有一个Object)
- 如果在闭包中使用SparkContext或者SqlContext,建议使用SparkContext.get() and SQLContext.getActiveOrCreate()
- 使用static或transient修饰不可序列化的属性从而避免序列化。
对于java语言开发,对于不可序列化对象,如果本身不需要存储或传输,则可使用static或trarnsient修饰;如果需要存储传输,则实现writeObject()/readObject()使用自定义序列化方法。
此外注意,对于Spark Streaming作业,注意哪些操作在driver,哪些操作在executor。因为在driver端(foreachRDD)实例化的对象,很可能不能在foreach中运行,因为对象不能从driver序列化传递到executor端(有些对象有TCP链接,一定不可以序列化)。所以这里一般在foreachPartitions或foreach算子中来实例化对象,这样对象在executor端实例化,没有从driver传输到executor的过程。
dstream.foreachRDD { rdd =>
val where1 = "on the driver"
rdd.foreach { record =>
val where2 = "on different executors"
}
}
}
问题解决了吗?
很遗憾,还是没有,依然会出现这个Task not serializable错误。
经过仔细debug,发现问题出在这个FileSystem.get:
val fs = FileSystem.get(URI.create(path), conf)
它会throws一个IOException ,导致了这个Task not serializable
所以真正的问题是:
在mapPartitionsWithIndex域内使用了一个org.apache.hadoop.fs.FileSystem
这个变量是不能被序列化的。
This happens after the broadcast because broadcast variable is a part of the enclosing object same as file system instance and closure cleaner is not perfect.
解决办法
将不可序列化的对象定义在闭包内
object SparkTest {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setMaster("local[*]").setAppName("test")
val sc = new SparkContext(conf)
val rdd = sc.parallelize(1 to 10,3)
rdd.map(x=>new UnserializableClass().method(x)).foreach(println(_)) //在map中创建UnserializableClass对象
}
}
将所调用的方法改为函数,在高阶函数中使用
将UnserializableClass类中的method方法改为method函数
class UnserializableClass {
//method方法
/*def method(x:Int):Int={
x*x
}*/
//method函数
val method = (x:Int)=>x*x
}
在SparkTest中传入函数:
object SparkTest {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setMaster("local[*]").setAppName("test")
val sc = new SparkContext(conf)
val rdd = sc.parallelize(1 to 10,3)
val usz = new UnserializableClass()
rdd.map(usz.method).foreach(println(_)) //注意这里传入的是函数
}
}
给无法序列化的类加上java.io.Serializable接口
class UnserializableClass extends java.io.Serializable{ //加接口
def method(x:Int):Int={
x*x
}
}
注册序列化类
以上三个方法基于UnserializableClass可以被修改来说的,假如UnserializableClass来自于第三方,你无法修改其源码就可以使用为其注册序列化类的方法。
object SparkTest {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setMaster("local[*]").setAppName("test")
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") //指定序列化类为KryoSerializer
conf.registerKryoClasses(Array(classOf[net.bigdataer.spark.UnserializableClass])) //将UnserializableClass注册到kryo需要序列化的类中
val sc = new SparkContext(conf)
val rdd = sc.parallelize(1 to 10,3)
val usz = new UnserializableClass()
rdd.map(x=>usz.method(x)).foreach(println(_))
}
}
transient
既然这个是不可序列化的,并且确实不需要序列话,
使用@transient把这个变量标记为瞬态即可,在函数调用中初始化它们,而不是在构造函数中初始化。
@transient val fs = FileSystem.get(URI.create(path), new Configuration())
参考: medium.com/swlh/spark-… stackoverflow.com/questions/2… cloud.tencent.com/developer/a…