Pytork将预测保存为每个历元的CSV

2024-04-26 00:11:24 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试将从PyTorch中的模型得到的预测保存为csv。下面的代码。但是,代码似乎覆盖了每个历元的预测,最终的.csv文件只包含3个值。你知道我做错了什么吗

output = model(images)
preds = output.sum(dim=[1,2,3])
np.savetxt('preds.csv', preds.detach())


Tags: 文件csv代码模型outputmodelnppytorch
1条回答
网友
1楼 · 发布于 2024-04-26 00:11:24

假设您的输出是float,您可以将其映射到str,并将join作为逗号分隔的字符串应用,最后将行附加到csv

import pandas as pd
import numpy as np
import torch

output = torch.rand(5,5,5,5)
preds = output.sum(dim=[1,2,3])

print(preds)

with open('preds.csv','a') as fd:
    fd.write( ','.join(map(str, preds.detach().tolist())) + '\n')

相关问题 更多 >