2024-04-26 00:11:24 发布
网友
我正在尝试将从PyTorch中的模型得到的预测保存为csv。下面的代码。但是,代码似乎覆盖了每个历元的预测,最终的.csv文件只包含3个值。你知道我做错了什么吗
output = model(images) preds = output.sum(dim=[1,2,3]) np.savetxt('preds.csv', preds.detach())
假设您的输出是float,您可以将其映射到str,并将join作为逗号分隔的字符串应用,最后将行附加到csv
import pandas as pd import numpy as np import torch output = torch.rand(5,5,5,5) preds = output.sum(dim=[1,2,3]) print(preds) with open('preds.csv','a') as fd: fd.write( ','.join(map(str, preds.detach().tolist())) + '\n')
假设您的输出是float,您可以将其映射到str,并将join作为逗号分隔的字符串应用,最后将行附加到csv
相关问题 更多 >
编程相关推荐