张量物体上的一种特殊切片

2024-04-25 12:29:49 发布

您现在位置:Python中文网/ 问答频道 /正文

问题的总结,在tensorflow中是否支持这种切片然后赋值?你知道吗

out[tf_a2[y],x[:,None]] = tf_a1[tf_a2[y],x[:,None]]
final = out[:-1]

举个例子,我有这样一个张量:

tf_a1 = tf.Variable([    [9.968594,  8.655439,  0.,        0.       ],
                         [0.,        8.3356,    0.,        8.8974   ],
                         [0.,        0.,        6.103182,  7.330564 ],
                         [6.609862,  0.,        3.0614321, 0.       ],
                         [9.497023,  0.,        3.8914037, 0.       ],
                         [0.,        8.457685,  8.602337,  0.       ],
                         [0.,        0.,        5.826657,  8.283971 ],
                         [0.,        0.,        0.,        0.       ]])

我有一个:

tf_a2 = tf.constant([[1, 2, 5],
                    [1, 4, 6],
                    [0, 7, 7],
                    [2, 3, 6],
                    [2, 4, 7]])

现在我想保留tf_a1中的元素,其中n(这里n是2)的组合(它们的索引)的值为tf_a2。这是什么意思?你知道吗

例如,在tf_a1的第一列中,有值的索引是:(0,3,4)。在tf_a2中是否有任何行包含这两个索引的任意组合:(0,3)、(0,4)或(3,4)。其实,没有这样的争吵。所以那列中的所有元素都变成了零。你知道吗

tf_a1中第二列的索引是(0,1)(0,5)(1,5)。如您所见,记录(1,5)在第一行的tf_a2中可用。这就是为什么我们把它们保存在tf_a1。你知道吗

这是正确的numpy代码:

y,x = np.where(np.count_nonzero(a1p[a2], axis=1) >= n)
out = np.zeros_like(tf_a1)
out[tf_a2[y],x[:,None]] = tf_a1[tf_a2[y],x[:,None]]
final = out[:-1]

这是这个numpy代码的预期输出(但我在tensorflow中需要它):

[[0.        0.        0.        0.       ]
 [0.        8.3356    0.        8.8974   ]
 [0.        0.        6.103182  7.330564 ]
 [0.        0.        3.0614321 0.       ]
 [0.        0.        3.8914037 0.       ]
 [0.        8.457685  8.602337  0.       ]
 [0.        0.        5.826657  8.283971 ]]

tensorflow代码应该是这样的:

y, x = tf.where(tf.count_nonzero(tf.gather(tf_a1, tf_a2, axis=0), axis=1) >= n)
out = tf.zeros_like(tf_a1)
out[tf_a2[y],x[:,None]] = tf_a1[tf_a2[y],x[:,None]]
final = out[:-1]

这部分代码tf.gather(tf_a1, tf_a2, axis=0), axis=1)正在执行numpy一样的切片tf_a1[tf_a2]

更新1

唯一不起作用的线路:

out[tf_a2[y],x[:,None]] = tf_a1[tf_a2[y],x[:,None]]
final = out[:-1]

你知道如何在tensorflow中实现这一点吗?这种切片在tensor对象中是受支持的吗?你知道吗

感谢您的帮助:)


Tags: 代码numpynonea2元素tftensorflowa1