Spark:Task not serializable 追根溯源

539 阅读6分钟

Spark:Task not serializable 追根溯源

在修改别人的spark代码时遇到了org.apache.spark.SparkException: Task not serializable这个问题。并且还有一大段摸不着头脑的报错信息,咋一看还以为是乱码,仔细一看才发现是各种类和方法名称。

上网一搜,算作spark的一个高频问题,有些人甚至用了notorious这个词,不少有些经验的开发者都踩过这个坑,更别说我scala/spark 0基础的萌新了。

在chrome浏览器里面打开了N个标签之后,总结出搜索结果大概给了几种不一样的解决方案,但很多都类似csdn风格,抄来抄去。看这些文档里面的操作方法,排列组合试几种可能解决这个问题,之后可能就不再深究了。但是很不幸,这次遇到的问题和网上高频遇到的情况都不同,显然哪些方法在这个情况下不太靠谱,并且没有把问题讲透彻,所以有了这篇文章。

问题分析

如果有时间的话,还是希望深入一下这个问题,探究spark为啥会报这个错误,而不是简单的把网上的方法都试一遍。其实我感觉追根溯源也是编程,甚至是计算科学的乐趣所在。

尝试最小代码复现

class MyJob(spark: SparkSession) {

  import spark.implicits._

  def wordCount: UserDefinedFunction =
    udf { str: String => str.length }

  def run(): Unit =
    spark
      .sql("SELECT 'hello' as text")
      .withColumn("word_count", wordCount($"text"))
      .show()
}

这个代码很简单,目的是计算某个字符串的长度。 写个main函数运行一下:

object SparkTask {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("SparkTask")

    new SparkContext(conf)
    val spark = SparkSession.builder()
      .appName(this.getClass.getSimpleName)
      .enableHiveSupport()
      .getOrCreate()

    val myJob = new MyJob(spark)
    myJob.run()
  }
}

输出很nice,编程的成就感油然而生:

image.png

作为解耦狂魔,为了应对产品随时可能提出的新的word_count需求,我们可以把代码改成这样:

class MyJob(spark: SparkSession) {

  import spark.implicits._

  def wordCount(str: String): Integer =
    str.length

  def wordCountUdf: UserDefinedFunction =
    udf { str: String => wordCount(str) }

  def run(): Unit =
    spark
      .sql("SELECT 'hello!' as text")
      .withColumn("word_count", wordCountUdf($"text"))
      .show()
}

看似完美,实则运行报错,还就是那个熟悉的错误:

image.png

试着加一个extends Serializable

class MyJob(spark: SparkSession) extends Serializable {

  import spark.implicits._

  def wordCount(str: String): Integer =
    str.length

  def wordCountUdf: UserDefinedFunction =
    udf { str: String => wordCount(str) }

  def run(): Unit =
    spark
      .sql("SELECT 'hello!' as text")
      .withColumn("word_count", wordCountUdf($"text"))
      .show()
}

可以了,没有报错。

后续写着写着试图加一个logger

class MyJob(spark: SparkSession) extends Serializable {

  import spark.implicits._
  val logger: Logger = LogManager.getLogger(getClass)
  logger.info("Init")

  def wordCount(str: String): Integer =
    str.length

  def wordCountUdf: UserDefinedFunction =
    udf { str: String => wordCount(str) }

  def run(): Unit =
    spark
      .sql("SELECT 'hello!' as text")
      .withColumn("word_count", wordCountUdf($"text"))
      .show()
}

发现那个错误又回来了😂

最后来一个终极版本解决

class MyJob(spark: SparkSession) extends Serializable {

  import spark.implicits._
  val logger: Logger = LogManager.getLogger(getClass)
  logger.info("Init")

  def wordCount(str: String): Integer =
    str.length

  def wordCountUdf: UserDefinedFunction =
    udf { str: String => WordCountJob.wordCount(str) }

  def run(): Unit =
    spark
      .sql("SELECT 'hello!' as text")
      .withColumn("word_count", wordCountUdf($"text"))
      .show()
}

object WordCountJob {
  def wordCount(str: String): Integer =
    str.length
}

为什么要序列化?

Spark最大的特点就是 分布式执行,处理的数据单元为RDD(弹性分布式数据集),表示分布在不同节点的数据。程序启动时,编译的代码被所有分布式节点加载。所以每个节点都需要有编译类的副本。

因为driver和excutor的运行在不同的jvm中,副本传输过程中势必会涉及到对象的序列化与反序列化,如果这个变量没法序列化,或则是引用的对象可以序列化,但是引用的对象本身引用的其他对象无法序列化,就会有异常。

  1. 代码中对象在driver本地序列化
  2. 传输到分发到远程executor
  3. 远程executor节序列化
  4. 执行

静态类的情况比较简单;编译期就可以确定。因此,可以将副本传递给每个executor节点、所以无需在runtime进行序列化。

还有一种情况:transient修饰的变量也不会被序列化,在被反序列化后,transient变量的值被设为0或者null。

在上面的错误例子中。show()执行时,由executor节点执行查询。此时,driver必须向executor发送要执行的code,有些不带static的会在运行时期传递,问题就出现在下面的这一行中,虽然看起来像是一个没有任何状态的简单的程序闭包。

str: String => wordCount(str)

虽然简单,但是也是一个对象的方法,而不是一个函数

str: String => this.wordCount(str)

这个可能会被JVM实现为

public class WordCountFunction {

    private WordCountJob job;

    public WordCountFunction(WordCountJob job) {
        this.job = job;
    }

    public String wordCountJob(String str) {
        return str.lenth
    }

}

在序列化的时候,Spark会调用ObjectOutputStream.writeObject(),类必须实现可序列化接口java.io.Serializable。对于复杂的类,会在原始对象持有的每个成员上递归调用(比如Logger),只要遇到任何一个不支持Seriezable接口的对象,就会抛异常。

再回到我们的问题:

对于scala语言开发,解决序列化问题主要如下几点:

  • 在Object中声明对象 (每个class对应有一个Object)
  • 如果在闭包中使用SparkContext或者SqlContext,建议使用SparkContext.get() and SQLContext.getActiveOrCreate()
  • 使用static或transient修饰不可序列化的属性从而避免序列化。

对于java语言开发,对于不可序列化对象,如果本身不需要存储或传输,则可使用static或trarnsient修饰;如果需要存储传输,则实现writeObject()/readObject()使用自定义序列化方法。

此外注意,对于Spark Streaming作业,注意哪些操作在driver,哪些操作在executor。因为在driver端(foreachRDD)实例化的对象,很可能不能在foreach中运行,因为对象不能从driver序列化传递到executor端(有些对象有TCP链接,一定不可以序列化)。所以这里一般在foreachPartitions或foreach算子中来实例化对象,这样对象在executor端实例化,没有从driver传输到executor的过程。

dstream.foreachRDD { rdd =>
  val where1 = "on the driver"
    rdd.foreach { record =>
      val where2 = "on different executors"
    }
  }
}

问题解决了吗?

很遗憾,还是没有,依然会出现这个Task not serializable错误。

经过仔细debug,发现问题出在这个FileSystem.get:

val fs = FileSystem.get(URI.create(path), conf)

它会throws一个IOException ,导致了这个Task not serializable

所以真正的问题是:

在mapPartitionsWithIndex域内使用了一个org.apache.hadoop.fs.FileSystem

这个变量是不能被序列化的。

This happens after the broadcast because broadcast variable is a part of the enclosing object same as file system instance and closure cleaner is not perfect.

解决办法

将不可序列化的对象定义在闭包内

object SparkTest {
  def main(args: Array[String]): Unit = {
  val conf = new SparkConf().setMaster("local[*]").setAppName("test")
  val sc = new SparkContext(conf)
  val rdd = sc.parallelize(1 to 10,3)
  rdd.map(x=>new UnserializableClass().method(x)).foreach(println(_)) //在map中创建UnserializableClass对象
  }
 }

将所调用的方法改为函数,在高阶函数中使用

将UnserializableClass类中的method方法改为method函数

class UnserializableClass {
  //method方法
  /*def method(x:Int):Int={
    x*x
  }*/

//method函数
  val method = (x:Int)=>x*x
}

在SparkTest中传入函数:

object SparkTest {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[*]").setAppName("test")
    val sc = new SparkContext(conf)
    val rdd = sc.parallelize(1 to 10,3)
    val usz  = new UnserializableClass()
    rdd.map(usz.method).foreach(println(_)) //注意这里传入的是函数
  }
}

给无法序列化的类加上java.io.Serializable接口

class UnserializableClass extends java.io.Serializable{ //加接口
  def method(x:Int):Int={
    x*x
  }
}

注册序列化类

以上三个方法基于UnserializableClass可以被修改来说的,假如UnserializableClass来自于第三方,你无法修改其源码就可以使用为其注册序列化类的方法。

object SparkTest {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[*]").setAppName("test")

    conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") //指定序列化类为KryoSerializer
    conf.registerKryoClasses(Array(classOf[net.bigdataer.spark.UnserializableClass])) //将UnserializableClass注册到kryo需要序列化的类中

    val sc = new SparkContext(conf)
    val rdd = sc.parallelize(1 to 10,3)
    val usz  = new UnserializableClass()
    rdd.map(x=>usz.method(x)).foreach(println(_))
  }
}

transient

既然这个是不可序列化的,并且确实不需要序列话,

使用@transient把这个变量标记为瞬态即可,在函数调用中初始化它们,而不是在构造函数中初始化。

@transient val fs = FileSystem.get(URI.create(path), new Configuration())

参考: medium.com/swlh/spark-… stackoverflow.com/questions/2… cloud.tencent.com/developer/a…