TensorFlow: Filtering duplicate 3D indices by maximum value

Posted 2024-06-16 11:53:18


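I am trying to create a filter mask that removes duplicate indices from a vector by keeping, for each duplicated index, only the entry with the larger value.

My current approach is:

  1. Convert the 3D index to a 1D index
  2. Find the unique 1D indices
  3. Compute the maximum value for each unique index
  4. Compare the maximum values with the original values. If a matching value exists, keep that 3D index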
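The four steps can be sketched in plain Python, without TensorFlow, to check the expected mask on the example data. The helper name `value_only_mask` and its default sizes are assumptions chosen to match the example, not part of the original code.

```python
# Plain-Python sketch of the four steps (no TensorFlow needed).
# `value_only_mask` is a hypothetical helper used only for illustration.

def value_only_mask(x_cells, y_cells, iou_index, values,
                    max_dim_y=10, num_anchors=5):
    # 1. Flatten each 3D index (x, y, anchor) into a single integer.
    flat = [x * max_dim_y * num_anchors + y * num_anchors + a
            for x, y, a in zip(x_cells, y_cells, iou_index)]
    # 2. + 3. Track the maximum value seen for every unique flat index.
    best = {}
    for i, v in zip(flat, values):
        best[i] = max(v, best.get(i, v))
    # 4. Keep an entry if its value equals any of the per-index maxima.
    maxima = set(best.values())
    return [v in maxima for v in values]


mask = value_only_mask([1, 2, 3, 4, 1], [4, 4, 4, 4, 4], [1, 2, 3, 4, 1],
                       [1.0, 2.0, 3.0, 4.0, 5.0])
print(mask)  # [False, True, True, True, True]
```

Note that step 4 here compares bare values only, which is exactly the weakness that shows up once values repeat across different indices.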

I want to end up with a filter array so that I can apply the same boolean_mask to other tensors as well. For this example, the mask should look like this: [False True True True True]
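As a small illustration of why a single mask is convenient, the sketch below applies one mask to several parallel sequences. It uses `itertools.compress` from the standard library as a stand-in for what `tf.boolean_mask` does with a 1-D tensor and a 1-D mask; the list names simply mirror the example.

```python
from itertools import compress

mask = [False, True, True, True, True]
x_cells = [1, 2, 3, 4, 1]
iou_max = [1.0, 2.0, 3.0, 4.0, 5.0]

# compress() keeps the items whose mask entry is truthy, so the same
# mask filters every parallel sequence consistently.
filtered_x = list(compress(x_cells, mask))
filtered_iou = list(compress(iou_max, mask))

print(filtered_x)    # [2, 3, 4, 1]
print(filtered_iou)  # [2.0, 3.0, 4.0, 5.0]
```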

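My current code works fine unless the values themselves are also duplicated. However, that appears to happen in my real usage, so I need to find a better solution.

Here is an example of what my code looks like: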

import tensorflow as tf

# Dummy input values with the same structure as the real data
x_cells   = tf.constant([1,2,3,4,1], dtype=tf.int32)   # Index_1
y_cells   = tf.constant([4,4,4,4,4], dtype=tf.int32)   # Index_2
iou_index = tf.constant([1,2,3,4,1], dtype=tf.int32)   # Index_3
iou_max   = tf.constant([1.,2.,3.,4.,5.], dtype=tf.float32) # Values

# The output should be a mask equal to [False True True True True],
# so filtering with it gives e.g. x_cells = [2,3,4,1] and iou_max = [2.,3.,4.,5.]

max_dim_y = tf.constant(10)
max_dim_x = tf.constant(20)
num_anchors = 5
stride = 32

# 1. Transforming the 3D-Index to 1D
tmp = tf.stack([x_cells, y_cells, iou_index], axis=1)
indices = tf.matmul(tmp, [[max_dim_y * num_anchors],     [num_anchors],[1]])

# 2. Looking for unique / duplicate indices
y, idx = tf.unique(tf.squeeze(indices))

# 3. Calculating the maximum value of each unique index.
# A function like unsorted_segment_argmax() would be awesome here
num_segments = tf.shape(y)[0]
ious = tf.math.unsorted_segment_max(iou_max, idx, num_segments)

iou_max_length = tf.shape(iou_max)[0]
ious_length = tf.shape(ious)[0]

# 4. Compare all max values to original values.
iou_max_tiled = tf.tile(iou_max, [ious_length])
iou_reshaped = tf.reshape(iou_max_tiled, [ious_length, iou_max_length])
iou_max_reshaped = tf.transpose(iou_reshaped)
filter_mask = tf.reduce_any(tf.equal(iou_max_reshaped, ious), -1)
filter_mask = tf.reshape(filter_mask, shape=[-1])
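The matmul in step 1 amounts to a weighted sum of the three index components. A plain-Python sketch of that flattening and its inverse is shown below; the function names `flatten_index` and `unflatten_index` are hypothetical, and the inverse is only valid under the stated range assumptions.

```python
max_dim_y = 10
num_anchors = 5


def flatten_index(x, y, anchor):
    # Same weights as the matmul: x is the slowest-varying component.
    return x * (max_dim_y * num_anchors) + y * num_anchors + anchor


def unflatten_index(flat):
    # Inverse mapping; assumes 0 <= y < max_dim_y and 0 <= anchor < num_anchors,
    # otherwise different 3D indices can collide on the same flat index.
    x, rest = divmod(flat, max_dim_y * num_anchors)
    y, anchor = divmod(rest, num_anchors)
    return x, y, anchor


flat = [flatten_index(x, y, a)
        for x, y, a in zip([1, 2, 3, 4, 1], [4, 4, 4, 4, 4], [1, 2, 3, 4, 1])]
print(flat)                 # [71, 122, 173, 224, 71]
print(unflatten_index(71))  # (1, 4, 1)
```

The first and last entries share the flat index 71, which is the duplicate the mask is meant to resolve.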

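If we simply change the first value of iou_max as shown here, the first entry (2.0) equals the maximum of a different index, so step 4 wrongly keeps it and the mask becomes all True: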

x_cells = tf.constant([1,2,3,4,1], dtype=tf.int32)
y_cells = tf.constant([4,4,4,4,4], dtype=tf.int32)
iou_index = tf.constant([1,2,3,4,1], dtype=tf.int32)
iou_max = tf.constant([2.,2.,3.,4.,5.], dtype=tf.float32)
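The failure can be reproduced without TensorFlow. The sketch below reuses the flattened indices from step 1 and mimics the value-only comparison of step 4; it is an illustration of the logic, not the original TensorFlow code.

```python
values = [2.0, 2.0, 3.0, 4.0, 5.0]
flat = [71, 122, 173, 224, 71]  # flattened indices from step 1

# Steps 2 and 3: maximum value per unique flat index.
best = {}
for i, v in zip(flat, values):
    best[i] = max(v, best.get(i, v))

# Step 4 as originally written: compare against ALL maxima, ignoring
# which index each maximum belongs to.
maxima = set(best.values())
mask = [v in maxima for v in values]

print(best)  # {71: 5.0, 122: 2.0, 173: 3.0, 224: 4.0}
print(mask)  # [True, True, True, True, True]
```

The first entry (value 2.0, index 71) is a false positive: it matches the maximum of index 122 rather than the maximum of its own index.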


1 answer
User
#1 · Posted 2024-06-16 11:53:18

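My current workaround changes step 4 of my question:

Basically, I now compare (value, index) tuples instead of single values. This lets me check logically whether both the index and the value appear among the remaining values from step 3.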
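The idea of comparing pairs rather than bare values can be sketched in plain Python as follows. The name `winning_pairs` is made up for this illustration; the data is the failing example from the question.

```python
values = [2.0, 2.0, 3.0, 4.0, 5.0]
flat = [71, 122, 173, 224, 71]  # flattened indices from step 1

# Maximum value per unique flat index (steps 2 and 3).
best = {}
for i, v in zip(flat, values):
    best[i] = max(v, best.get(i, v))

# Step 4, revised: an entry is kept only if its (value, index) pair
# is one of the (maximum, index) pairs.
winning_pairs = {(v, i) for i, v in best.items()}
mask = [(v, i) in winning_pairs for v, i in zip(values, flat)]

print(mask)  # [False, True, True, True, True]
```

The pair (2.0, 71) no longer matches (2.0, 122), so the false positive disappears.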

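# 4. Compare (max value, index) pairs with the original (value, index) pairs
rem_index_val_pair = tf.stack([ious, tf.cast(y, dtype=tf.float32)], axis=1)
# indices has shape [N, 1] after the matmul; squeeze it so both stacked tensors have shape [N]
orig_val_index_pair = tf.stack([iou_max, tf.cast(tf.squeeze(indices, axis=1), dtype=tf.float32)], axis=1)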

orig_val_index_pair_t = tf.tile(orig_val_index_pair, [1, ious_length])
orig_val_index_pair_s = tf.reshape(orig_val_index_pair_t, [iou_max_length, ious_length, 2])
filter_mask_1 = tf.equal(orig_val_index_pair_s, rem_index_val_pair)
filter_mask_2 = tf.reduce_all(filter_mask_1, -1)
filter_mask_3 = tf.reduce_any(filter_mask_2, -1)

# The orig_val_index_pair_s looks like the following
a =  [[[  2.  71.][  2.  71.][  2.  71.][  2.  71.]]
     [[  2. 122.][  2. 122.][  2. 122.][  2. 122.]]
     [[  3. 173.][  3. 173.][  3. 173.][  3. 173.]]
     [[  4. 224.][  4. 224.][  4. 224.][  4. 224.]]
     [[  5.  71.][  5.  71.][  5.  71.][  5.  71.]]]
# I then compare it to the rem_index_val_pair which looks like this.
b =  [[  5.  71.][  2. 122.][  3. 173.][  4. 224.]]

# Using equal(a,b) will now compare each of the values resulting in:
c = [[[False  True][ True False][False False][False False]]
     [[False False][ True  True][False False][False False]]
     [[False False][False False][ True  True][False False]]
     [[False False][False False][False False][ True  True]]
     [[ True  True][False False][False False][False False]]]

# Using tf.reduce_all(c, -1) I can filter the bool pairs with a logical And. 
# (This kicks out my false positives from before).
# Afterwards I can check if the line has any true value by tf.reduce_any().
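The two reductions can be checked by hand on the array `c` shown above. The sketch below uses Python's built-in `all` and `any` as stand-ins for `tf.reduce_all(..., -1)` and `tf.reduce_any(..., -1)`.

```python
# c[i][j] = [value_equal, index_equal] for original entry i vs. remaining pair j
c = [
    [[False, True], [True, False], [False, False], [False, False]],
    [[False, False], [True, True], [False, False], [False, False]],
    [[False, False], [False, False], [True, True], [False, False]],
    [[False, False], [False, False], [False, False], [True, True]],
    [[True, True], [False, False], [False, False], [False, False]],
]

# Logical AND over the last axis: a pair matches only if both the value
# and the index are equal.
pair_matches = [[all(pair) for pair in row] for row in c]

# Logical OR over the remaining axis: keep the entry if any pair matched.
mask = [any(row) for row in pair_matches]

print(pair_matches[0])  # [False, False, False, False]
print(mask)             # [False, True, True, True, True]
```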

In my view, this solution is still a hacky workaround. So if you have a suggestion for a better solution, please share it. :)
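One possible simplification, offered as a suggestion rather than as the author's method: instead of comparing every entry against every maximum, look up the maximum of each entry's own segment and compare element-wise. In TensorFlow terms this would presumably correspond to `tf.equal(iou_max, tf.gather(ious, idx))`, with `ious` and `idx` taken from steps 2 and 3 of the question. The plain-Python sketch below only illustrates the logic; the variable names are assumptions.

```python
values = [2.0, 2.0, 3.0, 4.0, 5.0]
flat = [71, 122, 173, 224, 71]  # flattened indices from step 1

# Like tf.unique: unique ids in first-seen order, plus a segment id per entry.
unique_ids = list(dict.fromkeys(flat))
segment_of = {u: s for s, u in enumerate(unique_ids)}
idx = [segment_of[i] for i in flat]

# Like an unsorted segment max: maximum value per segment id.
seg_max = [float("-inf")] * len(unique_ids)
for s, v in zip(idx, values):
    seg_max[s] = max(seg_max[s], v)

# Gather each entry's own segment maximum and compare element-wise.
mask = [v == seg_max[s] for s, v in zip(idx, values)]

print(idx)      # [0, 1, 2, 3, 0]
print(seg_max)  # [5.0, 2.0, 3.0, 4.0]
print(mask)     # [False, True, True, True, True]
```

This avoids building the N x M comparison tensor. One caveat: if two entries with the same index share the same maximum value, both are kept, so an extra tie-break would be needed in that case.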
