如何从所选标签的数组中从mnist创建自定义小批量?

2024-04-20 13:04:10 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试从MNIST创建一个所有数字都在0到9之间的小批量(10个元素)

我希望避免在标签ector中的所有元素上循环以逐个检查数字

实现这一目标的最简单方法是什么

我想我可以创建一个从0到9的标签数组“all_digits”,然后将其与我的mnist_标签列表“train_标签”进行比较(1D数组(n个元素)

我试图得到一个包含所有对等式检查的矩阵(n x 10) 但是我不能直接使用==,也没有广播版本的numpy.equal()

我也没有一个明确的想法,如何处理矩阵之后

import numpy as np
train_labels = np.random.randint(0,10,100)
all_digits = np.arange(10)
# doing a difference for now
train_labels.reshape((-1,1)) - all_digits

Tags: numpy元素目标labelsnptrain矩阵数字