如何微调Mistral-7B模型用于机器翻译?

2 投票
1 回答
249 浏览
提问于 2025-04-14 16:14

网上有很多教程使用原始文本,并用一些复杂的语法来表示文档的边界,这些文本通过Huggingface的 datasets.Dataset 对象来访问,通常是通过 text 这个键来获取。例如:

from datasets import load_dataset

dataset_name = "mlabonne/guanaco-llama2-1k"

dataset = load_dataset(dataset_name, split="train")
dataset["text"][42]

[输出]:

<s>[INST] ¿Cuáles son los actuales presidentes de la región de Sur América? Enumérelos en una lista con su respectivo país. [/INST] A fecha del 13 de febrero de 2023, estos son los presidentes de los países de Sudamérica, según Wikipedia:
-Argentina: Alberto Fernández
-Bolivia: Luis Arce
-Brasil: Luiz Inácio Lula da Silva
-Chile: Gabriel Boric
-Colombia: Gustavo Petro
-Ecuador: Guillermo Lasso
-Paraguay: Mario Abdo Benítez
-Perú: Dina Boluarte
-Uruguay: Luis Lacalle Pou
-Venezuela: Nicolás Maduro
-Guyana: Irfaan Ali
-Surinam: Chan Santokhi
-Trinidad y Tobago: Paula-Mae Weekes </s>

但是,机器翻译的数据集通常分为两个部分,分别是源文本和目标文本,使用 sentence_eng_Latnsentence_deu_Latn 这两个键来表示。例如:


valid_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False, 
                          split="dev")
valid_data[42]

[输出]:

{'id': 43,
 'URL': 'https://en.wikinews.org/wiki/Hurricane_Fred_churns_the_Atlantic',
 'domain': 'wikinews',
 'topic': 'disaster',
 'has_image': 0,
 'has_hyperlink': 0,
 'sentence_eng_Latn': 'The storm, situated about 645 miles (1040 km) west of the Cape Verde islands, is likely to dissipate before threatening any land areas, forecasters say.',
 'sentence_deu_Latn': 'Prognostiker sagen, dass sich der Sturm, der etwa 645 Meilen (1040 km) westlich der Kapverdischen Inseln befindet, wahrscheinlich auflösen wird, bevor er Landflächen bedroht.'}

如何对Mistral-7b模型进行微调,以完成机器翻译任务?

1 个回答

4

关键在于重新整理传统机器翻译数据集中的数据,这些数据集通常会把源文本和目标文本分开,而我们需要把它们组合成模型所期望的格式。

具体来说,对于Mistral 7B模型,它通常期望:

  • 每一行数据都要用<s> 和 包裹起来,其中:
    • 输入的源句子要放在[INST] ... [/INST]之间
    • 输出的目标句子则放在[/INST]符号之后
  • [INST] ... [/INST]之前可以有任何的提示信息

例如,如果我们想用这样的翻译提示"将英语翻译成德语:"

valid_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False, split="dev")

def preprocess_func(row):
  return {'text': "Translate from English to German: <s>[INST] " + row['sentence_eng_Latn'] + " [INST] " + row['sentence_deu_Latn'] + " </s>"}

valid_dataset = valid_data.map(preprocess_func)

valid_dataset[42]

[out]:

{'id': 43,
 'URL': 'https://en.wikinews.org/wiki/Hurricane_Fred_churns_the_Atlantic',
 'domain': 'wikinews',
 'topic': 'disaster',
 'has_image': 0,
 'has_hyperlink': 0,
 'sentence_eng_Latn': 'The storm, situated about 645 miles (1040 km) west of the Cape Verde islands, is likely to dissipate before threatening any land areas, forecasters say.',
 'sentence_deu_Latn': 'Prognostiker sagen, dass sich der Sturm, der etwa 645 Meilen (1040 km) westlich der Kapverdischen Inseln befindet, wahrscheinlich auflösen wird, bevor er Landflächen bedroht.',
 'text': 'Translate from English to German: <s>[INST] The storm, situated about 645 miles (1040 km) west of the Cape Verde islands, is likely to dissipate before threatening any land areas, forecasters say. [INST] Prognostiker sagen, dass sich der Sturm, der etwa 645 Meilen (1040 km) westlich der Kapverdischen Inseln befindet, wahrscheinlich auflösen wird, bevor er Landflächen bedroht. </s>'}

那么正常的微调Mistral-7b脚本可以直接读取数据集中的text键,比如:

需要

!pip install -U transformers sentencepiece datasets
!pip install -q -U transformers
!pip install -q -U accelerate
!pip install -q -U bitsandbytes

!pip install -U peft
!pip install -U trl

如果你在Jupyter环境中,安装完accelerate后需要重置内核,所以:

import os
os._exit(00)

然后:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,HfArgumentParser,TrainingArguments,pipeline, logging
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import os,torch
from datasets import load_dataset
from trl import SFTTrainer


base_model = "mistralai/Mistral-7B-Instruct-v0.2"
new_model = "mistral_7b_flores_dev_en_de"


bnb_config = BitsAndBytesConfig(  
    load_in_4bit= True,
    bnb_4bit_quant_type= "nf4",
    bnb_4bit_compute_dtype= torch.bfloat16,
    bnb_4bit_use_double_quant= False,
)
model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
)
model.config.use_cache = False 
model.config.pretraining_tp = 1
model.gradient_checkpointing_enable()



tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.padding_side = 'right'
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_eos_token = True
tokenizer.add_bos_token, tokenizer.add_eos_token



valid_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False, 
                          split="dev")

test_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False, 
                          split="devtest")



def preprocess_func(row):
  return {'text': "Translate from English to German: <s>[INST] " + row['sentence_eng_Latn'] + " [INST] " + row['sentence_deu_Latn'] + " </s>"}


valid_dataset = valid_data.map(preprocess_func)
test_dataset = test_data.map(preprocess_func)



model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj"]
)
model = get_peft_model(model, peft_config)


training_arguments = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_steps=25,
    logging_steps=25,
    learning_rate=2e-4,
    weight_decay=0.001,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant",
    report_to=None
)


trainer = SFTTrainer(
    model=model,
    train_dataset=valid_dataset,
    peft_config=peft_config,
    max_seq_length= None,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing= False,
)

trainer.train()

撰写回答