2024-04-25 20:23:08 发布
网友
我正在学习Scikit learn中的增量学习算法。sci-kit-learn中的SGD是一种允许通过传递块/批来增量学习的算法。你知道吗
增量学习的目的是将整个训练数据保存在记忆中。因此,学习大数据集是可能的,这些数据集不适合作为一个整体存储在内存中。如果训练数据逐段可用,增量学习也很有用。你知道吗
随机梯度下降法(SGD)在内存中不保留任何批,只保留正在处理的批。然而,这并不意味着它会立即忘记过去的补丁。批处理用于计算梯度,用于更新模型系数。因此,尽管数据本身被丢弃,但批处理中包含的信息仍保留在模型中。你知道吗
因为梯度是用最近的批更新的,所以较新的批比较旧的批对模型的当前训练状态有更大的影响。您可以说,最近的批处理在模型的内存中更加生动,而它会逐渐忘记较旧的批处理。你知道吗
下面是一个玩具示例来说明这个问题(底部的代码):
一个SGD分类器在前100批中用3个类进行增量训练。在100-200批中,训练数据中不存在第3类。很明显,分类器“忘记”了之前关于这个类的所有知识。您可以将此效果标记为“灾难性遗忘”,或者您可以将其视为需要的“适应数据的变化”;解释取决于用例。你知道吗
所以,是的,新加坡元似乎确实受到catastrophic forgetting的影响。不过,我认为这没什么大不了的;只是在特定应用程序中设计培训策略时必须注意的一点。你知道吗
import numpy as np from sklearn.linear_model import SGDClassifier from sklearn.datasets import make_blobs from sklearn.metrics import confusion_matrix import matplotlib.pyplot as plt np.random.seed(42) n_features = 150 centers = np.concatenate([np.eye(3)*3, np.zeros((3, n_features-3))], axis=1) x_test, y_test = make_blobs([100, 100, 100], centers=centers) cla = SGDClassifier() performance = [] def train_some_batches(n_samples_per_class): for _ in range(100): x_batch, y_batch = make_blobs(n_samples_per_class, centers=centers) cla.partial_fit(x_batch, y_batch, classes=[0, 1, 2]) conf = confusion_matrix(y_test, cla.predict(x_test)) performance.append(np.diag(conf) / np.sum(conf, axis=1)) train_some_batches([50, 50, 50]) train_some_batches([50, 50, 0]) plt.plot(performance) plt.legend(['class 1', 'class 2', 'class 3']) plt.xlabel('training batches') plt.ylabel('accuracy') plt.show()
增量学习的目的是将整个训练数据保存在记忆中。因此,学习大数据集是可能的,这些数据集不适合作为一个整体存储在内存中。如果训练数据逐段可用,增量学习也很有用。你知道吗
随机梯度下降法(SGD)在内存中不保留任何批,只保留正在处理的批。然而,这并不意味着它会立即忘记过去的补丁。批处理用于计算梯度,用于更新模型系数。因此,尽管数据本身被丢弃,但批处理中包含的信息仍保留在模型中。你知道吗
因为梯度是用最近的批更新的,所以较新的批比较旧的批对模型的当前训练状态有更大的影响。您可以说,最近的批处理在模型的内存中更加生动,而它会逐渐忘记较旧的批处理。你知道吗
下面是一个玩具示例来说明这个问题(底部的代码):
一个SGD分类器在前100批中用3个类进行增量训练。在100-200批中,训练数据中不存在第3类。很明显,分类器“忘记”了之前关于这个类的所有知识。您可以将此效果标记为“灾难性遗忘”,或者您可以将其视为需要的“适应数据的变化”;解释取决于用例。你知道吗
所以,是的,新加坡元似乎确实受到catastrophic forgetting的影响。不过,我认为这没什么大不了的;只是在特定应用程序中设计培训策略时必须注意的一点。你知道吗
相关问题 更多 >
编程相关推荐