如何微调Mistral-7B模型用于机器翻译?
网上有很多教程使用原始文本,并用一些复杂的语法来表示文档的边界,这些文本通过Huggingface的 datasets.Dataset
对象来访问,通常是通过 text
这个键来获取。例如:
from datasets import load_dataset
dataset_name = "mlabonne/guanaco-llama2-1k"
dataset = load_dataset(dataset_name, split="train")
dataset["text"][42]
[输出]:
<s>[INST] ¿Cuáles son los actuales presidentes de la región de Sur América? Enumérelos en una lista con su respectivo país. [/INST] A fecha del 13 de febrero de 2023, estos son los presidentes de los países de Sudamérica, según Wikipedia:
-Argentina: Alberto Fernández
-Bolivia: Luis Arce
-Brasil: Luiz Inácio Lula da Silva
-Chile: Gabriel Boric
-Colombia: Gustavo Petro
-Ecuador: Guillermo Lasso
-Paraguay: Mario Abdo Benítez
-Perú: Dina Boluarte
-Uruguay: Luis Lacalle Pou
-Venezuela: Nicolás Maduro
-Guyana: Irfaan Ali
-Surinam: Chan Santokhi
-Trinidad y Tobago: Paula-Mae Weekes </s>
但是,机器翻译的数据集通常分为两个部分,分别是源文本和目标文本,使用 sentence_eng_Latn
和 sentence_deu_Latn
这两个键来表示。例如:
valid_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False,
split="dev")
valid_data[42]
[输出]:
{'id': 43,
'URL': 'https://en.wikinews.org/wiki/Hurricane_Fred_churns_the_Atlantic',
'domain': 'wikinews',
'topic': 'disaster',
'has_image': 0,
'has_hyperlink': 0,
'sentence_eng_Latn': 'The storm, situated about 645 miles (1040 km) west of the Cape Verde islands, is likely to dissipate before threatening any land areas, forecasters say.',
'sentence_deu_Latn': 'Prognostiker sagen, dass sich der Sturm, der etwa 645 Meilen (1040 km) westlich der Kapverdischen Inseln befindet, wahrscheinlich auflösen wird, bevor er Landflächen bedroht.'}
如何对Mistral-7b模型进行微调,以完成机器翻译任务?
1 个回答
4
关键在于重新整理传统机器翻译数据集中的数据,这些数据集通常会把源文本和目标文本分开,而我们需要把它们组合成模型所期望的格式。
具体来说,对于Mistral 7B模型,它通常期望:
- 每一行数据都要用
<s> 和
包裹起来,其中:- 输入的源句子要放在
[INST] ... [/INST]
之间 - 输出的目标句子则放在
[/INST]
符号之后
- 输入的源句子要放在
- 在
[INST] ... [/INST]
之前可以有任何的提示信息
例如,如果我们想用这样的翻译提示"将英语翻译成德语:",
valid_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False, split="dev")
def preprocess_func(row):
return {'text': "Translate from English to German: <s>[INST] " + row['sentence_eng_Latn'] + " [INST] " + row['sentence_deu_Latn'] + " </s>"}
valid_dataset = valid_data.map(preprocess_func)
valid_dataset[42]
[out]:
{'id': 43,
'URL': 'https://en.wikinews.org/wiki/Hurricane_Fred_churns_the_Atlantic',
'domain': 'wikinews',
'topic': 'disaster',
'has_image': 0,
'has_hyperlink': 0,
'sentence_eng_Latn': 'The storm, situated about 645 miles (1040 km) west of the Cape Verde islands, is likely to dissipate before threatening any land areas, forecasters say.',
'sentence_deu_Latn': 'Prognostiker sagen, dass sich der Sturm, der etwa 645 Meilen (1040 km) westlich der Kapverdischen Inseln befindet, wahrscheinlich auflösen wird, bevor er Landflächen bedroht.',
'text': 'Translate from English to German: <s>[INST] The storm, situated about 645 miles (1040 km) west of the Cape Verde islands, is likely to dissipate before threatening any land areas, forecasters say. [INST] Prognostiker sagen, dass sich der Sturm, der etwa 645 Meilen (1040 km) westlich der Kapverdischen Inseln befindet, wahrscheinlich auflösen wird, bevor er Landflächen bedroht. </s>'}
那么正常的微调Mistral-7b脚本可以直接读取数据集中的text
键,比如:
需要
!pip install -U transformers sentencepiece datasets
!pip install -q -U transformers
!pip install -q -U accelerate
!pip install -q -U bitsandbytes
!pip install -U peft
!pip install -U trl
如果你在Jupyter环境中,安装完accelerate后需要重置内核,所以:
import os
os._exit(00)
然后:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,HfArgumentParser,TrainingArguments,pipeline, logging
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import os,torch
from datasets import load_dataset
from trl import SFTTrainer
base_model = "mistralai/Mistral-7B-Instruct-v0.2"
new_model = "mistral_7b_flores_dev_en_de"
bnb_config = BitsAndBytesConfig(
load_in_4bit= True,
bnb_4bit_quant_type= "nf4",
bnb_4bit_compute_dtype= torch.bfloat16,
bnb_4bit_use_double_quant= False,
)
model = AutoModelForCausalLM.from_pretrained(
base_model,
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
model.config.use_cache = False
model.config.pretraining_tp = 1
model.gradient_checkpointing_enable()
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.padding_side = 'right'
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_eos_token = True
tokenizer.add_bos_token, tokenizer.add_eos_token
valid_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False,
split="dev")
test_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False,
split="devtest")
def preprocess_func(row):
return {'text': "Translate from English to German: <s>[INST] " + row['sentence_eng_Latn'] + " [INST] " + row['sentence_deu_Latn'] + " </s>"}
valid_dataset = valid_data.map(preprocess_func)
test_dataset = test_data.map(preprocess_func)
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=64,
bias="none",
task_type="CAUSAL_LM",
target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj"]
)
model = get_peft_model(model, peft_config)
training_arguments = TrainingArguments(
output_dir="./results",
num_train_epochs=1,
per_device_train_batch_size=4,
gradient_accumulation_steps=1,
optim="paged_adamw_32bit",
save_steps=25,
logging_steps=25,
learning_rate=2e-4,
weight_decay=0.001,
max_grad_norm=0.3,
max_steps=-1,
warmup_ratio=0.03,
group_by_length=True,
lr_scheduler_type="constant",
report_to=None
)
trainer = SFTTrainer(
model=model,
train_dataset=valid_dataset,
peft_config=peft_config,
max_seq_length= None,
dataset_text_field="text",
tokenizer=tokenizer,
args=training_arguments,
packing= False,
)
trainer.train()