Numpy Array2string正在写。。。在串里?

2024-04-25 05:37:52 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一件简单的事情要做,读取一些向量并将它们写入一个文件。在

向量是1024维的。在

 for emb in src:
            print(len(emb[0].detach().cpu().numpy()))  #--> prints 1024!
            f.write(np.array2string(emb[0].detach().cpu().numpy(), separator=', ') + " \n")

我的文件如下:

^{pr2}$

所以,我不能访问我的向量,1024维被转换成任何6维或7维向量+。。。。:(

如何正确地将向量写入文件?在

干杯:)


Tags: 文件insrcnumpyforlennpcpu
2条回答

将2d数组写入文本文件(以便可以读回)的常规方法是使用np.savetxt

In [309]: src = np.random.rand(6,4)
In [310]: src
Out[310]: 
array([[0.78756364, 0.11385762, 0.16631052, 0.10987765],
       [0.59954504, 0.80417064, 0.22461205, 0.47827772],
       [0.10993457, 0.11650874, 0.55887911, 0.71854456],
       [0.53572426, 0.55055622, 0.25423811, 0.46038837],
       [0.05418115, 0.50696182, 0.31515915, 0.65310375],
       [0.81168653, 0.81063907, 0.95371101, 0.11875685]])

写下:

^{pr2}$

测试负载:

In [314]: np.genfromtxt('test.txt',delimiter=',')
Out[314]: 
array([[0.787564, 0.113858, 0.166311, 0.109878],
       [0.599545, 0.804171, 0.224612, 0.478278],
       [0.109935, 0.116509, 0.558879, 0.718545],
       [0.535724, 0.550556, 0.254238, 0.460388],
       [0.054181, 0.506962, 0.315159, 0.653104],
       [0.811687, 0.810639, 0.953711, 0.118757]])

savetxt逐行格式化写入,大致如下:

In [315]: fmt = ','.join(['%10.6f']*4)
In [316]: fmt
Out[316]: '%10.6f,%10.6f,%10.6f,%10.6f'
In [317]: for row in src:
     ...:     print(fmt%tuple(row))    # f.write(...)
     ...:     
  0.787564,  0.113858,  0.166311,  0.109878
  0.599545,  0.804171,  0.224612,  0.478278
  0.109935,  0.116509,  0.558879,  0.718545
  0.535724,  0.550556,  0.254238,  0.460388
  0.054181,  0.506962,  0.315159,  0.653104
  0.811687,  0.810639,  0.953711,  0.118757

事实上,我可以用file write来包装:

In [318]: with open('test1.txt','w') as f:
     ...:     for row in src:
     ...:         print(fmt%tuple(row), file=f)
     ...:     
In [319]: cat test1.txt
  0.787564,  0.113858,  0.166311,  0.109878
  0.599545,  0.804171,  0.224612,  0.478278
 ...

矢量仍然是1024维的,但是显示器只显示阵列的缩小视图。在

通过设置打印选项,可以查看整个阵列:

import numpy as np
np.set_printoptions(threshold=np.nan)

相关问题 更多 >