mxnet自定义激活函数/op in numpy

2024-06-06 07:58:58 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个关于在mxnet中创建自定义激活函数/op时使用的语法的问题。我在看这个例子: https://github.com/dmlc/mxnet/blob/master/example/numpy-ops/custom_softmax.py

具体而言,本部分:

class Softmax(mx.operator.CustomOp):
    def forward(self, is_train, req, in_data, out_data, aux):
        x = in_data[0].asnumpy()
        y = np.exp(x - x.max(axis=1).reshape((x.shape[0], 1)))
        y /= y.sum(axis=1).reshape((x.shape[0], 1))
        self.assign(out_data[0], req[0], mx.nd.array(y))

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        l = in_data[1].asnumpy().ravel().astype(np.int)
        y = out_data[0].asnumpy()
        y[np.arange(l.shape[0]), l] -= 1.0
        self.assign(in_grad[0], req[0], mx.nd.array(y))

输入数据[0]与输入数据[1]以及输出数据[0]与输出数据[1]之间的关系如何?这些指数对应什么?在

谢谢!在


Tags: 数据inselfdatadefnpoutreq