下面是我写的函数:
def channel_var(image_dataset):
res = image_dataset[0]
for image in image_dataset[1:]:
res += image
return tuple(map(lambda x: x/len(image_dataset),
(torch.var(res[0]),
torch.var(res[1]),
torch.var(res[2]))))
然后我用正态分布测试它:
m = normal.Normal(0, 3)
m.sample((1, 3, 32, 32))
我得到了一个错误的结果:
channel_var(list_test)
>>(tensor(0.0338), tensor(0.0352), tensor(0.0365))
谢谢
我想计算每个通道的所有数据集的方差。每个元素都是火炬张量。你的方法似乎适合一张图片。你能给我一个实现吗?你知道吗
谢谢
你的功能是错误的。这是因为你要计算平均图像,然后计算平均图像中的信道方差。我想你不想那样。你可以通过使用
torch.var(img, dim=[0,2,3])
假设
dim=1
是通道维数,img是火炬张量。如果img不是torch张量,那么可以将img的列表连接起来,形成一个张量。你知道吗您可以这样做,因为
torch.var(torch.cat(img, dim=0), dim=[0,2,3])
cat
操作将列表连接到张量。你知道吗相关问题 更多 >
编程相关推荐