如何创建线性分数分布作为自定义离散概率分布?

2024-06-01 01:45:29 发布

您现在位置:Python中文网/ 问答频道 /正文

我定义了以下自定义概率分布:

import scipy.stats as st

# parameters
a = 3 / 16
b = 1

class linear_fractional(st.rv_discrete):
    def _pdf(self, n):
        if (n == 0):
            return (a + b - 1) / (a + b)
        else:
            return (a * b ** (n - 1)) / (a + b) ** (n + 1)

LF = linear_fractional()
LF.rvs()

当我让脚本运行时,会收到一条很长的错误消息:

Traceback (most recent call last):
File "C:/Users/thoma/PycharmProjects/Host_Parasite_Coevolution/Asymptotics.py", line 17, in <module> LF.rvs()
File "C:\Users\thoma\AppData\Local\Programs\Python\Python37-32\lib\site-packages\scipy\stats\_distn_infrastructure.py", line 2969, in rvs
    return super(rv_discrete, self).rvs(*args, **kwargs)

...

RecursionError: maximum recursion depth exceeded while calling a Python object

如果我改为LF.mean(),我得到

Fatal Python error: Cannot recover from stack overflow.

有人知道这是为什么吗?我如何解决这个问题?我必须定义概率分布的上界吗


Tags: selfreturn定义statsscipyusersfilelinear
1条回答
网友
1楼 · 发布于 2024-06-01 01:45:29

根据the docsthis post给出的示例,该方法需要一些修改。重要的是,由于它是一个离散分布,因此应该使用_pmf而不是_pdf。另外,_pmf将被n的numpy样式数组调用,而n == 0的工作方式完全不同

因为(a * b ** (n - 1)) / (a + b) ** (n + 1)等于(a + b - 1) / (a + b)n == 0时,我们可以对所有n使用第一个表达式。但是,当b是整数且n = -1时,numpy会生成错误。将b1.0相乘会将其更改为浮点,numpy不会给出此类错误。如果多次使用相同的参数ab,可能会生成冻结分布

下面是一个示例,它创建生成样本的直方图,并将其与pmf进行比较

import scipy.stats as st
import numpy as np
from matplotlib import pyplot as plt

class linear_fractional(st.rv_discrete):
    def _pmf(self, n, a, b):
        return (a * (1.0 * b) ** (n - 1)) / (a + b) ** (n + 1)

# parameters
a = 3 / 16
b = 1

LF = linear_fractional()

N = 10000
plt.hist(LF.rvs(a, b, size=N), bins=np.arange(-0.5, 50), ec='w', label='histogram of samples')
plt.plot(LF.pmf(np.arange(50), a, b) * N, 'ro', label='probability mass function (scaled)')
plt.legend(title=f'$a={a}; b={b}$')
plt.autoscale(enable=True, axis='x', tight=True)
plt.show()

resulting histogram

LF.mean(a, b)输出5.33333333333286

散点图是说明分布样本的另一种方法:

plt.scatter(np.random.uniform(0, 1, N), LF.rvs(a, b, size=N), marker=',', alpha=0.2, lw=0, s=1, color='crimson')

scatter plot

PS:当b=1时,此分布的公式等于geometric distribution加上p = a/(a+1)并减去1。这要快得多,因为它完全是在numpy内部计算的

samples = np.random.geometric(a/(a+1), size=1000) - 1

相关问题 更多 >