4-layer perceptron neural network classifying the MNIST digit dataset not working properly

Posted 2024-05-16 14:22:35


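I have just started learning machine learning and have almost no coding background. I want to build a simple perceptron neural network that learns to classify the MNIST digit dataset without using any machine learning libraries. Under output_data I have included the first and last 80 print statements the program outputs. My program seems to be learning to minimize the cost function, but it guesses the same digit over and over again. I have played with the learning rate and the batch size, but neither seems to improve anything. If anyone could point me in the right direction, I would greatly appreciate it. Thanks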

mnist_recognition

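This is a 4-layer neural network for classifying the handwritten digits in the MNIST dataset. Using only pure Python and numpy, the program performs gradient descent on the cost function (∑(actual - target)^2) and changes the weights accordingly. After each iteration, the program prints the digit from the training data, the program's guess, and the cost for that iteration.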
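For reference, the cost described above and the output error term that `back_prop` starts from (`delta_3`) can be written out as below. This is a sketch in standard notation, not the poster's: z_k is the pre-activation of output neuron k (`lvl_output.mtrx`), a_k = σ(z_k) its activation, and y_k the one-hot target (`val` in the code).

```latex
\begin{aligned}
C &= \sum_{k=0}^{9} (a_k - y_k)^2, \qquad a_k = \sigma(z_k) \\
\sigma'(z) &= \sigma(z)\,\bigl(1 - \sigma(z)\bigr) \\
\frac{\partial C}{\partial z_k} &= 2\,(a_k - y_k)\,\sigma'(z_k)
\end{aligned}
```

The code's `delta_3` omits the constant factor 2, which only rescales the effective learning rate.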

Import the MNIST training data

import numpy as np

with np.load('mnist.npz') as data:
    training_images = data['training_images']
    training_labels = data['training_labels']
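To confirm what the archive actually contains (the key names, and whether each image is already a (784, 1) column and each label a one-hot (10, 1) vector, as `make_image` and `np.argmax(training_labels[counter])` appear to assume), the `files` attribute of the object returned by `np.load` lists the stored arrays. A minimal sketch, using a small in-memory stand-in instead of the real `mnist.npz`; the shapes below are assumptions, not the real file's.

```python
import io

import numpy as np

# Hypothetical stand-in for mnist.npz: tiny arrays stored under the same key
# names the question uses. The real file's shapes and dtypes may differ.
buffer = io.BytesIO()
np.savez(
    buffer,
    training_images=np.zeros((3, 784, 1)),
    training_labels=np.zeros((3, 10, 1)),
)
buffer.seek(0)

with np.load(buffer) as data:
    # data.files lists the names of the arrays stored in the archive
    keys = sorted(data.files)
    shapes = {name: data[name].shape for name in keys}

print(keys)  # ['training_images', 'training_labels']
print(shapes)
```

Running the same two lines against the real file shows whether any reshaping is needed before training.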

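Build the neural network and define the sigmoid function
self.mtrx holds the neurons of each level
self.weight, self.bias and self.grad hold the weight, bias and gradient values between level L-1 and level L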

class NeuralNetwork:

    def __init__(self, rows, columns=0):
        self.mtrx = np.zeros((rows, 1))
        self.weight = np.random.random((rows, columns)) / columns ** .5
        self.bias = np.random.random((rows, 1)) * -1.0
        self.grad = np.zeros((rows, columns))

    def sigmoid(self):
        return 1 / (1 + np.exp(-self.mtrx))

    def sigmoid_derivative(self):
        return self.sigmoid() * (1.0 - self.sigmoid())
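The class above draws its weights with `np.random.random`, which samples uniformly from [0, 1), so every initial weight is non-negative (and every bias non-positive). A commonly used alternative is a zero-mean initialization. The sketch below assumes a Gaussian scaled by 1/sqrt(fan-in) with zero biases, and evaluates the sigmoid only once per derivative call; `init_layer` is a hypothetical helper, not part of the original code.

```python
import numpy as np


def sigmoid(z):
    """Logistic function 1 / (1 + exp(-z)), applied element-wise."""
    return 1.0 / (1.0 + np.exp(-z))


def sigmoid_derivative(z):
    """Derivative of the logistic function, computing sigmoid(z) only once."""
    s = sigmoid(z)
    return s * (1.0 - s)


def init_layer(rows, columns, rng):
    """Zero-mean weights scaled by 1/sqrt(fan-in) and zero biases (an assumed scheme)."""
    weight = rng.standard_normal((rows, columns)) / np.sqrt(columns)
    bias = np.zeros((rows, 1))
    return weight, bias


rng = np.random.default_rng(0)
w1, b1 = init_layer(200, 784, rng)
print(w1.shape, b1.shape)       # (200, 784) (200, 1)
print(sigmoid_derivative(0.0))  # 0.25
print((w1 < 0).any())           # True: unlike np.random.random, some weights are negative
```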

Initialize the neural network levels

lvl_input = NeuralNetwork(784)
lvl_one = NeuralNetwork(200, 784)
lvl_two = NeuralNetwork(200, 200)
lvl_output = NeuralNetwork(10, 200)

Forward and backward propagation functions

def forward_prop():
    lvl_one.mtrx = lvl_one.weight.dot(lvl_input.mtrx) + lvl_one.bias
    lvl_two.mtrx = lvl_two.weight.dot(lvl_one.sigmoid()) + lvl_two.bias
    lvl_output.mtrx = lvl_output.weight.dot(lvl_two.sigmoid()) + lvl_output.bias
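Back-propagation needs both the pre-activations z and the activations a = σ(z) of every layer. A minimal sketch of the same 784-200-200-10 forward pass written as a loop over (weight, bias) pairs that returns both lists so they can be reused; like `forward_prop`, it feeds the raw input (not its sigmoid) into the first layer. The function name `forward` and the random parameters are illustrative assumptions.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def forward(x, params):
    """Run x (shape (784, 1)) through each (weight, bias) pair.

    Returns the pre-activations zs and the activations; activations[0] is the
    raw input, so len(activations) == len(params) + 1.
    """
    activations = [x]
    zs = []
    for weight, bias in params:
        z = weight @ activations[-1] + bias
        zs.append(z)
        activations.append(sigmoid(z))
    return zs, activations


rng = np.random.default_rng(0)
sizes = [784, 200, 200, 10]
params = [
    (rng.standard_normal((n_out, n_in)) / np.sqrt(n_in), np.zeros((n_out, 1)))
    for n_in, n_out in zip(sizes[:-1], sizes[1:])
]
x = rng.random((784, 1))
zs, activations = forward(x, params)
print([a.shape for a in activations])  # [(784, 1), (200, 1), (200, 1), (10, 1)]
print(int(np.argmax(activations[-1])))  # predicted digit (same argmax as on zs[-1])
```

Because the sigmoid is monotonic, taking `np.argmax` of `lvl_output.mtrx` (as the training loop does) gives the same guess as taking it of the activations.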


def back_prop(actual):
    val = np.zeros((10, 1))
    val[actual] = 1

    delta_3 = (lvl_output.sigmoid() - val) * lvl_output.sigmoid_derivative()
    delta_2 = np.dot(lvl_output.weight.transpose(), delta_3) * lvl_two.sigmoid_derivative()
    delta_1 = np.dot(lvl_two.weight.transpose(), delta_2) * lvl_one.sigmoid_derivative()

    lvl_output.grad = lvl_two.sigmoid().transpose() * delta_3
    lvl_two.grad = lvl_one.sigmoid().transpose() * delta_2
    lvl_one.grad = lvl_input.sigmoid().transpose() * delta_1
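One way to check a hand-written back-propagation is to compare its analytic gradients with finite differences. The sketch below does this for a tiny network with the same structure (sigmoid layers, cost ∑(a - y)^2). It keeps the factor 2 from the derivative so the two gradients match exactly, and it uses the raw input `x` for the first layer's weight gradient, since that is what the forward pass multiplies by (in the posted code, `forward_prop` uses `lvl_input.mtrx` while `back_prop` uses `lvl_input.sigmoid()`). Layer sizes, names and the random data are assumptions for illustration.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def cost(params, x, y):
    """Sum of squared errors between the network output and the one-hot target y."""
    a = x
    for weight, bias in params:
        a = sigmoid(weight @ a + bias)
    return float(np.sum((a - y) ** 2))


def backprop(params, x, y):
    """Return (dC/dweight, dC/dbias) for every layer, for one example."""
    activations, zs = [x], []
    for weight, bias in params:
        zs.append(weight @ activations[-1] + bias)
        activations.append(sigmoid(zs[-1]))
    grads = [None] * len(params)
    s = sigmoid(zs[-1])
    delta = 2.0 * (activations[-1] - y) * s * (1.0 - s)
    for layer in reversed(range(len(params))):
        # activations[layer] is this layer's *input*: the raw x for the first layer
        grads[layer] = (delta @ activations[layer].T, delta)
        if layer > 0:
            s = sigmoid(zs[layer - 1])
            delta = (params[layer][0].T @ delta) * s * (1.0 - s)
    return grads


rng = np.random.default_rng(1)
sizes = [6, 5, 4, 3]  # a scaled-down stand-in for 784-200-200-10
params = [
    (rng.standard_normal((n_out, n_in)), rng.standard_normal((n_out, 1)))
    for n_in, n_out in zip(sizes[:-1], sizes[1:])
]
x = rng.random((6, 1))
y = np.zeros((3, 1))
y[1] = 1.0

grads = backprop(params, x, y)

# Central finite difference for one entry of the first layer's weight matrix
eps = 1e-6
w = params[0][0]
original = w[2, 3]
w[2, 3] = original + eps
c_plus = cost(params, x, y)
w[2, 3] = original - eps
c_minus = cost(params, x, y)
w[2, 3] = original
numeric = (c_plus - c_minus) / (2 * eps)

print(abs(numeric - grads[0][0][2, 3]) < 1e-6)  # True if the gradients agree
```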

Store the MNIST data in an np.array

def make_image(c): 
    lvl_input.mtrx = training_images[c]
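`make_image` assumes each stored image is already a (784, 1) column, and the training loop assumes each label is a one-hot vector (it calls `np.argmax` on it). If the data were instead stored as 28×28 arrays of 0 to 255 pixel values with integer labels (an assumption; check the actual file), helpers along these lines would convert one example. `as_column` and `one_hot` are hypothetical names.

```python
import numpy as np


def as_column(image):
    """Copy any 784-pixel image into a float (784, 1) column scaled to [0, 1]."""
    column = np.array(image, dtype=np.float64).reshape(784, 1)  # np.array copies
    if column.max() > 1.0:  # assume raw 0-255 greyscale values
        column /= 255.0
    return column


def one_hot(label, classes=10):
    """Integer label -> (classes, 1) target vector with a single 1.0."""
    target = np.zeros((classes, 1))
    target[int(label)] = 1.0
    return target


raw = np.full((28, 28), 255, dtype=np.uint8)  # made-up all-white image
print(as_column(raw).shape, as_column(raw).max())  # (784, 1) 1.0
print(one_hot(3).ravel())  # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```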

Evaluate the cost function

def cost(actual):
    val = np.zeros((10, 1))
    val[actual] = 1
    cost_val = (lvl_output.sigmoid() - val) ** 2
    return np.sum(cost_val)
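The cost values in the output can be sanity-checked by hand. With a one-hot target, an output of (nearly) all ones costs about 9, which is consistent with the first iterations (≈ 8.97), while an output of (nearly) all zeros costs about 1 for every label. So the cost can fall from ≈ 9 to ≈ 1 without the network telling any digits apart. A small sketch of that arithmetic, using the same formula as `cost` but taking the output vector as an argument (`cost_of` is a hypothetical name):

```python
import numpy as np


def cost_of(output, actual):
    """Sum of squared errors between an output vector and the one-hot target."""
    val = np.zeros((10, 1))
    val[actual] = 1
    return float(np.sum((output - val) ** 2))


all_ones = np.ones((10, 1))
all_zeros = np.zeros((10, 1))
print(cost_of(all_ones, 5))   # 9.0: nine wrong classes each contribute 1
print(cost_of(all_zeros, 5))  # 1.0: only the correct class contributes 1
print(cost_of(all_zeros, 0))  # 1.0 again, whatever the label
```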

Subtract the gradients from the weights and initialize the learning rate

learning_rate = .01

def update():
    lvl_output.weight -= learning_rate * lvl_output.grad
    lvl_two.weight -= learning_rate * lvl_two.grad
    lvl_one.weight -= learning_rate * lvl_one.grad
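`update()` adjusts the weights but never the biases, and `back_prop` does not compute bias gradients. For each layer, the gradient with respect to the bias is simply that layer's delta (for example `delta_3` for `lvl_output`). A minimal sketch of a step that updates both, assuming the gradients are kept in lists aligned with the parameters (`sgd_step` is a hypothetical helper):

```python
import numpy as np


def sgd_step(params, grads, learning_rate):
    """In-place gradient-descent step on every (weight, bias) pair."""
    for (weight, bias), (grad_w, grad_b) in zip(params, grads):
        weight -= learning_rate * grad_w  # ndarray -= modifies the array in place
        bias -= learning_rate * grad_b


params = [(np.ones((2, 3)), np.zeros((2, 1)))]
grads = [(np.full((2, 3), 0.5), np.full((2, 1), -1.0))]
sgd_step(params, grads, learning_rate=0.1)
print(params[0][0][0, 0], params[0][1][0, 0])  # 0.95 0.1
```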

Train the neural network
iter_1 is the number of batches
iter_2 is the number of iterations in one batch

iter_1 = 50
iter_2 = 100
counter = 0  # index of the next training example

for batch_num in range(iter_1):
    update()
    for batches in range(iter_2):
        make_image(counter)
        num = np.argmax(training_labels[counter])
        counter += 1
        forward_prop()
        back_prop(num)
        print("actual: ", num, "     guess: ", np.argmax(lvl_output.mtrx), "     cost", cost(num))
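In the loop above, `update()` runs once before each batch, while every `back_prop` call overwrites `grad`; so each update applies only the gradient of the last example of the previous batch (and the very first update uses all-zero gradients). A minimal sketch of mini-batch gradient descent that instead averages the gradients over the whole batch before a single update. To stay self-contained it re-implements the propagation compactly and trains on a small made-up 2-D problem rather than MNIST; all names, sizes and hyperparameters are assumptions.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def backprop(params, x, y):
    """Per-example gradients of sum((a - y)**2) for every (weight, bias) pair."""
    activations = [x]
    for weight, bias in params:
        activations.append(sigmoid(weight @ activations[-1] + bias))
    a = activations[-1]
    delta = 2.0 * (a - y) * a * (1.0 - a)  # sigma'(z) == a * (1 - a)
    grads = [None] * len(params)
    for layer in reversed(range(len(params))):
        grads[layer] = (delta @ activations[layer].T, delta)
        if layer > 0:
            a_prev = activations[layer]
            delta = (params[layer][0].T @ delta) * a_prev * (1.0 - a_prev)
    return grads


def train(params, xs, ys, batch_size, epochs, learning_rate):
    for _ in range(epochs):
        for start in range(0, len(xs), batch_size):
            batch = range(start, min(start + batch_size, len(xs)))
            # Accumulate gradients over the batch instead of overwriting them
            sum_w = [np.zeros_like(w) for w, _ in params]
            sum_b = [np.zeros_like(b) for _, b in params]
            for i in batch:
                for layer, (gw, gb) in enumerate(backprop(params, xs[i], ys[i])):
                    sum_w[layer] += gw
                    sum_b[layer] += gb
            # One update per batch, using the mean gradient (biases included)
            for (w, b), sw, sb in zip(params, sum_w, sum_b):
                w -= learning_rate * sw / len(batch)
                b -= learning_rate * sb / len(batch)


def predict(params, x):
    a = x
    for weight, bias in params:
        a = sigmoid(weight @ a + bias)
    return a


def mean_cost(params, xs, ys):
    return float(np.mean([np.sum((predict(params, x) - y) ** 2) for x, y in zip(xs, ys)]))


rng = np.random.default_rng(0)
xs = rng.uniform(-1.0, 1.0, size=(200, 2, 1))     # made-up 2-D points
labels = (xs[:, 0, 0] > xs[:, 1, 0]).astype(int)  # class 1 if x0 > x1
ys = np.zeros((200, 2, 1))
ys[np.arange(200), labels, 0] = 1.0               # one-hot targets
sizes = [2, 8, 2]
params = [
    (rng.standard_normal((n_out, n_in)) / np.sqrt(n_in), np.zeros((n_out, 1)))
    for n_in, n_out in zip(sizes[:-1], sizes[1:])
]

before = mean_cost(params, xs, ys)
train(params, xs, ys, batch_size=20, epochs=150, learning_rate=1.0)
after = mean_cost(params, xs, ys)
accuracy = np.mean([np.argmax(predict(params, x)) == t for x, t in zip(xs, labels)])
print(f"mean cost {before:.3f} -> {after:.3f}, accuracy {accuracy:.2f}")
```

Measuring accuracy over many examples, rather than printing one guess per iteration, also makes it easier to see whether the network predicts more than one class.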

Output data

FIRST 80 ITERATIONS:

actual:  5      guess:  3      cost 8.967940654671088
actual:  0      guess:  3      cost 8.96727511953835
actual:  4      guess:  3      cost 8.966336311471029
actual:  1      guess:  3      cost 8.964614419297058
actual:  9      guess:  3      cost 8.969134701891605
actual:  2      guess:  3      cost 8.967053265932318
actual:  1      guess:  3      cost 8.964824848818395
actual:  3      guess:  3      cost 8.966473334609903
actual:  1      guess:  3      cost 8.960864501044062
actual:  4      guess:  3      cost 8.966927097539942
actual:  3      guess:  3      cost 8.96602960141387
actual:  5      guess:  3      cost 8.96457467709148
actual:  3      guess:  3      cost 8.966463452568336
actual:  6      guess:  3      cost 8.967170896271007
actual:  1      guess:  3      cost 8.961504554251428
actual:  7      guess:  3      cost 8.970226265002914
actual:  2      guess:  3      cost 8.966534186296752
actual:  8      guess:  3      cost 8.96806492904598
actual:  6      guess:  3      cost 8.963241663267867
actual:  9      guess:  3      cost 8.967891094208154
actual:  4      guess:  3      cost 8.968165257872185
actual:  0      guess:  3      cost 8.967495671691166
actual:  9      guess:  3      cost 8.967110016262358
actual:  1      guess:  3      cost 8.964392716554022
actual:  1      guess:  3      cost 8.965993742374005
actual:  2      guess:  3      cost 8.967551426336762
actual:  4      guess:  3      cost 8.963912501397779
actual:  3      guess:  3      cost 8.966729854711515
actual:  2      guess:  3      cost 8.967571805901548
actual:  7      guess:  3      cost 8.968031754047926
actual:  3      guess:  3      cost 8.965675580057647
actual:  8      guess:  3      cost 8.968461428875388
actual:  6      guess:  3      cost 8.965166939019545
actual:  9      guess:  3      cost 8.968763204750987
actual:  0      guess:  3      cost 8.967540507250032
actual:  5      guess:  3      cost 8.965545688959857
actual:  6      guess:  3      cost 8.967425028943891
actual:  0      guess:  3      cost 8.967566971035732
actual:  7      guess:  3      cost 8.969754175784066
actual:  6      guess:  3      cost 8.96702539598315
actual:  1      guess:  3      cost 8.96299163011006
actual:  8      guess:  3      cost 8.968175816089042
actual:  7      guess:  3      cost 8.966425056294776
actual:  9      guess:  3      cost 8.96796817338183
actual:  3      guess:  3      cost 8.963755408168
actual:  9      guess:  3      cost 8.96926567423336
actual:  8      guess:  3      cost 8.967543729824387
actual:  5      guess:  3      cost 8.967286499095575
actual:  9      guess:  3      cost 8.9677253773608
actual:  3      guess:  3      cost 8.966335253428326
actual:  3      guess:  3      cost 8.962829459684784
actual:  0      guess:  3      cost 8.966443407799728
actual:  7      guess:  3      cost 8.969485491531145
actual:  4      guess:  3      cost 8.964159055804105
actual:  9      guess:  3      cost 8.968054200103934
actual:  8      guess:  3      cost 8.96719386034473
actual:  0      guess:  3      cost 8.966374739396157
actual:  9      guess:  3      cost 8.9673694447568
actual:  4      guess:  3      cost 8.966879451409914
actual:  1      guess:  3      cost 8.963085409100401
actual:  4      guess:  3      cost 8.96659585831308
actual:  4      guess:  3      cost 8.964656458614465
actual:  6      guess:  3      cost 8.965997487130116
actual:  0      guess:  3      cost 8.966455019673488
actual:  4      guess:  3      cost 8.966295463866858
actual:  5      guess:  3      cost 8.964316168401316
actual:  6      guess:  3      cost 8.965707649845031
actual:  1      guess:  3      cost 8.962325088384468
actual:  0      guess:  3      cost 8.965286965834165
actual:  0      guess:  3      cost 8.966383201903987
actual:  1      guess:  3      cost 8.964628836235496
actual:  7      guess:  3      cost 8.968386161233427
actual:  1      guess:  3      cost 8.959224945536565
actual:  6      guess:  3      cost 8.965609436078736
actual:  3      guess:  3      cost 8.964347604784498
actual:  0      guess:  3      cost 8.96611237840382
actual:  2      guess:  3      cost 8.965381094981499
actual:  1      guess:  3      cost 8.963120677127996
actual:  1      guess:  3      cost 8.96405639510175
actual:  7      guess:  3      cost 8.968933638290096

LAST 80 ITERATIONS:

actual:  5      guess:  7      cost 1.1211637067627063
actual:  6      guess:  7      cost 1.0552531331006683
actual:  6      guess:  7      cost 1.0554754137400155
actual:  1      guess:  7      cost 1.1111335796511572
actual:  7      guess:  7      cost 0.39288606314850105
actual:  0      guess:  7      cost 0.9725346172858359
actual:  8      guess:  7      cost 1.1328622780173858
actual:  7      guess:  7      cost 0.3948858827858138
actual:  9      guess:  7      cost 1.1062621497616252
actual:  6      guess:  7      cost 1.0553258043668636
actual:  0      guess:  7      cost 0.9722536736320497
actual:  9      guess:  7      cost 1.104398056849663
actual:  1      guess:  7      cost 1.1107715879340798
actual:  6      guess:  7      cost 1.0548862811967128
actual:  2      guess:  7      cost 1.0591572140435124
actual:  9      guess:  7      cost 1.1061966300353054
actual:  3      guess:  7      cost 1.0898127633513548
actual:  2      guess:  7      cost 1.060007348683317
actual:  4      guess:  7      cost 1.0866881038452423
actual:  5      guess:  7      cost 1.121417891094558
actual:  5      guess:  7      cost 1.120467850157213
actual:  1      guess:  7      cost 1.108298490272241
actual:  6      guess:  7      cost 1.0553087806403634
actual:  1      guess:  7      cost 1.110088687169048
actual:  7      guess:  7      cost 0.39320440442417653
actual:  7      guess:  7      cost 0.393599379050548
actual:  8      guess:  7      cost 1.132814048747761
actual:  5      guess:  7      cost 1.121939678599679
actual:  9      guess:  7      cost 1.106025451446512
actual:  1      guess:  0      cost 0.9622564124884784
actual:  0      guess:  0      cost 0.8593530649130317
actual:  4      guess:  0      cost 0.9393396290138495
actual:  1      guess:  0      cost 0.9622940421810162
actual:  4      guess:  0      cost 0.9404822496691384
actual:  2      guess:  0      cost 0.9178125638868906
actual:  3      guess:  0      cost 0.9426691285338776
actual:  3      guess:  0      cost 0.942690753734257
actual:  1      guess:  0      cost 0.9623297861049644
actual:  4      guess:  0      cost 0.940371478405245
actual:  1      guess:  0      cost 0.9623031268176
actual:  5      guess:  0      cost 0.9712347696568674
actual:  0      guess:  0      cost 0.8590451041238242
actual:  6      guess:  0      cost 0.9144392962122468
actual:  7      guess:  0      cost 0.8587038959856634
actual:  7      guess:  0      cost 0.8595722910216409
actual:  3      guess:  0      cost 0.9426103963458455
actual:  8      guess:  0      cost 0.9818405405702578
actual:  9      guess:  0      cost 0.8595962707758777
actual:  5      guess:  0      cost 0.9714964202087537
actual:  1      guess:  0      cost 0.9624568499960074
actual:  1      guess:  0      cost 0.9618997440804723
actual:  1      guess:  0      cost 0.9616207680615205
actual:  9      guess:  0      cost 0.8592443630955677
actual:  5      guess:  0      cost 0.9713378210566505
actual:  9      guess:  0      cost 0.8594114362413322
actual:  1      guess:  0      cost 0.9618546105579062
actual:  7      guess:  0      cost 0.8588493542347546
actual:  1      guess:  0      cost 0.9618718521886025
actual:  1      guess:  0      cost 0.9617515805652713
actual:  6      guess:  0      cost 0.9136159964857046
actual:  0      guess:  0      cost 0.8592350961109494
actual:  8      guess:  0      cost 0.9818162514673088
actual:  9      guess:  0      cost 0.8595943721384651
actual:  7      guess:  0      cost 0.8594020275011919
actual:  0      guess:  0      cost 0.8594831625021346
actual:  2      guess:  0      cost 0.9166105269490908
actual:  5      guess:  0      cost 0.9715520240455258
actual:  3      guess:  0      cost 0.9426616801200285
actual:  9      guess:  0      cost 0.859581449204677
actual:  6      guess:  0      cost 0.9142265327401935
actual:  7      guess:  0      cost 0.8588630292850988
actual:  8      guess:  0      cost 0.981801914298837
actual:  1      guess:  0      cost 0.9618275889195599
actual:  0      guess:  0      cost 0.8587978160230371
actual:  7      guess:  0      cost 0.8591755187382748
actual:  3      guess:  0      cost 0.9427927022073027
actual:  2      guess:  0      cost 0.9175924937090146
actual:  1      guess:  0      cost 0.9608604565831466
actual:  2      guess:  0      cost 0.9171708777221477
--- 3.403271198272705 seconds ---
