从一个张量中取出一个元素,该元素也同时存在于另一个张量中

2024-04-20 13:47:34 发布

您现在位置:Python中文网/ 问答频道 /正文

我有两个张量,我必须迭代第一个,只取另一个张量中的元素。在t2中只有一个元素也在t1中。这里有一个例子

t1 = tf.where(values > 0) # I get some indices example [6, 0], [3, 0]
t2 = tf.where(values2 > 0) # I get [4, 0], [3, 0]

t3 = .... # [3, 0]

我尝试使用.eval()对它们求值和迭代,并使用操作符in检查t2的元素是否在t1中,但没有起作用。有张量流的函数可以做到吗?你知道吗

编辑

for index in xrange(max_indices):
    indices = tf.where(tf.equal(values, (index + 1))).eval() # indices: [[1 0]\n [4 0]\n [9 0]]
    cent_indices = tf.where(centers > 0).eval() # cent_indices: [[6 0]\n [9 0]]
    indices_list.append(indices)
    for cent in cent_indices:
        if cent in indices:
           centers_list.append(cent)
           break

第一次迭代cent具有值[6 0],但它进入了if条件。你知道吗

回答

for index in xrange(max_indices):
    indices = tf.where(tf.equal(values, (index + 1))).eval()
    cent_indices = tf.where(centers > 0).eval()
    indices_list.append(indices)
    for cent in cent_indices:
        # batch_item is an iterator from an outer loop
        if values[batch_item, cent[0]].eval() == (index + 1):
           centers_list.append(tf.constant(cent))
           break

这个解与我的任务有关,但是如果你在寻找一维张量的解,我建议你看看tf.sets.set_intersection


Tags: in元素forindextfevalwherecent
1条回答
网友
1楼 · 发布于 2024-04-20 13:47:34

这就是你想要的吗?我只用了这两个测试用例。你知道吗

x = tf.constant([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 1]])
y = tf.constant([[1, 2, 3, 4, 3, 6], [1, 2, 3, 4, 5, 1]])
# x = tf.constant([[1, 2], [4, 5], [7, 7]])
# y = tf.constant([[7, 7], [3, 5]])

def match(xiterations, yiterations, yvalues, xvalues ):
    for i in range(xiterations):
        for j in range(yiterations):
            if (np.array_equal(yvalues[j], xvalues[i])):
                print( yvalues[j])

with tf.Session() as sess:
    xindex = tf.where( x > 4 )
    yindex = tf.where( y > 4 )

    xvalues = xindex.eval()
    yvalues = yindex.eval()

    xiterations =  tf.shape(xvalues)[0].eval()
    yiterations =  tf.shape(yvalues)[0].eval()

    print(tf.shape(xvalues)[0].eval())
    print(tf.shape(yvalues)[0].eval())

    if tf.shape(xvalues)[0].eval() >= tf.shape(yvalues)[0].eval():
        match( xiterations, yiterations, yvalues, xvalues)
    else:
        match( yiterations, xiterations, xvalues, yvalues)

相关问题 更多 >