Numpyro AR(1)均值切换模型采样不一致性
我正在尝试估计一个AR(1)过程,记作y
,这个过程的均值会根据一个潜在状态S
(取值为0或1)进行切换。这个状态是一个马尔可夫过程,具有固定的转移概率(就像这里所说的那样)。简单来说,它的形式如下:
y_t - mu_{0/1} = phi * (y_{t-1} - mu_{0/1})+ epsilon_t
在这个公式中,如果状态state_t
等于0,就会使用mu_0
,如果state_t
等于1,就会使用mu_1
。我正在使用jax/numpyro和DiscreteHMCGibbs(虽然使用正常的NUTS和潜在状态枚举也能得到相同的结果),但我似乎无法让采样器正常工作。从我运行的所有诊断来看,所有的超参数都卡在了初始化值上,汇总结果也显示所有的标准差都等于0。下面是一个最小可重现示例(MWE),可以复现我的问题。我在实现中是不是犯了什么明显的错误?
最小可重现示例:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.infer import MCMC, NUTS,DiscreteHMCGibbs
from jax import random, pure_callback
import jax
import numpy as np
def generate_synthetic_data(T=100, mu=[0, 5], phi=0.5, sigma=1.0, p=np.array([[0.95, 0.05], [0.1, 0.9]])):
states = np.zeros(T, dtype=np.int32)
y = np.zeros(T)
current_state = np.random.choice([0, 1], p=[0.5, 0.5])
states[0] = current_state
y[0] = np.random.normal(mu[current_state], sigma)
for t in range(1, T):
current_state = np.random.choice([0, 1], p=p[current_state,:])
states[t] = current_state
y[t] = np.random.normal(mu[current_state] + phi * (y[t-1] - mu[current_state]), sigma)
return y, states
def mean_switching_AR1_model(y):
T = len(y)
phi = numpyro.sample('phi', dist.Normal(0, 1))
sigma = numpyro.sample('sigma', dist.Exponential(1))
with numpyro.plate('state_plate', 2):
mu = numpyro.sample('mu', dist.Normal(0, 5))
p = numpyro.sample('p', dist.Dirichlet(jnp.ones(2)))
probs_init = numpyro.sample('probs_init', dist.Dirichlet(jnp.ones(2)))
s_0 = numpyro.sample('s_0', dist.Categorical(probs_init))
def transition_fn(carry, y_t):
prev_state = carry
state_probs = p[prev_state]
state = numpyro.sample('state', dist.Categorical(state_probs))
mu_state = mu[state]
y_mean = mu_state + phi * (y_t - mu_state)
y_next = numpyro.sample('y_next', dist.Normal(y_mean, sigma), obs=y_t)
return state, (state, y_next)
_ , (signal, y)=scan(transition_fn, s_0, y[:-1], length=T-1)
return (signal, y)
# Synthetic data generation
T = 1000
mu_true = [0, 3]
phi_true = 0.5
sigma_true = 0.25
transition_matrix_true = np.array([[0.95, 0.05], [0.1, 0.9]])
y, states_true = generate_synthetic_data(T, mu=mu_true, phi=phi_true, sigma=sigma_true, p=transition_matrix_true)
rng_key = random.PRNGKey(0)
nuts_kernel = NUTS(mean_switching_AR1_model)
gibbs_kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)
# Run MCMC
mcmc = MCMC(gibbs_kernel, num_samples=1000, num_warmup=1000)
mcmc.run(rng_key, y=y)
mcmc.print_summary()
1 个回答
1
结果发现,我确实犯了一个很明显的错误,就是没有正确地把 y_{t-1}
作为状态变量传递下去。下面这个修正后的转换函数可以顺利地得到想要的结果,没有问题。
def transition_fn(carry, y_curr):
prev_state, y_prev = carry
state_probs = p[prev_state]
state = numpyro.sample('state', dist.Categorical(state_probs))
mu_state = mu[state]
y_mean = mu_state + phi * (y_prev - mu_state)
y_curr = numpyro.sample('y_curr', dist.Normal(y_mean, sigma), obs=y_curr)
return (state, y_curr), (state, y_curr)
_, (signal, y) = scan(transition_fn, (s_0, y[0]), y[1:], length=T-1)