PyTorch张量运算详解
目录
引言
张量(Tensor)是PyTorch中最基本的数据结构,类似于NumPy中的数组,但具有更强的功能,特别是支持GPU加速计算。在深度学习中,张量用于表示各种数据,如输入特征、模型参数、梯度等。掌握张量的各种操作对于深度学习模型的构建和训练至关重要。
PyTorch提供了丰富的张量操作函数,包括基础算术运算、统计计算、索引切片、形状变换等。这些操作不仅能高效地利用GPU进行并行计算,还能满足深度学习中各种复杂的数据处理需求。
本文将详细介绍PyTorch中张量的各种操作,帮助读者全面掌握张量的使用方法。
张量的基础运算
算术运算
PyTorch支持常见的算术运算,包括加法、减法、乘法和除法。这些运算既可以是逐元素的,也可以是矩阵运算。
import torch
# 创建示例张量
torch.manual_seed(666)
v = torch.randint(low=1, high=9, size=(2, 3))
print("原始张量:")
print(v)
# 加法运算
print("\n加法运算:")
print("v + 1 =", v + 1) # 逐元素加法
print("v.add(1) =", v.add(1)) # 使用add函数
# 减法运算
print("\n减法运算:")
print("v - 1 =", v - 1) # 逐元素减法
print("v.sub(1) =", v.sub(1)) # 使用sub函数
# 乘法运算
print("\n乘法运算:")
print("v * 2 =", v * 2) # 逐元素乘法
print("v.mul(2) =", v.mul(2)) # 使用mul函数
# 除法运算
print("\n除法运算:")
print("v / 2 =", v / 2) # 逐元素除法
print("v.div(2) =", v.div(2)) # 使用div函数
原地运算与非原地运算
PyTorch中的张量运算分为原地运算和非原地运算。原地运算会直接修改原张量,而非原地运算会返回一个新的张量。
# 非原地运算(不修改原张量)
print("非原地运算:")
print("原始张量 v:")
print(v)
result = v + 1
print("v + 1 后的 v:")
print(v) # v 未被修改
print("运算结果:")
print(result)
# 原地运算(直接修改原张量)
print("\n原地运算:")
print("原始张量 v:")
print(v)
v.add_(1) # 使用下划线后缀表示原地运算
print("v.add_(1) 后的 v:")
print(v) # v 被修改
张量的统计运算
最值运算
PyTorch提供了多种获取张量最值的方法,包括全局最值和按维度的最值。
torch.manual_seed(22)
v = torch.randint(low=1, high=5, size=(2, 3))
print("示例张量:")
print(v)
# 全局最值
print("\n全局最值:")
print("最大值:", v.max())
print("最小值:", v.min())
# 按维度最值
print("\n按维度最值:")
print("按列(dim=0)的最大值:", v.max(dim=0))
print("按行(dim=1)的最大值:", v.max(dim=1))
print("按列(dim=0)的最小值:", v.min(dim=0))
print("按行(dim=1)的最小值:", v.min(dim=1))
# 获取最值和索引
print("\n最值和索引:")
max_values, max_indices = v.max(dim=0)
print("按列最大值:", max_values)
print("按列最大值索引:", max_indices)
均值与求和
均值和求和是常用的统计运算,需要注意数据类型的要求。
# 注意:mean()要求张量是浮点类型
v_float = v.float()
print("\n均值运算:")
print("全局均值:", v_float.mean())
print("按列均值:", v_float.mean(dim=0))
print("按行均值:", v_float.mean(dim=1))
print("\n求和运算:")
print("全局求和:", v.sum())
print("按列求和:", v.sum(dim=0))
print("按行求和:", v.sum(dim=1))
其他数学运算
PyTorch还提供了丰富的数学运算函数。
print("\n其他数学运算:")
print("平方根:", v_float.sqrt())
print("指数:", v_float.exp())
print("对数:", v_float.log())
print("幂运算:", v.pow(2))
print("向下取整:", v_float.floor())
print("向上取整:", v_float.ceil())
print("四舍五入:", v_float.round())
print("绝对值:", v_float.abs())
张量的索引与切片
基础索引
张量支持类似NumPy的索引方式,可以访问特定元素、行或列。
torch.manual_seed(666)
v = torch.randint(low=1, high=9, size=(5, 5))
print("示例张量:")
print(v)
# 访问特定元素
print("\n基础索引:")
print("第2行第3列元素:", v[1, 2]) # 注意索引从0开始
print("第3行:", v[2])
print("第4列:", v[:, 3])
布尔索引
布尔索引允许根据条件筛选元素。
print("\n布尔索引:")
print("大于4的元素:", v[v > 4])
print("第1行大于4的元素:", v[0][v[0] > 4])
高级索引
高级索引允许使用索引数组来访问元素。
print("\n高级索引:")
rows = torch.tensor([1, 3]) # 第2行和第4行
cols = torch.tensor([2, 4]) # 第3列和第5列
print("第2行和第4行,第3列和第5列的元素:", v[rows[:, None], cols])
切片操作
切片操作允许访问张量的子集。
print("\n切片操作:")
print("前3行,后2列:", v[:3, -2:])
print("第2-4行,第2-3列:", v[1:4, 1:3])
张量的形状修改与维度操作
形状重塑
形状重塑是深度学习中常用的操作,用于调整张量的维度以满足模型要求。
v = torch.randint(low=1, high=9, size=(1, 1, 1, 1, 3, 4, 1, 1))
print("原始张量形状:", v.shape)
# reshape和view都可以改变张量形状
reshape_1 = v.reshape(6, -1) # -1表示自动计算该维度大小
print("reshape后形状:", reshape_1.shape)
reshape_2 = v.view(6, -1)
print("view后形状:", reshape_2.shape)
# reshape和view的区别
print("reshape和view是否共享内存:",
torch.equal(reshape_1, reshape_2) and
(reshape_1.data_ptr() == reshape_2.data_ptr()))
维度增减
squeeze和unsqueeze用于删除或增加大小为1的维度。
print("\n维度增减:")
# squeeze删除大小为1的维度
squeeze_1 = v.squeeze()
print("squeeze后形状:", squeeze_1.shape)
# unsqueeze增加大小为1的维度
unsqueeze_1 = squeeze_1.unsqueeze(0)
print("unsqueeze后形状:", unsqueeze_1.shape)
维度变换
permute和transpose用于改变维度顺序。
print("\n维度变换:")
# transpose交换两个维度
transpose_1 = v.transpose(4, 5)
print("transpose后形状:", transpose_1.shape)
# permute重新排列所有维度
permute_1 = v.permute(5, 4, 6, 7, 0, 1, 2, 3)
print("permute后形状:", permute_1.shape)
print("permute后是否连续:", permute_1.is_contiguous())
矩阵运算
矩阵运算是深度学习中的核心操作,特别是矩阵乘法。
print("\n矩阵运算:")
torch.manual_seed(22)
a = torch.randint(low=1, high=5, size=(2, 3))
b = torch.randint(low=1, high=9, size=(3, 2))
print("矩阵A (2x3):")
print(a)
print("矩阵B (3x2):")
print(b)
# 矩阵乘法
print("矩阵乘法 A @ B:")
print(a @ b)
# 使用torch.mm进行矩阵乘法(仅适用于2D张量)
print("使用torch.mm:")
print(torch.mm(a, b))
# 使用torch.matmul进行矩阵乘法(适用于任意维度)
print("使用torch.matmul:")
print(torch.matmul(a, b))
总结
PyTorch提供了丰富的张量操作函数,涵盖了从基础算术运算到复杂形状变换的各种需求。掌握这些操作对于深度学习模型的构建和训练至关重要。
关键要点总结:
- 基础运算:支持加减乘除等算术运算,区分原地和非原地操作
- 统计运算:提供最值、均值、求和等统计函数
- 索引切片:支持基础索引、布尔索引、高级索引和切片操作
- 形状操作:reshape/view用于形状重塑,squeeze/unsqueeze用于维度增减,permute/transpose用于维度变换
- 矩阵运算:使用@、torch.mm或torch.matmul进行矩阵乘法
在实际应用中,合理使用这些张量操作可以大大提高数据处理效率,为深度学习模型的训练和推理提供有力支持。随着对PyTorch理解的深入,你将能够更加灵活地运用这些操作来解决复杂的机器学习问题。