Pytork的FFT不保持线性

2024-03-28 19:42:46 发布

您现在位置:Python中文网/ 问答频道 /正文

我不确定这是否是一个bug,或者我对索引/移位有些误解。当我试图计算傅里叶域中的差分时,计算中出现了一些错误。对于一个简单的情况,考虑一个恒等核Em>K<EEM>,这样对于任何信号 xEEE>跟随x -k*x=0。此外,x-kox=0。然而,当我使用pytorch计算傅里叶域中的残差时,我没有得到数值为0的解

import torch
from torch.fft import rfftn, irfftn
from PIL import Image
from torchvision.transforms import ToTensor
from torch.nn.functional import pad
from matplotlib import pyplot as plt

# Let X be an input image.
with Image.open("barbara.png") as im:
    X = ToTensor()(im)[0,...]

# Let K be a convolutional identity kernel.
K = torch.zeros((3,3))
K[1,1] = 1.0

# Full convolution shape.
cnvShape = (X.shape[0] + K.shape[0] - 1, X.shape[1] + K.shape[1] - 1)

# Apply DFT. Uncentered zero-padding to full shape.
Xhat = rfftn(X, s=cnvShape, dim=(0,1))
Khat = rfftn(K, s=cnvShape, dim=(0,1))

# Calculate convolution via Fourier domain.
XKhat = Xhat * Khat
XK = irfftn(XKhat, s=cnvShape)[1:-1,1:-1]
# Difference in spatial domain. Is numerically close to 0.
R = X - XK

# Difference in Fourier domain. This should be numerically close to 0.
Shat = Xhat - XKhat
S = irfftn(Shat)

plt.subplot(1,2,1)
plt.imshow(R)
plt.colorbar()
plt.subplot(1,2,2)
plt.imshow(S)
plt.colorbar()
plt.show()

# Center zeropad image FFT, produces numerically 0 solution.
That = rfftn(pad(X, (1,1)*2), s=cnvShape) - XKhat
T = irfftn(That)
plt.imshow(T)
plt.colorbar()
plt.show()

这导致 Difference between the signal and the convolved signal. On the left is taking the difference in spatial domain. On the right is the Fourier domain.Using a centered zeropad around the signal (y) such that y - k * x approx 0.

从上一个图中,我可以看到傅里叶域中的运算发生了某种变化,但这不是直观的。我还想弄清楚如何在傅里叶域中实现一致性运算,因为这只是我需要计算的一个小样本。我知道fftshift,但这在这里似乎不合适,因为这种移位似乎不具有负频率和正频率。那么,索引中发生了什么


Tags: tofromimportdomainplttorchbeshape