多维十位数的加权抽样

2024-04-27 03:54:53 发布

您现在位置:Python中文网/ 问答频道 /正文

我需要对多维张量进行加权采样。你知道吗

我有一个形状为[X,Y]的张量A和形状为[X]的概率分布B。 我需要根据分布对B中的N元素进行采样。你知道吗

B表示子传感器的分布。每个子传感器内的采样是均匀的。你知道吗

A中有一些填充,所以我必须考虑到这一点。什么是填充的信息包含在掩码中。你知道吗

例如

A      = [[1,   2,   3,   X,   X,  X],
          [10,  20,  30,  X,   X,  X],
          [100, 200, 300, 400, 500, 600]]
A_mask = [[T,   T,   T,   F,   F,  F],
          [T,   T,   T,   F,   F,  F],
          [T,   T,   T,   T,   T,  T]]
B = [0.5, 0.4, 0.1]

# a possible output, with N = 10
ouput = [1, 1, 2, 2, 3, 10, 20, 30, 30, 200]

我可以从A的每个嵌套张量中检索要采样的元素数,方法是:

tf.multinomial(tf.log(probability_distribution), N)

# a possible output of that function, with N = 10, is:
[1, 1, 1, 1, 1, 2, 2, 2, 2, 3]

对于这些数字中的每一个,我都必须在子传感器中执行均匀采样。你知道吗

我能够计算每个子传感器的最大值。你知道吗

subtensor_sizes = tf.reduce_sum(tf.cast(A_mask, tf.int32), axis=1)

# it would return: [3, 3, 6]

此时,对于多项式函数返回的每个子传感器,我应该在0与其maxvalue之间执行统一采样(或者类似地,计算出现次数并从多项式输出中出现T次的子传感器中采样T个元素)。你知道吗

我不知道怎么处理,怎么办?你知道吗


Tags: 方法log信息元素outputtfwithmask
1条回答
网友
1楼 · 发布于 2024-04-27 03:54:53

所以有一个包含不同长度序列的张量A。您想要从这些序列中提取值,以不同的概率B为每个序列选择一个值。你知道吗

您可以按以下步骤进行:

import tensorflow as tf

A = tf.constant(
    [[1,   2,   3,   -1,  -1,  -1],
     [10,  20,  30,  -1,  -1,  -1],
     [100, 200, 300, 400, 500, 600]])
A_mask = tf.constant(
    [[True,   True,   True,   False,   False,  False],
     [True,   True,   True,   False,   False,  False],
     [True,   True,   True,   True,   True,  True]])
B = tf.constant([0.5, 0.4, 0.1])
subtensor_sizes = tf.reduce_sum(tf.cast(A_mask, tf.int32), axis=1)

# get random sample index
output = tf.to_int32(tf.multinomial(tf.log(B[None]), 10)[0])
# get corresponding sample size
output_sizes = tf.gather(subtensor_sizes, output)
# generate a random index in each range
random_idxs = tf.map_fn(
  lambda x: tf.random_uniform((), maxval=x, dtype=tf.int32), output_sizes)
# construct nd-index for tf.gather
random_ndxs = tf.concat([output[:, None], random_idxs[:, None]], axis=-1)
# get sample values
random_samples = tf.gather_nd(A, random_ndxs)

相关问题 更多 >