以十为单位修改张量的函数

2024-06-12 01:16:20 发布

您现在位置:Python中文网/ 问答频道 /正文

我在numpy有职能:

def modify_vec(old_config):
    idx = np.where(old_config != 0)[0]
    remove = np.random.choice(idx)
    old_config[remove] -= 1
    increase = np.random.randint(N)
    old_config[increase] += 1
    return old_config

其中,输入只是N个正整数分量的numpy向量。该函数所做的只是随机获取向量的索引,其中对应的元素与0不同,然后将该元素减少1,并将其添加到另一个随机选择的元素

我想使用Tensorflow(不调用任何会话)执行完全相同的操作,其中old_config是一个形状为(1,N)的常量张量。我在以下方面有困难。我已经在tensorflow中实现了where,如下所示:

idx = tf.where(tf.logical_not(tf.math.equal(old_config, tf.constant(0, dtype=tf.float32))))[:, 1]

但是现在,我在随机选择这个新张量idx的元素时遇到了问题。现在,这个old_config张量不能被修改,因为它是一个常数张量。如何将其内容“复制”到numpy数组,或者如何通过对函数modify_vec的修改创建另一个常量张量?谢谢


Tags: 函数numpyconfig元素tfnprandomwhere