从100x100 pytorch张量获得一个10x10的补丁,边界周围有环面样式的环绕

2024-06-16 09:58:06 发布

您现在位置:Python中文网/ 问答频道 /正文

如何从100x100 pytorch张量中获得10x10面片,并附加一个约束条件,即如果面片超出阵列的边界,则它将环绕边缘(就像阵列是一个圆环,顶部与底部相连,左侧与右侧相连)

我写了这段代码来完成这项工作,我正在寻找更优雅、高效和清晰的东西:

def shift_matrix(a, distances) -> Tensor:
  x, y = distances
  a = torch.cat((a[x:], a[0:x]), dim=0)
  a = torch.cat((a[:, y:], a[:, :y]), dim=1)
  return a

def randomly_shift_matrix(a) -> Tensor:
  return shift_matrix(a, np.random.randint(low = 0, high = a.size()))

def random_patch(a, size) -> Tensor:
  full_shifted_matrix = randomly_shift_matrix(a)
  return full_shifted_matrix[0:size[0], 0:size[1]]

我觉得带负指数片的东西应该可以用。不过我还没找到

你可以see the code in google colab here


Tags: sizereturnshiftdefrandomtorchrandomlymatrix