我有一个张量nextq
,它是某个问题集的概率分布。我对synthetic_answers
中的每个问题都有可能的答案,即0或1
在nextq
中为批处理中的每个向量找到具有最大值的索引
如果该索引处的synthetic_answers
为1,则将该索引处cur_qinput
的第三个特征设置为1,否则设置第二个特征
这是一些非函数性代码,它在for循环中是非函数性的,因为我不知道如何正确地将张量与其他张量/赋值切分,我只是尝试用python语法编写它,以明确我的意图
#nextq shape = batch_size x q_size
#nextq_index shape = batch_size
nextq_index = tf.argmax(nextq,axis=1)
#synthetic_answers shape = batch_size x q_size
#cur_qinput shape = batch_size x q_size x 3
#"iterate over batch", doesn't actually work and I guess needs to be done entirely differently
for k in tf.range(tf.shape(nextq_index)[0]):
cur_qinput[k,nextq_index[k],1+synthetic_answers[k,nextq_index[k]]]=1
我假设你的数据如下,因为这个问题没有例子
首先,您可以使用
tf.one_hot
构建mask
来描述该索引处的synthetic_answers
是否等于1
然后将
mask
展开为与cur_qinput
相同的形状最后,您可以
tf.where
将1
分配给cur_qinput
相关问题 更多 >
编程相关推荐