Training BERT and running out of memory on Google Colab

Even after I bought Google Colab Pro with 25 GB of RAM, I keep running out of memory, and I have no idea why. I have tried every kernel I could (Google Colab, Google Colab Pro, a Kaggle kernel, Amazon SageMaker, Google Cloud Platform). I reduced the batch size to 8, but it did not help.
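
As a minimal sketch (assuming a Colab/Jupyter runtime with an attached GPU; psutil and nvidia-smi are available there by default), the available RAM and GPU memory can be checked from inside the notebook like this:

# Minimal sketch: report RAM and GPU memory inside a Colab/Jupyter runtime.
import psutil

ram = psutil.virtual_memory()
print("RAM used/total (GB): {:.1f} / {:.1f}".format(
    (ram.total - ram.available) / 1024 ** 3, ram.total / 1024 ** 3))

# GPU memory as reported by the driver (requires a GPU runtime)
!nvidia-smi --query-gpu=memory.used,memory.total --format=csv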

My goal is to train BERT in Deep Pavlov (with its Russian text classification extension) to predict the emotion of tweets. It is a multiclass classification problem with 5 classes.
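
For reference, the data files are assumed to look roughly like this; the column names content and emotions come from the reader call below, while the example rows and label names are purely hypothetical:

import pandas as pd

# Hypothetical rows illustrating the assumed layout of train_pikabu.csv:
# a text column `content` and a label column `emotions` with 5 possible classes.
sample = pd.DataFrame({
    "content": ["Какой прекрасный день!", "Это просто ужасно..."],
    "emotions": ["joy", "anger"],
})
print(sample.head())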

Here is my full code:

!pip3 install deeppavlov
import pandas as pd
train_df = pd.read_csv('train_pikabu.csv')
test_df = pd.read_csv('test_pikabu.csv')
val_df = pd.read_csv('validation_pikabu.csv')

from deeppavlov.dataset_readers.basic_classification_reader import BasicClassificationDatasetReader
# read data from particular columns of `.csv` file
data = BasicClassificationDatasetReader().read(
    data_path='./',
    train='train_pikabu.csv',
    valid="validation_pikabu_a.csv",
    test="test_pikabu.csv",
    x='content',
    y='emotions'
)

from deeppavlov.dataset_iterators.basic_classification_iterator import BasicClassificationDatasetIterator
# initializing an iterator
iterator = BasicClassificationDatasetIterator(data, seed=42, shuffle=True)

!python -m deeppavlov install squad_bert
from deeppavlov.models.preprocessors.bert_preprocessor import BertPreprocessor
bert_preprocessor = BertPreprocessor(vocab_file="./bert/vocab.txt",
                                     do_lower_case=False,
                                     max_seq_length=256)

from deeppavlov.core.data.simple_vocab import SimpleVocabulary
vocab = SimpleVocabulary(save_path="./binary_classes.dict")
# fit the label vocabulary on the training labels
vocab.fit(iterator.get_instances(data_type="train")[1])

from deeppavlov.models.preprocessors.one_hotter import OneHotter
one_hotter = OneHotter(depth=vocab.len,
                       single_vector=True)  # one vector per sample

from deeppavlov.models.classifiers.proba2labels import Proba2Labels
prob2labels = Proba2Labels(max_proba=True)

from deeppavlov.models.bert.bert_classifier import BertClassifierModel
from deeppavlov.metrics.accuracy import sets_accuracy

bert_classifier = BertClassifierModel(
    n_classes=vocab.len,
    return_probas=True,
    one_hot_labels=True,
    bert_config_file="./bert/bert_config.json",
    pretrained_bert="./bert/bert_model.ckpt",
    save_path="sst_bert_model/model",
    load_path="sst_bert_model/model",
    keep_prob=0.5,
    learning_rate=1e-05,
    learning_rate_drop_patience=5,
    learning_rate_drop_div=2.0
)


# Method `get_instances` returns all the samples of a particular data field
x_valid, y_valid = iterator.get_instances(data_type="valid")
# Save the model only when the validation score is higher than the previous one.
# This variable will hold the highest accuracy score
best_score = 0.
patience = 2
impatience = 0

# let's train for 3 epochs
for ep in range(3):

    nbatches = 0
    for x, y in iterator.gen_batches(batch_size=8,
                                     data_type="train", shuffle=True):
        x_feat = bert_preprocessor(x)
        y_onehot = one_hotter(vocab(y))
        bert_classifier.train_on_batch(x_feat, y_onehot)
        print("Batch done\n")
        nbatches += 1

        if nbatches % 1 == 0:
            # validating after every batch
            y_valid_pred = bert_classifier(bert_preprocessor(x_valid))
            score = sets_accuracy(y_valid, vocab(prob2labels(y_valid_pred)))
            print("Batches done: {}. Valid Accuracy: {}".format(nbatches, score))

    y_valid_pred = bert_classifier(bert_preprocessor(x_valid))
    score = sets_accuracy(y_valid, vocab(prob2labels(y_valid_pred)))
    print("Epochs done: {}. Valid Accuracy: {}".format(ep + 1, score))
    if score > best_score:
        bert_classifier.save()
        print("New best score. Saving model.")
        best_score = score
        impatience = 0
    else:
        impatience += 1
        if impatience == patience:
            print("Out of patience. Stop training.")
            break

It runs for at most one batch and then crashes.

