Apache Spark is a unified analytics engine made up of several core modules, each with a specific role. The main modules are described below, with detailed explanations and usage examples.
1. Spark Core
Function: the foundation module, providing task scheduling, memory management, and fault recovery. Its core abstraction is the Resilient Distributed Dataset (RDD), an immutable distributed collection of objects that can be operated on in parallel.
Example: RDD operations
from pyspark import SparkContext
# Create the SparkContext
sc = SparkContext("local", "Core Example")
# Create an RDD with 3 partitions
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
rdd = sc.parallelize(data, 3)
# Transformations (lazy)
squared_rdd = rdd.map(lambda x: x * x)      # square each element
filtered_rdd = rdd.filter(lambda x: x > 5)  # keep elements greater than 5
# Actions (trigger computation)
reduced = rdd.reduce(lambda a, b: a + b)    # sum all elements
print("Count:", rdd.count())      # number of elements
print("First:", rdd.first())      # first element
print("Collect:", rdd.collect())  # fetch all elements to the driver
print("Sum:", reduced)            # result of the reduce
# Save to a file system (path is a placeholder)
rdd.saveAsTextFile("hdfs://path/to/output")
sc.stop()
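Word count is the other canonical Spark Core example. A minimal standalone sketch, with a placeholder input path:
from pyspark import SparkContext
sc = SparkContext("local", "WordCount Example")
# Read a text file, split each line into words, map every word to
# (word, 1), then sum the counts per word with reduceByKey
lines = sc.textFile("hdfs://path/to/input.txt")
counts = (lines.flatMap(lambda line: line.split())
               .map(lambda word: (word, 1))
               .reduceByKey(lambda a, b: a + b))
for word, count in counts.take(10):
    print(word, count)
sc.stop()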
2. Spark SQL
Function: structured data processing through the DataFrame and Dataset APIs and SQL queries. It can read from many data sources, such as Hive, CSV, JSON, and Parquet.
Example: DataFrame operations
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, avg
# Create the SparkSession
spark = SparkSession.builder \
    .appName("SparkSQL Example") \
    .getOrCreate()
# 1. Create DataFrames from various sources
# From a CSV file
df = spark.read.csv("data.csv", header=True, inferSchema=True)
# From a JSON file
json_df = spark.read.json("data.json")
# From a Parquet file
parquet_df = spark.read.parquet("data.parquet")
# From a Hive table (requires enableHiveSupport() on the builder)
hive_df = spark.sql("SELECT * FROM my_table")
# 2. DataFrame operations
# Print the schema
df.printSchema()
# Basic queries
df.select("name", "age").show()
df.filter(col("age") > 25).show()
df.groupBy("department").agg(avg("salary").alias("avg_salary")).show()
# 3. SQL queries
# Register the DataFrame as a temporary view
df.createOrReplaceTempView("employees")
# Run SQL against the view
result = spark.sql("""
    SELECT department,
           AVG(salary) AS avg_salary,
           COUNT(*) AS emp_count
    FROM employees
    WHERE age > 25
    GROUP BY department
    HAVING COUNT(*) > 5
    ORDER BY avg_salary DESC
""")
result.show()
# 4. A more complex example: window functions
from pyspark.sql.window import Window
from pyspark.sql.functions import rank, row_number
windowSpec = Window.partitionBy("department").orderBy(col("salary").desc())
df_with_rank = df.withColumn("rank", rank().over(windowSpec)) \
    .withColumn("row_number", row_number().over(windowSpec))
df_with_rank.show()
spark.stop()
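Writing results back out uses the same API in reverse. A minimal sketch, assuming the result DataFrame from the example above and run before spark.stop(); the output paths are placeholders:
# Parquet with partitioning is a common choice for downstream queries
result.write.mode("overwrite").partitionBy("department").parquet("output/by_department")
# Plain CSV with a header, for quick inspection
result.write.mode("overwrite").option("header", True).csv("output/csv")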
3. Spark Streaming
Function: stream processing. Incoming data is split into small batches and processed with high throughput and fault tolerance, from sources such as Kafka, Flume, and TCP sockets. The example below uses the newer Structured Streaming API (see also section 6); the legacy DStream API is sketched after it.
Example: processing a Kafka data stream
from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, col, window
from pyspark.sql.types import StructType, StringType, IntegerType
# Create the SparkSession
spark = SparkSession.builder \
    .appName("StructuredStreaming Example") \
    .config("spark.sql.shuffle.partitions", "2") \
    .getOrCreate()
# Define the schema of the incoming JSON messages
schema = StructType() \
    .add("user_id", StringType()) \
    .add("event_type", StringType()) \
    .add("timestamp", StringType()) \
    .add("value", IntegerType())
# Read the stream from Kafka
kafka_stream = spark \
    .readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "localhost:9092") \
    .option("subscribe", "user_events") \
    .option("startingOffsets", "latest") \
    .load()
# Parse the JSON payload and cast the timestamp string to a real
# timestamp type (required for watermarking below)
parsed_stream = kafka_stream \
    .select(from_json(col("value").cast("string"), schema).alias("data")) \
    .select("data.*") \
    .withColumn("timestamp", col("timestamp").cast("timestamp"))
# Real-time processing: counts per 5-minute window per event type
windowed_counts = parsed_stream \
    .withWatermark("timestamp", "10 minutes") \
    .groupBy(
        col("event_type"),
        window(col("timestamp"), "5 minutes")
    ) \
    .count()
# Write to the console ("update" and "append" are the other output modes)
query = windowed_counts \
    .writeStream \
    .outputMode("complete") \
    .format("console") \
    .option("truncate", "false") \
    .start()
# Write to Kafka
output_query = windowed_counts \
    .selectExpr("to_json(struct(*)) AS value") \
    .writeStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "localhost:9092") \
    .option("topic", "aggregated_events") \
    .option("checkpointLocation", "/tmp/checkpoint") \
    .start()
query.awaitTermination()
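For reference, the legacy DStream API (the original Spark Streaming module) processes the stream as explicit micro-batches. A minimal word-count sketch over a TCP socket (test locally with nc -lk 9999):
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
sc = SparkContext("local[2]", "DStream Example")
ssc = StreamingContext(sc, 5)  # 5-second micro-batches
# Count words arriving on a TCP socket
lines = ssc.socketTextStream("localhost", 9999)
counts = (lines.flatMap(lambda line: line.split())
               .map(lambda word: (word, 1))
               .reduceByKey(lambda a, b: a + b))
counts.pprint()  # print each batch's counts
ssc.start()
ssc.awaitTermination()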
4. MLlib (machine learning)
Function: Spark's machine learning library, providing common algorithms (classification, regression, clustering, collaborative filtering) plus tools for feature extraction, transformation, and dimensionality reduction.
Example: a complete machine learning pipeline
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, StringIndexer, StandardScaler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
spark = SparkSession.builder.appName("MLlib Example").getOrCreate()
# 1. Load the data
data = spark.read.csv("iris.csv", header=True, inferSchema=True)
# 2. Feature engineering
# Convert the string label to an index
label_indexer = StringIndexer(inputCol="species", outputCol="label")
# Assemble the feature columns into a single vector
feature_cols = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
# Standardize the features
scaler = StandardScaler(inputCol="features", outputCol="scaled_features")
# 3. Split into training and test sets
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)
# 4. Define the model
rf = RandomForestClassifier(
    featuresCol="scaled_features",
    labelCol="label",
    numTrees=100,
    maxDepth=5,
    seed=42
)
# 5. Build the Pipeline
pipeline = Pipeline(stages=[
    label_indexer,
    assembler,
    scaler,
    rf
])
# 6. Hyperparameter grid
param_grid = ParamGridBuilder() \
    .addGrid(rf.numTrees, [50, 100, 200]) \
    .addGrid(rf.maxDepth, [3, 5, 7]) \
    .build()
# 7. Cross-validation
evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="f1"
)
crossval = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    numFolds=5,
    seed=42
)
# 8. Train
cv_model = crossval.fit(train_data)
# 9. Predict on the test set
predictions = cv_model.transform(test_data)
# 10. Evaluate (the evaluator is configured for F1, not accuracy)
f1_score = evaluator.evaluate(predictions)
print(f"Model F1 score: {f1_score:.4f}")
# 11. Inspect the best model's feature importances
best_model = cv_model.bestModel
rf_model = best_model.stages[-1]
importances = rf_model.featureImportances
print("Feature importances:", importances)
# 12. Save the model
cv_model.write().overwrite().save("models/random_forest_model")
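A saved CrossValidatorModel can be loaded back later for inference. A short sketch using the path from step 12:
from pyspark.ml.tuning import CrossValidatorModel
# Reload the persisted model and reuse it for predictions
loaded_model = CrossValidatorModel.load("models/random_forest_model")
loaded_model.transform(test_data).select("prediction").show()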
5. GraphX (graph processing)
Function: graph abstractions and parallel graph algorithms, used for tasks such as social network analysis and path finding. Note that GraphX itself is a Scala API; from Python the equivalent functionality is provided by the GraphFrames package, which the example below uses.
Example: social network analysis
from pyspark.sql import SparkSession
from graphframes import GraphFrame
# Create the SparkSession. The graphframes package must be on the
# classpath (e.g. via spark-submit --packages, with a version that
# matches your Spark and Scala build).
spark = SparkSession.builder \
    .appName("GraphFrames Example") \
    .getOrCreate()
# 1. Create the vertex and edge DataFrames
# Vertex data
vertices_data = [
    ("a", "Alice", 34),
    ("b", "Bob", 36),
    ("c", "Charlie", 30),
    ("d", "David", 29),
    ("e", "Esther", 32),
    ("f", "Fanny", 36),
    ("g", "Gabby", 60)
]
vertices = spark.createDataFrame(vertices_data, ["id", "name", "age"])
# Edge data
edges_data = [
    ("a", "b", "friend"),
    ("b", "c", "follow"),
    ("c", "b", "follow"),
    ("f", "c", "follow"),
    ("e", "f", "follow"),
    ("e", "d", "friend"),
    ("d", "a", "friend"),
    ("a", "e", "friend")
]
edges = spark.createDataFrame(edges_data, ["src", "dst", "relationship"])
# 2. Build the graph
g = GraphFrame(vertices, edges)
# 3. Basic graph queries
print("Vertex count:", g.vertices.count())
print("Edge count:", g.edges.count())
# 4. Degree calculations
g.inDegrees.show()   # in-degree
g.outDegrees.show()  # out-degree
g.degrees.show()     # total degree
# 5. Run PageRank
results = g.pageRank(resetProbability=0.15, maxIter=10)
results.vertices.select("id", "pagerank").show()
# 6. Connected components (requires a checkpoint directory)
spark.sparkContext.setCheckpointDir("/tmp/graphframes_checkpoint")
result = g.connectedComponents()
result.select("id", "component").orderBy("component").show()
# 7. Triangle counting
triangle_count = g.triangleCount()
triangle_count.select("id", "count").show()
# 8. Label propagation (community detection)
result = g.labelPropagation(maxIter=10)
result.select("id", "label").show()
6. Spark Structured Streaming
Function: the stream-processing engine built on top of Spark SQL
Example: a real-time data pipeline
import time
from pyspark.sql import SparkSession
from pyspark.sql.functions import window, col, count, avg, min, max, rand
# Create the SparkSession
spark = SparkSession.builder \
    .appName("StructuredStreaming Advanced") \
    .getOrCreate()
# Simulated streaming source (in practice this would be Kafka, Kinesis, etc.)
streaming_df = spark \
    .readStream \
    .format("rate") \
    .option("rowsPerSecond", 100) \
    .load() \
    .withColumn("value", rand() * 100)  # replace the counter with a random value
# Sliding-window aggregation: 5-minute windows advancing every minute
windowed_agg = streaming_df \
    .withWatermark("timestamp", "10 seconds") \
    .groupBy(
        window(col("timestamp"), "5 minutes", "1 minute")
    ) \
    .agg(
        count("*").alias("count"),
        avg("value").alias("avg_value"),
        min("value").alias("min_value"),
        max("value").alias("max_value")
    )
# Write to multiple sinks
# 1. Console sink
console_query = windowed_agg \
    .writeStream \
    .outputMode("complete") \
    .format("console") \
    .option("truncate", "false") \
    .start()
# 2. In-memory table (for interactive queries)
memory_query = windowed_agg \
    .writeStream \
    .queryName("aggregated_data") \
    .outputMode("complete") \
    .format("memory") \
    .start()
# 3. Poll the in-memory table with ordinary batch SQL (runs until interrupted)
while True:
    spark.sql("SELECT * FROM aggregated_data").show()
    time.sleep(10)
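For sinks without a built-in connector, foreachBatch hands each micro-batch to ordinary batch code. A minimal sketch, assuming the windowed_agg stream above; the JDBC URL, table name, and credentials are placeholders:
def write_batch(batch_df, batch_id):
    # Each micro-batch arrives as a regular DataFrame, so any batch
    # writer works here, including JDBC
    batch_df.write \
        .format("jdbc") \
        .option("url", "jdbc:postgresql://localhost:5432/metrics") \
        .option("dbtable", "windowed_stats") \
        .option("user", "spark") \
        .option("password", "secret") \
        .mode("append") \
        .save()

jdbc_query = windowed_agg \
    .writeStream \
    .outputMode("update") \
    .foreachBatch(write_batch) \
    .option("checkpointLocation", "/tmp/jdbc_checkpoint") \
    .start()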
7. SparkR
Function: the R language interface to Spark
Example: data analysis in R
# Initialize SparkR
library(SparkR)
sparkR.session(appName = "SparkR Example")
# 1. Create a DataFrame
df <- createDataFrame(iris)
# 2. Basic operations
showDF(df)
head(df)
# Stop the SparkR session before switching packages
sparkR.session.stop()
# 3. dplyr-style operations via sparklyr (a separate package from SparkR)
library(sparklyr)
library(dplyr)
sc <- spark_connect(master = "local")
iris_tbl <- copy_to(sc, iris)
# dplyr verbs are translated to Spark SQL
iris_tbl %>%
  filter(Sepal_Length > 5.0) %>%
  group_by(Species) %>%
  summarise(
    count = n(),
    avg_sepal_length = mean(Sepal_Length),
    max_petal_length = max(Petal_Length)
  ) %>%
  arrange(desc(avg_sepal_length)) %>%
  print()
# 4. Machine learning with sparklyr
model <- ml_random_forest(
  iris_tbl,
  Species ~ .,
  type = "classification",
  max_depth = 5,
  num_trees = 100
)
# Predictions
predictions <- ml_predict(model, iris_tbl)
# Disconnect from Spark
spark_disconnect(sc)
8. Practical tools and features
Example: Spark configuration and optimization
from pyspark import StorageLevel
from pyspark.sql import SparkSession
# Create a SparkSession with tuning options
spark = SparkSession.builder \
    .appName("Optimized Job") \
    .master("yarn") \
    .config("spark.executor.memory", "4g") \
    .config("spark.executor.cores", "2") \
    .config("spark.executor.instances", "10") \
    .config("spark.dynamicAllocation.enabled", "true") \
    .config("spark.shuffle.service.enabled", "true") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .config("spark.sql.autoBroadcastJoinThreshold", "100MB") \
    .config("spark.sql.broadcastTimeout", "300") \
    .config("spark.default.parallelism", "200") \
    .config("spark.sql.shuffle.partitions", "200") \
    .enableHiveSupport() \
    .getOrCreate()
# Caching
df = spark.read.parquet("large_dataset.parquet")
df.cache()  # cache with the default storage level
# ...or choose an explicit storage level instead (use one or the other):
df.persist(StorageLevel.MEMORY_AND_DISK)  # spill to disk when memory is short
# Checkpointing (for iterative algorithms or long lineage chains)
spark.sparkContext.setCheckpointDir("hdfs://checkpoint_dir")
df = df.checkpoint()  # checkpoint() returns the checkpointed DataFrame
# Broadcast variable (small lookup data, read via .value inside tasks)
small_lookup = {"A": 1, "B": 2, "C": 3}
broadcast_var = spark.sparkContext.broadcast(small_lookup)
# Accumulator
rdd = spark.sparkContext.parallelize(range(20))
counter = spark.sparkContext.accumulator(0)
def count_func(x):
    global counter
    if x > 10:
        counter += 1
    return x * 2
rdd.map(count_func).collect()
print("Count of values > 10:", counter.value)
Summary comparison
| Module | Main use | Core concepts | Typical scenarios |
|---|---|---|---|
| Spark Core | base engine | RDDs, transformations/actions | low-level data processing, custom algorithms |
| Spark SQL | structured queries | DataFrame, Dataset, SQL | ETL, data analysis, data warehousing |
| Spark Streaming | real-time processing | DStream, micro-batches | real-time monitoring, streaming ETL |
| MLlib | machine learning | Pipeline, Transformer, Estimator | predictive analytics, recommender systems |
| GraphX | graph processing | vertices, edges, property graphs | social networks, recommendations, path planning |
| Structured Streaming | structured streams | DataFrame, event time, windows | complex event processing, real-time analytics |
Together, these modules make Spark a unified analytics platform that handles batch processing, stream processing, machine learning, and graph computation.