Pymc3:非常慢且卡顿
有没有什么原因导致NUTS采样器运行缓慢或停滞不前呢?我正在使用这个链接作为我进行一些层次线性回归工作的基础。
我尝试从find_MAP()开始,但在2000次迭代中,经过100次后仍然停滞不前。
我的代码是
with pm.Model() as hierarchical_model:
# Hyperpriors for group nodes
mu_a = pm.Normal('mu_alpha', mu=0., sd=100**2)
sigma_a = pm.Uniform('sigma_alpha', lower=0, upper=100)
mu_b = pm.Normal('mu_beta', mu=0., sd=100**2)
sigma_b = pm.Uniform('sigma_beta', lower=0, upper=100)
a = pm.Normal('alpha', mu=mu_a, sd=sigma_a, shape=n_dis)
b = pm.Normal('beta', mu=mu_b, sd=sigma_b, shape=n_dis)
# Model error
eps = pm.Uniform('eps', lower=0, upper=100)
actual_est = a[disRefV] + b[disRefV] * data.baseline.values
actual_like = pm.Normal('actual_like', mu=actual_est, sd=eps, observed=data.prepanel)
with hierarchical_model:
start = pm.find_MAP()
step = pm.NUTS()
hierarchical_trace = pm.sample(2000, step, progressbar=True)
非常感谢!
1 个回答
3
NUTS 有时候确实会卡住。你有没有对你的数据进行 z-score 标准化?在我的实验中,这通常会有所帮助。如果你在使用层次模型,那么你可能需要根据组的平均值来进行 z-score 标准化,而不是单个的平均值。
然后你可以把你的后验结果重新缩放回原来的范围(可以参考 Kruscke 的书《Doing Bayesian Data Analysis》中的公式 17.1)。