详解如何自定义Spark外部数据源-补充

713 阅读4分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第1天,点击查看活动详情

前言

还记得在本专栏上篇文章的最后留下了两个问题 

 1-如何使用format("excel")来读取数据 

 2-如何在RDD计算中,真正做到读取固定的列信息。做到列裁剪。或者提前过滤等功能。 

如何简化format中的外部数据源格式

在之前的文章中我们已经实现了读取excel的外部数据源,并且可以基于如下方式进行使用

spark.sqlContext.read.format("lizunew.source.v1.excel.ExcelProvider")
  .option("pathDir", "/Users/lizu/idea/study/spark-extend-dataSource/data/excel")
  .load()

我们都用过spark读取jdbc的api,大致如下:

format("jdbc").save()

可以看到这里并没有指定完整的类名称,而是更加简单的指定了jdbc这个format。那么我们也来改造一下我们的外部数据源吧,让其指定excel为format。

定义format

这一步我们只需在之前定义的Provider中集成DataSourceRegister就可以

class ExcelProvider  extends RelationProvider with DataSourceRegister

然后实现其方法,就可以指定我们想要的format名称了:

override def shortName(): String = "excel"  //这里定义了我们的外部数据源名称为excel

SPI配置文件

因为spark源码中获取Provider会基于SPI机制,因此我们需要在下面的目录中定义如下文件:

/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister

文件内容如下,这里根据自己Provider的类名进行替换:

lizunew.source.v1.excel.ExcelProvider

完成之后我们就可以基于下面的方式来读取数据了:

spark.sqlContext.read.format("excel")
  .option("pathDir", "/Users/lizu/idea/study/spark-extend-dataSource/data/excel")
  .load()

相关源码

入口load()方法,然后我们会看到lookupDataSource方法:

def lookupDataSource(provider: String, conf: SQLConf): Class[_] = {
...
}

方法注释为:Given a provider name, look up the data source class definition:

就是根据传过来的provider名称来获取对应的实现类

继续看源码可以看到如下代码,和我们定义的文件对应上了:

val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader)

返回

case head :: Nil =>  // there is exactly one registered alias  head.getClass

具体的源码细节就不贴出来了都在org.apache.spark.sql.execution.datasources中的lookupDataSource方法中。

如何对我们的外部数据源做到列裁剪和过滤

在简化完使用方式后,就需要看第二个问题:优化数据读取的过程,列裁剪和数据提取过滤过滤。

我们之前已经实现了with PrunedFilteredScan下面的buildScan接口。但是并没有真正的去利用他来实现我们想要的功能,里面的两个参数 requiredColumns, filters 都没有使用上。这里我们来完善一下:

buildScan方法修改

获取需要读取的列信息以及其对应的scheam的index。传入我们的ExcelRDD

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    //requiredColumns 可以对列做过滤
    println("build scan requiredColumns:" + requiredColumns.mkString(", "))

    new ExcelRDD(sparkSession.sparkContext,
      pruneSchema(schema, requiredColumns), getRequiredColumnsIndex(schema, requiredColumns),
      requiredColumns, parts)
  }
  
  
    /**
   * 裁剪schema
   * */
  private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
    val fieldMap = Map(schema.fields.map(x => x.name -> x): _*)
    new StructType(columns.map(name => fieldMap(name)))
  }
  
  
    /**
   * 返回 column,index
   * */
  private def getRequiredColumnsIndex(schema: StructType, columns: Array[String]): Map[String, Int] = {
    val schemaFields = schema.fields
    var indexMap = Map[String, Int]()

    for (i <- 0 to schemaFields.length - 1) {
      if (columns.contains(schemaFields(i).name)) {
        indexMap += (schemaFields(i).name -> i)
      }
    }
    indexMap
  }


计算逻辑优化compute

这里主要是针对解析excel的时候做了两点修改:

1-只获取excel中需要的字段,添加到需要返回的Row中

2-根据schema的字段类型来获取excel中真实的字段类型数据,比如完整的schema如下:

  /**
   * 全局的schema (TODO 动态获取)
   * 比如excel中完整的列信息
   * */
  def getSchema(): StructType = {
    StructType(
        StructField("id", StringType, false) ::
        StructField("name", StringType, true) ::
        StructField("age", IntegerType, true) :: Nil
    )
  }

3-getNext方法修改返回Row是基于GenericRowWithSchema构造的Row。GenericRowWithSchema继承自Row,它是可以指定StructType来定义Row中的数据类型的:

      override protected def getNext(): Row = {
        if (sheetIter.hasNext) {
          val row = sheetIter.next()
          val cells = new ListBuffer[Any]
          //这里只获取需要读取的列,并且需要对列的value进行类型转换
          for (i <- 0 to columns.length - 1) {
            val index = indexMap.get(columns(i)).get
            cells.append(castTo(row.getCell(index), fieldMap.get(columns(i)).get.dataType))
          } 
          //Row.fromSeq(cells)
          //根据schema返回信息
          new GenericRowWithSchema(cells.toArray, schema)

        } else {
          //结束标识
          finished = true
          null.asInstanceOf[Row]
        }
      }


  /**
   * 数据类型转换 TODO ,补充完整
   * */
  def castTo(
              datum: Cell,
              castType: DataType): Any = {
    castType match {
      case _: IntegerType => datum.getNumericCellValue.toInt
      case _: StringType => datum.getStringCellValue
      case _ => throw new RuntimeException(s"Unsupported type")
    }
  }

到此,已经实现了列的裁剪。

最后还剩下一个提前过滤功能没有实现,也就是filters参数没有使用上:

buildScan(requiredColumns: Array[String], filters: Array[Filter])

这个大家可以自己去实现一下。这里我会跟大家分享一下JDBCRDD中是如何实现的,如果没有思路的可以参考一下JDBCRDD的实现。

JDBCRDD中对PrunedFilteredScan的使用

首先找到JDBCRelation-->buildScan方法-->JDBCRDD.scanTable方法

  def scanTable(
      sc: SparkContext,
      schema: StructType,
      requiredColumns: Array[String],
      filters: Array[Filter],
      parts: Array[Partition],
      options: JDBCOptions): RDD[InternalRow] = {
    val url = options.url
    val dialect = JdbcDialects.get(url)
    val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
    new JDBCRDD(
      sc,
      JdbcUtils.createConnectionFactory(options),
      pruneSchema(schema, requiredColumns),
      quotedColumns,
      filters,
      parts,
      url,
      options)
  }

这里可以看到在new JDBCRDD的时候也进行了pruneSchema操作。并且也传入了我们说的两个参数requiredColumns,filters。

我们继续看JDBCRDD中关于filters的处理

   /**
   * `filters`, but as a WHERE clause suitable for injection into a SQL query.
   */
  private val filterWhereClause: String =
    filters
      .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url)))
      .map(p => s"($p)").mkString(" AND ")

通过注释可以知道这里是转换成了SQL中的where 条件,完整代码如下:

  /**
   * Turns a single Filter into a String representing a SQL expression.
   * Returns None for an unhandled filter.
   */
  def compileFilter(f: Filter, dialect: JdbcDialect): Option[String] = {
    def quote(colName: String): String = dialect.quoteIdentifier(colName)

    Option(f match {
      case EqualTo(attr, value) => s"${quote(attr)} = ${dialect.compileValue(value)}"
      case EqualNullSafe(attr, value) =>
        val col = quote(attr)
        s"(NOT ($col != ${dialect.compileValue(value)} OR $col IS NULL OR " +
          s"${dialect.compileValue(value)} IS NULL) OR " +
          s"($col IS NULL AND ${dialect.compileValue(value)} IS NULL))"
      case LessThan(attr, value) => s"${quote(attr)} < ${dialect.compileValue(value)}"
      case GreaterThan(attr, value) => s"${quote(attr)} > ${dialect.compileValue(value)}"
      case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${dialect.compileValue(value)}"
      case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${dialect.compileValue(value)}"
      case IsNull(attr) => s"${quote(attr)} IS NULL"
      case IsNotNull(attr) => s"${quote(attr)} IS NOT NULL"
      case StringStartsWith(attr, value) => s"${quote(attr)} LIKE '${value}%'"
      case StringEndsWith(attr, value) => s"${quote(attr)} LIKE '%${value}'"
      case StringContains(attr, value) => s"${quote(attr)} LIKE '%${value}%'"
      case In(attr, value) if value.isEmpty =>
        s"CASE WHEN ${quote(attr)} IS NULL THEN NULL ELSE FALSE END"
      case In(attr, value) => s"${quote(attr)} IN (${dialect.compileValue(value)})"
      case Not(f) => compileFilter(f, dialect).map(p => s"(NOT ($p))").getOrElse(null)
      case Or(f1, f2) =>
        // We can't compile Or filter unless both sub-filters are compiled successfully.
        // It applies too for the following And filter.
        // If we can make sure compileFilter supports all filters, we can remove this check.
        val or = Seq(f1, f2).flatMap(compileFilter(_, dialect))
        if (or.size == 2) {
          or.map(p => s"($p)").mkString(" OR ")
        } else {
          null
        }
      case And(f1, f2) =>
        val and = Seq(f1, f2).flatMap(compileFilter(_, dialect))
        if (and.size == 2) {
          and.map(p => s"($p)").mkString(" AND ")
        } else {
          null
        }
      case _ => null
    })
  }

因为JDBC最终会去执行SQL来获取数据,所以不像我们读取Excel,他最终都是在为拼接SQL来做铺垫,包括requiredColumns,只需要拼接到最终的SQL上面就可以做到列裁剪了。

val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause"

总结

到此对于spark读取外部数据源的介绍基本告一段落,我们介绍了如何自定义一个Spark的外部数据源,以及其中的一些小细节。最后更加推荐大家去阅读JDBCRDD的相关源码。