张量图像中点的插值采样

2024-06-07 17:03:45 发布

您现在位置:Python中文网/ 问答频道 /正文

给定的是一个灰度图像I作为2D张量(维W,H)和一个坐标张量C(Dim)。无,2)。我想将C的行解释为I中的坐标,在这些坐标处使用某种插值(双线性可能对我的用例很好)采样I,并将得到的值存储在一个新的张量P中(无量纲,即1维,有尽可能多的条目C有行)。

使用TensorFlow这可能(有效)吗?我能找到的只是调整图像大小的函数(如果你愿意,可以等距重采样)。但在坐标列表中我找不到任何现成的样本。

也就是说,我本想找到一个类似tf.interpolate()的函数:

I = tf.placeholder("float", shape=[128, 128])
C = tf.placeholder("float", shape=[None, 2])
P = tf.interpolate(I, C, axis=[0, 1], method="linear")

理想情况下,我会寻找一种解决方案,使我能够使用带形状(None,M)的C沿M维插值N维张量I,并产生N-M+1维输出,如上面代码中的“axis”参数所示。

(我的应用程序中的“图像”不是图片,而是来自物理模型(用作占位符)或替代学习模型(用作变量)的采样数据。现在这个物理模型有2个自由度,因此在“图像”中进行插值就足够了,但我可能会在将来研究更高维度的模型。)

如果现有的TensorFlow特性不可能实现这样的功能:当我想实现类似于tf.interpolate()的运算符时,应该从哪里开始?(文档和/或简单示例代码)


Tags: 函数代码模型图像nonetftensorflow物理
2条回答

没有执行这种插值的内置操作,但是您应该能够使用现有TensorFlow操作的组合来执行。对于双线性情况,我建议采取以下策略:

  1. 从索引的张量C,计算对应于四个角点的整数张量。例如(名称假定原点在左上角):

    top_left = tf.cast(tf.floor(C), tf.int32)
    
    top_right = tf.cast(
        tf.concat(1, [tf.floor(C[:, 0:1]), tf.ceil(C[:, 1:2])]), tf.int32)
    
    bottom_left = tf.cast(
        tf.concat(1, [tf.ceil(C[:, 0:1]), tf.floor(C[:, 1:2])]), tf.int32)
    
    bottom_right = tf.cast(tf.ceil(C), tf.int32)
    
  2. 从表示特定角点的每个张量中,从这些点的I中提取值向量。例如,对于以下函数,对二维情况执行此操作:

    def get_values_at_coordinates(input, coordinates):
      input_as_vector = tf.reshape(input, [-1])
      coordinates_as_indices = (coordinates[:, 0] * tf.shape(input)[1]) + coordinates[:, 1]
      return tf.gather(input_as_vector, coordinates_as_indices)
    
    values_at_top_left = get_values_at_coordinates(I, top_left)
    values_at_top_right = get_values_at_coordinates(I, top_right)
    values_at_bottom_left = get_values_at_coordinates(I, bottom_left)
    values_at_bottom_right = get_values_at_coordinates(I, bottom_right)
    
  3. 首先计算水平方向上的插值:

    # Varies between 0.0 and 1.0.
    horizontal_offset = C[:, 0] - tf.cast(top_left[:, 0], tf.float32)
    
    horizontal_interpolated_top = (
        ((1.0 - horizontal_offset) * values_at_top_left)
        + (horizontal_offset * values_at_top_right))
    
    horizontal_interpolated_bottom = (
        ((1.0 - horizontal_offset) * values_at_bottom_left)
        + (horizontal_offset * values_at_bottom_right))
    
  4. 现在计算垂直方向上的插值:

    vertical_offset = C[:, 1] - tf.cast(top_left[:, 1], tf.float32)
    
    interpolated_result = (
        ((1.0 - vertical_offset) * horizontal_interpolated_top)
        + (vertical_offset * horizontal_interpolated_bottom))
    

考虑到TF还没有Numpy切片的通用性(github issue #206),而且gather只在第一维上工作,这对最近的邻居来说是很棘手的。但有一种方法可以通过使用gather->;transpose->;gather->;extract diagonal来解决这个问题

def identity_matrix(n):
  """Returns nxn identity matrix."""
  # note, if n is a constant node, this assert node won't be executed,
  # this error will be caught during shape analysis 
  assert_op = tf.Assert(tf.greater(n, 0), ["Matrix size must be positive"])
  with tf.control_dependencies([assert_op]):
    ones = tf.fill(n, 1)
    diag = tf.diag(ones)
  return diag

def extract_diagonal(tensor):
  """Extract diagonal of a square matrix."""

  shape = tf.shape(tensor)
  n = shape[0]
  assert_op = tf.Assert(tf.equal(shape[0], shape[1]), ["Can't get diagonal of "
                                                       "a non-square matrix"])

  with tf.control_dependencies([assert_op]):
    return tf.reduce_sum(tf.mul(tensor, identity_matrix(n)), [0])


# create sample matrix
size=4
I0=np.zeros((size,size), dtype=np.int32)
for i in range(size):
  for j in range(size):
    I0[i, j] = 10*i+j

I = tf.placeholder(dtype=np.int32, shape=(size,size))
C = tf.placeholder(np.int32, shape=[None, 2])
C0 = np.array([[0, 1], [1, 2], [2, 3]])
row_indices = C[:, 0]
col_indices = C[:, 1]

# since gather only supports dim0, have to transpose
I1 = tf.gather(I, row_indices)
I2 = tf.gather(tf.transpose(I1), col_indices)
I3 = extract_diagonal(tf.transpose(I2))

sess = create_session()
print sess.run([I3], feed_dict={I:I0, C:C0})

从这样的矩阵开始:

array([[ 0,  1,  2,  3],
       [10, 11, 12, 13],
       [20, 21, 22, 23],
       [30, 31, 32, 33]], dtype=int32)

此代码提取主目录上方的对角线

[array([ 1, 12, 23], dtype=int32)]

[]运算符变成SqueezeSlice时会发生一些神奇的变化

enter image description here

相关问题 更多 >

    热门问题