Why do I get a matmul mismatch error when updating the weight vector in my perceptron algorithm?

Posted 2024-04-25 21:19:57


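I have started learning machine learning (ML) and neural networks, and I began implementing the perceptron algorithm for handwritten digit recognition. The code works fine until I try to update my weight vector. I suspect there is a problem with the sizes of the NumPy array vectors, but I don't know how to fix it. Any ideas?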

```python
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import matplotlib.pyplot as plt
plt.style.use('ggplot')

import numpy as np
from sklearn.datasets import load_digits

# This function passes X, Y to the perceptron algorithm and plots the accuracy after each iteration.
def digit(digit_to_recognize=5):
    # load the 8x8 handwritten digits dataset bundled with scikit-learn
    n_example = 100
    X, Y = load_digits(n_class=10, return_X_y=True)
    plt.matshow(X[n_example,:].reshape(8,8));
    plt.xticks([]); plt.yticks([]);
    plt.title(Y[n_example])
    plt.savefig("usps_example.png")
    # transform the 10-class labels into binary form (+1 for the target digit, -1 otherwise)
    y = np.sign((Y == digit_to_recognize) * 1.0 - .5)
    _, acc = perceptron_train(X, y)
    plt.figure(figsize=[12,4])
    plt.plot(acc)
    plt.xlabel("Iterations"); plt.ylabel("Accuracy");
    plt.savefig("learning_curve.png")


def perceptron_train(X, Y, iterations=100, eta=.01):
    acc = np.zeros(iterations)
    # initialize weight vector
    weights = np.random.randn(X.shape[1]) * 1e-5
    for it in np.arange(iterations):
        # indices of misclassified data
        wrong = (np.sign(X @ weights) != Y).nonzero()[0]
        if wrong.shape[0] > 0:
            # pick a random misclassified data point
            i = np.random.choice(wrong, 1)
            rand_ex = X[i] * Y[i]
            # update weight vector
            weights = weights + (eta/it) * rand_ex
            # compute accuracy
            acc[it] = np.double(np.sum(np.sign(X @ weights) == Y)) / X.shape[0]
    # return weight vector and accuracy
    return weights, acc
```

Here is the error:

```
<ipython-input-128-f4e41796a9be> in perceptron_train(X, Y, iterations, eta)
     16             weights = weights + (eta/it)*rand_ex
     17             # compute accuracy
---> 18             acc[it] = sp.double(sp.sum(sp.sign(X @ weights)==Y))/X.shape[0]
     19     # return weight vector and accuracy
     20     return weights,acc

ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 1 is different from 64)
```
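Reading the error together with the code, a likely cause is the index returned by `random.choice(wrong, 1)`: passing a size of 1 returns a one-element index array rather than a scalar, so `X[i]` is a `(1, 64)` matrix instead of a `(64,)` vector. Adding that to the 1-D `weights` silently broadcasts the weights to shape `(1, 64)`, so the update line itself succeeds. The next `X @ weights` then multiplies an `(n, 64)` matrix by a `(1, 64)` one, which is the "size 1 is different from 64" complaint and explains why the traceback points at the accuracy line rather than the update. The sketch below reproduces this with random stand-in data (100 samples, 64 features, not the digits set) and shows that calling `choice` without a size keeps `weights` one-dimensional.

```python
import numpy as np

np.random.seed(0)
X = np.random.rand(100, 64)                         # stand-in data: 100 samples, 64 features
Y = np.where(np.random.rand(100) < 0.5, -1.0, 1.0)  # stand-in labels in {-1, +1}
weights = np.zeros(X.shape[1])                      # 1-D weights, shape (64,)
wrong = np.arange(X.shape[0])                       # pretend every sample is misclassified

# As in the question: size=1 returns an index array, so X[i] has shape (1, 64)
i = np.random.choice(wrong, 1)
bad = weights + 0.01 * X[i] * Y[i]
print(i.shape, X[i].shape, bad.shape)   # (1,) (1, 64) (1, 64)
try:
    X @ bad                             # (100, 64) @ (1, 64): inner sizes 64 and 1 differ
except ValueError as err:
    print("ValueError:", err)

# Fix: without a size argument, choice returns a scalar index, so X[i] has shape (64,)
i = np.random.choice(wrong)
good = weights + 0.01 * X[i] * Y[i]
print(X[i].shape, good.shape)           # (64,) (64,)
print((X @ good).shape)                 # (100,)
```

Equivalently, `X[i[0]]` or `X[i].ravel()` would also restore a 1-D example; the scalar draw is simply the most direct change.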

Here is the perceptron algorithm in pseudocode:

[Image: Perceptron Algorithm]
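The pseudocode image did not survive, so the following is only a sketch of the stochastic perceptron loop that `perceptron_train` implements (repeatedly pick a misclassified example `i` and add a step size times `Y[i] * X[i]` to the weights), rewritten with NumPy alone. Assumptions that differ from the question's code: the index is drawn as a scalar so the weights stay 1-D; the step size is `eta / (it + 1)` instead of `eta / it`, because `it` starts at 0 and the first update would otherwise divide by zero and fill the weights with `inf`/`nan`; accuracy is recorded on every iteration, not only while some point is still misclassified; and the `seed` parameter is a hypothetical addition for reproducibility. The demo at the bottom uses synthetic linearly separable data, not the digits dataset.

```python
import numpy as np


def perceptron_train(X, Y, iterations=100, eta=0.01, seed=None):
    """Stochastic perceptron. X: shape (n, d); Y: labels in {-1, +1}, shape (n,)."""
    rng = np.random.default_rng(seed)  # `seed` is a hypothetical addition
    acc = np.zeros(iterations)
    # small random initial weights, kept 1-D with shape (d,)
    weights = rng.standard_normal(X.shape[1]) * 1e-5
    for it in range(iterations):
        # indices of misclassified data points
        wrong = np.flatnonzero(np.sign(X @ weights) != Y)
        if wrong.size > 0:
            # no size argument: i is a scalar, so X[i] has shape (d,)
            i = rng.choice(wrong)
            # perceptron update; it + 1 avoids dividing by zero on the first pass
            weights = weights + (eta / (it + 1)) * Y[i] * X[i]
        # fraction of points classified correctly after this iteration
        acc[it] = np.mean(np.sign(X @ weights) == Y)
    return weights, acc


# Demo on synthetic, linearly separable data (not the digits dataset)
demo_rng = np.random.default_rng(1)
X_demo = demo_rng.standard_normal((200, 64))
Y_demo = np.sign(X_demo @ demo_rng.standard_normal(64))
w, acc = perceptron_train(X_demo, Y_demo, iterations=500, seed=0)
print(w.shape, acc.shape)  # (64,) (500,)
```

Because the signature keeps `X`, `Y`, `iterations` and `eta` in the same order, the call `perceptron_train(X, y)` inside `digit()` should work unchanged with this version.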

