沿给定位置(X,Y)的第三轴(Z)更新rank3 tensorflow张量中的切片

2024-05-28 19:49:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试使用Tensorflow 1.9.0重新实现下面的函数(用numpy编写)

def lateral_inhibition2(conv_spikes,SpikesPerNeuronAllowed):
    vbn = np.where(SpikesPerNeuronAllowed==0)
    conv_spikes[vbn[0],vbn[1],:]=0 
    return conv_spikes

conv_spikes是秩3的二元张量,SpikesPerNeuronAllowed是秩2的张量conv_spikes是一个变量,如果某个特定位置的神经元含有10则表示该位置的神经元没有出现峰值SpikesPerNeuronAllowed变量表示沿Z轴位于X-Y位置的所有神经元是否允许出现峰值。在SpikesPerNeuronAllowed中的1表示在conv_spikes中相应的X-Y位置和沿着Z轴的神经元可以出现尖峰。一个0表示在conv_spikes和沿着Z轴的相应X-Y位置的神经元不允许出现尖峰

conv_spikes2 = (np.random.rand(5,5,3)>=0.5).astype(np.int16)
temp2 = np.random.choice([0, 1], size=(25,), p=[3./4, 1./4])
SpikesPerNeuronAllowed2 = temp2.reshape(5,5)
print(conv_spikes2[:,:,0])
print
print(conv_spikes2[:,:,1])
print
print(conv_spikes2[:,:,2])
print
print(SpikesPerNeuronAllowed2)

生成以下输出

##First slice of conv_spikes across Z-axis
[[0 0 1 1 1]
 [1 0 0 1 1]
 [1 0 1 1 0]
 [0 1 0 1 1]
 [0 1 0 0 0]]
##Second slice of conv_spikes across Z-axis
[[0 0 1 0 0]
 [0 0 1 0 1]
 [0 0 1 1 1]
 [0 0 0 1 0]
 [1 1 1 1 1]]
##Third slice of conv_spikes across Z-axis
[[0 1 1 0 0]
 [0 0 1 0 0]
 [0 1 1 0 0]
 [0 0 0 1 0]
 [1 0 1 1 1]]
##SpikesPerNeuronAllowed2
[[0 0 0 0 1]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [1 1 0 0 0]
 [0 0 0 1 0]]

现在,当函数被调用时

conv_spikes2 = lateral_inhibition2(conv_spikes2,SpikesPerNeuronAllowed2)
print(conv_spikes2[:,:,0])
print
print(conv_spikes2[:,:,1])
print
print(conv_spikes2[:,:,2])

生成以下输出

##First slice of conv_spikes across Z-axis
[[0 0 0 0 1]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 1 0 0 0]
 [0 0 0 0 0]]
##Second slice of conv_spikes across Z-axis
[[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 1 0]]
##Third slice of conv_spikes across Z-axis
[[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 1 0]]

我试着在Tensorflow中重复下面的内容

conv_spikes_tf = tf.Variable((np.random.rand(5,5,3)>=0.5).astype(np.int16))
a_placeholder = tf.placeholder(tf.float32,shape=(5,5))
b_placeholder = tf.placeholder(tf.float32)
inter2 = tf.where(tf.equal(a_placeholder,b_placeholder))
output= sess.run(inter2,feed_dict{a_placeholder:SpikesPerNeuronAllowed2,b_placeholder:0})
print(output)

产生以下输出

[[0 0]
 [0 1]
 [0 2]
 [0 3]
 [1 0]
 [1 1]
 [1 2]
 [1 3]
 [1 4]
 [2 0]
 [2 1]
 [2 2]
 [2 3]
 [2 4]
 [3 2]
 [3 3]
 [3 4]
 [4 0]
 [4 1]
 [4 2]
 [4 4]]

我试着用下面的代码更新conv_spikes_tf,结果出现了一个错误,我试着阅读scatter_nd_update的手册,但我想我不是很理解

update = tf.scatter_nd_update(conv_spikes_tf, output, np.zeros(output.shape[0]))
sess.run(update)

ValueError: The inner 1 dimensions of input.shape=[5,5,3] must match the inner 1 dimensions of updates.shape=[21,2]: Dimension 0 in both shapes must be equal, but are 3 and 2. Shapes are [3] and [2]. for 'ScatterNdUpdate_8' (op: 'ScatterNdUpdate') with input shapes: [5,5,3], [21,2], [21,2].

我不理解错误消息,特别是inner 1 dimensions是什么意思,如何使用tensorflow实现上述numpy功能


Tags: ofoutputtfnpsliceplaceholderacrossprint
1条回答
网友
1楼 · 发布于 2024-05-28 19:49:36

tf.scatter_nd_updateupdates的最后一个维度应该是3,这等于ref的最后一个维度

update = tf.scatter_nd_update(conv_spikes_tf, output, np.zeros(output.shape[0], 3))

如果我理解正确,您希望对conv\u峰值应用SpikesPerNeuronAllowed2(mask)。一个更简单的方法是将conv_spikes重塑为(3,5,5)并乘以SpikesPerNeuronAllowed2

我用一个常数例子来说明结果。你也可以把它改成tf.Variable

conv = (np.random.rand(3,5,5)>=0.5).astype(np.int32)
tmp = np.random.choice([0, 1], size=(25,), p=[3./4, 1./4])
mask = tmp.reshape(5,5)
# array([[[1, 1, 0, 0, 0],
#         [0, 1, 0, 0, 1],
#         [0, 1, 0, 0, 1],
#         [1, 0, 0, 0, 1],
#         [1, 0, 0, 1, 0]],

#        [[1, 0, 0, 0, 1],
#         [1, 0, 1, 1, 1],
#         [0, 0, 1, 0, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1]],

#        [[0, 0, 0, 1, 0],
#         [0, 1, 1, 0, 1],
#         [0, 1, 1, 0, 1],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 0, 1]]], dtype=int32)

# array([[0, 0, 0, 1, 1],
#        [0, 0, 0, 1, 0],
#        [0, 0, 0, 0, 0],
#        [0, 1, 0, 1, 0],
#        [0, 0, 1, 0, 1]])
tf_conv = tf.constant(conv, dtype=tf.int32)
tf_mask = tf.constant(mask, dtype=tf.int32)
res = tf_conv * tf_mask
sess = tf.InteractiveSession()
sess.run(res)
# array([[[0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0]],

#        [[0, 0, 0, 0, 1],
#         [0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 0],
#         [0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 1]],

#        [[0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [0, 1, 0, 1, 0],
#         [0, 0, 1, 0, 1]]], dtype=int32)

相关问题 更多 >

    热门问题