我们使用chainer.functions.linear
来计算y=Wx+b
。你知道吗
在我的例子中,我必须实现一个多维度的线性链接。你知道吗
假设输入示例是(c, x)
,那么所需的输出就是y = W_c x + b
。让我们忽略偏见,让它y = W_c x
。{c}
的基数是预先知道的(通常是样本类)。你知道吗
理论上W
参数可以实现为三维张量(C, y_dims, x_dims)
。但还有什么?我是否必须迭代批处理并提取形状为(y_dims, x_dims)
的W_c
并仅为该(1, x_dims)
形状的示例调用functions.linear
?你知道吗
嗯,我自己找到了一个解决问题的办法。你知道吗
数据的形状如下:
W: (C, y_dims, x_dims)
x: (batch, x_dims)
c: (batch, 1)
首先,我必须为每批x得到一个权重矩阵:
所以这里的关键函数是
get_item
,它接受numpy.ndarray
和cupy.ndarray
,但不接受chainer.Variable
。它的工作原理类似于numpy.take
,但是它是可微的,并且节省了大量的工作。你知道吗相关问题 更多 >
编程相关推荐