如何从100x100 pytorch张量中获得10x10面片,并附加一个约束条件,即如果面片超出阵列的边界,则它将环绕边缘(就像阵列是一个圆环,顶部与底部相连,左侧与右侧相连)
我写了这段代码来完成这项工作,我正在寻找更优雅、高效和清晰的东西:
def shift_matrix(a, distances) -> Tensor:
x, y = distances
a = torch.cat((a[x:], a[0:x]), dim=0)
a = torch.cat((a[:, y:], a[:, :y]), dim=1)
return a
def randomly_shift_matrix(a) -> Tensor:
return shift_matrix(a, np.random.randint(low = 0, high = a.size()))
def random_patch(a, size) -> Tensor:
full_shifted_matrix = randomly_shift_matrix(a)
return full_shifted_matrix[0:size[0], 0:size[1]]
我觉得带负指数片的东西应该可以用。不过我还没找到
您正在寻找^{}
相关问题 更多 >
编程相关推荐