[spark]spark本地模式 (实战)

9 阅读9分钟

我们将编写一个完整的Spark本地模式运行的Python代码示例。
代码将包括以下步骤:

  1. 导入必要的库
  2. 创建SparkSession
  3. 创建一些示例数据(例如,一个简单的DataFrame)
  4. 对数据执行一些操作(例如,过滤、聚合)
  5. 输出结果
  6. 停止SparkSession

注意:确保本地已经安装了Java和Spark,并且设置了相应的环境变量。
如果没有安装,可以参考以下步骤:

  • 安装Java 8或更高版本
  • 下载Spark并解压,设置SPARK_HOME和PATH环境变量(此步骤可以没有, 因为安装pyspark库包的时候, 会自动下载安装预编译的 Spark 发行版(包含在包内) )
  • 安装pyspark:pip install pyspark

以下代码在本地运行,使用本地模式(local[*])表示使用所有可用的核心。

完整的Spark本地模式Python代码示例

以下是一个完整的、可在本地运行的Spark应用代码,包含数据处理、转换和分析的完整流程:

# spark_local_demo.py
"""
完整的Spark本地模式运行示例
功能:分析电商交易数据,计算用户购买行为统计
"""

import os
import sys
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import logging

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

class SparkLocalDemo:
    """Spark本地模式演示类"""
    
    def __init__(self, app_name="SparkLocalDemo", master="local[*]", memory="2g"):
        """
        初始化Spark会话
        
        参数:
            app_name: 应用名称
            master: Spark运行模式,本地模式使用"local[*]"
            memory: 执行器内存大小
        """
        self.app_name = app_name
        self.master = master
        self.memory = memory
        
        logger.info(f"正在初始化Spark会话,应用名称: {app_name}")
        
        # 创建SparkSession
        self.spark = SparkSession.builder \
            .appName(self.app_name) \
            .master(self.master) \
            .config("spark.executor.memory", self.memory) \
            .config("spark.driver.memory", self.memory) \
            .config("spark.sql.adaptive.enabled", "true") \
            .config("spark.sql.shuffle.partitions", "4") \
            .config("spark.default.parallelism", "4") \
            .getOrCreate()
        
        # 获取SparkContext
        self.sc = self.spark.sparkContext
        
        logger.info(f"Spark初始化完成,版本: {self.sc.version}")
        logger.info(f"Spark UI地址: http://localhost:4040")
        
    def generate_sample_data(self):
        """生成模拟的电商交易数据"""
        logger.info("正在生成模拟数据...")
        
        # 定义Schema
        transaction_schema = StructType([
            StructField("transaction_id", StringType(), True),
            StructField("user_id", IntegerType(), True),
            StructField("product_id", IntegerType(), True),
            StructField("product_name", StringType(), True),
            StructField("category", StringType(), True),
            StructField("price", DoubleType(), True),
            StructField("quantity", IntegerType(), True),
            StructField("transaction_date", DateType(), True),
            StructField("payment_method", StringType(), True),
            StructField("city", StringType(), True)
        ])
        
        user_schema = StructType([
            StructField("user_id", IntegerType(), True),
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
            StructField("gender", StringType(), True),
            StructField("registration_date", DateType(), True),
            StructField("membership_level", StringType(), True)
        ])
        
        # 创建模拟数据
        np.random.seed(42)
        
        # 生成用户数据
        user_data = []
        for i in range(1, 101):
            user_data.append((
                i,
                f"User_{i}",
                np.random.randint(18, 70),
                np.random.choice(["M", "F"]),
                datetime.now() - timedelta(days=np.random.randint(1, 1000)),
                np.random.choice(["Bronze", "Silver", "Gold", "Platinum"], p=[0.4, 0.3, 0.2, 0.1])
            ))
        
        # 生成交易数据
        products = [
            (1, "Laptop", "Electronics", 999.99),
            (2, "Smartphone", "Electronics", 699.99),
            (3, "Headphones", "Electronics", 199.99),
            (4, "T-Shirt", "Clothing", 24.99),
            (5, "Jeans", "Clothing", 59.99),
            (6, "Book", "Books", 14.99),
            (7, "Coffee Mug", "Home", 12.99),
            (8, "Backpack", "Accessories", 49.99),
            (9, "Watch", "Accessories", 199.99),
            (10, "Shoes", "Clothing", 89.99)
        ]
        
        cities = ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix", 
                 "Philadelphia", "San Antonio", "San Diego", "Dallas", "San Jose"]
        
        payment_methods = ["Credit Card", "Debit Card", "PayPal", "Apple Pay", "Google Pay"]
        
        transaction_data = []
        for i in range(1, 1001):
            product = products[np.random.randint(0, len(products))]
            transaction_data.append((
                f"TXN{str(i).zfill(6)}",
                np.random.randint(1, 101),
                product[0],
                product[1],
                product[2],
                product[3],
                np.random.randint(1, 4),
                datetime.now() - timedelta(days=np.random.randint(1, 365)),
                np.random.choice(payment_methods),
                np.random.choice(cities)
            ))
        
        # 创建DataFrame
        users_df = self.spark.createDataFrame(user_data, user_schema)
        transactions_df = self.spark.createDataFrame(transaction_data, transaction_schema)
        
        logger.info(f"模拟数据生成完成: 用户数={users_df.count()}, 交易数={transactions_df.count()}")
        
        return users_df, transactions_df
    
    def analyze_data(self, users_df, transactions_df):
        """执行数据分析任务"""
        logger.info("开始数据分析...")
        
        # 1. 基本数据预览
        logger.info("1. 数据预览:")
        logger.info(f"用户数据Schema: {users_df.printSchema()}")
        logger.info(f"交易数据Schema: {transactions_df.printSchema()}")
        
        logger.info("用户数据示例:")
        users_df.show(5, truncate=False)
        
        logger.info("交易数据示例:")
        transactions_df.show(5, truncate=False)
        
        # 2. 计算总交易额
        logger.info("2. 计算总交易额和平均订单价值:")
        transactions_df = transactions_df.withColumn(
            "total_amount", col("price") * col("quantity")
        )
        
        total_revenue = transactions_df.agg(sum("total_amount").alias("total_revenue")).collect()[0][0]
        avg_order_value = transactions_df.agg(avg("total_amount").alias("avg_order_value")).collect()[0][0]
        
        logger.info(f"总交易额: ${total_revenue:,.2f}")
        logger.info(f"平均订单价值: ${avg_order_value:,.2f}")
        
        # 3. 按产品类别统计
        logger.info("3. 按产品类别统计:")
        category_stats = transactions_df.groupBy("category") \
            .agg(
                count("*").alias("transaction_count"),
                sum("quantity").alias("total_quantity"),
                sum("total_amount").alias("category_revenue"),
                avg("price").alias("avg_price")
            ) \
            .orderBy(desc("category_revenue"))
        
        category_stats.show(truncate=False)
        
        # 4. 按城市统计交易
        logger.info("4. 按城市统计交易:")
        city_stats = transactions_df.groupBy("city") \
            .agg(
                count("*").alias("transaction_count"),
                sum("total_amount").alias("city_revenue"),
                countDistinct("user_id").alias("unique_customers")
            ) \
            .orderBy(desc("city_revenue"))
        
        city_stats.show(truncate=False)
        
        # 5. 用户购买行为分析
        logger.info("5. 用户购买行为分析:")
        user_transactions = transactions_df.groupBy("user_id") \
            .agg(
                count("*").alias("purchase_count"),
                sum("total_amount").alias("total_spent"),
                avg("total_amount").alias("avg_purchase_value")
            ) \
            .orderBy(desc("total_spent"))
        
        top_customers = user_transactions.limit(10)
        logger.info("消费最高的10位客户:")
        top_customers.show(truncate=False)
        
        # 6. 关联用户和交易数据
        logger.info("6. 关联用户和交易数据分析:")
        joined_df = transactions_df.join(users_df, "user_id", "inner")
        
        # 按会员等级分析
        membership_stats = joined_df.groupBy("membership_level") \
            .agg(
                countDistinct("user_id").alias("customer_count"),
                sum("total_amount").alias("total_revenue"),
                avg("total_amount").alias("avg_revenue_per_customer")
            ) \
            .orderBy(desc("total_revenue"))
        
        logger.info("按会员等级统计:")
        membership_stats.show(truncate=False)
        
        # 7. 按年龄组分析
        logger.info("7. 按年龄组分析:")
        
        # 创建年龄分组
        age_binned_df = joined_df.withColumn(
            "age_group",
            when(col("age") < 25, "18-24")
            .when((col("age") >= 25) & (col("age") < 35), "25-34")
            .when((col("age") >= 35) & (col("age") < 45), "35-44")
            .when((col("age") >= 45) & (col("age") < 55), "45-54")
            .otherwise("55+")
        )
        
        age_group_stats = age_binned_df.groupBy("age_group") \
            .agg(
                countDistinct("user_id").alias("customer_count"),
                sum("total_amount").alias("total_revenue"),
                avg("total_amount").alias("avg_purchase_value")
            ) \
            .orderBy("age_group")
        
        age_group_stats.show(truncate=False)
        
        # 8. 月度销售趋势
        logger.info("8. 月度销售趋势分析:")
        
        monthly_sales = transactions_df.withColumn("month", date_format("transaction_date", "yyyy-MM")) \
            .groupBy("month") \
            .agg(
                count("*").alias("transaction_count"),
                sum("total_amount").alias("monthly_revenue"),
                avg("total_amount").alias("avg_order_value")
            ) \
            .orderBy("month")
        
        monthly_sales.show(truncate=False)
        
        return {
            "total_revenue": total_revenue,
            "avg_order_value": avg_order_value,
            "top_customers": top_customers,
            "category_stats": category_stats,
            "city_stats": city_stats,
            "membership_stats": membership_stats,
            "age_group_stats": age_group_stats,
            "monthly_sales": monthly_sales
        }
    
    def save_results(self, analysis_results):
        """保存分析结果"""
        logger.info("保存分析结果...")
        
        # 创建输出目录
        output_dir = "spark_analysis_results"
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        
        # 保存每个DataFrame为CSV
        for name, df in analysis_results.items():
            if name != "total_revenue" and name != "avg_order_value":
                output_path = os.path.join(output_dir, f"{name}.csv")
                df.coalesce(1).write.mode("overwrite").option("header", "true").csv(output_path)
                logger.info(f"已保存: {output_path}")
        
        # 保存汇总统计
        summary_path = os.path.join(output_dir, "summary.txt")
        with open(summary_path, "w") as f:
            f.write("=== Spark分析结果汇总 ===\n\n")
            f.write(f"总交易额: ${analysis_results['total_revenue']:,.2f}\n")
            f.write(f"平均订单价值: ${analysis_results['avg_order_value']:,.2f}\n\n")
            f.write("分析完成时间: " + datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n")
        
        logger.info(f"汇总结果已保存: {summary_path}")
        return output_dir
    
    def run_complete_demo(self):
        """运行完整的演示流程"""
        try:
            logger.info("开始运行Spark本地演示...")
            
            # 1. 生成模拟数据
            users_df, transactions_df = self.generate_sample_data()
            
            # 2. 执行数据分析
            analysis_results = self.analyze_data(users_df, transactions_df)
            
            # 3. 保存结果
            output_dir = self.save_results(analysis_results)
            
            # 4. 显示Spark作业信息
            logger.info("Spark作业信息:")
            logger.info(f"应用ID: {self.sc.applicationId}")
            
            # 5. 显示数据缓存信息
            logger.info("数据缓存信息:")
            for rdd in self.sc._jsc.getRDDStorageInfo():
                logger.info(f"RDD: {rdd.name()}, 缓存级别: {rdd.storageLevel().description()}")
            
            logger.info(f"演示完成! 结果保存在: {output_dir}")
            logger.info(f"可以在浏览器中访问 http://localhost:4040 查看Spark UI")
            
            return True
            
        except Exception as e:
            logger.error(f"运行过程中发生错误: {str(e)}")
            import traceback
            traceback.print_exc()
            return False
        
    def cleanup(self):
        """清理资源"""
        logger.info("正在停止Spark会话...")
        self.spark.stop()
        logger.info("Spark会话已停止")

def main():
    """主函数"""
    print("=" * 60)
    print("Spark本地模式完整演示")
    print("=" * 60)
    
    # 创建演示实例
    demo = SparkLocalDemo(
        app_name="ECommerceAnalysis",
        master="local[*]",  # 使用所有可用的CPU核心
        memory="2g"
    )
    
    try:
        # 运行完整演示
        success = demo.run_complete_demo()
        
        if success:
            print("\n" + "=" * 60)
            print("演示成功完成!")
            print("=" * 60)
        else:
            print("\n演示过程中出现错误!")
            
    finally:
        # 清理资源
        demo.cleanup()

if __name__ == "__main__":
    main()

运行说明

1. 环境要求

  • Python 3.7+
  • Java 8或更高版本
  • Apache Spark(或通过PySpark自动安装)

2. 安装依赖

# 安装PySpark(`会自动下载Spark)`
pip install pyspark pandas numpy

# 或者如果已经安装了Spark,确保SPARK_HOME环境变量已设置

3. 运行代码

# 直接运行Python脚本
python spark_local_demo.py

# 或者设置更多内存(如果需要)
# python spark_local_demo.py --driver-memory 4g

4. 运行结果

  • 控制台输出分析结果
  • 生成spark_analysis_results目录,包含CSV格式的分析结果
  • 可以通过浏览器访问 http://localhost:4040 查看Spark作业详情

代码功能说明

  1. 数据生成: 创建模拟的电商交易和用户数据
  2. 数据分析:
    • 计算总交易额和平均订单价值
    • 按产品类别、城市、会员等级、年龄组统计
    • 用户购买行为分析
    • 月度销售趋势分析
  3. 结果保存: 将分析结果保存为CSV文件
  4. 资源管理: 正确初始化和清理Spark资源

注意事项

  1. 第一次运行时会自动下载Spark(如果未安装)
  2. 确保系统有足够的内存(至少4GB可用内存)
  3. 可以在代码中调整master参数:
    • local[*]: 使用所有CPU核心
    • local[2]: 使用2个CPU核心
    • local: 使用1个CPU核心

这个代码示例展示了Spark本地模式的完整工作流程,可以直接运行并查看结果,适合学习和测试使用。




问题:我在本地执行spark任务, 只需要安装pyspark这个python库, 不需要下载安装apache spark是吗

是的,你的理解基本正确。在本地执行 PySpark 任务时,通常只需要安装 pyspark 这个 Python 包,而不需要单独下载和安装 Apache Spark。

为什么只需要安装 PySpark?

PySpark 包已经包含了 Spark 运行时

pip install pyspark

当你安装 PySpark 时,它会自动下载一个预编译的 Spark 发行版(包含在包内),这个发行版包含了:

  • Spark 核心引擎
  • 必要的依赖库
  • PySpark 的 Python 绑定

验证安装

安装后,你可以直接使用:

from pyspark.sql import SparkSession

# 创建本地 Spark 会话
spark = SparkSession.builder \
    .appName("MyApp") \
    .master("local[*]")  # 使用所有CPU核心
    .getOrCreate()

print(spark.version)  # 查看 Spark 版本

注意事项

  1. Java 依赖:Spark 是基于 Java 的,所以你仍然需要安装 Java 8 或 11

    # 检查 Java 是否已安装
    java -version
    
  2. 完整 Spark vs PySpark 包

    • PySpark 包:适合大多数本地开发和测试
    • 完整 Spark:如果需要访问 Spark 的所有命令行工具、完整的配置选项,或者需要特定版本的 Spark,可以单独下载
  3. Python 版本兼容性:确保你的 Python 版本与 PySpark 兼容(通常是 Python 3.7+)

如果遇到问题

如果你需要特定的 Spark 版本或配置,也可以:

# 安装特定版本的 PySpark
pip install pyspark==3.5.0

# 或者,如果你确实需要完整 Spark
# 1. 下载 Spark 安装包
# 2. 设置 SPARK_HOME 环境变量
# 3. 将 PySpark 指向这个安装路径

对于大多数本地开发和测试场景,pip install pyspark 就足够了,这也是最简单的入门方式。

总结

  1. 核心结论:仅装 pyspark(Spark 3.0+)+ 装 Java,就能满足基础的本地 Spark 开发,无需手动下载 Apache Spark 安装包;但 pyspark 本身依赖 Spark 核心,并非 “不需要 Apache Spark”,只是 pip 帮你内置了。
  2. 必装依赖:Java 是 Spark 运行的基础,无论是否手动装 Apache Spark,都必须安装。
  3. 特殊场景:需自定义 Spark 配置 / 版本时,仍需手动下载 Apache Spark 并配置 SPARK_HOME 环境变量。