利用张量与其他张量的切片来赋值

2024-04-28 17:26:51 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个张量nextq,它是某个问题集的概率分布。我对synthetic_answers中的每个问题都有可能的答案,即0或1

  1. nextq中为批处理中的每个向量找到具有最大值的索引

  2. 如果该索引处的synthetic_answers为1,则将该索引处cur_qinput的第三个特征设置为1,否则设置第二个特征

这是一些非函数性代码,它在for循环中是非函数性的,因为我不知道如何正确地将张量与其他张量/赋值切分,我只是尝试用python语法编写它,以明确我的意图

#nextq shape =  batch_size x q_size
#nextq_index shape =  batch_size
nextq_index = tf.argmax(nextq,axis=1)


#synthetic_answers shape =  batch_size x q_size
#cur_qinput shape = batch_size x q_size x 3

#"iterate over batch", doesn't actually work and I guess needs to be done entirely differently
for k in tf.range(tf.shape(nextq_index)[0]):
    cur_qinput[k,nextq_index[k],1+synthetic_answers[k,nextq_index[k]]]=1

Tags: 函数答案forsizeindextfbatch特征
1条回答
网友
1楼 · 发布于 2024-04-28 17:26:51

我假设你的数据如下,因为这个问题没有例子

import tensorflow as tf

nextq = tf.constant([[1,5,4],[6,8,10]],dtype=tf.float32)
synthetic_answers = tf.constant([[0,1,1],[1,1,0]],dtype=tf.int32)
cur_qinput = tf.random_normal(shape=(tf.shape(nextq)[0],tf.shape(nextq)[1],3))

首先,您可以使用tf.one_hot构建mask来描述该索引处的synthetic_answers是否等于1

nextq_index = tf.argmax(nextq,axis=1)
# [1 2]
nextq_index_hot = tf.one_hot(nextq_index,depth=nextq.shape[1],dtype=tf.int32)
# [[0 1 0]
#  [0 0 1]]
mask = tf.logical_and(tf.equal(nextq_index_hot,synthetic_answers),tf.equal(nextq_index_hot,1))
# [[False  True False]
#  [False False False]]

然后将mask展开为与cur_qinput相同的形状

mask = tf.one_hot(tf.cast(mask,dtype=tf.int32)+1,depth=3)
# [[[0. 1. 0.]
#   [0. 0. 1.]
#   [0. 1. 0.]]
#
#  [[0. 1. 0.]
#   [0. 1. 0.]
#   [0. 1. 0.]]]

最后,您可以tf.where1分配给cur_qinput

scatter = tf.where(tf.equal(mask,1),tf.ones_like(cur_qinput),cur_qinput)

with tf.Session() as sess:
    cur_qinput_val,scatter_val = sess.run([cur_qinput,scatter])
    print(cur_qinput_val)
    print(scatter_val)
[[[ 1.3651905  -0.96688586  0.74061954]
  [-1.1236337  -0.6730857  -0.8439895 ]
  [-0.52024084  1.1968751   0.79242617]]

 [[ 1.4969068  -0.12403865  0.06582119]
  [ 0.79385823 -0.7952771  -0.8562217 ]
  [-0.05428046  1.4613343   0.2726114 ]]]
[[[ 1.3651905   1.          0.74061954]
  [-1.1236337  -0.6730857   1.        ]
  [-0.52024084  1.          0.79242617]]

 [[ 1.4969068   1.          0.06582119]
  [ 0.79385823  1.         -0.8562217 ]
  [-0.05428046  1.          0.2726114 ]]]

相关问题 更多 >