"NumPy是否覆盖了==运算符,因为我无法理解跟随的Python代码"

2024-05-16 10:53:44 发布

您现在位置:Python中文网/ 问答频道 /正文

image_size = 28
num_labels = 10

def reformat(dataset, labels):
  dataset = dataset.reshape((-1, image_size * image_size)).astype(np.float32)
  # Map 0 to [1.0, 0.0, 0.0 ...], 1 to [0.0, 1.0, 0.0 ...]
  labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)
  return dataset, labels
train_dataset, train_labels = reformat(train_dataset, train_labels)
valid_dataset, valid_labels = reformat(valid_dataset, valid_labels)
test_dataset, test_labels = reformat(test_dataset, test_labels)
print('Training set', train_dataset.shape, train_labels.shape)
print('Validation set', valid_dataset.shape, valid_labels.shape)
print('Test set', test_dataset.shape, test_labels.shape)

这条线是什么意思?你知道吗

labels = (np.arange(num_labels) == labels[:,None]).astype(np.float32)

代码来自https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/udacity/2_fullyconnected.ipynb


Tags: testimagesizelabelsnptraindatasetnum
3条回答

对于Numpy数组,==运算符是返回布尔数组的按元素操作。astype函数将布尔值True转换为1.0False转换为0.0,如注释中所述。你知道吗

https://docs.python.org/3/reference/expressions.html#value-comparisons描述值比较,如==。虽然默认比较是identityx is y,但它首先检查任一参数是否实现了__eq__方法。数字、列表和字典实现它们自己的版本。numpy也是如此。你知道吗

关于numpy__eq__的独特之处在于,如果可能的话,它会逐个元素进行比较,并返回相同大小的布尔数组。你知道吗

In [426]: [1,2,3]==[1,2,3]
Out[426]: True
In [427]: z1=np.array([1,2,3]); z2=np.array([1,2,3])
In [428]: z1==z2
Out[428]: array([ True,  True,  True], dtype=bool)
In [432]: z1=np.array([1,2,3]); z2=np.array([1,2,4])
In [433]: z1==z2
Out[433]: array([ True,  True, False], dtype=bool)
In [434]: (z1==z2).astype(float)     # change bool to float
Out[434]: array([ 1.,  1.,  0.])

一个常见的SO问题是“为什么我会得到这个ValueError?”你知道吗

In [435]: if z1==z2: print('yes')
...
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

这是因为比较生成的数组有多个真/假值。你知道吗

比较浮动也是一个常见的问题。检查一下iscloseallclose这个问题出现了。你知道吗

在numpy中,==操作符在比较两个numpy数组时意味着不同的含义(正如在该行中所做的那样),因此是的,它在这个意义上是重载的。它对两个numpy数组进行elementwise比较,并返回与两个输入大小相同的布尔numpy数组。其他比较也一样,比如>=<

例如。 你知道吗

import numpy as np
print(np.array([5,8,2]) == np.array([5,3,2]))
# [True False True]
print((np.array([5,8,2]) == np.array([5,3,2])).astype(np.float32))
# [1. 0. 1.]

相关问题 更多 >