image_size = 28
num_labels = 10
def reformat(dataset, labels):
dataset = dataset.reshape((-1, image_size * image_size)).astype(np.float32)
# Map 0 to [1.0, 0.0, 0.0 ...], 1 to [0.0, 1.0, 0.0 ...]
labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)
return dataset, labels
train_dataset, train_labels = reformat(train_dataset, train_labels)
valid_dataset, valid_labels = reformat(valid_dataset, valid_labels)
test_dataset, test_labels = reformat(test_dataset, test_labels)
print('Training set', train_dataset.shape, train_labels.shape)
print('Validation set', valid_dataset.shape, valid_labels.shape)
print('Test set', test_dataset.shape, test_labels.shape)
这条线是什么意思?你知道吗
labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)
对于Numpy数组,
==
运算符是返回布尔数组的按元素操作。astype
函数将布尔值True
转换为1.0
,False
转换为0.0
,如注释中所述。你知道吗https://docs.python.org/3/reference/expressions.html#value-comparisons描述值比较,如
==
。虽然默认比较是identity
x is y
,但它首先检查任一参数是否实现了__eq__
方法。数字、列表和字典实现它们自己的版本。numpy
也是如此。你知道吗关于
numpy
__eq__
的独特之处在于,如果可能的话,它会逐个元素进行比较,并返回相同大小的布尔数组。你知道吗一个常见的SO问题是“为什么我会得到这个ValueError?”你知道吗
这是因为比较生成的数组有多个真/假值。你知道吗
比较浮动也是一个常见的问题。检查一下
isclose
和allclose
这个问题出现了。你知道吗在numpy中,
==
操作符在比较两个numpy数组时意味着不同的含义(正如在该行中所做的那样),因此是的,它在这个意义上是重载的。它对两个numpy数组进行elementwise比较,并返回与两个输入大小相同的布尔numpy数组。其他比较也一样,比如>=
、<
等例如。 你知道吗
相关问题 更多 >
编程相关推荐