连续使用fixed_quad错误

1 投票
1 回答
543 浏览
提问于 2025-04-18 17:30

我刚开始学习Python,使用quad函数做连续积分时没有问题,但当我换成fixed_quad时,就遇到了麻烦,像下面这个测试例子:

def test(z,r):
  return r*z**2 + 5
def inttest(r):
  return fixed_quad(test, 10*r, 100, args=(r,))[0]
def test2(r, t):
  return inttest(r)*(2*t+3)
def inttest2(t):
  return fixed_quad(test2, 3, 5, args = (t,), n=5)[0]
print inttest2(3) 

我得到了以下错误:

Traceback (most recent call last):
  File "C:\Python27\tt11.py", line 132, in <module>
    print inttest2(3)
  File "C:\Python27\tt11.py", line 129, in inttest2
    return fixed_quad(test2, 3, 5, args = (t,), n=5)[0]
  File "C:\Python27\lib\site-packages\scipy\integrate\quadrature.py", line 58, in fixed_quad
    return (b-a)/2.0*sum(w*func(y,*args),0), None
  File "C:\Python27\tt11.py", line 124, in test2
    return inttest(r)*(2*t+3)
  File "C:\Python27\tt11.py", line 119, in inttest
    return fixed_quad(test, 10*r, 100, args=(r,))[0]
  File "C:\Python27\lib\site-packages\scipy\integrate\quadrature.py", line 54, in fixed_quad
    if ainf or binf:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or   a.all()

我哪里出错了呢?任何帮助都非常感谢。

1 个回答

1

简单来说,你的代码不工作是因为 test2 不能接受 np.array 作为输入。

fixed_quad 里的相关代码非常简短:

def fixed_quad(func,a,b,args=(),n=5):
    [x,w] = p_roots(n)
    x = real(x)
    ainf, binf = map(isinf,(a,b))
    if ainf or binf:
         raise ValueError("Gaussian quadrature is only available for "
                 "finite limits.")
    y = (b-a)*(x+1)/2.0 + a
    return (b-a)/2.0*sum(w*func(y,*args),0), None

而文档说明中提到:

Parameters
----------
func : callable
    A Python function or method to integrate (must accept vector inputs).

y 将会是一个大小为 (n, )np.array。所以如果你的 test2 不能接受 np.array 作为输入,func(y,*args) 就会报错。

因此,如果我们设置 n=1,它会正常工作(但当然是没什么用的):

def inttest2(t):
  return si.fixed_quad(test2, 3, 5, args = (t,), n=1)[0]
print inttest2(3) 
#22469400.0

解决办法是让 test2 能够接受 np.array 作为 r 的输入:

def test2(r, t):
    return np.array(map(inttest, np.array(r)))*(2*t+3)
def inttest2(t):
    return fixed_quad(test2, 3, 5, args = (t,), n=5)[0]
print inttest2(3)
#22276200.0

撰写回答