SGD增量学习在scikitlearn中的表现如何?

2024-04-25 20:23:08 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在学习Scikit learn中的增量学习算法。sci-kit-learn中的SGD是一种允许通过传递块/批来增量学习的算法。你知道吗

  • sci kit learn是否将训练数据的所有批都保存在内存中?你知道吗
  • 或者它是否将内存中的块/批保持在一定的大小?你知道吗
  • 或者它在内存中训练时只保留一个块/批,而在训练后删除其他训练的块/批?这是否意味着它会遭受灾难性的遗忘?你知道吗

Tags: 数据内存算法scikitlearn增量kitsci
1条回答
网友
1楼 · 发布于 2024-04-25 20:23:08

增量学习的目的是将整个训练数据保存在记忆中。因此,学习大数据集是可能的,这些数据集不适合作为一个整体存储在内存中。如果训练数据逐段可用,增量学习也很有用。你知道吗

随机梯度下降法(SGD)在内存中不保留任何批,只保留正在处理的批。然而,这并不意味着它会立即忘记过去的补丁。批处理用于计算梯度,用于更新模型系数。因此,尽管数据本身被丢弃,但批处理中包含的信息仍保留在模型中。你知道吗

因为梯度是用最近的批更新的,所以较新的批比较旧的批对模型的当前训练状态有更大的影响。您可以说,最近的批处理在模型的内存中更加生动,而它会逐渐忘记较旧的批处理。你知道吗

下面是一个玩具示例来说明这个问题(底部的代码):

enter image description here

一个SGD分类器在前100批中用3个类进行增量训练。在100-200批中,训练数据中不存在第3类。很明显,分类器“忘记”了之前关于这个类的所有知识。您可以将此效果标记为“灾难性遗忘”,或者您可以将其视为需要的“适应数据的变化”;解释取决于用例。你知道吗

所以,是的,新加坡元似乎确实受到catastrophic forgetting的影响。不过,我认为这没什么大不了的;只是在特定应用程序中设计培训策略时必须注意的一点。你知道吗

import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.datasets import make_blobs
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

np.random.seed(42)
n_features = 150
centers = np.concatenate([np.eye(3)*3, np.zeros((3, n_features-3))], axis=1)

x_test, y_test = make_blobs([100, 100, 100], centers=centers)

cla = SGDClassifier()
performance = []

def train_some_batches(n_samples_per_class):
    for _ in range(100):
        x_batch, y_batch = make_blobs(n_samples_per_class, centers=centers)
        cla.partial_fit(x_batch, y_batch, classes=[0, 1, 2])
        conf = confusion_matrix(y_test, cla.predict(x_test))
        performance.append(np.diag(conf) / np.sum(conf, axis=1))

train_some_batches([50, 50, 50])
train_some_batches([50, 50, 0])            

plt.plot(performance)
plt.legend(['class 1', 'class 2', 'class 3'])
plt.xlabel('training batches')
plt.ylabel('accuracy')

plt.show()

相关问题 更多 >