Numpy.dot bug?NaN行为不一致

9 投票
1 回答
1569 浏览
提问于 2025-04-18 04:48

我发现了在使用 numpy.dot 时,当里面有 nan(不是一个数字)和零的时候,表现得不太一致。

有没有人能解释一下这是怎么回事?这是个bug吗?还是说只是 dot 函数特有的问题?

我用的是 numpy 版本 1.6.1,64位,运行在 Linux 上(我也在 1.6.2 版本上测试过)。我还在 32位的 Windows 上测试了 1.8.0(所以我不太确定这些差异是因为版本、操作系统还是架构造成的)。

from numpy import *
0*nan, nan*0
=> (nan, nan)  # makes sense

#1
a = array([[0]])
b = array([[nan]])
dot(a, b)
=> array([[ nan]])  # OK

#2 -- adding a value to b. the first value in the result is
#     not expected to be affected.
a = array([[0]])
b = array([[nan, 1]])
dot(a, b)
=> array([[ 0.,  0.]])  # EXPECTED : array([[ nan,  0.]])
# (also happens in 1.6.2 and 1.8.0)
# Also, as @Bill noted, a*b works as expected, but not dot(a,b)

#3 -- changing a from 0 to 1, the first value in the result is
#     not expected to be affected.
a = array([[1]])
b = array([[nan, 1]])
dot(a, b)
=> array([[ nan,   1.]])  # OK

#4 -- changing shape of a, changes nan in result
a = array([[0],[0]])
b = array([[ nan, 1.]])
dot(a, b)
=> array([[ 0.,  0.], [ 0.,  0.]])  # EXPECTED : array([[ nan,  0.], [ nan,  0.]])
# (works as expected in 1.6.2 and 1.8.0)

第4个案例在 1.6.2 和 1.8.0 版本中似乎运行正常,但第2个案例却不行……


补充:@seberg 指出这是一个 BLAS 的问题,所以我通过运行 from numpy.distutils.system_info import get_info; get_info('blas_opt') 找到了关于 BLAS 安装的信息:

1.6.1 linux 64bit
/usr/lib/python2.7/dist-packages/numpy/distutils/system_info.py:1423: UserWarning: 
    Atlas (http://math-atlas.sourceforge.net/) libraries not found.
    Directories to search for the libraries can be specified in the
    numpy/distutils/site.cfg file (section [atlas]) or by setting
    the ATLAS environment variable.
  warnings.warn(AtlasNotFoundError.__doc__)
{'libraries': ['blas'], 'library_dirs': ['/usr/lib'], 'language': 'f77', 'define_macros': [('NO_ATLAS_INFO', 1)]}

1.8.0 windows 32bit (anaconda)
c:\Anaconda\Lib\site-packages\numpy\distutils\system_info.py:1534: UserWarning:
   Blas (http://www.netlib.org/blas/) sources not found.
   Directories to search for the sources can be specified in the
   numpy/distutils/site.cfg file (section [blas_src]) or by setting
   the BLAS_SRC environment variable.
 warnings.warn(BlasSrcNotFoundError.__doc__)
{}

(我个人也不知道该怎么理解这些信息)

1 个回答

3

我觉得,正如seberg所说,这个问题可能出在使用的BLAS库上。如果你查看numpy.dot是怎么实现的,可以在这里这里找到相关代码,你会发现它调用了cblas_dgemm(),这是处理双精度矩阵相乘的函数。

这个C程序复现了你的一些例子,使用“普通”的BLAS时输出是一样的,而使用ATLAS时输出是正确的。

#include <stdio.h>
#include <math.h>

#include "cblas.h"

void onebyone(double a11, double b11, double expectc11)
{
  enum CBLAS_ORDER order=CblasRowMajor;
  enum CBLAS_TRANSPOSE transA=CblasNoTrans;
  enum CBLAS_TRANSPOSE transB=CblasNoTrans;
  int M=1;
  int N=1;
  int K=1;
  double alpha=1.0;
  double A[1]={a11};
  int lda=1;
  double B[1]={b11};
  int ldb=1;
  double beta=0.0;
  double C[1];
  int ldc=1;

  cblas_dgemm(order, transA, transB,
              M, N, K,
              alpha,A,lda,
              B, ldb,
              beta, C, ldc);

  printf("dot([ %.18g],[%.18g]) -> [%.18g]; expected [%.18g]\n",a11,b11,C[0],expectc11);
}

void onebytwo(double a11, double b11, double b12,
              double expectc11, double expectc12)
{
  enum CBLAS_ORDER order=CblasRowMajor;
  enum CBLAS_TRANSPOSE transA=CblasNoTrans;
  enum CBLAS_TRANSPOSE transB=CblasNoTrans;
  int M=1;
  int N=2;
  int K=1;
  double alpha=1.0;
  double A[]={a11};
  int lda=1;
  double B[2]={b11,b12};
  int ldb=2;
  double beta=0.0;
  double C[2];
  int ldc=2;

  cblas_dgemm(order, transA, transB,
              M, N, K,
              alpha,A,lda,
              B, ldb,
              beta, C, ldc);

  printf("dot([ %.18g],[%.18g, %.18g]) -> [%.18g, %.18g]; expected [%.18g, %.18g]\n",
         a11,b11,b12,C[0],C[1],expectc11,expectc12);
}

int
main()
{
  onebyone(0, 0, 0);
  onebyone(2, 3, 6);
  onebyone(NAN, 0, NAN);
  onebyone(0, NAN, NAN);
  onebytwo(0, 0,0, 0,0);
  onebytwo(2, 3,5, 6,10);
  onebytwo(0, NAN,0, NAN,0);
  onebytwo(NAN, 0,0, NAN,NAN);
  return 0;
}

使用BLAS的输出:

dot([ 0],[0]) -> [0]; expected [0]
dot([ 2],[3]) -> [6]; expected [6]
dot([ nan],[0]) -> [nan]; expected [nan]
dot([ 0],[nan]) -> [0]; expected [nan]
dot([ 0],[0, 0]) -> [0, 0]; expected [0, 0]
dot([ 2],[3, 5]) -> [6, 10]; expected [6, 10]
dot([ 0],[nan, 0]) -> [0, 0]; expected [nan, 0]
dot([ nan],[0, 0]) -> [nan, nan]; expected [nan, nan]

使用ATLAS的输出:

dot([ 0],[0]) -> [0]; expected [0]
dot([ 2],[3]) -> [6]; expected [6]
dot([ nan],[0]) -> [nan]; expected [nan]
dot([ 0],[nan]) -> [nan]; expected [nan]
dot([ 0],[0, 0]) -> [0, 0]; expected [0, 0]
dot([ 2],[3, 5]) -> [6, 10]; expected [6, 10]
dot([ 0],[nan, 0]) -> [nan, 0]; expected [nan, 0]
dot([ nan],[0, 0]) -> [nan, nan]; expected [nan, nan]

当第一个操作数是NaN时,BLAS的表现似乎是正常的,但当第一个操作数是零而第二个是NaN时,结果就不对了。

总之,我认为这个bug不在Numpy层面,而是在BLAS里。看起来可以通过使用ATLAS来解决这个问题。

以上是在Ubuntu 14.04上生成的,使用的是Ubuntu提供的gcc、BLAS和ATLAS。

撰写回答