假设我有一个2D numpy数组。给定n,我想把矩阵中除前n以外的所有元素都加上
我试过idx = (-y_pred).argsort(axis=-1)[:, :n]
确定最大n值的索引是什么,但是idx
形状是[H,W,n],我不明白为什么。你知道吗
我试过-
sorted_list = sorted(y_pred, key=lambda x: x[0], reverse=True)
top_ten = sorted_list[:10]
但它并没有真正返回前十大指数。你知道吗
有没有一种有效的方法来找出前n个指数,然后把其余的归零?你知道吗
编辑 输入是一个NxM值矩阵,输出是大小为NxM的相同矩阵,因此除了对应于前10个值的索引外,所有值都是0
基于How do I get indices of N maximum values in a NumPy array?的思想,这里有一种使用^{} 的方法
下面的代码将使
NxM
矩阵X
无效。你知道吗注意:当存在重复值时,此方法可以返回包含多于
n
个非零元素的矩阵。你知道吗相关问题 更多 >
编程相关推荐