MindSpore:如何在静态图模式下,在construct函数里更新网络权重?

154 阅读1分钟

import numpy as np

import mindspore

import mindspore as ms

from mindspore import ops

from mindspore import nn, Tensor, Parameter, context

# context.set_context(mode=context.PYNATIVE_MODE)

\

class MyConv2d(nn.Cell):

    def init(self):

        super().init()

        self.conv = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')

        self.tmp = ms.ParameterTuple(self.get_parameters()) 

    def construct(self, x, w):

        # w是权重

        for weight in self.tmp:

            # 更新权重

            ops.Assign()(weight, w)

        return self.conv(x)

\

x = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)

w = Tensor(np.ones([240, 120, 4, 4]), mindspore.float32)

output = MyConv2d()(x, w)

print(output)