Handwritten digits

Posted 2024-04-18 23:54:44


I am trying to write a Python script that recognizes handwritten digits, using this dataset: http://deeplearning.net/data/mnist/mnist.pkl.gz.

More information about the problem and the algorithm I am trying to implement can be found here: http://neuralnetworksanddeeplearning.com/chap1.html

I have implemented a classification algorithm that uses one perceptron per digit.

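import gzip
import pickle
import numpy as np

# mnist.pkl.gz was pickled under Python 2; encoding='latin1' lets Python 3 read it
with gzip.open('mnist.pkl.gz', 'rb') as f:
    train_set, valid_set, test_set = pickle.load(f, encoding='latin1')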

def activation(x):
    if x > 0:
        return 1
    return 0

bias = 0.5
learningRate = 0.01

images = train_set[0]
targets = train_set[1]

weights = np.random.uniform(0,1,(10,784))
for nr in range(0,10):
    for i in range(0,49999):
        x = images[i]
        t = targets[i]
        z = np.dot(weights[nr],x) + bias
        output = activation(z)
        weights[nr] = weights[nr] + (t - output) * x * learningRate
        bias = bias + (t - output) * learningRate

images = test_set[0]
targets = test_set[1]

OK = 0

for i in range(0, 10000):
    vec = []
    for j in range(0,10):
        vec.append(np.dot(weights[j],images[i]))
    if np.argmax(vec) == targets[i]:
        OK = OK + 1

print("The network recognized " + str(OK) +'/'+ "10000")

I usually recognize only about 10% of the digits, which means my algorithm does no better than random guessing.

Even though I know this problem is popular and I could easily find another solution on the web, I would still like your help finding the error in my code.

Maybe I initialized learningRate, bias and weights with the wrong values.
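One way to see why the accuracy stays near 10% is to compare the error term `(t - output)` that each of the ten perceptrons receives for a single image. The sketch below assumes that every perceptron fires (output 1), which holds at least at the start of training in the code above: the initial weights are positive, the pixels are non-negative and the bias is positive, so every `z` is positive. The helper names `raw_label_errors` and `one_vs_rest_errors` are illustrative, not from the original post.

```python
import numpy as np

N_DIGITS = 10

def raw_label_errors(label, outputs):
    # Question's code: every perceptron nr is compared with the raw digit label t
    return label - outputs

def one_vs_rest_errors(label, outputs):
    # Binary target: 1 for the perceptron whose index equals the label, 0 for the rest
    targets = (np.arange(N_DIGITS) == label).astype(int)
    return targets - outputs

# Assumption: all ten perceptrons fire (output 1) for this image
outputs = np.ones(N_DIGITS, dtype=int)

raw = raw_label_errors(7, outputs)    # the same value, 6, for all ten perceptrons
ovr = one_vs_rest_errors(7, outputs)  # 0 for perceptron 7, -1 for every other one
print("raw label errors:  ", raw)
print("one-vs-rest errors:", ovr)
```

With the raw label, all ten rows of `weights` receive the same update for a given image as long as they all fire, so the perceptrons have no signal that tells the digits apart. With a one-vs-rest target, the matching perceptron is left alone and the other nine are pushed down.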


1 answer
User
Answer 1 · posted 2024-04-18 23:54:44

Thanks to @Kevinj22 and some other tools, I finally solved the problem.

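import gzip
import pickle
import numpy as np

# mnist.pkl.gz was pickled under Python 2; encoding='latin1' lets Python 3 read it
# (under Python 2, use cPickle.load(f) without the encoding argument)
with gzip.open('mnist.pkl.gz', 'rb') as f:
    train_set, valid_set, test_set = pickle.load(f, encoding='latin1')

def activation(x):
    # Step function: the perceptron fires (1) only for a positive weighted sum
    if x > 0:
        return 1
    return 0

learningRate = 0.01

images = train_set[0]   # 50000 x 784 pixel intensities
targets = train_set[1]  # digit labels 0-9

# One perceptron (a row of 784 weights) per digit; no bias term
weights = np.random.uniform(0, 1, (10, 784))

for nr in range(0, 10):
    for i in range(0, 50000):
        x = images[i]
        t = targets[i]
        z = np.dot(weights[nr], x)
        output = activation(z)
        # Perceptron nr answers "is this digit nr?": binary target, not the raw label
        if nr == t:
            target = 1
        else:
            target = 0
        # Perceptron learning rule: w <- w + learningRate * (target - output) * x
        adjust = np.multiply((target - output) * learningRate, x)
        weights[nr] = np.add(weights[nr], adjust)

images = test_set[0]
targets = test_set[1]

OK = 0

for i in range(0, 10000):
    # Score the image with every perceptron and predict the highest-scoring digit
    vec = []
    for j in range(0, 10):
        vec.append(np.dot(weights[j], images[i]))
    if np.argmax(vec) == targets[i]:
        OK = OK + 1

print("The network recognized " + str(OK) + '/' + "10000")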

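This is my latest code. In my first attempt I left out the error calculation: each perceptron nr has to be trained against a binary target (1 if the image's label equals nr, 0 otherwise) instead of the raw digit label t. I also removed the bias, because I found it was not useful in my implementation.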
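The answer's fix (a binary target per perceptron, no bias) can also be written in vectorized NumPy. This is a sketch, not the poster's code: the function names, the `epochs` and `seed` parameters and the toy dataset are hypothetical. Because each row of `weights` is updated only from its own score, updating all perceptrons for each image in turn gives the same result as the answer's outer loop over `nr`, for the same initial weights.

```python
import numpy as np

def train_one_vs_rest(images, labels, n_classes=10, learning_rate=0.01, epochs=1, seed=0):
    """One step-activation perceptron per class, no bias, binary (one-vs-rest) targets."""
    rng = np.random.default_rng(seed)
    weights = rng.uniform(0, 1, (n_classes, images.shape[1]))
    class_ids = np.arange(n_classes)
    for _ in range(epochs):
        for x, label in zip(images, labels):
            outputs = (weights @ x > 0).astype(float)     # step activation for every perceptron
            targets = (class_ids == label).astype(float)  # 1 only for the true class
            # Perceptron rule for all rows at once: w_nr += lr * (target_nr - output_nr) * x
            weights += learning_rate * np.outer(targets - outputs, x)
    return weights

def predict(weights, images):
    # Highest raw score wins, as in the answer's evaluation loop
    return np.argmax(images @ weights.T, axis=1)

# Hypothetical toy data: 3 classes, class k clustered around 5 * e_k
rng = np.random.default_rng(1)
labels = rng.integers(0, 3, 300)
images = 5.0 * np.eye(3)[labels] + rng.uniform(0, 1, (300, 3))

weights = train_one_vs_rest(images, labels, n_classes=3, learning_rate=0.1, epochs=5)
train_accuracy = np.mean(predict(weights, images) == labels)
print("training accuracy:", train_accuracy)
```

On MNIST the call would be `train_one_vs_rest(train_set[0], train_set[1])`, whose defaults match the answer's single pass with `learningRate = 0.01`.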

I ran this code 10 times; the average accuracy was 88%.
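The evaluation loop and the "run it 10 times and average" procedure can be sketched in vectorized NumPy. `accuracy` computes the same quantity as `OK / 10000` in the answer's code; `mean_accuracy` and its `run_once(seed)` callback are hypothetical helpers, and the small arrays are toy data, not MNIST.

```python
import numpy as np

def accuracy(weights, images, targets):
    # Same as the answer's loop: score each image with every perceptron, take the argmax
    scores = images @ weights.T              # shape (n_images, n_digits)
    predictions = np.argmax(scores, axis=1)
    return np.mean(predictions == targets)

def mean_accuracy(run_once, n_runs=10):
    # Average of run_once(seed) over n_runs seeds, e.g. one full train-and-test run per seed
    return float(np.mean([run_once(seed) for seed in range(n_runs)]))

# Toy check with identity weights: each "image" is predicted as its largest component
weights = np.eye(3)
images = np.array([[0.9, 0.1, 0.0],
                   [0.2, 0.7, 0.1],
                   [0.0, 0.3, 0.6],
                   [0.8, 0.1, 0.1]])
targets = np.array([0, 1, 2, 1])
print(accuracy(weights, images, targets))  # 3 of the 4 toy images are classified correctly

# Averaging over random initial weights, in the spirit of the 10 runs reported above
avg = mean_accuracy(
    lambda seed: accuracy(np.random.default_rng(seed).uniform(0, 1, (3, 3)), images, targets))
print(avg)
```

For the MNIST experiment, `run_once` would wrap the answer's training loop with a different random initialization per seed and return `accuracy(weights, test_set[0], test_set[1])`.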
