我正在尝试从MNIST创建一个所有数字都在0到9之间的小批量(10个元素)
我希望避免在标签ector中的所有元素上循环以逐个检查数字
实现这一目标的最简单方法是什么
我想我可以创建一个从0到9的标签数组“all_digits”,然后将其与我的mnist_标签列表“train_标签”进行比较(1D数组(n个元素)
我试图得到一个包含所有对等式检查的矩阵(n x 10) 但是我不能直接使用==,也没有广播版本的numpy.equal()
我也没有一个明确的想法,如何处理矩阵之后
import numpy as np
train_labels = np.random.randint(0,10,100)
all_digits = np.arange(10)
# doing a difference for now
train_labels.reshape((-1,1)) - all_digits
目前没有回答
相关问题 更多 >
编程相关推荐