如何在Tensorflow中复制PyTorch的nn.functional.unfold函数?

2024-05-16 08:11:37 发布

您现在位置:Python中文网/ 问答频道 /正文

我想使用tensorflow重写pytorch的torch.nn.functional.unfold函数:

#input x:[16, 1, 50, 36]
x = torch.nn.functional.unfold(x, kernel_size=(5, 36), stride=3)
#output x:[16, 180, 16]

我尝试使用函数tf.extract_image_patches()

x = tf.extract_image_patches(x,ksizes=[1, 1,5, 98],strides=[1, 1, 3, 1], rates=[1, 1, 1, 1],padding='VALID')

输入x.shape[16,1,64,98]

我得到输出x.shape[16,1,20,490]

然后我将X重塑为[16,490,20],这是我所期望的

但当我输入数据时,我得到了错误:

UnimplementedError (see above for traceback): Only support ksizes across space.
[[Node:hcn/ExtractImagePatches = ExtractImagePatches[T=DT_FLOAT, ksizes=[1, 1, 5, 98], padding="VALID", rates=[1, 1, 1, 1], strides=[1, 1, 3, 1], _device="/job:localhost/replica:0/task:0/device:GPU:0"](hcn/Reshape)]]

如何使用tensorflow重写pytorchtorch.nn.functional.unfold函数来更改X


Tags: 函数imagetftensorflowextractnntorchfunctional