💻 Примеры fine-tuning пайплайна

← К оглавлению урока

⚡ Главный пример

Учебный пайплайн: создать QA-датасет о компании, превратить пары вопрос-ответ в instruction format, обучить LoRA-адаптер для небольшой causal LM и проверить ответы на тестовых вопросах.

Задача

Мы хотим, чтобы модель отвечала в устойчивом формате на вопросы о вымышленной компании TechInnovate Solutions. Пример основан на исходном файле lesson-ai-10/ai10-1.py, но токены и настройки показаны безопаснее для учебного проекта.

1. Генерация QA-датасета

# fine_tune_company.py
import os
import pandas as pd
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from huggingface_hub import login

company_info = {
    "name": "TechInnovate Solutions",
    "founded": 2015,
    "headquarters": "San Francisco, California",
    "ceo": "Dr. Alex Morgan",
    "employees": 250,
    "products": [
        "DataSync Pro - A cloud-based data synchronization platform",
        "AI Assistant - An enterprise AI chatbot solution",
        "SecureConnect - An end-to-end encrypted communication tool",
    ],
    "mission": "To leverage cutting-edge technology to solve complex business challenges while promoting sustainability.",
    "values": ["Innovation", "Integrity", "Collaboration", "Sustainability"],
    "revenue": "$45 million (2023)",
}

def generate_qa_pairs(company_data):
    rows = [
        {
            "question": "What is the name of the company?",
            "answer": "The company name is " + company_data["name"] + ".",
        },
        {
            "question": "Who is the CEO of the company?",
            "answer": "The CEO of " + company_data["name"] + " is " + company_data["ceo"] + ".",
        },
        {
            "question": "What products does the company offer?",
            "answer": company_data["name"] + " offers: " + ", ".join(company_data["products"]) + ".",
        },
        {
            "question": "What was the company's revenue in 2018?",
            "answer": "I don't have information about the company's revenue in 2018.",
        },
    ]

    for product in company_data["products"]:
        product_name, product_desc = product.split(" - ", 1)
        rows.append({
            "question": "What is " + product_name + "?",
            "answer": product_name + " is " + product_desc + ".",
        })

    return rows

qa_data = generate_qa_pairs(company_info)
pd.DataFrame(qa_data).to_csv("company_qa_data.csv", index=False)

2. Формат для instruction tuning

# fine_tune_company.py
def format_for_training(question, answer):
    return (
        "### Instruction: Answer the following question about "
        + company_info["name"]
        + " accurately. If the information is missing, say that you do not have it.\n\n"
        + "### Input: "
        + question
        + "\n\n"
        + "### Response: "
        + answer
    )

formatted_rows = [
    {"text": format_for_training(row["question"], row["answer"])}
    for row in qa_data
]

dataset = Dataset.from_list(formatted_rows)
split = dataset.train_test_split(test_size=0.25, seed=42)
train_dataset = split["train"]
test_dataset = split["test"]

3. Модель, токенизатор и LoRA

# fine_tune_company.py
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=512,
        padding="max_length",
    )

tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_test = test_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

4. Обучение

# fine_tune_company.py
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
model.to(device)

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    max_steps=100,
    learning_rate=2e-4,
    fp16=torch.cuda.is_available(),
    logging_steps=10,
    save_steps=100,
    save_total_limit=2,
    report_to="none",
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    data_collator=data_collator,
)

trainer.train()
print(trainer.evaluate())

model.save_pretrained("./fine_tuned_company_adapter")
tokenizer.save_pretrained("./fine_tuned_company_adapter")
Ресурсы. Даже маленькая модель может требовать много памяти. Для слабого ноутбука уменьшайте max_steps, batch size, длину последовательности или запускайте обучение в облачной GPU-среде.

5. Проверка результата

# fine_tune_company.py
def generate_response(question, max_new_tokens=120):
    prompt = (
        "### Instruction: Answer the following question about "
        + company_info["name"]
        + " accurately. If the information is missing, say that you do not have it.\n\n"
        + "### Input: "
        + question
        + "\n\n"
        + "### Response:"
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.2,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return decoded.split("### Response:")[-1].strip()

test_questions = [
    "What is TechInnovate Solutions?",
    "Who is the CEO of the company?",
    "What products does the company offer?",
    "What was the company's revenue in 2020?",
]

for question in test_questions:
    print("Q:", question)
    print("A:", generate_response(question))
    print("-" * 50)

6. Что оценивать

ПроверкаХороший результат
Факты из датасетаCEO, продукты, миссия и числа не искажаются.
Нет информацииМодель честно говорит, что данных нет.
ФорматОтветы короткие, одинакового стиля, без лишнего текста.
BaselineFine-tuned вариант лучше обычной модели на тестовых вопросах.
SafetyМодель не раскрывает секреты, не придумывает персональные данные, не даёт опасных советов.