MindSpore `mindspore.nn`模块学习

2 阅读1分钟

mindspore.nn是 MindSpore 框架中用于构建神经网络的核心模块,提供了丰富的预定义构建块(称为 Cell)和计算单元。

1. 核心概念:神经网络 Cell

  • Cell是 MindSpore 中神经网络的基本构成单元。所有网络层、损失函数、优化器等都可以视为 Cell。
  • 动态 Shape 支持:需要注意不同接口对动态 shape 的支持情况。
  • API 变更:不同 MindSpore 版本间,该模块的 API 可能会有新增、删除或平台支持变化。

2. 基本构成单元

  • Cell:所有神经网络结构的基类。
  • GraphCell:用于运行从 MindIR 格式加载的计算图。
  • LossBase:所有损失函数的基类。
  • Optimizer:所有参数优化器的基类。

3. 容器 (Containers)

用于组织和管理多个 Cell,构建更复杂的网络结构:

  • SequentialCell:顺序容器,用于按顺序组合多个 Cell。
  • CellList:以列表形式管理 Cell。
  • CellDict:以字典形式管理 Cell。

4. 封装层 (Wrapper Cells)

主要用于训练流程的封装、分布式训练和优化策略:

  • 训练流程封装:TrainOneStepCell(单步训练),TrainOneStepWithLossScaleCell(混合精度训练),WithLossCell(组合网络与损失),WithEvalCell(组合网络、损失函数用于评估)。
  • 分布式与并行:DistributedGradReducer(分布式梯度聚合),PipelineCell(流水线并行),GradAccumulationCell(梯度累积)。
  • 其他工具:GetNextSingleOp(数据获取),ParameterUpdate(参数更新),DynamicLossScaleUpdateCell/ FixedLossScaleUpdateCell(动态/固定损失缩放)。

5. 主要网络层类型

文档详细列出了多种类型的网络层,是构建模型的核心组件:

  • 卷积神经网络层:

    • 基础卷积:Conv1dConv2dConv3d
    • 转置卷积:Conv1dTransposeConv2dTransposeConv3dTranspose
  • 循环神经网络层:

    • 经典RNN:RNNRNNCell
    • 门控单元:GRUGRUCellLSTMLSTMCell
  • Transformer 层:

    • 核心组件:MultiheadAttention(多头注意力)。
    • 编码器与解码器:TransformerEncoderLayerTransformerDecoderLayerTransformerEncoderTransformerDecoder
    • 完整模块:Transformer
  • 嵌入层:EmbeddingEmbeddingLookupMultiFieldEmbeddingLookup

  • 线性层:Dense(全连接层), BiDense(双线性全连接层)。

  • 非线性激活函数层:种类繁多,包括:

    • 常用函数:ReLUSigmoidTanhGELUSoftmaxLogSoftmax
    • 其他变体:LeakyReLUPReLUELUSiLU(Swish), Mish等。
  • 归一化层:

    • 批归一化:BatchNorm1d/2d/3dSyncBatchNorm(跨设备同步)。
    • 其他归一化:LayerNormGroupNormInstanceNorm1d/2d/3d
  • 池化层:

    • 自适应池化:AdaptiveAvgPool1d/2d/3dAdaptiveMaxPool1d/2d/3d
    • 常规池化:AvgPool1d/2d/3dMaxPool1d/2d/3d及其逆操作 MaxUnpool1d/2d/3d
    • 其他:LPPool1d/2dFractionalMaxPool3d
  • Dropout 层:DropoutDropout1d/2d/3d

  • 填充层 (Padding):支持常量填充(ConstantPad1d/2d/3d)、反射填充(ReflectionPad1d/2d/3d)、复制填充(ReplicationPad1d/2d/3d)等多种方式。

  • 图像处理层:PixelShuffle(像素重组), PixelUnshuffle(逆操作), Upsample(上采样)。

  • 公共层:Flatten(展平), Unflatten(反展平), ChannelShuffle(通道重排), Identity(恒等映射)。

6. 损失函数

提供了丰富的损失函数,涵盖分类、回归、序列学习等任务:

  • 分类任务:CrossEntropyLossBCELossBCEWithLogitsLossNLLLoss
  • 回归任务:L1LossMSELossSmoothL1LossHuberLoss
  • 其他:CTCLoss(时序分类), CosineEmbeddingLoss(相似度), TripletMarginLoss(三元组损失), KLDivLoss(KL散度) 等。
  • 注意:文档标注 DiceLossFocalLossSampledSoftmaxLoss等从特定版本开始已被弃用。

7. 优化器

实现了主流的梯度下降优化算法及其变种:

  • 自适应算法:AdamAdamW(文档中为AdamWeightDecay), AdaMaxAdagradAdadeltaRMSProp
  • 带动量的算法:SGDMomentum
  • 大规模/分布式优化:LAMBLARS
  • 其他:ASGDFTRLRprop等。
  • 注意:LazyAdam已被标记为弃用。

8. 动态学习率

支持两种方式实现学习率动态调整:

  • LearningRateSchedule 类:如 CosineDecayLRPolynomialDecayLRWarmUpLR。将类的实例传递给优化器。
  • Dynamic LR 函数:如 cosine_decay_lrpiecewise_constant_lrwarmup_lr。调用函数并将返回的列表传递给优化器。

9. 工具

  • no init:一个实用工具,用于在不初始化参数的情况下创建 Cell。