Skip to content

Finetune with HuggingFace TRL

In this tutorial, we will fine-tune a language model using HuggingFace TRL. The example script below defaults to SmolLM2-1.7B; to fine-tune a distilled DeepSeek-R1 checkpoint instead, pass its name via `--model_name_or_path`.

DeepSeek-R1 is a state-of-the-art reasoning language model, and its distilled variants make its capabilities available at smaller model sizes. HuggingFace TRL is a library that simplifies training and fine-tuning transformer language models.

Prerequisites

To follow this tutorial, you need a machine with a GPU and the CUDA driver installed.
Additionally, PyTorch and Transformers are required.

First, install the necessary packages.

pip install steev trl

To proceed, you need to authenticate with Steev.
For more details, refer to the Authentication page.

steev auth login

Get the example code

We’ll use an example from HuggingFace TRL Documentation. Copy the following code and save it as trl_train.py

The code below is a simplified version of the example.

import argparse
import multiprocessing
import os

import torch
from accelerate import PartialState
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    logging,
    set_seed,
)
from trl import (
    SFTTrainer,
    SFTConfig,
    get_kbit_device_map,
)


def _str2bool(value):
    """Parse a command-line boolean string, case-insensitively.

    argparse's ``type=bool`` is a classic trap: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so flags like ``--use_qlora False``
    would silently *enable* QLoRA. This parser accepts the usual spellings of
    true/false and rejects anything else with a clear error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "1", "yes", "y"):
        return True
    if lowered in ("false", "f", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args():
    """Build and parse the command-line arguments for the fine-tuning run.

    Returns:
        argparse.Namespace with dataset, model, training, and QLoRA options.
    """
    parser = argparse.ArgumentParser()

    # Dataset arguments
    parser.add_argument("--model_name_or_path", type=str, default="HuggingFaceTB/SmolLM2-1.7B")
    parser.add_argument("--dataset_name", type=str, default="bigcode/the-stack-smol")
    parser.add_argument("--dataset_config", type=str, default=None)
    parser.add_argument("--dataset_train_split", type=str, default="train")
    parser.add_argument("--dataset_test_split", type=str, default="test")
    parser.add_argument("--dataset_text_field", type=str, default="content")

    # Model arguments
    parser.add_argument("--model_revision", type=str, default=None)
    parser.add_argument("--trust_remote_code", type=_str2bool, default=True)
    parser.add_argument("--attn_implementation", type=str, default=None)
    parser.add_argument("--torch_dtype", type=str, default="bfloat16")

    # Training arguments
    parser.add_argument("--max_seq_length", type=int, default=2048)
    parser.add_argument("--max_steps", type=int, default=1000)
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--eval_strategy", type=str, default="no")
    parser.add_argument("--gradient_checkpointing", type=_str2bool, default=False)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--output_dir", type=str, default="finetune_smollm2_python")
    parser.add_argument("--num_proc", type=int, default=None)
    parser.add_argument("--push_to_hub", type=_str2bool, default=True)
    parser.add_argument("--repo_id", type=str, default="SmolLM2-1.7B-finetune")

    # QLoRA arguments
    parser.add_argument("--use_qlora", type=_str2bool, default=False)
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument("--lora_target_modules", type=str, nargs="+", default=["q_proj", "v_proj"])

    return parser.parse_args()


def main(args: argparse.Namespace) -> None:
    """Run supervised fine-tuning (SFT) of a causal LM with TRL.

    Builds model-init kwargs (optionally 4-bit QLoRA), loads the tokenizer and
    dataset, trains with ``SFTTrainer``, saves the model to ``args.output_dir``,
    and optionally pushes it to the HuggingFace Hub.

    Args:
        args: Parsed command-line namespace from ``get_args()``.
    """
    # Convert torch_dtype from string to torch.dtype
    # e.g. "bfloat16" -> torch.bfloat16; "auto" is passed through unchanged.
    if args.torch_dtype != "auto":
        args.torch_dtype = getattr(torch, args.torch_dtype)

    # Model init kwargs
    # Passed through SFTConfig.model_init_kwargs so SFTTrainer can build the
    # model itself from the name/path given below.
    model_kwargs = dict(
        revision=args.model_revision,
        trust_remote_code=args.trust_remote_code,
        attn_implementation=args.attn_implementation,
        torch_dtype=args.torch_dtype,
        # KV-caching is incompatible with gradient checkpointing during training.
        use_cache=False if args.gradient_checkpointing else True,
    )

    # QLoRA config
    if args.use_qlora:
        # 4-bit NF4 quantization for the frozen base weights; LoRA adapters
        # are trained on top of the quantized model.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            # Fall back to bfloat16 compute when dtype selection is "auto".
            bnb_4bit_compute_dtype=args.torch_dtype if args.torch_dtype != "auto" else torch.bfloat16,
        )
        model_kwargs["device_map"] = get_kbit_device_map()
        model_kwargs["quantization_config"] = quantization_config

        peft_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            target_modules=args.lora_target_modules,
            bias="none",
            task_type="CAUSAL_LM",
        )
    else:
        # Full-precision path: pin the whole model to this process's device
        # (one device per process under accelerate/DDP).
        quantization_config = None
        peft_config = None
        model_kwargs["device_map"] = {"": PartialState().process_index}

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path, trust_remote_code=args.trust_remote_code, use_fast=True
    )
    # Many causal-LM tokenizers ship without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    # Load dataset
    # HF_TOKEN (if set) authenticates access to gated/private datasets.
    token = os.environ.get("HF_TOKEN", None)
    dataset = load_dataset(
        args.dataset_name,
        name=args.dataset_config,
        token=token,
        # Default to one worker per CPU core when --num_proc is not given.
        num_proc=args.num_proc if args.num_proc else multiprocessing.cpu_count(),
    )

    # Format dataset to use the correct text field
    # Maps each raw example to the plain-text string SFTTrainer should train on.
    def formatting_func(example):
        return example[args.dataset_text_field]

    # Setup trainer
    # NOTE(review): SFTConfig kwarg names are version-sensitive in TRL (e.g.
    # max_seq_length was renamed in later releases) — confirm against the
    # installed TRL version.
    trainer = SFTTrainer(
        # Passing a string lets SFTTrainer instantiate the model using
        # model_init_kwargs from the config below.
        model=args.model_name_or_path,
        train_dataset=dataset[args.dataset_train_split],
        # Skip eval entirely when evaluation is disabled.
        eval_dataset=dataset[args.dataset_test_split] if args.eval_strategy != "no" else None,
        args=SFTConfig(
            per_device_train_batch_size=args.per_device_train_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            learning_rate=args.learning_rate,
            lr_scheduler_type=args.lr_scheduler_type,
            weight_decay=args.weight_decay,
            max_seq_length=args.max_seq_length,
            gradient_checkpointing=args.gradient_checkpointing,
            logging_strategy="steps",
            logging_steps=10,
            output_dir=args.output_dir,
            # Paged 8-bit AdamW keeps optimizer-state memory low.
            optim="paged_adamw_8bit",
            seed=args.seed,
            run_name=None,
            # Empty list disables all experiment trackers (wandb, tensorboard, ...).
            report_to=[],
            push_to_hub=args.push_to_hub,
            model_init_kwargs=model_kwargs,
        ),
        processing_class=tokenizer,
        peft_config=peft_config,
        formatting_func=formatting_func,
    )

    # Train
    print("Training...")
    trainer.train()

    # Save and push to hub
    trainer.save_model(args.output_dir)
    if args.push_to_hub:
        trainer.push_to_hub(dataset_name=args.dataset_name)

    print("Training Done! 💥")


if __name__ == "__main__":
    cli_args = get_args()
    # Seed every RNG before any model/data work so runs are reproducible.
    set_seed(cli_args.seed)
    # Make sure the checkpoint directory exists up front.
    os.makedirs(cli_args.output_dir, exist_ok=True)
    # Keep transformers quiet except for actual errors.
    logging.set_verbosity_error()
    main(cli_args)

Run the code with Steev with no configuration

To run the script with Steev, simply execute the following command.

steev run trl_train.py 

If you want to modify the parameters, there is no need to change the code. Just pass the parameters to the command. Steev handles the rest.

steev run trl_train.py --kwargs lr=0.1