跳转至

Apache Spark 入门

一句话概述:Apache Spark 是分布式大数据计算引擎,能在集群上并行处理TB级数据,比传统MapReduce快10-100倍。

核心知识点表

概念白话解释
RDD弹性分布式数据集,Spark最底层的数据抽象(现在一般不直接用)
DataFrame类似pandas的表格数据结构,但可以分布在多台机器上
SparkSessionSpark程序的入口,所有操作都从它开始
Transformation转换操作(如filter、map),懒执行,不立即计算
Action触发计算的操作(如collect、count),调用后才真正开始算
Partition数据分区,一份大数据被切成多块,分配到不同机器上并行处理
Executor工作节点,真正干活的进程
Driver驱动程序,负责调度和协调
Spark SQL用SQL语法查询分布式数据
PySparkSpark的Python API

版本信息(2026年5月)

  • Spark 4.1.1(当前稳定版)
  • Spark 4.2.0(预览中)
  • 亮点:Spark Declarative Pipelines、实时流处理、VARIANT数据类型

安装配置

方式一:PySpark(最简单,学习推荐)

# 创建虚拟环境
conda create -n spark python=3.12 -y
conda activate spark

# 安装PySpark
pip install pyspark  # 包含了Spark引擎,不需要额外装Java(自带)

# 验证安装
pyspark  # 启动PySpark交互式Shell
# 或者
python -c "import pyspark; print(pyspark.__version__)"

方式二:完整安装

# 1. 安装Java 17+(Spark 4.x要求)
sudo apt install openjdk-17-jdk -y  # Ubuntu
java -version  # 验证

# 2. 下载Spark
wget https://downloads.apache.org/spark/spark-4.1.1/spark-4.1.1-bin-hadoop3.tgz
tar -xzf spark-4.1.1-bin-hadoop3.tgz  # 解压
sudo mv spark-4.1.1-bin-hadoop3 /opt/spark  # 移到/opt

# 3. 配置环境变量
echo 'export SPARK_HOME=/opt/spark' >> ~/.bashrc
echo 'export PATH=$SPARK_HOME/bin:$PATH' >> ~/.bashrc
source ~/.bashrc

# 4. 验证
spark-shell  # Scala交互式Shell
pyspark      # Python交互式Shell
spark-submit --version  # 查看版本

方式三:Docker

docker pull apache/spark:4.1.1  # 拉取镜像
docker run -it apache/spark:4.1.1 /opt/spark/bin/pyspark  # 启动PySpark

基本使用

创建SparkSession

# spark_basic.py
from pyspark.sql import SparkSession  # 导入SparkSession

# 创建Spark会话(程序入口)
spark = SparkSession.builder \
    .appName("MyFirstApp") \        # 应用名称(在Spark UI上显示)
    .master("local[*]") \           # 本地模式,用所有CPU核心
    .getOrCreate()                  # 获取或创建会话

# 查看Spark版本
print(spark.version)

# Spark UI 地址:http://localhost:4040

读取和操作数据

# ===== 从CSV读取 =====
df = spark.read.csv(
    "data/users.csv",       # 文件路径
    header=True,            # 第一行是列名
    inferSchema=True,       # 自动推断数据类型
)

df.show(5)        # 显示前5行
df.printSchema()  # 显示表结构(列名+类型)
df.count()        # 统计行数

# ===== 从JSON读取 =====
df_json = spark.read.json("data/events.json")

# ===== 从Parquet读取(推荐格式,压缩+列式存储) =====
df_parquet = spark.read.parquet("data/sales.parquet")

# ===== 手动创建DataFrame =====
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

data = [("Alice", 25), ("Bob", 30), ("Charlie", 35)]  # 数据
schema = StructType([                                   # 定义Schema
    StructField("name", StringType(), True),            # 名字,字符串,可为空
    StructField("age", IntegerType(), True),            # 年龄,整数,可为空
])
df = spark.createDataFrame(data, schema)  # 创建DataFrame

DataFrame常用操作

from pyspark.sql.functions import col, avg, count, when, lit  # 导入常用函数

# 选择列
df.select("name", "age").show()  # 选择name和age列

# 过滤
df.filter(col("age") > 25).show()  # 筛选年龄大于25的
df.where(col("status") == "active").show()  # where和filter等价

# 添加新列
df = df.withColumn(
    "age_group",  # 新列名
    when(col("age") < 30, "young")  # 条件:小于30→young
    .when(col("age") < 50, "middle")  # 30-49→middle
    .otherwise("senior")  # 其他→senior
)

# 分组聚合
df.groupBy("department") \
    .agg(
        count("*").alias("employee_count"),  # 员工数
        avg("salary").alias("avg_salary"),   # 平均工资
    ) \
    .orderBy(col("avg_salary").desc()) \     # 按平均工资降序
    .show()

# 去重
df.dropDuplicates(["email"]).show()  # 按email去重

# 排序
df.orderBy(col("age").desc()).show()  # 按年龄降序

# 重命名列
df = df.withColumnRenamed("old_name", "new_name")

Spark SQL

# 注册临时视图
df.createOrReplaceTempView("users")  # 注册为SQL临时表

# 用SQL查询
result = spark.sql("""
    SELECT
        department,
        COUNT(*) AS cnt,
        AVG(salary) AS avg_salary
    FROM users
    WHERE status = 'active'
    GROUP BY department
    HAVING COUNT(*) > 5
    ORDER BY avg_salary DESC
""")
result.show()

高级用法

读写Parquet(生产推荐格式)

# 写入Parquet
df.write.parquet(
    "output/users.parquet",  # 输出路径
    mode="overwrite",        # 覆盖写入(append=追加)
    compression="snappy",    # 压缩算法
)

# 分区写入(按日期分区,查询时自动剪枝)
df.write.partitionBy("date") \
    .parquet("output/events/", mode="overwrite")

# 读取分区数据(自动发现分区)
df = spark.read.parquet("output/events/")
df.filter(col("date") == "2026-05-13").show()  # 只读2026-05-13的分区

UDF用户自定义函数

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# 定义Python函数
def classify_age(age):
    """根据年龄分类"""
    if age < 18:
        return "未成年"
    elif age < 60:
        return "成年"
    return "老年"

# 注册为UDF
classify_udf = udf(classify_age, StringType())  # 返回类型是字符串

# 使用UDF
df = df.withColumn("age_category", classify_udf(col("age")))

窗口函数

from pyspark.sql.window import Window  # 导入窗口
from pyspark.sql.functions import row_number, rank, lag

# 定义窗口:按部门分组,按工资降序
window_spec = Window.partitionBy("department").orderBy(col("salary").desc())

df = df.withColumn("salary_rank", rank().over(window_spec))  # 工资排名
df = df.withColumn("row_num", row_number().over(window_spec))  # 行号

连接(Join)

# 内连接
result = df_users.join(
    df_orders,
    df_users["user_id"] == df_orders["user_id"],  # 连接条件
    "inner"  # 连接类型:inner/left/right/full/cross
)

# 广播连接(小表广播到所有节点,避免Shuffle)
from pyspark.sql.functions import broadcast
result = df_big.join(
    broadcast(df_small),  # 小表广播
    "user_id"
)

常见报错与解决

报错信息原因解决方案
Java not found没装Java或版本不对安装Java 17+:sudo apt install openjdk-17-jdk
OutOfMemoryError数据太大内存不够增大内存:.config("spark.driver.memory", "4g")
AnalysisException: cannot resolve列名写错df.columns查看所有列名
Py4JJavaErrorJava底层报错看完整错误栈,通常是数据类型问题
FileNotFoundException文件路径错误检查路径,Spark不支持~,用绝对路径
Task not serializableUDF中引用了不可序列化对象用broadcast变量或简化UDF

速查表

# ===== SparkSession =====
spark = SparkSession.builder.appName("X").master("local[*]").getOrCreate()
spark.stop()  # 停止会话

# ===== 读取数据 =====
spark.read.csv("path", header=True, inferSchema=True)
spark.read.json("path")
spark.read.parquet("path")
spark.read.jdbc(url, table, properties)

# ===== 写入数据 =====
df.write.csv("path", header=True, mode="overwrite")
df.write.parquet("path", mode="overwrite")
df.write.json("path")

# ===== 常用操作 =====
df.select("col1", "col2")        # 选列
df.filter(col("x") > 10)        # 过滤
df.groupBy("col").agg(count("*"))  # 分组聚合
df.join(df2, "key", "inner")     # 连接
df.orderBy(col("x").desc())     # 排序
df.dropDuplicates(["col"])      # 去重
df.withColumn("new", expr)       # 新增列
df.drop("col")                   # 删除列

# ===== 常用函数 =====
# col, lit, when, otherwise
# count, sum, avg, min, max
# concat, substring, lower, upper, trim
# year, month, day, date_format
# row_number, rank, dense_rank, lag, lead

同类工具对比

特性SparkPolarsPandasDask
数据规模TB-PB级GB-TB级MB-GB级GB-TB级
运行方式分布式集群单机多线程单机单线程分布式
延迟秒级毫秒级毫秒级秒级
语言Scala/Python/R/SQLRust/PythonPythonPython
学习曲线中高最低
适合场景大规模数据处理中型高性能处理小数据探索pandas扩展

面试建议:Spark是大数据面试必考项。重点理解:1)懒执行(Transformation vs Action);2)DataFrame vs RDD;3)Shuffle的概念(数据跨节点重分布);4)broadcast join优化。能解释"Spark为什么比MapReduce快"(内存计算+DAG优化+懒执行)是加分项。