mindspore.nn是 MindSpore 框架中用于构建神经网络的核心模块,提供了丰富的预定义构建块(称为 Cell)和计算单元。
1. 核心概念:神经网络 Cell
- Cell是 MindSpore 中神经网络的基本构成单元。所有网络层、损失函数、优化器等都可以视为 Cell。
- 动态 Shape 支持:需要注意不同接口对动态 shape 的支持情况。
- API 变更:不同 MindSpore 版本间,该模块的 API 可能会有新增、删除或平台支持变化。
2. 基本构成单元
Cell:所有神经网络结构的基类。GraphCell:用于运行从 MindIR 格式加载的计算图。LossBase:所有损失函数的基类。Optimizer:所有参数优化器的基类。
3. 容器 (Containers)
用于组织和管理多个 Cell,构建更复杂的网络结构:
SequentialCell:顺序容器,用于按顺序组合多个 Cell。CellList:以列表形式管理 Cell。CellDict:以字典形式管理 Cell。
4. 封装层 (Wrapper Cells)
主要用于训练流程的封装、分布式训练和优化策略:
- 训练流程封装:
TrainOneStepCell(单步训练),TrainOneStepWithLossScaleCell(混合精度训练),WithLossCell(组合网络与损失),WithEvalCell(组合网络、损失函数用于评估)。 - 分布式与并行:
DistributedGradReducer(分布式梯度聚合),PipelineCell(流水线并行),GradAccumulationCell(梯度累积)。 - 其他工具:
GetNextSingleOp(数据获取),ParameterUpdate(参数更新),DynamicLossScaleUpdateCell/FixedLossScaleUpdateCell(动态/固定损失缩放)。
5. 主要网络层类型
文档详细列出了多种类型的网络层,是构建模型的核心组件:
-
卷积神经网络层:
- 基础卷积:
Conv1d,Conv2d,Conv3d。 - 转置卷积:
Conv1dTranspose,Conv2dTranspose,Conv3dTranspose。
- 基础卷积:
-
循环神经网络层:
- 经典RNN:
RNN,RNNCell。 - 门控单元:
GRU,GRUCell;LSTM,LSTMCell。
- 经典RNN:
-
Transformer 层:
- 核心组件:
MultiheadAttention(多头注意力)。 - 编码器与解码器:
TransformerEncoderLayer,TransformerDecoderLayer,TransformerEncoder,TransformerDecoder。 - 完整模块:
Transformer。
- 核心组件:
-
嵌入层:
Embedding,EmbeddingLookup,MultiFieldEmbeddingLookup。 -
线性层:
Dense(全连接层),BiDense(双线性全连接层)。 -
非线性激活函数层:种类繁多,包括:
- 常用函数:
ReLU,Sigmoid,Tanh,GELU,Softmax,LogSoftmax。 - 其他变体:
LeakyReLU,PReLU,ELU,SiLU(Swish),Mish等。
- 常用函数:
-
归一化层:
- 批归一化:
BatchNorm1d/2d/3d,SyncBatchNorm(跨设备同步)。 - 其他归一化:
LayerNorm,GroupNorm,InstanceNorm1d/2d/3d。
- 批归一化:
-
池化层:
- 自适应池化:
AdaptiveAvgPool1d/2d/3d,AdaptiveMaxPool1d/2d/3d。 - 常规池化:
AvgPool1d/2d/3d,MaxPool1d/2d/3d及其逆操作MaxUnpool1d/2d/3d。 - 其他:
LPPool1d/2d,FractionalMaxPool3d。
- 自适应池化:
-
Dropout 层:
Dropout,Dropout1d/2d/3d。 -
填充层 (Padding):支持常量填充(
ConstantPad1d/2d/3d)、反射填充(ReflectionPad1d/2d/3d)、复制填充(ReplicationPad1d/2d/3d)等多种方式。 -
图像处理层:
PixelShuffle(像素重组),PixelUnshuffle(逆操作),Upsample(上采样)。 -
公共层:
Flatten(展平),Unflatten(反展平),ChannelShuffle(通道重排),Identity(恒等映射)。
6. 损失函数
提供了丰富的损失函数,涵盖分类、回归、序列学习等任务:
- 分类任务:
CrossEntropyLoss,BCELoss,BCEWithLogitsLoss,NLLLoss。 - 回归任务:
L1Loss,MSELoss,SmoothL1Loss,HuberLoss。 - 其他:
CTCLoss(时序分类),CosineEmbeddingLoss(相似度),TripletMarginLoss(三元组损失),KLDivLoss(KL散度) 等。 - 注意:文档标注
DiceLoss,FocalLoss,SampledSoftmaxLoss等从特定版本开始已被弃用。
7. 优化器
实现了主流的梯度下降优化算法及其变种:
- 自适应算法:
Adam,AdamW(文档中为AdamWeightDecay),AdaMax,Adagrad,Adadelta,RMSProp。 - 带动量的算法:
SGD,Momentum。 - 大规模/分布式优化:
LAMB,LARS。 - 其他:
ASGD,FTRL,Rprop等。 - 注意:
LazyAdam已被标记为弃用。
8. 动态学习率
支持两种方式实现学习率动态调整:
- LearningRateSchedule 类:如
CosineDecayLR,PolynomialDecayLR,WarmUpLR。将类的实例传递给优化器。 - Dynamic LR 函数:如
cosine_decay_lr,piecewise_constant_lr,warmup_lr。调用函数并将返回的列表传递给优化器。
9. 工具
no init:一个实用工具,用于在不初始化参数的情况下创建 Cell。