Tensorflow概率:从联合分布中检索特定的随机变量

2024-04-29 03:30:14 发布

您现在位置:Python中文网/ 问答频道 /正文

我不熟悉tensorflow概率。 我正在构建一个层次模型,为此我使用JointDistributionSequentialAPI:

jds = tfp.distributions.JointDistributionSequential(
[
    # mu_g ~ uniform on sphere
    tfp.distributions.VonMisesFisher(
        mean_direction= [1] + [0]*(D-1),
        concentration=0,
        validate_args=True,
        name="mu_g"
    ),
    # epsilon ~ Exponential
    tfp.distributions.Exponential(
        rate=1,
        validate_args=True,
        name="epsilon"
    ),
    # mu_s ~ von Mises Fisher centered on mu_g
    lambda epsilon, mu_g: tfp.distributions.VonMisesFisher(
        mean_direction=mu_g,
        concentration=np.array(
            [epsilon]*S
        ),
        validate_args=True,
        name="mu_s"
    ),
    # sigma ~ Exponential
    tfp.distributions.Exponential(
        rate=1,
        validate_args=True,
        name="sigma"
    ),
    # mu_t_s ~ von Mises Fisher centered on mu_s
    lambda sigma, mu_s: tfp.distributions.VonMisesFisher(
        mean_direction=mu_s,
        concentration=np.array(
            [
                [sigma]*S
            ]*T
        ),
        validate_args=True,
        name="mu_t_s"
    ),
    # kappa ~ Exponential
    tfp.distributions.Exponential(
        rate=1,
        validate_args=True,
        name="kappa"
    ),
    # x_t_s ~ mixture of L groups of vMF
    lambda kappa, mu_t_s: tfp.distributions.VonMisesFisher(
        mean_direction=mu_t_s,
        concentration=np.array(
            [
                [
                    [
                        kappa
                    ]*S
                ]*T
            ]*N
        ),
        validate_args=True,
    name="x_t_s
    )            
]
)

然后,我打算使用混合API创建这些模型的混合:

l = tfp.distributions.Categorical(
probs=np.array(
    [
        [
            [
                [1.0/L]*L
            ]*S
        ]*T 
    ]*N               
),
name="l"
)

mixture = tfd.Mixture(
cat=l,
components=[
    jds
] * L,
validate_args=True
)

这不管用。我想要混合的是层次模型“末端”的随机变量,即批形状(N,t,s)的x\u t\u s。我想我需要将它们输入到混合物的组件参数中。问题是我无法从模型对象轻松检索这些变量

有人能找到解决这个问题的办法吗

注意,我尝试使用jds.model[-1]而不是jds,但这指向lambda函数,这不是我在这里需要的


Tags: lambdanametrueargsmeanvalidatedistributionsconcentration
1条回答
网友
1楼 · 发布于 2024-04-29 03:30:14

这里有几点想法

  1. 考虑SphericalUniform作为第一个发行版对于相同类型的{{CD2}}s,考虑使用^ {CD3}}。<李>
  2. 将混合物放入分层模型中。i、 e.不是最后一个发行版是vMF,它可以是MixtureSameFamily(Categorical(...), VonMisesFisher(...))
  3. 如果以后要访问组件,可以调用ds, xs = jds.sample_distributions(),然后查看ds[-1].component_distribution

欢迎发电子邮件tfprobability@tensorflow.orgw/还有问题

相关问题 更多 >