我们将编写一个完整的Spark本地模式运行的Python代码示例。
代码将包括以下步骤:
- 导入必要的库
- 创建SparkSession
- 创建一些示例数据(例如,一个简单的DataFrame)
- 对数据执行一些操作(例如,过滤、聚合)
- 输出结果
- 停止SparkSession
注意:确保本地已经安装了Java和Spark,并且设置了相应的环境变量。
如果没有安装,可以参考以下步骤:
- 安装Java 8或更高版本
- 下载Spark并解压,设置SPARK_HOME和PATH环境变量(此步骤可以没有, 因为安装pyspark库包的时候, 会自动下载安装预编译的 Spark 发行版(包含在包内) )
- 安装pyspark:pip install pyspark
以下代码在本地运行,使用本地模式(local[*])表示使用所有可用的核心。
完整的Spark本地模式Python代码示例
以下是一个完整的、可在本地运行的Spark应用代码,包含数据处理、转换和分析的完整流程:
# spark_local_demo.py
"""
完整的Spark本地模式运行示例
功能:分析电商交易数据,计算用户购买行为统计
"""
import os
import sys
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import logging
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class SparkLocalDemo:
"""Spark本地模式演示类"""
def __init__(self, app_name="SparkLocalDemo", master="local[*]", memory="2g"):
"""
初始化Spark会话
参数:
app_name: 应用名称
master: Spark运行模式,本地模式使用"local[*]"
memory: 执行器内存大小
"""
self.app_name = app_name
self.master = master
self.memory = memory
logger.info(f"正在初始化Spark会话,应用名称: {app_name}")
# 创建SparkSession
self.spark = SparkSession.builder \
.appName(self.app_name) \
.master(self.master) \
.config("spark.executor.memory", self.memory) \
.config("spark.driver.memory", self.memory) \
.config("spark.sql.adaptive.enabled", "true") \
.config("spark.sql.shuffle.partitions", "4") \
.config("spark.default.parallelism", "4") \
.getOrCreate()
# 获取SparkContext
self.sc = self.spark.sparkContext
logger.info(f"Spark初始化完成,版本: {self.sc.version}")
logger.info(f"Spark UI地址: http://localhost:4040")
def generate_sample_data(self):
"""生成模拟的电商交易数据"""
logger.info("正在生成模拟数据...")
# 定义Schema
transaction_schema = StructType([
StructField("transaction_id", StringType(), True),
StructField("user_id", IntegerType(), True),
StructField("product_id", IntegerType(), True),
StructField("product_name", StringType(), True),
StructField("category", StringType(), True),
StructField("price", DoubleType(), True),
StructField("quantity", IntegerType(), True),
StructField("transaction_date", DateType(), True),
StructField("payment_method", StringType(), True),
StructField("city", StringType(), True)
])
user_schema = StructType([
StructField("user_id", IntegerType(), True),
StructField("name", StringType(), True),
StructField("age", IntegerType(), True),
StructField("gender", StringType(), True),
StructField("registration_date", DateType(), True),
StructField("membership_level", StringType(), True)
])
# 创建模拟数据
np.random.seed(42)
# 生成用户数据
user_data = []
for i in range(1, 101):
user_data.append((
i,
f"User_{i}",
np.random.randint(18, 70),
np.random.choice(["M", "F"]),
datetime.now() - timedelta(days=np.random.randint(1, 1000)),
np.random.choice(["Bronze", "Silver", "Gold", "Platinum"], p=[0.4, 0.3, 0.2, 0.1])
))
# 生成交易数据
products = [
(1, "Laptop", "Electronics", 999.99),
(2, "Smartphone", "Electronics", 699.99),
(3, "Headphones", "Electronics", 199.99),
(4, "T-Shirt", "Clothing", 24.99),
(5, "Jeans", "Clothing", 59.99),
(6, "Book", "Books", 14.99),
(7, "Coffee Mug", "Home", 12.99),
(8, "Backpack", "Accessories", 49.99),
(9, "Watch", "Accessories", 199.99),
(10, "Shoes", "Clothing", 89.99)
]
cities = ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix",
"Philadelphia", "San Antonio", "San Diego", "Dallas", "San Jose"]
payment_methods = ["Credit Card", "Debit Card", "PayPal", "Apple Pay", "Google Pay"]
transaction_data = []
for i in range(1, 1001):
product = products[np.random.randint(0, len(products))]
transaction_data.append((
f"TXN{str(i).zfill(6)}",
np.random.randint(1, 101),
product[0],
product[1],
product[2],
product[3],
np.random.randint(1, 4),
datetime.now() - timedelta(days=np.random.randint(1, 365)),
np.random.choice(payment_methods),
np.random.choice(cities)
))
# 创建DataFrame
users_df = self.spark.createDataFrame(user_data, user_schema)
transactions_df = self.spark.createDataFrame(transaction_data, transaction_schema)
logger.info(f"模拟数据生成完成: 用户数={users_df.count()}, 交易数={transactions_df.count()}")
return users_df, transactions_df
def analyze_data(self, users_df, transactions_df):
"""执行数据分析任务"""
logger.info("开始数据分析...")
# 1. 基本数据预览
logger.info("1. 数据预览:")
logger.info(f"用户数据Schema: {users_df.printSchema()}")
logger.info(f"交易数据Schema: {transactions_df.printSchema()}")
logger.info("用户数据示例:")
users_df.show(5, truncate=False)
logger.info("交易数据示例:")
transactions_df.show(5, truncate=False)
# 2. 计算总交易额
logger.info("2. 计算总交易额和平均订单价值:")
transactions_df = transactions_df.withColumn(
"total_amount", col("price") * col("quantity")
)
total_revenue = transactions_df.agg(sum("total_amount").alias("total_revenue")).collect()[0][0]
avg_order_value = transactions_df.agg(avg("total_amount").alias("avg_order_value")).collect()[0][0]
logger.info(f"总交易额: ${total_revenue:,.2f}")
logger.info(f"平均订单价值: ${avg_order_value:,.2f}")
# 3. 按产品类别统计
logger.info("3. 按产品类别统计:")
category_stats = transactions_df.groupBy("category") \
.agg(
count("*").alias("transaction_count"),
sum("quantity").alias("total_quantity"),
sum("total_amount").alias("category_revenue"),
avg("price").alias("avg_price")
) \
.orderBy(desc("category_revenue"))
category_stats.show(truncate=False)
# 4. 按城市统计交易
logger.info("4. 按城市统计交易:")
city_stats = transactions_df.groupBy("city") \
.agg(
count("*").alias("transaction_count"),
sum("total_amount").alias("city_revenue"),
countDistinct("user_id").alias("unique_customers")
) \
.orderBy(desc("city_revenue"))
city_stats.show(truncate=False)
# 5. 用户购买行为分析
logger.info("5. 用户购买行为分析:")
user_transactions = transactions_df.groupBy("user_id") \
.agg(
count("*").alias("purchase_count"),
sum("total_amount").alias("total_spent"),
avg("total_amount").alias("avg_purchase_value")
) \
.orderBy(desc("total_spent"))
top_customers = user_transactions.limit(10)
logger.info("消费最高的10位客户:")
top_customers.show(truncate=False)
# 6. 关联用户和交易数据
logger.info("6. 关联用户和交易数据分析:")
joined_df = transactions_df.join(users_df, "user_id", "inner")
# 按会员等级分析
membership_stats = joined_df.groupBy("membership_level") \
.agg(
countDistinct("user_id").alias("customer_count"),
sum("total_amount").alias("total_revenue"),
avg("total_amount").alias("avg_revenue_per_customer")
) \
.orderBy(desc("total_revenue"))
logger.info("按会员等级统计:")
membership_stats.show(truncate=False)
# 7. 按年龄组分析
logger.info("7. 按年龄组分析:")
# 创建年龄分组
age_binned_df = joined_df.withColumn(
"age_group",
when(col("age") < 25, "18-24")
.when((col("age") >= 25) & (col("age") < 35), "25-34")
.when((col("age") >= 35) & (col("age") < 45), "35-44")
.when((col("age") >= 45) & (col("age") < 55), "45-54")
.otherwise("55+")
)
age_group_stats = age_binned_df.groupBy("age_group") \
.agg(
countDistinct("user_id").alias("customer_count"),
sum("total_amount").alias("total_revenue"),
avg("total_amount").alias("avg_purchase_value")
) \
.orderBy("age_group")
age_group_stats.show(truncate=False)
# 8. 月度销售趋势
logger.info("8. 月度销售趋势分析:")
monthly_sales = transactions_df.withColumn("month", date_format("transaction_date", "yyyy-MM")) \
.groupBy("month") \
.agg(
count("*").alias("transaction_count"),
sum("total_amount").alias("monthly_revenue"),
avg("total_amount").alias("avg_order_value")
) \
.orderBy("month")
monthly_sales.show(truncate=False)
return {
"total_revenue": total_revenue,
"avg_order_value": avg_order_value,
"top_customers": top_customers,
"category_stats": category_stats,
"city_stats": city_stats,
"membership_stats": membership_stats,
"age_group_stats": age_group_stats,
"monthly_sales": monthly_sales
}
def save_results(self, analysis_results):
"""保存分析结果"""
logger.info("保存分析结果...")
# 创建输出目录
output_dir = "spark_analysis_results"
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# 保存每个DataFrame为CSV
for name, df in analysis_results.items():
if name != "total_revenue" and name != "avg_order_value":
output_path = os.path.join(output_dir, f"{name}.csv")
df.coalesce(1).write.mode("overwrite").option("header", "true").csv(output_path)
logger.info(f"已保存: {output_path}")
# 保存汇总统计
summary_path = os.path.join(output_dir, "summary.txt")
with open(summary_path, "w") as f:
f.write("=== Spark分析结果汇总 ===\n\n")
f.write(f"总交易额: ${analysis_results['total_revenue']:,.2f}\n")
f.write(f"平均订单价值: ${analysis_results['avg_order_value']:,.2f}\n\n")
f.write("分析完成时间: " + datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n")
logger.info(f"汇总结果已保存: {summary_path}")
return output_dir
def run_complete_demo(self):
"""运行完整的演示流程"""
try:
logger.info("开始运行Spark本地演示...")
# 1. 生成模拟数据
users_df, transactions_df = self.generate_sample_data()
# 2. 执行数据分析
analysis_results = self.analyze_data(users_df, transactions_df)
# 3. 保存结果
output_dir = self.save_results(analysis_results)
# 4. 显示Spark作业信息
logger.info("Spark作业信息:")
logger.info(f"应用ID: {self.sc.applicationId}")
# 5. 显示数据缓存信息
logger.info("数据缓存信息:")
for rdd in self.sc._jsc.getRDDStorageInfo():
logger.info(f"RDD: {rdd.name()}, 缓存级别: {rdd.storageLevel().description()}")
logger.info(f"演示完成! 结果保存在: {output_dir}")
logger.info(f"可以在浏览器中访问 http://localhost:4040 查看Spark UI")
return True
except Exception as e:
logger.error(f"运行过程中发生错误: {str(e)}")
import traceback
traceback.print_exc()
return False
def cleanup(self):
"""清理资源"""
logger.info("正在停止Spark会话...")
self.spark.stop()
logger.info("Spark会话已停止")
def main():
"""主函数"""
print("=" * 60)
print("Spark本地模式完整演示")
print("=" * 60)
# 创建演示实例
demo = SparkLocalDemo(
app_name="ECommerceAnalysis",
master="local[*]", # 使用所有可用的CPU核心
memory="2g"
)
try:
# 运行完整演示
success = demo.run_complete_demo()
if success:
print("\n" + "=" * 60)
print("演示成功完成!")
print("=" * 60)
else:
print("\n演示过程中出现错误!")
finally:
# 清理资源
demo.cleanup()
if __name__ == "__main__":
main()
运行说明
1. 环境要求
- Python 3.7+
- Java 8或更高版本
- Apache Spark(或通过PySpark自动安装)
2. 安装依赖
# 安装PySpark(`会自动下载Spark)`
pip install pyspark pandas numpy
# 或者如果已经安装了Spark,确保SPARK_HOME环境变量已设置
3. 运行代码
# 直接运行Python脚本
python spark_local_demo.py
# 或者设置更多内存(如果需要)
# python spark_local_demo.py --driver-memory 4g
4. 运行结果
- 控制台输出分析结果
- 生成
spark_analysis_results目录,包含CSV格式的分析结果 - 可以通过浏览器访问 http://localhost:4040 查看Spark作业详情
代码功能说明
- 数据生成: 创建模拟的电商交易和用户数据
- 数据分析:
- 计算总交易额和平均订单价值
- 按产品类别、城市、会员等级、年龄组统计
- 用户购买行为分析
- 月度销售趋势分析
- 结果保存: 将分析结果保存为CSV文件
- 资源管理: 正确初始化和清理Spark资源
注意事项
- 第一次运行时会自动下载Spark(如果未安装)
- 确保系统有足够的内存(至少4GB可用内存)
- 可以在代码中调整
master参数:local[*]: 使用所有CPU核心local[2]: 使用2个CPU核心local: 使用1个CPU核心
这个代码示例展示了Spark本地模式的完整工作流程,可以直接运行并查看结果,适合学习和测试使用。
问题:我在本地执行spark任务, 只需要安装pyspark这个python库, 不需要下载安装apache spark是吗
是的,你的理解基本正确。在本地执行 PySpark 任务时,通常只需要安装 pyspark 这个 Python 包,而不需要单独下载和安装 Apache Spark。
为什么只需要安装 PySpark?
PySpark 包已经包含了 Spark 运行时:
pip install pyspark
当你安装 PySpark 时,它会自动下载一个预编译的 Spark 发行版(包含在包内),这个发行版包含了:
- Spark 核心引擎
- 必要的依赖库
- PySpark 的 Python 绑定
验证安装
安装后,你可以直接使用:
from pyspark.sql import SparkSession
# 创建本地 Spark 会话
spark = SparkSession.builder \
.appName("MyApp") \
.master("local[*]") # 使用所有CPU核心
.getOrCreate()
print(spark.version) # 查看 Spark 版本
注意事项
-
Java 依赖:Spark 是基于 Java 的,所以你仍然需要安装 Java 8 或 11
# 检查 Java 是否已安装 java -version -
完整 Spark vs PySpark 包:
- PySpark 包:适合大多数本地开发和测试
- 完整 Spark:如果需要访问 Spark 的所有命令行工具、完整的配置选项,或者需要特定版本的 Spark,可以单独下载
-
Python 版本兼容性:确保你的 Python 版本与 PySpark 兼容(通常是 Python 3.7+)
如果遇到问题
如果你需要特定的 Spark 版本或配置,也可以:
# 安装特定版本的 PySpark
pip install pyspark==3.5.0
# 或者,如果你确实需要完整 Spark
# 1. 下载 Spark 安装包
# 2. 设置 SPARK_HOME 环境变量
# 3. 将 PySpark 指向这个安装路径
对于大多数本地开发和测试场景,pip install pyspark 就足够了,这也是最简单的入门方式。
总结
- 核心结论:仅装 pyspark(Spark 3.0+)+ 装 Java,就能满足基础的本地 Spark 开发,无需手动下载 Apache Spark 安装包;但 pyspark 本身依赖 Spark 核心,并非 “不需要 Apache Spark”,只是 pip 帮你内置了。
- 必装依赖:Java 是 Spark 运行的基础,无论是否手动装 Apache Spark,都必须安装。
- 特殊场景:需自定义 Spark 配置 / 版本时,仍需手动下载 Apache Spark 并配置
SPARK_HOME环境变量。