详解如何自定义Spark外部数据源

2,343 阅读3分钟

我报名参加金石计划1期挑战——瓜分10万奖池,这是我的第3篇文章,点击查看活动详情

接口说明

首先来看一下主要的两个接口

BaseRelation

Represents a collection of tuples with a known schema. 
Classes that extend BaseRelation must be able to produce the schema of their data in the form of a StructType.

源码中的注释说明,继承BaseRelation必须可以产生对应schema结构的数据
RelationProvider

Returns a new base relation with the given parameters.
通过源码注释可以看到他会提供一个relation,并且会提供参数

实现读取EXCEL文件

这里通过简单读取excel文件来说明如何具体实现。

1-实现ExcelRelation

首先定义class ExcelRelation,并且继承BaseRelation,实现TableScan接口,如果需要实现下推的功能可以实现PrunedFilteredScan接口,如下:

class ExcelRelation(override val schema: StructType, parts: Array[Partition])
                   (@transient val sparkSession: SparkSession) 
                   extends BaseRelation with TableScan
                   with PrunedFilteredScan with Serializable {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  //A BaseRelation that can produce all of its tuples as an RDD of Row objects
  override def buildScan(): RDD[Row] = {
    println("build scan ");

    new ExcelRDD(sparkSession.sparkContext, parts)
  }
  
  //A BaseRelation that can eliminate unneeded columns and filter using selected predicates before producing an RDD containing all matching tuples as Row objects.
  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    //requiredColumns 用来对列做过滤
    println("build scan requiredColumns:" + requiredColumns.mkString(", "))
    new ExcelRDD(sparkSession.sparkContext, parts)
  }
}

通过代码注释可以看到,PrunedFilteredScan可以在早期就过滤掉不用的数据,实现下推的优化。

2-自定义ExcelRDD

上面的代码我们中我们需要实现buildScan方法,通过代码注释就可以知道他需要产生RDD的数据集。这里我们通过自己实现一个ExcelRDD来实现。 自定义RDD可以参考EmptyRDD

An RDD that has no partitions and no elements.

private[spark] class EmptyRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) {

  //Implemented by subclasses to return the set of partitions in this RDD. This method will only be called once, so it is safe to implement a time-consuming computation in it.
  override def getPartitions: Array[Partition] = Array.empty


  //Implemented by subclasses to compute a given partition.
  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    throw new UnsupportedOperationException("empty RDD")
  }
}

这个RDD实现了一个没有分区和没有数据的空RDD.其中两个核心方法compute会按传入的分区来计算并返回Iterator。 getPartitions则代表当前RDD所有的分区信息,并且源码注释说明了只会调用一次。 那我们就可以参考他来实现一个有分区并且可以产生RDD数据集的ExcelRDD。重点就是实现上面两个方法。

此时我们可以先停下来,去设计一下应该如何实现这两个方法: 比如我这里的需求是可以读取一个目录下所有相同格式的Excel文件,/data/file目录下面有三个excel文件

  • test01.excel
  • test02.excel
  • test03.excel

getPartitions需要返回当前RDD所有的分区集合,我们这里可以设计每个文件作为一个分区,结构可以为 Array(文件路径,index) index代表第几个分区。 Array(test01.excel,1) Array(test02.excel,2) Array(test03.excel,3)

那么compute方法其实就需要根据传入进来的Partition信息获取到当前分区中的文件路径去读取对应的文件内容,最终返回一个Iterator

class ExcelRDD(sc: SparkContext, partitions: Array[Partition]) extends RDD[Row](sc, Nil) {

  /**
   * split 当前分区
   * */
  override def compute(split: Partition, context: TaskContext): Iterator[Row] = {

    val part = split.asInstanceOf[EXCELPartition]
    //解析当前分区的一个EXCEL
    val workbook = StreamingReader.builder.rowCacheSize(100).bufferSize(4096)
      .open(new File(part.fileName))
    //获取sheet
    val sheet = workbook.getSheetAt(0)
    val sheetIter = sheet.iterator()

    //task 结束后回调
    context.addTaskCompletionListener[Unit](_ => close())

    def close(): Unit = {
      if (workbook != null) {
        workbook.close()
      }
    }

    //返回迭代器,惰性返回数据
    new NextCloseIterator[Row] {

      override protected def getNext(): Row = {
        if (sheetIter.hasNext) {
          val row = sheetIter.next()
          val cellIter = JavaConverters.asScalaIterator(row.cellIterator())
          val cells = new ListBuffer[String]
          for (cell <- cellIter) {
            cells.append(cell.getStringCellValue)
          }
          Row.fromSeq(cells)
        } else {
          //结束标识
          finished = true
          null.asInstanceOf[Row]
        }
      }

      override protected def close(): Unit = {
        println("NextCloseIterator close resource")
      }
    }
  }
  
  /**
   * getPartitions用于获得分区信息。
   这里没有分区逻辑实现,是因为在构造当前RDD的时候传入了定义好的分区,在后面的代码中
   * */
  override protected def getPartitions: Array[Partition] = {
    partitions
  }

}

通过上面我们可以知道返回给buildScan的是我们自定义的RDD,他不是直接返回一个大的完整的RDD数据集,而是通过一种惰性返回。

3-实现ExcelProvider

上面我们已经完整的实现了ExcelRelation的功能,他可以通过我们自定义的ExcelRDD返回buildScan所需的RDD数据集。 接下来我们就需要实现Provider,来提供我们实现的ExcelRelation.

首先就是定义出schema,如果是jdbc,则可以去表中查询出当前表中所有的schema信息。 我这里是读取excel,先写死了,其实可以通过外部参数传递进来。

  /**
   * 全局的schema (TODO 动态获取)
   * 比如excel中完整的列信息
   * */
  def getSchema(): StructType = {
    StructType(
        StructField("id", StringType, false) ::
        StructField("name", StringType, true) ::
        StructField("age", IntegerType, true) :: Nil
    )
  }

接下来就是实现createRelation方法,这里主要就是从parameters这个Map中获取我们使用当前数据源所传递的参数,我们之前的设定就是读取一个目录下面的所有Excel,因此只定义了一个pathDir参数。

在返回我们的ExcelRelation之前,还需要将我们之前RDD中需要的分区信息定义好,也是是目录下的每个文件是一个分区,如果是jdbc则可以根据主键进行切分为多个SQL

  /**
   * 获取分区,(文件路径,index)
   * */
  def excelPartition(pathDir: String): Array[Partition] = {
    val ans = new ArrayBuffer[Partition]()

    val files = getFiles(new File(pathDir))

    for (i <- 0 until files.length) {
      ans += EXCELPartition(files(i).getPath, i)
    }

    ans.toArray
  }
  
   /**
   * 获取目录下的所有文件
   * */
  def getFiles(dir: File): Array[File] = {
    dir.listFiles.filter(_.isFile).toArray
  }

最后来看一下完整的ExcelProvider吧

class ExcelProvider
  extends RelationProvider with DataSourceRegister {

  /**
   * 读取excel数据源
   * */
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {

    //定义有总共哪些字段,以及字段类型
    val schema = ExcelRelation.getSchema()

    /**
     * 要求一个目录下放schema相同的excel
     * */
    val pathDir = parameters.get("pathDir").get
    System.out.println("pathDir :" + pathDir)
    if (StringUtils.isEmpty(pathDir)) {
      throw new RuntimeException("pathDir empty")
    }

    val partitions = ExcelRelation.excelPartition(pathDir);

    new ExcelRelation(schema, partitions)(sqlContext.sparkSession);

  }


  override def shortName(): String = "excel"
}

4-使用测试


  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("spark-excel-datasource")
    val spark = SparkSession.builder().config(conf).master("local").getOrCreate()


    val df = spark.sqlContext.read.format("lizunew.source.v1.excel.ExcelProvider")
      .option("pathDir", "/Users/lizu/idea/study/spark-extend-dataSource/data/excel")
      .load()


    //创建Properties存储数据库相关属性
    val prop = new Properties()
    prop.put("user", "root")
    prop.put("password", "root123")

    //PrunedFilteredScan
    df.select("id", "name").write.mode("append").jdbc("jdbc:mysql://114.67.67.44:4306/test?useSSL=false", "excel", prop)

    Thread.sleep(1000000000)

  }

image.png

最终也是有三个task来完成当前的读取,并写入到了mysql中

5-问题

  • 如何更加简单的定义spark.sqlContext.read.format中的格式,而不是指定一个全类名称,比如format("excel")
  • 如何在RDD计算中,真正做到读取固定的列信息,而不是现在这样还是解析获取了excel中所有的列,当前读取excel其实体现不出来PrunedFilteredScan的好处,试想下如果读取的是数据库,如果可以提前只读取部分列数据或者过滤一部分数据其实是可以节省很多资源和时间的。有兴趣的可以先研究一下JDBCRDD的实现。

如果对上面问题有兴趣的话,可以关注后续的文章