
# 章节 2: 高级张量操作
在前面介绍的基础张量操作之上,本章将讲解更高级的张量操作方法。对张量结构、数据类型和设备放置进行精细控制,对于有效准备数据和实现复杂的深度学习模型来说非常重要。
在本章中,你将学会:
- 使用索引和切片方法,精确选择和修改张量元素。
- 使用 `view()`、`reshape()` 和 `permute()` 重构张量,而不改变其数据。
- 使用 `cat()`、`stack()`、`split()` 和 `chunk()` 组合多个张量或将单个张量分割成不同部分。
- 使用 PyTorch 的广播机制,对形状兼容但不同的张量之间执行操作。
- 处理不同的数值数据类型(例如 `float`、`int`),并相应地转换张量类型。
- 使用 `.to(device)` 在 CPU 内存和 GPU 加速器之间传输张量。
掌握这些操作对于处理深度学习工作流中遇到的多样化数据格式和计算要求是必要的。
# 张量索引与切片
访问和修改张量的特定部分是处理深度学习数据时常有的需求。无论您是需要选择单个数据点、提取一批训练样本、裁剪图像补丁,还是挑选特定特征,PyTorch都提供了强大且灵活的索引和切片机制,类似于NumPy数组中的那些,但PyTorch的机制与GPU加速和自动微分功能相集成。
### 基本索引
访问张量元素最直接的方式是使用标准的Python整数索引。请记住,PyTorch张量与Python列表和NumPy数组一样,使用0作为起始索引。
对于一维张量,您可以使用其索引访问元素:
```scala 3
import torch.*
// 创建一个一维张量
val x_1d = torch.tensor(Seq(10, 11, 12, 13, 14))
println(f"原始一维张量:\n{x_1d}")
// 访问第一个元素
val first_element = x_1d(0)
println(f"\n第一个元素 (x_1d(0)): {first_element}, 类型: {type(first_element)}")
// 访问最后一个元素
val last_element = x_1d(-1)
println(f"最后一个元素 (x_1d(-1)): {last_element}")
// 修改一个元素
x_1d(1) = 110
println(f"\n修改后的张量:\n{x_1d}")
```
```java
Tensor x_1d = tensor(new int[]{10, 11, 12, 13, 14});
System.out.printf("原始一维张量:\n %s\n", x_1d);
// 访问第一个元素
Tensor first_element = x_1d.get(0);
System.out.printf("\n第一个元素 (x_1d(0)): %s, 类型: %s\n", first_element, first_element.dtype());
// 访问最后一个元素
Tensor last_element = x_1d.get(x_1d.size(0) - 1);
System.out.printf("最后一个元素 (x_1d(-1)): %s, 类型: %s\n", last_element, last_element.dtype());
// 修改一个元素 ??? 需要确认
x_1d.set_(new LongPointer(1), 110);
System.out.printf("\n修改后的张量:\n %s\n", x_1d);
```
注意,访问单个元素会返回一个包含单个值的`torch.Tensor`(一个0维张量或标量),而不是一个标准的Python数字,除非您使用`.item()`明确提取它。元素的修改是原地进行的。
对于多维张量,您需要为每个维度提供索引,并用逗号分隔:
```scala 3
// 创建一个二维张量 (例如,一个小矩阵)
val x_2d = torch.tensor(Seq(Seq(1, 2, 3),
Seq(4, 5, 6),
Seq(7, 8, 9)))
println(f"原始二维张量:\n{x_2d}")
// 访问第0行第1列的元素
val element_0_1 = x_2d(0, 1)
println(f"\n在 [0, 1] 的元素: {element_0_1}")
// 访问整个第一行 (索引0)
val first_row = x_2d(0) // or x_2d(0, *)
println(f"\n第一行 (x_2d(0)): {first_row}")
// 访问整个第二列 (索引1)
val second_col = x_2d(*, 1) // or x_2d(*, 1)
println(f"第二列 (x_2d(*, 1)): {second_col}")
// 修改一个元素
x_2d(1, 1) = 55
println(f"\n修改后的二维张量:\n{x_2d}")
```
```java
var data2 = new int[][]{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
var flatData2 = TensorToolkit.flatten(data2);
var shape2 = TensorToolkit.getShape(data2);
Tensor x_2d = torch.tensor((int[]) flatData2).reshape(shape2);
System.out.printf("原始二维张量:\n %s\n", x_2d);
// 访问第0行第1列的元素
Tensor element_0_1 = x_2d.get(torch.tensor(new int[]{0, 1}));
System.out.printf("\n在 [0, 1] 的元素: %s, 类型: %s\n", element_0_1, element_0_1.dtype());
// 访问整个第一行 (索引0)
Tensor first_row = x_2d.get(torch.tensor(new int[]{0}));
System.out.printf("\n第一行 (x_2d(0)): %s, 类型: %s\n", first_row, first_row.dtype());
// 访问整个第二列 (索引1)
Tensor second_col = x_2d.get(torch.tensor(new int[]{1}));
System.out.printf("第二列 (x_2d(*, 1)): %s, 类型: %s\n", second_col, second_col.dtype());
// 修改一个元素
x_2d.set_(tensor(55),1,1);//new LongPointer(1), new LongPointer(1), 55);
System.out.printf("\n修改后的二维张量:\n %s\n", x_2d);
```
提供的索引数量少于维数时,会沿剩余维度选择一个完整的子张量。例如,`x_2d[0]` 会选择整个第一行。
### 张量切片
切片允许您沿张量维度选择一系列元素。语法是 `start:stop:step`,其中 `start` 是包含的,`stop` 是不包含的,而 `step` 定义了间隔。省略 `start` 默认为0,省略 `stop` 默认为维度的末尾,省略 `step` 默认为1。
```scala 3
// 创建一个一维张量
val y_1d = torch.arange(10) // Tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
println(f"原始一维张量: {y_1d}")
// 选择从索引2开始到(不包含)索引5的元素
val slice1 = y_1d(2 until 5)
println(f"\n切片 y_1d(2 until 5): {slice1}")
// 选择从开头到索引4的元素
val slice2 = y_1d(0 until 4)
println(f"切片 y_1d(0 until 4): {slice2}")
// 选择从索引6到末尾的元素
val slice3 = y_1d(6 until y_1d.size(0))
println(f"切片 y_1d(6 until y_1d.size(0)): {slice3}")
// 选择每隔一个的元素
val slice4 = y_1d(0 until y_1d.size(0) by 2)
println(f"切片 y_1d(0 until y_1d.size(0) by 2): {slice4}")
// 选择从索引1到7的元素,步长为2
val slice5 = y_1d(1 until 8 by 2)
println(f"切片 y_1d(1 until 8 by 2): {slice5}")
// 反转张量
val slice6 = y_1d(y_1d.size(0) - 1 until 0 by -1)
println(f"切片 y_1d(y_1d.size(0) - 1 until 0 by -1): {slice6}")
```
```java
Tensor y_1d = torch.arange(new Scalar(10));
System.out.printf("原始一维张量: %s\n", y_1d);
// 选择从索引2开始到(不包含)索引5的元素
Tensor slice1 = y_1d.slice(0, new LongOptional(2),new LongOptional(5), 1);
System.out.printf("\n切片 y_1d(2 until 5): %s\n", slice1);
// 选择从开头到索引4的元素
Tensor slice2 = y_1d.slice(0, new LongOptional(0), new LongOptional(4), 1);
System.out.printf("切片 y_1d(0 until 4): %s\n", slice2);
// 选择从索引6到末尾的元素
Tensor slice3 = y_1d.slice(0, new LongOptional(6), new LongOptional(y_1d.size(0)) , 1);
System.out.printf("切片 y_1d(6 until y_1d.size(0)): %s\n", slice3);
// 选择每隔一个的元素
Tensor slice41 = y_1d.slice(0, new LongOptional(0), new LongOptional( y_1d.size(0)), 2);
System.out.printf("切片 y_1d(0 until y_1d.size(0) by 2): %s\n", slice41);
// 选择从索引1到7的元素
Tensor slice52 = y_1d.slice(0, new LongOptional(1), new LongOptional(8), 1);
System.out.printf("切片 y_1d(1 until 8 by 2): %s\n", slice52);
```
切片对多维张量的工作方式类似。您可以将整数索引和切片结合使用:
```scala 3
// 创建一个3x4张量
val x_2d = torch.tensor(Seq(Seq( 0, 1, 2, 3),
Seq( 4, 5, 6, 7),
Seq( 8, 9, 10, 11)))
println(f"原始二维张量:\n{x_2d}")
// 选择前两行以及第1和第2列
val sub_tensor1 = x_2d(0 until 2, 1 until 3)
println(f"\n切片 x_2d(0 until 2, 1 until 3):\n{sub_tensor1}")
// 选择所有行,但只选择最后两列
val sub_tensor2 = x_2d(*, -2 until x_2d.size(1))
println(f"\n切片 x_2d(*, -2 until x_2d.size(1)):\n{sub_tensor2}")
// 选择第一行,从第1列到末尾
val sub_tensor3 = x_2d(0, 1 until x_2d.size(1))
println(f"\n切片 x_2d(0, 1 until x_2d.size(1)):\n{sub_tensor3}")
// 选择第0行和第2行(使用步长),所有列
val sub_tensor4 = x_2d(0 until x_2d.size(0) by 2, *)
println(f"\n切片 x_2d(0 until x_2d.size(0) by 2, *):\n{sub_tensor4}")
```
```java
var data4 = new int[][]{{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}};
var flatData4 = TensorToolkit.flatten(data4);
var shape4 = TensorToolkit.getShape(data4);
Tensor x_2d2 = torch.tensor((int[]) flatData4).reshape(shape4);
System.out.printf("原始二维张量:\n %s\n", x_2d2);
// 选择前两行以及第1和第2列
Tensor sub_tensor1 = x_2d2.slice(0, new LongOptional(0),new LongOptional(2) , 1).slice(1, new LongOptional(1),new LongOptional(3), 1);
System.out.printf("\n切片 x_2d2(0 until 2, 1 until 3):\n %s\n", sub_tensor1);
// 选择所有行,但只选择最后两列
Tensor sub_tensor2 = x_2d2.slice(1, new LongOptional(-2),new LongOptional( x_2d2.size(1)) , 1);
System.out.printf("\n切片 x_2d2(*, -2 until x_2d2.size(1)):\n %s\n", sub_tensor2);
// 选择第一行,从第1列到末尾
Tensor sub_tensor3 = x_2d2.get(new Scalar(0)).slice(0, new LongOptional(1),new LongOptional(x_2d2.size(1)) , 1);
System.out.printf("\n切片 x_2d2(0, 1 until x_2d2.size(1)):\n %s\n", sub_tensor3);
// 选择第0行和第2行(使用步长),所有列
Tensor sub_tensor41 = x_2d2.slice(0, new LongOptional(0),new LongOptional( x_2d2.size(0)), 2);
System.out.printf("\n切片 x_2d2(0 until x_2d2.size(0) by 2, *):\n %s\n", sub_tensor41);
```
原始张量 (x_2d)切片: x_2d[0:2, 1:3]012345678910111256

> 张量 `x_2d` 使用 `x_2d[0:2, 1:3]` 进行切片的视觉表示。它选择第0和第1行,以及第1和第2列。
切片的一个重要特性(与某些其他形式的索引不同)是,返回的张量通常与原始张量共享底层存储。修改切片会修改原始张量。
```scala 3
println(f"修改切片前的原始 x_2d:\n{x_2d}")
// 获取一个切片
val sub_tensor = x_2d(0 until 2, 1 until 3)
# 修改切片
sub_tensor(0, 0) = 101
println(f"\n修改后的切片:\n{sub_tensor}")
println(f"\n修改切片后的原始 x_2d:\n{x_2d}") // 注意变化!
```
```java
System.out.printf("修改切片前的原始 x_2d:\n %s\n", x_2d2);
// 获取一个切片
Tensor sub_tensor = x_2d2.slice(0, new LongOptional(0),new LongOptional(2), 1).slice(1, new LongOptional(1),new LongOptional(3) , 1);
// 修改切片
sub_tensor.set_(torch.tensor(101),0,0); ;
System.out.printf("\n修改后的切片:\n %s\n", sub_tensor);
System.out.printf("\n修改切片后的原始 x_2d2:\n %s\n", x_2d2);
```
如果您需要一个不共享内存的副本,请在切片上使用 `.clone()`: `sub_tensor_copy = x_2d[0:2, 1:3].clone()`。
### 布尔索引 (遮罩)
您可以使用布尔张量来索引另一个张量。布尔张量的形状必须能够广播到被索引张量的形状(通常,它们的形状完全相同)。只有布尔张量中对应 `True` 值的元素(即“遮罩”)才会被选中。这对于根据条件筛选数据非常有用。
```scala 3
// 创建一个张量
val data = torch.tensor(Seq(Seq(1, 2), Seq(3, 4), Seq(5, 6)))
println(f"原始数据张量:\n{data}")
// 创建一个布尔遮罩 (例如,选择大于3的元素)
val mask = data > 3
println(f"\n布尔遮罩 (data > 3):\n{mask}")
// 应用遮罩
val selected_elements = data(mask)
println(f"\n通过遮罩选择的元素:\n{selected_elements}")
println(f"所选元素的形状: {selected_elements.shape}")
// 根据条件修改元素
data(data <= 3) = 0
println(f"\n将小于等于3的元素设置为零后的数据:\n{data}")
```
```java
// 创建一个张量
var data6 = new int[][]{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
var flatData6 = TensorToolkit.flatten(data6);
var shape6 = TensorToolkit.getShape(data6);
Tensor data = torch.tensor((int[]) flatData6).reshape(shape6);
System.out.printf("原始数据张量:\n %s\n", data);
// 创建一个布尔遮罩 (例如,选择大于3的元素)
Tensor mask = torch.gt(data, new Scalar(3));
System.out.printf("\n布尔遮罩 (data > 3):\n %s\n", mask);
// 应用遮罩
Tensor selected_elements = data.masked_select(mask);
System.out.printf("\n通过遮罩选择的元素:\n %s\n", selected_elements);
System.out.printf("所选元素的形状: %s\n", selected_elements.sizes());
// 根据条件修改元素
Tensor index = torch.le(data, new Scalar(3));
System.out.printf("\n将小于等于3的元素为零后的数据索引 mask :\n %s\n", data);
data.masked_fill_(index, new Scalar(0));
System.out.printf("\n将小于等于3的元素设置为零后的数据 更新后 :\n %s\n", data);
```
布尔索引通常返回一个包含所有选定元素的一维张量。与切片不同,它*不*保留原始形状。此外,布尔索引通常会创建一个副本,而不是一个视图。
您可以将布尔索引与其他形式结合使用。例如,根据应用于其中一列的条件选择行:
```scala 3
// 选择第一列大于2的行
val row_mask = data(:, 0) > 2
println(f"\n行遮罩 (data[:, 0] > 2): {row_mask}")
// 使用 ':' 选择所选行中的所有列
// Or simply: data[row_mask] - PyTorch 通常会推断出完整的行选择
val selected_rows = data(row_mask, *)
println(f"\n第一列大于2的行:\n{selected_rows}")
```
```java
Tensor row_mask = torch.gt(data.get(0), new Scalar(2));
System.out.printf("\n行遮罩 (data[:, 0] > 2): %s\n", row_mask);
// 使用掩码选择行
Tensor selected_rows2 = data.masked_select(row_mask);
System.out.printf("\n第一列大于2的行:\n %s\n", selected_rows2);
```
PyTorch 中 mask(掩码)和 masked_select 函数的实际用途,其实它们的核心价值是精准筛选 / 操作张量中满足特定条件的元素,避免手动遍历,同时充分利用 PyTorch 的向量化计算和 GPU 加速
Mask(掩码张量):是一个和目标张量形状相同的布尔型张量(元素为 True/False),True 表示 “选中该位置元素”,False 表示 “忽略该位置元素”;
masked_select:PyTorch 内置函数,作用是 “根据掩码张量,从目标张量中提取所有 mask=True 位置的元素,返回一维张量”。
简单说:mask 是 “筛选规则”,masked_select 是 “执行筛选的工具”—— 替代手动 for 循环筛选元素,效率提升几十倍(尤其是 GPU 上)。
1. 数据清洗 / 异常值过滤(最常用)
实际业务中,数据常包含异常值(如缺失值、超出范围的值、无效值),用 mask+masked_select 可快速过滤:
例 1:过滤张量中的负数(只保留非负数);
例 2:过滤图像像素中的无效值(如像素值 > 255 或 < 0);
例 3:过滤时间序列中的缺失值(NaN/Inf)
```java
// 1. 创建包含异常值的张量(含负数、0、正数)
Tensor data = torch.tensor(new float[]{1.2f, -3.5f, 0.0f, 4.8f, -0.9f, 6.1f},
torch.tensorOptions().dtype(ScalarType.Float));
System.out.println("===== 原始数据 =====");
print(data);
// 2. 创建掩码:只保留 >0 的元素(True表示选中)
Tensor mask = torch.gt(data, new Scalar(0)); // gt = greater than(大于)
System.out.println("\n===== 掩码张量(True=保留>0的元素) =====");
print(mask);
// 3. 使用masked_select过滤元素
Tensor filtered_data = torch.masked_select(data, mask);
System.out.println("\n===== 过滤后的数据(仅保留正数) =====");
print(filtered_data);
// 4. 进阶:过滤图像像素中的无效值(0-255之外)
Tensor image_pixels = torch.randint(new Scalar(0), new Scalar(300), new long[]{2, 3}); // 2x3像素,0-299
System.out.println("\n===== 原始图像像素(含>255的无效值) =====");
print(image_pixels);
// 创建复合掩码:像素值 >=0 且 <=255
Tensor mask_valid = torch.logical_and(
torch.ge(image_pixels, new Scalar(0)), // ge = greater equal(大于等于)
torch.le(image_pixels, new Scalar(255)) // le = less equal(小于等于)
);
Tensor valid_pixels = torch.masked_select(image_pixels, mask_valid);
System.out.println("\n===== 过滤后的有效像素 =====");
print(valid_pixels);
```
2. 深度学习中的损失函数计算(核心场景)
在分类 / 回归任务中,常需要忽略部分样本的损失(如:
例 1:目标检测中,忽略 “背景框” 的损失(只计算前景框);
例 2:序列标注中,忽略 padding 部分的损失(如句子长度不一致时的补 0);
例 3:推荐系统中,忽略用户未交互的物品评分
```java
loss = cross_entropy(pred, target) # 计算所有样本的损失
mask = (target != PAD_TOKEN) # 掩码:只保留非padding的样本
loss = torch.masked_select(loss, mask) # 过滤掉padding样本的损失
mean_loss = loss.mean() # 只计算有效样本的平均损失
```
3. 稀疏数据处理(和之前讲的 Layout 呼应)
对于稀疏张量(如 SparseCsr 布局),mask+masked_select 可快速提取非零元素或满足条件的稀疏元素:
例:从稀疏邻接矩阵中提取 “权重> 0.5” 的边(图神经网络常用);
例:从用户 - 物品评分矩阵中提取 “评分>=4” 的有效评分。
4. 条件赋值 / 替换(扩展用法)
虽然 masked_select 是 “提取”,但结合掩码可实现 “条件替换”(如将不满足条件的元素设为 0):
```java
// 将张量中负数替换为0
Tensor data = torch.tensor(new float[]{1.2f, -3.5f, 4.8f, -0.9f});
Tensor mask = torch.lt(data, new Scalar(0)); // lt = less than(小于)
Tensor new_data = torch.where(mask, torch.zeros_like(data), data); // where(掩码, 替换值, 原值)
System.out.println("替换后数据: " + new_data);
// 输出:1.2 0.0 4.8 0.0
```
二、为什么不用手动循环?(核心优势)
你可能会问:“用 for 循环也能筛选,为什么用 mask?”—— 核心是效率和兼容性:
速度提升:
CPU 上:向量化操作比 for 循环快 5-10 倍;
GPU 上:差距可达 100 倍以上(GPU 擅长并行处理向量化操作,不擅长循环);
代码简洁:一行 masked_select 替代多行 for 循环 + 条件判断;
梯度兼容:PyTorch 的掩码操作支持自动求导(手动循环会破坏计算图),深度学习中必须用这种方式。
总结(关键点回顾)
核心用途:
数据清洗:过滤异常值 / 无效值 / 缺失值;
损失计算:忽略无效样本的损失(padding / 背景等);
稀疏数据:提取满足条件的稀疏元素;
条件赋值:按规则替换张量元素。
核心优势:
速度快:向量化操作,CPU/GPU 加速;
代码简:替代手动循环,可读性高;
兼容梯度:支持自动求导,适配深度学习。
使用逻辑:
先通过布尔运算生成 mask(筛选规则)→ 再用 masked_select 执行筛选 → 得到满足条件的一维张量。
简单来说,mask+masked_select 是 PyTorch 中 “按条件筛选元素” 的标准做法,几乎所有涉及 “条件过滤” 的场景都能用到,是深度学习和数据处理的必备工具。
### 整数数组索引
除了单个整数和切片,您还可以使用列表或一维整数张量沿维度进行索引。这使得您可以按任意顺序选择元素,或多次选择相同的元素。
```scala 3
// 创建一个一维张量
val x = torch.arange(10, 20) // Tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
println(f"原始一维张量: {x}")
// 注意索引2的重复
val indices = torch.tensor(Seq(0, 4, 2, 2))
println(f"\n使用索引 {indices} 选择的元素: {x(indices)}")
// 对于二维张量
val y = torch.arange(12).reshape(3, 4)
// [[ 0, 1, 2, 3],
// [ 4, 5, 6, 7],
// [ 8, 9, 10, 11]]
println(f"\n原始二维张量:\n{y}")
// 选择特定行
val row_indices = torch.tensor(Seq(0, 2))
val selected_rows = y(row_indices, *)
println(f"\n使用索引 {row_indices} 选择的行:\n{selected_rows}")
// 选择特定列
val col_indices = torch.tensor(Seq(1, 3))
val selected_cols = y(*, col_indices) // 从所有行中选择第1列和第3列
println(f"\n使用索引 {col_indices} 选择的列:\n{selected_cols}")
// 使用索引对选择特定元素
val row_idx = torch.tensor(Seq(0, 1, 2))
val col_idx = torch.tensor(Seq(1, 3, 0))
val selected_elements = y(row_idx, col_idx) // 选择 (0,1), (1,3), (2,0) -> [1, 7, 8]
println(f"\n使用 (row_idx, col_idx) 选择的特定元素:\n{selected_elements}")
```
```java
Tensor x = torch.arange(new Scalar(10), new Scalar(20));
System.out.printf("原始一维张量: \n %s\n", x);
// 注意索引2的重复
Tensor indices = torch.tensor(new int[]{0, 4, 2, 2});
System.out.printf("\n使用索引 %s 选择的元素: \n %s\n", indices, x.index_select(0, indices));
// 对于二维张量
Tensor y = torch.arange(new Scalar(12)).reshape(3, 4);
System.out.printf("\n原始二维张量:\n %s\n", y);
// 选择特定行
Tensor row_indices = torch.tensor(new int[]{0, 2});
Tensor selected_rows = y.index_select(0, row_indices);
System.out.printf("\n使用索引 %s 选择的行:\n %s\n", row_indices, selected_rows);
// 选择特定列
Tensor col_indices = torch.tensor(new int[]{1, 3});
Tensor selected_cols = y.index_select(1, col_indices);
System.out.printf("\n使用索引 %s 选择的列:\n %s\n", col_indices, selected_cols);
```
与布尔索引类似,整数数组索引通常返回一个新的张量(一个副本),而不是原始张量存储的视图。输出的形状取决于索引方法。当选择完整的行或列时,其他维度会被保留。当为多个维度提供索引数组(例如 `y[row_idx, col_idx]`)时,结果通常是一个对应于所选元素的一维张量。
掌握这些索引和切片技术能够精准地控制张量数据,为后续步骤中的数据准备、特征提取以及模型输入输出的操作奠定根基。
# 张量的重塑与维度调整
通常,你会发现现有张量的结构不太适合后续计算步骤,尤其是在将数据送入特定的神经网络层时。PyTorch 提供了灵活的工具,可以在不改变底层数据元素本身的情况下,改变张量的形状或调整其维度。用于这些操作的主要方法是:`view()`、`reshape()` 和 `permute()`。
### 使用 `view()` 和 `reshape()` 改变形状
`view()` 和 `reshape()` 都允许你改变张量的维度,前提是总元素数量保持不变。它们在将多维张量展平后传递给线性层,或增加/移除大小为1的维度等任务中非常有用。
#### 使用 `view()`
`view()` 方法返回一个新的张量,该张量与原始张量共享相同的底层数据,但具有不同的形状。它非常高效,因为它避免了数据复制。然而,`view()` 要求张量在内存中是*连续的*。连续张量是指其元素在内存中按维度顺序连续存储,没有间隙的张量。大多数新创建的张量是连续的,但某些操作(如切片或使用 `t()` 进行转置)会产生非连续张量。
我们来看一个例子:
```scala 3
import torch.*
// 创建一个连续张量
val x = torch.arange(12) // tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
println(f"原始张量: {x}")
println(f"原始形状: {x.shape}")
println(f"是否连续? {x.is_contiguous()}")
// 使用 view() 重塑
val y = x.view(3, 4)
println("\nview(3, 4) 后的张量:")
println(y)
println(f"新形状: {y.shape}")
println(f"与 x 共享存储吗? {y.storage().data_ptr() == x.storage().data_ptr()}") // 检查它们是否共享内存
println(f"y 是否连续? {y.is_contiguous()}")
// 尝试另一个视图
val z = y.view(2, 6)
println("\nview(2, 6) 后的张量:")
println(z)
println(f"新形状: {z.shape}")
println(f"与 x 共享存储吗? {z.storage().data_ptr() == x.storage().data_ptr()}")
println(f"z 是否连续? {z.is_contiguous()}")
```
```java
// 创建一个连续张量
Tensor x2 = torch.arange(new Scalar(12));
System.out.printf("原始张量: %s\n", x2);
System.out.printf("原始形状: %s\n", x2.sizes());
System.out.printf("是否连续? %b\n", x2.is_contiguous());
// 使用 view() 重塑
Tensor y2 = x2.view(3, 4);
System.out.println("\nview(3, 4) 后的张量:");
System.out.println(y2);
System.out.printf("新形状: %s\n", y2.sizes());
System.out.printf("与 x2 共享存储吗? %b\n",
y2.storage().data_ptr().address() == x2.storage().data_ptr().address());
System.out.printf("y2 是否连续? %b\n", y2.is_contiguous());
// 尝试另一个视图
Tensor z = y2.view(2, 6);
System.out.println("\nview(2, 6) 后的张量:");
System.out.println(z);
System.out.printf("新形状: %s\n", z.sizes());
System.out.printf("与 x2 共享存储吗? %b\n",
z.storage().data_ptr().address() == x2.storage().data_ptr().address());
System.out.printf("z 是否连续? %b\n", z.is_contiguous());
```
你可以在 `view()` 调用中对一个维度使用 `-1`,PyTorch 将根据总元素数量和其它维度的尺寸自动推断出该维度的正确尺寸。
```scala 3
// 使用 -1 进行推断
val w = x.view(2, 2, -1) // 推断出最后一个维度为 3 (12 / (2*2) = 3)
println("\nview(2, 2, -1) 后的张量:")
println(w)
println(f"新形状: {w.shape}")
```
如果你尝试在非连续张量上调用 `view()`,你会得到一个 `RuntimeError`。
```scala 3
// view() 在非连续张量上失败的例子
val a = torch.arange(12).view(3, 4)
val b = a.t() // 转置操作会创建一个非连续张量
println(f"\nb 是否连续? {b.is_contiguous()}")
try:
val c = b.view(12)
catch (e: RuntimeException) =>
println(f"\n尝试 b.view(12) 时出错: {e}")
```
```java
// view() 在非连续张量上失败的例子
Tensor a = torch.arange(new Scalar(12)).view(3, 4);
Tensor b = a.t(); // 转置操作会创建一个非连续张量
System.out.printf("\nb 是否连续? %b\n", b.is_contiguous());
try {
Tensor c = b.view(12);
} catch (RuntimeException e) {
System.out.printf("\n尝试 b.view(12) 时出错: %s\n", e.getMessage());
}
```
#### 使用 `reshape()`
`reshape()` 方法的行为类似于 `view()`,但提供了更多灵活性。如果张量对于目标形状是连续的,它会尝试返回一个视图。如果无法返回视图(例如,因为原始张量在与新形状兼容的方式上不是连续的),`reshape()` 将把数据*复制*到一个新的、具有所需形状的连续张量中。这使得 `reshape()` 通常更安全、更通用,尽管如果发生复制,性能可能会降低。
我们再次查看使用 `reshape()` 的转置例子:
```scala 3
// 在非连续张量 'b' 上使用 reshape()
println(f"\n原始非连续张量 b:\n{b}")
println(f"b 的形状: {b.shape}")
println(f"b 是否连续? {b.is_contiguous()}")
// 即使 'b' 不连续,reshape 也能工作
val c = b.reshape(12)
println(f"\nb.reshape(12) 后的张量 c:\n{c}")
println(f"c 的形状: {c.shape}")
println(f"c 是否连续? {c.is_contiguous()}")
// 检查 'c' 是否与 'b' 共享存储。由于 reshape 可能进行了复制,所以它们很可能不共享。
println(f"与 b 共享存储吗? {c.storage().data_ptr() == b.storage().data_ptr()}")
// reshape 也可以用 -1 推断维度
val d = b.reshape(2, -1) // 推断出最后一个维度为 6
println(f"\nb.reshape(2, -1) 后的张量 d:\n{d}")
println(f"d 的形状: {d.shape}")
```
```java
// 在非连续张量 'b' 上使用 reshape()
System.out.printf("\n原始非连续张量 b:\n %s\n", b);
System.out.printf("b 的形状: %s\n", b.sizes());
System.out.printf("b 是否连续? %b\n", b.is_contiguous());
// 即使 'b' 不连续,reshape 也能工作
Tensor c = b.reshape(12);
System.out.printf("\nb.reshape(12) 后的张量 c:\n %s\n", c);
System.out.printf("c 的形状: %s\n", c.sizes());
System.out.printf("c 是否连续? %b\n", c.is_contiguous());
// 检查 'c' 是否与 'b' 共享存储
System.out.printf("与 b 共享存储吗? %b\n",
c.storage().data_ptr().address() == b.storage().data_ptr().address());
// reshape 也可以用 -1 推断维度
Tensor d = b.reshape(2, -1);
System.out.printf("\nb.reshape(2, -1) 后的张量 d:\n %s\n", d);
System.out.printf("d 的形状: %s\n", d.sizes());
System.out.printf("d 是否连续? %b\n", d.is_contiguous());
```
**何时使用哪个方法?**
- 如果你确定张量是连续的,并且希望确保不发生数据复制以获得最高性能,请使用 `view()`。如果连续性假设有误,请准备好处理可能的 `RuntimeError`。
- `reshape()` 适用于连续和非连续张量。如果可能,它会返回一个视图,否则会创建一个副本。除非性能绝对关键且你能保证连续性,否则这通常是优选方法。
维度 view() reshape()
内存行为 仅改变 “视图”,不复制内存(浅操作) 优先复用内存(连续时 = view),不连续则复制(深操作)
前置条件 必须是 contiguous 张量,否则报错 无前置条件,自动兼容连续 / 非连续张量
返回值内存 和原张量共享内存(修改一方会影响另一方) 连续时共享内存,非连续时生成新内存
适用场景 确定张量连续,追求极致效率(无内存拷贝) 通用场景,不想关心连续性(兼容所有情况)
| 维度 | view() | reshape() |
| -------------- | ---------------------------------------- | ----------------------------------------------------- |
| **内存行为** | 仅改变 “视图”,不复制内存(浅操作) | 优先复用内存(连续时 = view),不连续则复制(深操作) |
| **前置条件** | 必须是 contiguous 张量,否则报错 | 无前置条件,自动兼容连续 / 非连续张量 |
| **返回值内存** | 和原张量共享内存(修改一方会影响另一方) | 连续时共享内存,非连续时生成新内存 |
| **适用场景** | 确定张量连续,追求极致效率(无内存拷贝) | 通用场景,不想关心连续性(兼容所有情况) |
### 使用 `permute()` 调整维度顺序
`view()` 和 `reshape()` 通过重新安排元素在维度间的解释方式来改变形状,而 `permute()` 则明确地交换维度本身。它不改变总元素数量,也不改变每个轴上元素*数量*方面的形状,但它改变的是*哪个*轴对应哪个原始维度。
假设你有一个图像数据,存储为(通道,高,宽)格式,但为了特定的库或可视化需求,需要其格式为(高,宽,通道)。`permute()` 就是为此而设计的工具。你将所需的维度顺序作为参数提供。
`permute()` 函数的用法,尤其是它和之前聊的 `view()`/`reshape()` 在内存层面的区别 ——`permute()` 核心是**调整张量维度的顺序**(比如 NCHW 转 NHWC),但不会改变元素的内存存储位置,只会修改 “维度访问规则”,最终导致张量变为**非连续的**
`permute()` 的参数是**维度索引**,用来指定新的维度顺序。比如:
- 张量形状为 `(N, C, H, W)`(批量、通道、高、宽),`permute(0,2,3,1)` 会转为 `(N, H, W, C)`(对应之前讲的 `ChannelsLast` 内存格式);
- 张量形状为 `(2,3)`,`permute(1,0)` 等价于转置 `.t()`。
```scala 3
// 创建一个三维张量(例如,表示通道、高、宽)
val image_tensor = torch.randn(3, 32, 32) // 通道,高,宽
println(f"原始形状: {image_tensor.shape}") // torch.Size([3, 32, 32])
// 调整为(高,宽,通道)
val permuted_tensor = image_tensor.permute(1, 2, 0) // 指定新顺序:维度 1,维度 2,维度 0
println(f"调整后的形状: {permuted_tensor.shape}") // torch.Size([32, 32, 3])
// permute 通常返回一个非连续的视图
println(f"permuted_tensor 是否连续? {permuted_tensor.is_contiguous()}")
// 调回原状
val original_again = permuted_tensor.permute(2, 0, 1) // 回到通道,高,宽
println(f"调回后的形状: {original_again.shape}") // torch.Size([3, 32, 32])
println(f"original_again 是否连续? {original_again.is_contiguous()}") // (可能仍然是非连续的)
// 检查存储共享
println(f"与原始张量共享存储吗? {original_again.storage().data_ptr() == image_tensor.storage().data_ptr()}")
```
```java
// 创建一个三维张量(例如,表示通道、高、宽)
Tensor image_tensor = torch.randn(3, 32, 32);
System.out.printf("原始形状: %s\n", image_tensor.sizes());
// 调整为(高,宽,通道)
Tensor permuted_tensor = image_tensor.permute(1, 2, 0);
System.out.printf("调整后的形状: %s\n", permuted_tensor.sizes());
// permute 通常返回一个非连续的视图
System.out.printf("permuted_tensor 是否连续? %b\n", permuted_tensor.is_contiguous());
// 调回原状
Tensor original_again = permuted_tensor.permute(2, 0, 1);
System.out.printf("调回后的形状: %s\n", original_again.sizes());
System.out.printf("original_again 是否连续? %b\n", original_again.is_contiguous());
// 检查存储共享
System.out.printf("与原始张量共享存储吗? %b\n",
original_again.storage().data_ptr().address() == image_tensor.storage().data_ptr().address());
```
和 `view()` 一样,`permute()` 返回一个与原始张量共享底层数据的张量。它不复制数据。然而,生成的张量通常*不是*连续的。如果你在调换维度后需要一个连续张量(例如,为了后续使用 `view()`),你可以链式调用 `.contiguous()` 方法:
```scala 3
// 使调整维度的张量连续
val contiguous_permuted = permuted_tensor.contiguous()
println(f"\ncontiguous_permuted 是否连续? {contiguous_permuted.is_contiguous()}")
// 现在可以安全地使用 view()
val flattened_permuted = contiguous_permuted.view(-1)
println(f"展平后的形状: {flattened_permuted.shape}")
```
```java
// 使调整维度的张量连续
Tensor contiguous_permuted = permuted_tensor.contiguous();
System.out.printf("\ncontiguous_permuted 是否连续? %b\n", contiguous_permuted.is_contiguous());
// 现在可以安全地使用 view()
Tensor flattened_permuted = contiguous_permuted.view(-1);
System.out.printf("展平后的形状: %s\n", flattened_permuted.sizes());
// 10
// view() 在非连续张量上失败的例子
Tensor a = torch.arange(new Scalar(12)).view(3, 4);
Tensor b = a.t(); // 转置操作会创建一个非连续张量
System.out.printf("\nb 是否连续? %b\n", b.is_contiguous());
try {
Tensor c = b.view(12);
} catch (RuntimeException e) {
System.out.printf("\n尝试 b.view(12) 时出错: %s\n", e.getMessage());
}
```
掌握 `view()`、`reshape()` 和 `permute()` 让你能够精确控制张量的结构,这是将数据适配到不同 PyTorch 操作和模型层要求所需的一项必备技能。请记住这些权衡:`view()` 速度快但要求连续性,`reshape()` 灵活但可能会复制,而 `permute()` 交换维度而不复制,但通常会产生非连续张量。
### 二、permute () 的内存行为(核心重点)
#### 1. 不复制内存,仅修改 “维度访问规则”
`permute()` 是纯 “视图操作”—— 元素在内存中的存储位置完全不变,只是 PyTorch 改变了 “如何通过维度索引访问内存” 的规则(通过 `stride` 步幅实现)。
比如原始张量 `x (2x3)` 的内存顺序是:`[1,2,3,4,5,6]`,`permute(1,0)` 后,内存还是这个顺序,但访问 `x_permute[0][0]` 时,PyTorch 会去内存中取第 0 个元素(1),访问 `x_permute[0][1]` 时取第 3 个元素(4)—— 相当于 “换了个角度看内存”,内存本身没动。
#### 2. permute () 后张量必然是非连续的
因为维度顺序调换后,元素的内存访问路径出现了 “空洞”,不再是紧密连续的。这也是为什么 permute 后的张量不能直接用 `view()`(会报错),必须先 `contiguous()` 复制为连续张量。
`contiguous()` 函数的原理、作用和实际使用场景,这是理解张量内存模型的关键 ——`contiguous()` 本质是**将非连续的张量重新整理为内存连续的张量**,核心解决 “非连续张量无法使用 `view()`、计算效率低” 的问题。我会从 “原理→作用→场景→代码示例” 逐步拆解,结合之前聊的 `view()`/`permute()` 等知识点,让你彻底搞懂。
### 先记核心结论
- **原理**:`contiguous()` 会**复制非连续张量的元素**,按 “维度顺序” 重新排列到一块连续的内存空间,生成一个全新的、内存连续的张量;
- **作用**:让非连续张量恢复 “内存连续性”,兼容 `view()` 等要求连续的操作,同时提升计算效率;
- **核心场景**:permute / 转置后用 `view()`、提升非连续张量的计算速度、兼容第三方库 / 算子。
## 一、contiguous () 核心原理(内存层面)
要理解 `contiguous()`,先回顾「内存连续性」的本质:
PyTorch 张量的多维结构最终映射到**一维内存数组**,“连续” 意味着:
> 张量的元素在内存中严格按照「从最后一维到第一维」的顺序紧密排列,无任何 “空洞”。
### 1. 非连续张量的产生原因
之前讲过,`permute()`、`.t()`(转置)、`narrow()`、`index_select()` 等操作会生成**非连续张量**—— 这些操作只修改 “维度访问规则(stride)”,不改变元素的内存存储位置,导致内存访问路径出现 “空洞”。
| 作用 | 具体说明 |
| ---------------------- | ------------------------------------------------------------ |
| 1. 兼容 `view()` 操作 | `view()` 要求张量必须连续,非连续张量调用 `view()` 会报错,`contiguous()` 是前置条件; |
| 2. 提升计算效率 | 连续张量的内存访问是 “顺序的”,CPU/GPU 的缓存命中率更高,计算速度提升 10%-50%; |
| 3. 兼容第三方库 / 算子 | 部分 C++/CUDA 算子、第三方库(如 OpenCV)仅支持连续张量,非连续会导致崩溃 / 结果错误; |
| 4. 统一内存布局 | 确保张量内存布局符合 PyTorch 原生规则,避免因 stride 异常导致的隐性 bug; |
| 维度 | permute() | view() | reshape() |
| ------------ | ----------------------------------- | ---------------------------- | --------------------------------- |
| **核心作用** | 调整维度顺序(如 NCHW→NHWC) | 重塑形状(如 2x3→6) | 智能重塑形状(兼容连续 / 非连续) |
| **维度变化** | 维度数 / 各维度大小不变,仅顺序变 | 维度数 / 大小可改,总数不变 | 同 view () |
| **内存行为** | 不复制内存,仅改访问规则 | 不复制内存(要求连续) | 连续时不复制,非连续时复制 |
| **连续性** | 必然返回非连续张量 | 仅支持连续张量,返回连续张量 | 连续时返回连续,否则返回连续 |
| **实战场景** | 图像维度转换(NCHW↔NHWC)、矩阵转置 | 确定连续的张量重塑 | 通用形状重塑 |
# 张量的合并与分割
许多深度学习情境下,你需要将多个张量组合成一个,或将一个更大的张量拆分为更小的部分。这可能涉及汇总不同处理步骤的结果、准备数据批次或分离特征。PyTorch 提供了几个函数,用于有效合并和分割张量。
### 合并张量
组合张量是一个常见操作,尤其是在处理数据批次或合并特征表示时。PyTorch 提供了两种主要方式来合并张量:拼接 (`torch.cat`) 和堆叠 (`torch.stack`)。主要区别在于它们是沿着现有维度操作,还是引入一个新维度。
#### 使用 `torch.cat` 进行拼接
`torch.cat` 函数沿着现有维度拼接一系列张量。序列中的所有张量必须形状相同(除了拼接维度),或者为空。
```scala 3
import torch.*
// 创建两个张量
val tensor_a = torch.randn(2, 3)
val tensor_b = torch.randn(2, 3)
println(f"Tensor A (Shape: {tensor_a.shape}):\n{tensor_a}")
println(f"Tensor B (Shape: {tensor_b.shape}):\n{tensor_b}\n")
// 沿着维度0(行)进行拼接
// 结果形状: (2+2, 3) = (4, 3)
val cat_dim0 = torch.cat((tensor_a, tensor_b), dim=0)
println(f"沿着维度0拼接 (形状: {cat_dim0.shape}):\n{cat_dim0}\n")
// 沿着维度1(列)进行拼接
// 张量必须在其他维度(维度0)上匹配
// 结果形状: (2, 3+3) = (2, 6)
val cat_dim1 = torch.cat((tensor_a, tensor_b), dim=1)
println(f"沿着维度1拼接 (形状: {cat_dim1.shape}):\n{cat_dim1}")
// 3D张量示例
val tensor_c = torch.randn(1, 2, 3)
val tensor_d = torch.randn(1, 2, 3)
// 沿着维度0(批次维度)进行拼接
// 结果形状: (1+1, 2, 3) = (2, 2, 3)
val cat_3d_dim0 = torch.cat((tensor_c, tensor_d), dim=0)
println(f"\n3D张量沿着维度0拼接 (形状: {cat_3d_dim0.shape})")
```
```java
// 创建两个张量
Tensor tensor_a = torch.randn(2, 3);
Tensor tensor_b = torch.randn(2, 3);
System.out.printf("Tensor A (Shape: %s):\n %s\n", tensor_a.sizes(), tensor_a);
System.out.printf("Tensor B (Shape: %s):\n %s\n", tensor_b.sizes(), tensor_b);
// 沿着维度0(行)进行拼接
TensorVector cat_input0 = new TensorVector(tensor_a, tensor_b);
Tensor cat_dim0 = torch.cat(cat_input0, 0);
System.out.printf("沿着维度0拼接 (形状: %s):\n %s\n", cat_dim0.sizes(), cat_dim0);
// 沿着维度1(列)进行拼接
TensorVector cat_input1 = new TensorVector(tensor_a, tensor_b);
Tensor cat_dim1 = torch.cat(cat_input1, 1);
System.out.printf("沿着维度1拼接 (形状: %s):\n %s\n", cat_dim1.sizes(), cat_dim1);
// 3D张量示例
Tensor tensor_c = torch.randn(1, 2, 3);
Tensor tensor_d = torch.randn(1, 2, 3);
// 沿着维度0(批次维度)进行拼接
TensorVector cat_input3d = new TensorVector(tensor_c, tensor_d);
Tensor cat_3d_dim0 = torch.cat(cat_input3d, 0);
System.out.printf("\n3D张量沿着维度0拼接 (形状: %s)\n", cat_3d_dim0.sizes());
```
请注意,`torch.cat` 增加了指定维度的大小,同时保持其他维度不变。张量在所有维度上都必须大小匹配,*除了*你进行拼接的那个维度。
张量 A (2x3)张量 B (2x3)torch.cat((A, B), dim=0)(4x3)torch.cat((A, B), dim=1)(2x6)a11a12a13a21a22a23b11b12b13b21b22b23a11a12a13a21a22a23b11b12b13b21b22b23a11a12a13b11b12b13a21a22a23b21b22b23cluster_a+ 维度0+ 维度1cluster_bcluster_cat0cluster_cat1
> 
>
> `torch.cat` 沿维度0和维度1对两个2x3张量进行拼接的视觉比较。
#### 使用 `torch.stack` 进行堆叠
与 `cat` 不同,`torch.stack` 沿着一个*新*维度连接一系列张量。当你希望从单个示例创建批次或将相关张量分组时,这会很有用。为了 `stack` 能够工作,输入序列中的所有张量必须具有完全相同的形状。
```scala 3
import torch.*
// 创建两个形状相同的张量
val tensor_e = torch.arange(6).reshape(2, 3)
val tensor_f = torch.arange(6, 12).reshape(2, 3)
println(f"Tensor E (Shape: {tensor_e.shape}):\n{tensor_e}")
println(f"Tensor F (Shape: {tensor_f.shape}):\n{tensor_f}\n")
// 沿着新维度0进行堆叠
// 结果形状: (2, 2, 3)
val stack_dim0 = torch.stack((tensor_e, tensor_f), dim=0)
println(f"沿着新维度0堆叠 (形状: {stack_dim0.shape}):\n{stack_dim0}\n")
// 沿着新维度1进行堆叠
// 结果形状: (2, 2, 3)
val stack_dim1 = torch.stack((tensor_e, tensor_f), dim=1)
println(f"沿着新维度1堆叠 (形状: {stack_dim1.shape}):\n{stack_dim1}\n")
// 沿着新维度2(最后一个维度)进行堆叠
// 结果形状: (2, 3, 2)
val stack_dim2 = torch.stack((tensor_e, tensor_f), dim=2)
println(f"沿着新维度2堆叠 (形状: {stack_dim2.shape}):\n{stack_dim2}")
```
```java
// 创建两个形状相同的张量
Tensor tensor_e = torch.arange(new Scalar(0), new Scalar(6)).reshape(2, 3);
Tensor tensor_f = torch.arange(new Scalar(6), new Scalar(12)).reshape(2, 3);
System.out.printf("Tensor E (Shape: %s):\n %s\n", tensor_e.sizes(), tensor_e);
System.out.printf("Tensor F (Shape: %s):\n %s\n", tensor_f.sizes(), tensor_f);
// 沿着新维度0进行堆叠
TensorVector stack_input0 = new TensorVector(tensor_e.to(ScalarType.Int), tensor_f);
Tensor stack_dim0 = torch.stack(stack_input0, 0);
System.out.printf("沿着新维度0堆叠 (形状: %s):\n %s\n", stack_dim0.sizes(), stack_dim0);
// 沿着新维度1进行堆叠
TensorVector stack_input1 = new TensorVector(tensor_e.to(ScalarType.Int), tensor_f);
Tensor stack_dim1 = torch.stack(stack_input1, 1);
System.out.printf("沿着新维度1堆叠 (形状: %s):\n %s\n", stack_dim1.sizes(), stack_dim1);
// 沿着新维度2(最后一个维度)进行堆叠
TensorVector stack_input2 = new TensorVector(tensor_e, tensor_f);
Tensor stack_dim2 = torch.stack(stack_input2, 2);
System.out.printf("沿着新维度2堆叠 (形状: %s):\n %s\n", stack_dim2.sizes(), stack_dim2);
```
张量 E (2x3)张量 F (2x3)torch.stack((E, F), dim=0)(2x2x3)切片 0切片 1torch.stack((E, F), dim=1)(2x2x3)e11e12e13e21e22e23f11f12f13f21f22f23e11e12e13e21e22e23f11f12f13f21f22f23cluster_stack0_ecluster_stack0_fe11e12e13f11f12f13e21e22e23f21f22f23cluster_e堆叠 维度0堆叠 维度1cluster_fcluster_stack0cluster_stack1


> `torch.stack` 在 `dim=0` 和 `dim=1` 处插入新维度的视觉比较。请注意原始张量如何成为新张量中的切片。
选择 `cat` 还是 `stack` 取决于你是想沿着现有维度合并,还是创建一个新维度。`cat` 通常用于水平/垂直组合批次或特征,`stack` 则常用于从单个样本创建批次。
| 函数 | 核心行为 | 维度变化 | 拼接规则 | 核心场景 |
| --------------- | ------------------------------------- | ----------------- | ------------------------------ | ---------------------------- |
| `cat`(concat) | 在**已有维度**上拼接,不新增维度 | 维度数不变 | 除拼接维度外,其他维度必须一致 | 同结构张量拼接(如批量合并) |
| `stack` | 在**新维度**上堆叠,新增 1 个维度 | 维度数 + 1 | 所有张量形状必须完全一致 | 新增维度分组(如多特征合并) |
| `hstack` | 水平拼接(优先拼最后一维 / 展平低维) | 维度数不变 / 降维 | 低维自动适配(1D→2D,2D 拼列) | 行 / 列水平拼接(直观对齐) |
| `column_stack` | 列拼接(1D 转列后拼列,2D 拼列) | 维度数≥2 | 1D 张量自动转为列向量 | 按列拼接(特征列合并) |
1. **检查张量形状兼容性**(不同函数规则不同);
2. **按指定规则重新排列内存**(连续拼接,无数据丢失);
3. **生成新张量**(多数情况共享内存?不 ——PyTorch 拼接后默认生成新的连续张量,原张量内存独立)。
关键区别在于:
- `cat`:在**已有维度**上 “延长”(比如 2x3 和 2x4 拼列→2x7);
- `stack`:在**新维度**上 “堆叠”(比如两个 2x3 堆叠→2x2x3 或 2x3x2);
- `hstack/column_stack`:是 `cat` 的 “便捷封装”,简化低维张量的拼接规则。
全称 `concatenate`,`concat` 是 `cat` 的别名;
在**指定的已有维度 dim** 上拼接,要求:除 `dim` 外,其他所有维度的大小必须完全一致;
维度数不变,仅拼接维度的大小增加(如 dim=1 时,2x3 + 2x4 → 2x7)
```java
// ========== 3. torch.hstack() 示例 ==========
// 场景1:1D张量水平拼接(等价于 cat(dim=0))
Tensor e = torch.tensor(new float[]{1,2,3}); // 1D (3)
Tensor f = torch.tensor(new float[]{4,5}); // 1D (2)
Tensor hstack_1d = torch.hstack(new TensorVector(e, f));
printTensorInfo(hstack_1d, "torch.hstack(e,f) (1D水平拼接)");
// 场景2:2D张量水平拼接(等价于 cat(dim=1))
Tensor g = torch.tensor(new float[]{1,2,3,4}).view(2,2); // 2x2
Tensor h = torch.tensor(new float[]{5,6,7,8,9,10}).view(2,3); // 2x3
Tensor hstack_2d = torch.hstack(new TensorVector(g, h));
printTensorInfo(hstack_2d, "torch.hstack(g,h) (2D水平拼接)");
```
```console
===== torch.hstack(e,f) (1D水平拼接) =====
形状: 5
值: 1.0 2.0 3.0 4.0 5.0
===== torch.hstack(g,h) (2D水平拼接) =====
形状: 2 5
值:
1.0000 2.0000 5.0000 6.0000 7.0000
3.0000 4.0000 8.0000 9.0000 10.0000
```
### 分割张量
正如你可以合并张量一样,你也经常需要将它们分开。这可能涉及将一个批次拆分回单个样本、将特征与标签分离或为并行处理划分数据。PyTorch 为这些任务提供了 `torch.split` 和 `torch.chunk` 函数。
#### 使用 `torch.split` 按特定大小分割
`torch.split` 函数沿着指定维度将张量分割成块。你可以指定每个块的大小(如果你想要等份),或者提供一个包含每个所需块大小的列表。
```scala 3
import torch
// 创建一个要分割的张量
val tensor_g = torch.arange(12).reshape(6, 2)
println(f"原始张量 (形状: {tensor_g.shape}):\n{tensor_g}\n")
// 沿着维度0(行)按大小2分割成块
// 6行 / 2行/块 = 3块
val split_equal = torch.split(tensor_g, 2, dim=0)
println("分割成大小为2的等份(dim=0):")
for i, chunk <- split_equal:
println(f" 块 {i} (形状: {chunk.shape}):\n{chunk}")
println("-" * 20)
// 沿着维度0按大小 [1, 2, 3] 分割成块
// 总大小必须等于该维度的大小 (1 + 2 + 3 = 6)
val split_unequal = torch.split(tensor_g, List(1, 2, 3), dim=0)
println("\n分割成大小不等的块 [1, 2, 3](dim=0):")
for i, chunk <- split_unequal:
println(f" 块 {i} (形状: {chunk.shape}):\n{chunk}")
println("-" * 20)
// 沿着维度1(列)进行分割
// 形状: (6, 2)。沿着维度1按大小1分割成块
val split_dim1 = torch.split(tensor_g, 1, dim=1)
println("\n分割成大小为1的等份(dim=1):")
for i, chunk <- split_dim1:
// 使用 squeeze 移除大小为1的维度,以便更清晰地显示
println(f" 块 {i} (形状: {chunk.shape}):\n{chunk.squeeze()}")
```
```java
// 创建一个要分割的张量
Tensor tensor_g = torch.arange(new Scalar(12)).reshape(6, 2);
System.out.printf("原始张量 (形状: %s):\n %s\n", tensor_g.sizes(), tensor_g);
// 沿着维度0(行)按大小2分割成块
TensorVector split_equal = torch.split(tensor_g, 2, 0);
System.out.println("分割成大小为2的等份(dim=0):");
for (int i = 0; i < split_equal.size(); i++) {
Tensor chunk = split_equal.get(i);
System.out.printf(" 块 %d (形状: %s):\n %s\n", i, chunk.sizes(), chunk);
}
System.out.println("-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-");
// 沿着维度0按大小 [1, 2, 3] 分割成块
LongPointer sizes = new LongPointer(1, 2, 3);
TensorVector split_unequal = torch.split(tensor_g, new LongArrayRef(sizes), 0);
System.out.println("\n分割成大小不等的块 [1, 2, 3](dim=0):");
for (int i = 0; i < split_unequal.size(); i++) {
Tensor chunk = split_unequal.get(i);
System.out.printf(" 块 %d (形状: %s):\n %s\n", i, chunk.sizes(), chunk);
}
System.out.println("-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-");
// 沿着维度1(列)进行分割
TensorVector split_dim1 = torch.split(tensor_g, 1, 1);
System.out.println("\n分割成大小为1的等份(dim=1):");
for (int i = 0; i < split_dim1.size(); i++) {
Tensor chunk = split_dim1.get(i);
// 使用 squeeze 移除大小为1的维度,以便更清晰地显示
System.out.printf(" 块 %d (形状: %s):\n %s\n", i, chunk.sizes(), chunk.squeeze());
}
```
`torch.split` 返回一个张量元组。如果你为 `split_size_or_sections` 参数提供一个整数,PyTorch 会沿着指定的 `dim` 将张量分割成该大小的块。如果维度大小不能被分割大小完全整除,最后一个块会更小。如果你提供一个大小列表,它们的总和必须等于被分割维度的大小。
#### 使用 `torch.chunk` 按数量分割
另一种方法是,`torch.chunk` 沿着给定维度将张量分割成指定*数量*的块。PyTorch 会尝试使这些块的大小尽可能相等。与需要指定块大小的 `torch.split` 不同,`chunk` 只需指定所需的块数量。
```scala 3
import torch.*
// 创建一个张量
val tensor_h = torch.arange(10).reshape(5, 2) // 沿着维度0的大小为5
println(f"原始张量 (形状: {tensor_h.shape}):\n{tensor_h}\n")
// 沿着维度0分割成3个块
// 5行 / 3块 -> 大小将是 [2, 2, 1] (前几个块取 ceil(5/3)=2)
val chunked_tensor = torch.chunk(tensor_h, 3, dim=0)
println("分割成3个部分(dim=0):")
for i, chunk <- chunked_tensor:
println(f" 块 {i} (形状: {chunk.shape}):\n{chunk}")
println("-" * 20)
// 创建另一个张量
val tensor_i = torch.arange(12).reshape(3, 4) // 沿着维度1的大小为4
println(f"\n原始张量 (形状: {tensor_i.shape}):\n{tensor_i}\n")
// 沿着维度1分割成2个块
// 4列 / 2块 -> 大小将是 [2, 2] (ceil(4/2)=2)
val chunked_tensor_dim1 = torch.chunk(tensor_i, 2, dim=1)
println("分割成2个部分(dim=1):")
for i, chunk <- chunked_tensor_dim1:
println(f" 块 {i} (形状: {chunk.shape}):\n{chunk}")
```
```java
// 创建一个张量
Tensor tensor_h = torch.arange(new Scalar(10)).reshape(5, 2);
System.out.printf("原始张量 (形状: %s):\n %s\n", tensor_h.sizes(), tensor_h);
// 沿着维度0分割成3个块
// 5行 / 3块 -> 大小将是 [2, 2, 1] (前几个块取 ceil(5/3)=2)
TensorVector chunked_tensor = torch.chunk(tensor_h, 3, 0);
System.out.println("分割成3个部分(dim=0):");
for (int i = 0; i < chunked_tensor.size(); i++) {
Tensor chunk = chunked_tensor.get(i);
System.out.printf(" 块 %d (形状: %s):\n %s\n", i, chunk.sizes(), chunk);
}
System.out.println("-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-" + "-");
// 创建另一个张量
Tensor tensor_i = torch.arange(new Scalar(12)).reshape(3, 4);
System.out.printf("\n原始张量 (形状: %s):\n %s\n", tensor_i.sizes(), tensor_i);
// 沿着维度1分割成2个块
TensorVector chunked_tensor_dim1 = torch.chunk(tensor_i, 2, 1);
System.out.println("分割成2个部分(dim=1):");
for (int i = 0; i < chunked_tensor_dim1.size(); i++) {
Tensor chunk = chunked_tensor_dim1.get(i);
System.out.printf(" 块 %d (形状: %s):\n %s\n", i, chunk.sizes(), chunk);
}
```
| 函数 | 核心拆分规则 | 维度变化 | 关键参数 | 核心场景 |
| --------- | -------------------------------------- | ---------- | ------------------------ | -------------------------------- |
| `split()` | 按**指定长度**拆分目标维度(可不等长) | 维度数不变 | `split_size`(拆分长度) | 不等长拆分(如按序列长度拆分) |
| `chunk()` | 按**指定份数**拆分目标维度(尽量等长) | 维度数不变 | `chunks`(拆分份数) | 等长拆分(如批量均分、并行计算) |
## 一、核心原理:拆分的本质
两者的底层逻辑都是:
1. 确定**目标拆分维度 `dim`**(默认 dim=0);
2. 按各自规则计算每个子张量在 `dim` 上的长度;
3. 生成**视图(view)而非拷贝**—— 子张量和原张量共享内存(修改子张量会影响原张量);
4. 返回子张量列表,所有子张量的维度数与原张量一致。
关键差异在 “拆分长度的计算方式”:
- `split()`:你指定 “每段多长”,PyTorch 按这个长度切分(最后一段可能更短);
- `chunk()`:你指定 “切分成几段”,PyTorch 自动均分长度(最后一段可能少 1)
当你知道想要多少个部分,而不关心维度大小是否能被均匀整除时,`torch.chunk` 很方便。当你需要大小精确且可能变化的块时,`torch.split` 提供了更多的控制。
掌握这些合并和分割操作很重要,可以帮助你有效处理数据,因为它会流经你的深度学习管线的不同阶段,从初始加载和预处理,到训练的批处理,以及模型输出的分析。