ValueError: 数组的真值不明确
我正在尝试运行以下代码 :( 这是一个用Python写的简单K均值算法的代码。这个算法的过程分为两步,直到每个数据点被分配到的簇和中心点不再变化为止。这个算法能保证收敛,但得到的结果可能只是一个局部最优解。实际上,这个算法会运行多次,然后取平均值。
import numpy as np
import random
from numpy import *
points = [[1,1],[1.5,2],[3,4],[5,7],[3.5,5],[4.5,5], [3.5,4]]
def cluster(points,center):
clusters = {}
for x in points:
z= min([(i[0], np.linalg.norm(x-center[i[0]])) for i in enumerate(center)], key=lambda t:t[1])
try:
clusters[z].append(x)
except KeyError:
clusters[z]=[x]
return clusters
def update(oldcenter,clusters):
d=[]
r=[]
newcenter=[]
for k in clusters:
if k[0]==0:
d.append(clusters[(k[0],k[1])])
else:
r.append(clusters[(k[0],k[1])])
c=np.mean(d, axis=0)
u=np.mean(r,axis=0)
newcenter.append(c)
newcenter.append(u)
return newcenter
def shouldStop(oldcenter,center, iterations):
MAX_ITERATIONS=4
if iterations > MAX_ITERATIONS: return True
return (oldcenter == center)
def kmeans():
points = np.array([[1,1],[1.5,2],[3,4],[5,7],[3.5,5],[4.5,5], [3.5,4]])
clusters={}
iterations = 0
oldcenter=([[],[]])
center= ([[1,1],[5,7]])
while not shouldStop(oldcenter, center, iterations):
# Save old centroids for convergence test. Book keeping.
oldcenter=center
iterations += 1
clusters=cluster(points,center)
center=update(oldcenter,clusters)
return (center,clusters)
kmeans()
但是现在我卡住了。有没有人能帮我一下呢?
Traceback (most recent call last):
File "has_converged.py", line 64, in <module>
(center,clusters)=kmeans()
File "has_converged.py", line 55, in kmeans
while not shouldStop(oldcenter, center, iterations):
File "has_converged.py", line 46, in shouldStop
return (oldcenter == center)
ValueError: The truth value of an array with more than one element is ambiguous.
Use a.any() or a.all()
1 个回答
8
根据错误提示,你不能用 ==
来比较两个数组,这在 NumPy 中是行不通的:
>>> a = np.random.randn(5)
>>> b = np.random.randn(5)
>>> a
array([-0.28636246, 0.75874234, 1.29656196, 1.19471939, 1.25924266])
>>> b
array([-0.13541816, 1.31538069, 1.29514837, -1.2661043 , 0.07174764])
>>> a == b
array([False, False, False, False, False], dtype=bool)
使用 ==
的结果是一个逐元素的布尔数组,也就是说,它会告诉你每个元素是否相等。你可以用 all
方法来检查这个数组是不是全部都是“真”的:
>>> (a == b).all()
False
不过,用这种方法来检查中心点(centroids)是否变化是不太可靠的,因为会有四舍五入的问题。你可能想用 np.allclose
来代替。