如何从张量中返回最大值及其邻域

2023-02-06 13:57:39 发布

您现在位置:Python中文网/ 问答频道 /正文

我创建了一个神经网络模型,并希望自定义损失函数

我想知道如何从张量中返回最大值及其邻域

我知道tf.argmax可以从张量返回max value的索引。但是,是否有可能得到一个新的张量,其中包括[max之前的3个值,max之后的3个值]


Tags: 函数模型valuetf神经网络max邻域损失argmax个值
1条回答
网友
1楼 · 发布于 2023-02-06 13:57:39

刚才看到您在评论中说过,您希望在非急切模式下运行它(即内部损失函数),因此tensor.numpy()不可用

让我们创建一个随机张量,并找到最大值及其邻域,而不使用numpy和任何其他在非急切模式下不可用的函数:

a = tf.random.uniform((20,), minval=0, maxval=100,dtype=tf.int32)
tf.print(a,summarize=20)
# [5 42 25 18 15 95 1 51 47 42 36 72 92 11 21 32 1 68 84 24]

start_index = tf.maximum(0,tf.subtract(tf.argmax(a,output_type=tf.int32),3))
end_index   = tf.minimum(a.shape[0],tf.add(start_index,7))

max_neighbors = a[start_index: end_index]
tf.print(max_neighbors,summarize=7)
# [25 18 15 95 1 51 47]

上面的代码,可以在非急切模式下运行。请注意,您可以将run_eagerly=True参数传递给model.compile(),以便在不受非急切模式限制的情况下编写代码(例如,能够转换为numpy),但培训将不会有效

相关问题 更多 >