调节pytorch 1.5.0和pytorch 1.9.0之间fft输出的差异

2024-06-16 14:22:15 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试使一些与Pythorch 1.5.0配合使用的python3代码在较新版本上也能正常工作(我目前正在使用Pythorch 1.9.0)。更具体地说,我正在尝试更新进行快速傅立叶变换的代码。我试图用pytorch 1.9.0中的torch.fft.fftn()和torch.view_as_real()替换pytorch 1.5.0中的torch.rfft()。我注意到,当我运行以下命令时,得到的输出略有不同:

使用PyTorch 1.5.0:

import torch
import numpy as np
arr = torch.from_numpy(np.array([[1.,2.,3.,4.,5.],
                                 [6.,7.,8.,9.,10.],
                                 [11.,12.,13.,14.,15.],
                                 [16.,17.,18.,19.,20.]]))
ftt_arr = torch.rfft(arr,2,onesided=False)
print(fft_arr)

使用PyTorch 1.9.0:

import torch
import numpy as np
arr = torch.from_numpy(np.array([[1.,2.,3.,4.,5.],
                                 [6.,7.,8.,9.,10.],
                                 [11.,12.,13.,14.,15.],
                                 [16.,17.,18.,19.,20.]]))
fft_arr = torch.fft.fftn(arr,norm="backward")
fft_arr = torch.view_as_real(fft_arr)
print(fft_arr)

两个快速傅里叶变换的输出如下:

pytorch 1.5.0:

tensor([[[211.0000,   0.0000],
         [-10.8090,  13.1760],
         [ -9.6910,   4.2003],
         [ -9.6910,  -4.2003],
         [-10.8090, -13.1760]],

        [[-50.0000,  51.0000],
         [  0.5878,  -0.8090],
         [ -0.9511,   0.3090],
         [  0.9511,   0.3090],
         [ -0.5878,  -0.8090]],

        [[-51.0000,   0.0000],
         [  0.8090,   0.5878],
         [ -0.3090,  -0.9511],
         [ -0.3090,   0.9511],
         [  0.8090,  -0.5878]],

        [[-50.0000, -51.0000],
         [ -0.5878,   0.8090],
         [  0.9511,  -0.3090],
         [ -0.9511,  -0.3090],
         [  0.5878,   0.8090]]], dtype=torch.float64)

pytorch 1.9.0:

tensor([[[ 2.1000e+02,  0.0000e+00],
         [-1.0000e+01,  1.3764e+01],
         [-1.0000e+01,  3.2492e+00],
         [-1.0000e+01, -3.2492e+00],
         [-1.0000e+01, -1.3764e+01]],

        [[-5.0000e+01,  5.0000e+01],
         [ 2.2204e-15,  0.0000e+00],
         [ 1.7764e-15, -4.4409e-16],
         [ 1.7764e-15, -4.4409e-16],
         [ 2.2204e-15,  0.0000e+00]],

        [[-5.0000e+01,  0.0000e+00],
         [-1.7764e-15,  0.0000e+00],
         [-8.8818e-16,  0.0000e+00],
         [-8.8818e-16,  0.0000e+00],
         [-1.7764e-15,  0.0000e+00]],

        [[-5.0000e+01, -5.0000e+01],
         [ 2.2204e-15,  0.0000e+00],
         [ 1.7764e-15,  4.4409e-16],
         [ 1.7764e-15,  4.4409e-16],
         [ 2.2204e-15,  0.0000e+00]]], dtype=torch.float64)

所有的输出值似乎在+/-1左右变化,我无法解释或协调


Tags: 代码importfftnumpyviewasnptorch