感知机学习算法收敛需要很多迭代吗?

4 投票
1 回答
4992 浏览
提问于 2025-04-18 07:59

我正在做加州理工学院机器学习课程的作业-1(http://work.caltech.edu/homework/hw1.pdf)。为了完成第7到10题,我们需要实现一个感知器(PLA)。这是我用Python写的实现:

import sys,math,random

w=[] # stores the weights
data=[] # stores the vector X(x1,x2,...)
output=[] # stores the output(y)


# returns 1 if dot product is more than 0
def sign_dot_product(x):
    global w
    dot=sum([w[i]*x[i] for i in xrange(len(w))])
    if(dot>0):
        return 1
    else :
        return -1

# checks if a point is misclassified
def is_misclassified(rand_p):
    return (True if sign_dot_product(data[rand_p])!=output[rand_p] else False)


# loads data in the following format:
# x1 x2 ... y
# In the present case for d=2
# x1 x2 y
def load_data():
    f=open("data.dat","r")
    global w
    for line in f:
        data_tmp=([1]+[float(x) for x in line.split(" ")])
        data.append(data_tmp[0:-1])
        output.append(data_tmp[-1])


def train():
    global w
    w=[ random.uniform(-1,1) for i in xrange(len(data[0]))] # initializes w with random weights
    iter=1
    while True:

        rand_p=random.randint(0,len(output)-1) # randomly picks a point
        check=[0]*len(output) # check is a list. The ith location is 1 if the ith point is correctly classified
        while not is_misclassified(rand_p):
            check[rand_p]=1
            rand_p=random.randint(0,len(output)-1)
            if sum(check)==len(output):
                print "All points successfully satisfied in ",iter-1," iterations"
                print iter-1,w,data[rand_p]
                return iter-1
        sign=output[rand_p]
        w=[w[i]+sign*data[rand_p][i] for i in xrange(len(w))] # changing weights
        if iter>1000000:
            print "greater than 1000"
            print w
            return 10000000
        iter+=1

load_data()

def simulate():
   #tot_iter=train()
    tot_iter=sum([train() for x in xrange(100)])
    print float(tot_iter)/100

simulate()

根据第7题的答案,感知器在训练集大小的情况下应该大约需要15次迭代才能收敛,但我的实现平均需要50000次迭代。训练数据应该是随机生成的,但我生成的数据是一些简单的直线,比如x=4,y=2等等。这是我得到错误答案的原因吗?还是说还有其他问题?这是我的训练数据示例(可以用y=2分开):

1 2.1 1
231 100 1
-232 1.9 -1
23 232 1
12 -23 -1
10000 1.9 -1
-1000 2.4 1
100 -100 -1
45 73 1
-34 1.5 -1

数据格式是x1 x2 输出(y)

1 个回答

4

很明显,你在学习Python和分类算法方面做得很好,付出了很多努力。

不过,由于你代码中的一些风格问题,可能会让人很难帮助你,也可能导致你和教授之间出现沟通不畅的情况。

比如,教授希望你使用“在线模式”还是“离线模式”的感知器?在“在线模式”下,你应该顺序处理每个数据点,而不应该回头再处理之前的数据点。从作业的猜测来看,似乎只需要15次迭代就能收敛,我很好奇这是否意味着前15个数据点按顺序处理会得到一个可以线性分隔你的数据集的分类器。

如果你随机抽样并允许重复,可能会导致你花费更长的时间(尽管根据数据样本的分布和大小,这种情况不太可能,因为你可以预期任何15个点的效果大致和前15个点一样)。

另一个问题是,当你检测到一个正确分类的点(即not is_misclassifiedTrue时),如果接着遇到一个新的随机点而这个点是错误分类的,那么你的代码会进入外层while循环的更大部分,然后回到顶部,覆盖check向量为全0。

这意味着,只有当你评估的随机序列(在内层while循环中)恰好是全1的情况下,代码才能检测到它已经正确分类了所有点,除非在某个特定的0上,在那次遍历数组时,它能够正确分类。

我无法准确说明为什么我认为这样会让程序运行得更慢,但看起来你的代码需要一种更严格的收敛方式,似乎必须在训练阶段的后期一次性学习所有内容,而这时已经更新了很多次。

一个简单的方法来验证我的直觉是否正确,就是把check=[0]*len(output)这一行移到while loop外面,只初始化一次。

一些让代码更易管理的建议:

  1. 不要使用全局变量。相反,让你的函数加载和准备数据后返回结果。

  2. 在一些地方,比如你写的:

    return (True if sign_dot_product(data[rand_p])!=output[rand_p] else False)

    这种写法可以简化为:

    return sign_dot_product(data[rand_p]) != output[rand_p]

    这样更容易阅读,也更直接地表达了你想检查的标准。

  3. 我怀疑效率在这里并不重要,因为这似乎是一个教学练习,但有很多方法可以重构你使用列表推导的方式,这可能会有帮助。如果可以的话,直接使用NumPy,它有原生的数组类型。看到一些操作必须用list操作来表达,真让人感到遗憾。即使你的教授不希望你使用NumPy,因为他或她想教你纯粹的基础知识,我建议你无视这个,去学习NumPy。这会在工作、实习和实际技能方面大大帮助你,尤其是在Python中处理这些操作时,而不是与原生数据类型斗争,去做它们并不擅长的事情(数组计算)。

撰写回答