Training Guide

Overview

Agent-Tunix implements GRPO (Group Relative Policy Optimization) training for language models with parameter-efficient fine-tuning via LoRA.

Basic Training

Start training with the default configuration:

python run_training.py

This will:

  1. Load the reference model (frozen)

  2. Create policy model with LoRA

  3. Run GRPO training loop

  4. Save checkpoints periodically

Training Process

The training process involves:

1. Data Loading

  • Loads GSM8K dataset (grade school math)

  • Splits into train/validation/test sets

  • Tokenizes with model’s tokenizer
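GSM8K reference solutions end with a line of the form "#### <answer>", and correctness rewards need that final answer. The sketch below shows one way to extract it. The function name extract_gsm8k_answer is hypothetical and is not taken from Agent-Tunix.

```python
from typing import Optional


def extract_gsm8k_answer(solution: str) -> Optional[str]:
    """Return the final answer from a GSM8K reference solution, or None."""
    # GSM8K solutions put the final answer after a "####" marker.
    if "####" not in solution:
        return None
    answer = solution.rsplit("####", 1)[1].strip()
    # Strip thousands separators so "1,234" compares equal to "1234".
    return answer.replace(",", "")


example = "Natalia sold 48 / 2 = 24 clips in May.\n48 + 24 = 72 clips in total.\n#### 72"
print(extract_gsm8k_answer(example))  # 72
```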

2. Model Setup

  • Loads base model from Kaggle

  • Applies LoRA to specified layers

  • Initializes reference and policy models

3. GRPO Training

For each training step:

  • Generate multiple responses per prompt

  • Compute rewards (correctness + format)

  • Calculate policy gradient with KL penalty

  • Update the policy model's LoRA weights with the resulting gradient
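The steps above can be sketched in plain Python. This is an illustration of the GRPO objective as usually described, not the framework's code: rewards are normalized within each prompt's group of responses, and the per-token loss combines a clipped probability ratio with a KL penalty toward the frozen reference model. The function names are hypothetical, and the KL estimator shown is the one from the GRPO paper, assumed here.

```python
import math
import statistics
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Normalize rewards within one prompt's group of sampled responses."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    # eps guards against division by zero when every response gets the same reward.
    return [(r - mean) / (std + eps) for r in rewards]


def grpo_token_loss(logp_new: float, logp_old: float, logp_ref: float,
                    advantage: float, epsilon: float = 0.2, beta: float = 0.08) -> float:
    """Per-token loss: clipped surrogate plus a KL penalty toward the reference."""
    ratio = math.exp(logp_new - logp_old)
    clipped = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
    surrogate = min(ratio * advantage, clipped * advantage)
    # Non-negative KL estimate: exp(d) - d - 1 with d = logp_ref - logp_new.
    d = logp_ref - logp_new
    kl = math.exp(d) - d - 1.0
    return -surrogate + beta * kl


# Four responses to one prompt (grpo.num_generations=4) with made-up rewards.
rewards = [3.5, 0.5, 0.0, 3.5]
advantages = group_relative_advantages(rewards)
print([round(a, 2) for a in advantages])
```

Responses that score above the group mean get positive advantages and are reinforced; those below the mean are pushed down. When every response in a group receives the same reward, all advantages are zero and the group contributes no policy gradient.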

4. Evaluation

  • Periodically evaluate on validation set

  • Compute metrics (accuracy, format accuracy)

  • Log to Weights & Biases

5. Checkpointing

  • Save model checkpoints

  • Save training logs

  • Save configuration for reproducibility

Training Configuration

Key hyperparameters:

# Model
model.lora_rank: 32              # LoRA rank
model.model_size: 270m           # Model size

# Training
training.micro_batch_size: 4     # Batch per device
training.num_batches: 3738       # Training batches
training.num_epochs: 1           # Epochs

# Optimization
optimizer.learning_rate: 3e-6    # Peak learning rate
optimizer.warmup_ratio: 0.1      # Warmup as a fraction of total steps
optimizer.max_grad_norm: 0.1     # Gradient clipping

# GRPO
grpo.num_generations: 4          # Responses per prompt
grpo.beta: 0.08                  # KL divergence weight
grpo.epsilon: 0.2                # PPO-style clipping range

# Generation
generation.max_generation_steps: 512
generation.temperature: 0.9
generation.top_k: 50
generation.top_p: 1.0
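To make the optimizer settings concrete: with training.num_batches at 3738 and optimizer.warmup_ratio at 0.1, warmup covers roughly 373 steps. The sketch below assumes one optimizer step per batch, truncation when converting the ratio to steps, and a cosine decay after warmup. None of these details are stated in this guide, so treat the schedule shape as an assumption.

```python
import math


def learning_rate(step: int, total_steps: int = 3738, peak_lr: float = 3e-6,
                  warmup_ratio: float = 0.1) -> float:
    """Linear warmup to peak_lr, then cosine decay to zero (decay shape is assumed)."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (0, 186, 373, 2000, 3738):
    print(step, f"{learning_rate(step):.2e}")
```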

Monitoring Training

View logs:

tail -f outputs/tunix-grpo/YYYY-MM-DD/HH-MM-SS/train.log

Weights & Biases

Training logs are sent to W&B by default. View at: https://wandb.ai

To disable:

python run_training.py wandb_disabled=true

TensorBoard

View training metrics:

make tensorboard

Then open http://localhost:6006

Memory Optimization

If training runs out of memory, try:

Reduce batch size:

python run_training.py training.micro_batch_size=1

Use a smaller model:

python run_training.py model=gemma3_270m

Reduce LoRA rank:

python run_training.py model.lora_rank=8

Shorten sequences:

python run_training.py generation.max_prompt_length=128 generation.max_generation_steps=256

Reduce number of generations:

python run_training.py grpo.num_generations=2
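Several of these overrides help for the same reason: they shrink the number of tokens held per step. A rough model, assuming each device processes micro_batch_size prompts with num_generations responses each, is sketched below. The max_prompt_length of 256 in the first call is an illustrative value, not a documented default.

```python
def max_tokens_per_step(micro_batch_size: int, num_generations: int,
                        max_prompt_length: int, max_generation_steps: int) -> int:
    """Upper bound on tokens processed per device in one GRPO step."""
    sequences = micro_batch_size * num_generations
    return sequences * (max_prompt_length + max_generation_steps)


# max_prompt_length=256 is an illustrative value, not a documented default.
default = max_tokens_per_step(4, 4, 256, 512)
reduced = max_tokens_per_step(1, 2, 128, 256)
print(default, reduced)  # 12288 768
```

Activation memory grows with this token count, so batch size, generation count, and sequence length multiply together. Reducing any one of them helps; reducing several compounds the savings.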

Distributed Training

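For multi-GPU training, modify the mesh shape in the model config. Quote the override so the shell does not interpret the brackets or strip the inner quotes:

python run_training.py 'model.mesh_shape=[[2,2],["fsdp","tp"]]'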

See Distributed Training Guide for details.
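The mesh shape pairs axis sizes with axis names, so the number of devices it needs is the product of the sizes. The small check below makes that explicit. The helper name is hypothetical, and reading "fsdp" and "tp" as fully sharded data parallel and tensor parallel axes follows the usual convention rather than anything stated in this guide.

```python
import math
from typing import List, Tuple


def required_devices(mesh_shape: Tuple[List[int], List[str]]) -> int:
    """Return how many devices a mesh needs: the product of its axis sizes."""
    sizes, names = mesh_shape
    if len(sizes) != len(names):
        raise ValueError("each mesh axis needs exactly one size and one name")
    return math.prod(sizes)


# A 2 x 2 mesh with axes named "fsdp" and "tp".
print(required_devices(([2, 2], ["fsdp", "tp"])))  # 4
```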

Resuming Training

Resume from latest checkpoint:

python run_training.py checkpoint_dir=./checkpoints/ckpts/

The framework automatically detects and continues from the latest checkpoint.
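As an illustration of what such detection can look like, the sketch below picks the highest step number among numerically named subdirectories. The directory layout and the function name are assumptions made for this example. The framework's actual lookup may differ.

```python
import tempfile
from pathlib import Path
from typing import Optional


def latest_checkpoint_step(checkpoint_dir: str) -> Optional[int]:
    """Return the highest step number among numerically named subdirectories."""
    root = Path(checkpoint_dir)
    if not root.is_dir():
        return None
    steps = [int(p.name) for p in root.iterdir() if p.is_dir() and p.name.isdigit()]
    return max(steps, default=None)


# Build a throwaway directory with three fake checkpoints to show the lookup.
with tempfile.TemporaryDirectory() as tmp:
    for step in (50, 100, 150):
        (Path(tmp) / str(step)).mkdir()
    print(latest_checkpoint_step(tmp))  # 150
```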

Custom Reward Functions

Modify reward functions in src/agent_tunix/rewards.py:

  • match_format_exactly: Rewards format compliance

  • check_answer: Rewards correct answers

  • check_numbers: Extracts and validates numbers

See Custom Rewards for custom implementations.
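As a starting point, a format reward can be as small as a regular expression check. In the sketch below, the signature (prompts and completions in, one float per completion out) and the <reasoning>/<answer> tag format are assumptions for illustration. Check src/agent_tunix/rewards.py for the interface the trainer actually expects.

```python
import re
from typing import List

# Hypothetical response format: reasoning and answer wrapped in tags.
FORMAT_RE = re.compile(
    r"^\s*<reasoning>.+?</reasoning>\s*<answer>(.+?)</answer>\s*$",
    flags=re.DOTALL,
)


def format_reward(prompts: List[str], completions: List[str], **kwargs) -> List[float]:
    """Return 1.0 for each completion that matches the expected format, else 0.0."""
    return [1.0 if FORMAT_RE.match(c) else 0.0 for c in completions]


completions = [
    "<reasoning>48 + 24 = 72</reasoning><answer>72</answer>",
    "The answer is 72.",
]
print(format_reward(["q", "q"], completions))  # [1.0, 0.0]
```

Keeping each reward function small and single-purpose makes it easier to see which signal is driving the policy when the metrics shift.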

Troubleshooting

Out of Memory

Reduce the batch size or model size as shown in the Memory Optimization section.

Slow Training

Check GPU utilization with:

nvidia-smi -l 1

If utilization is low, data loading may be the bottleneck.

NaN Loss

May indicate:

  • Learning rate too high: reduce with optimizer.learning_rate=1e-6

  • Gradient overflow: reduce optimizer.max_grad_norm

  • Data issue: inspect training data
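For reference, gradient clipping by global norm, which is what a setting like optimizer.max_grad_norm typically controls, can be sketched as follows. This is a plain-Python illustration with a hypothetical function name, not the optimizer's implementation.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float) -> List[List[float]]:
    """Scale all gradients so their combined L2 norm is at most max_norm."""
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if not math.isfinite(global_norm):
        raise ValueError("non-finite gradient norm; inspect the batch and the loss")
    scale = min(1.0, max_norm / (global_norm + 1e-12))
    return [[g * scale for g in grad] for grad in grads]


# Two parameter groups with a combined norm of 5.0, clipped to 0.1.
clipped = clip_by_global_norm([[3.0], [4.0]], max_norm=0.1)
print(clipped)
```

Clipping bounds large but finite gradients. It cannot repair gradients that are already NaN or infinite, which is why a data or loss problem has to be ruled out separately.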

Model Not Improving

Possible causes:

  • Learning rate too low

  • Insufficient training data

  • Reward function not properly calibrated

Tips and Best Practices

  1. Start small: Use +experiment=quick_test first

  2. Monitor metrics: Check logs and W&B frequently

  3. Save often: Use training.save_interval_steps=50

  4. Document runs: Add tags to experiments in config

  5. Validate early: Evaluate on validation set frequently

  6. Checkpoint management: Keep important checkpoints

Example Training Runs

Quick Test (10 steps, reduced settings):

python run_training.py +experiment=quick_test

Single GPU Training (270M model):

python run_training.py model=gemma3_270m training.micro_batch_size=1

Ablation Study (sweep learning rates):

python run_training.py --multirun optimizer.learning_rate=1e-6,3e-6,1e-5

Production Run (1B model, full training):

python run_training.py model=gemma3_1b +experiment=full_training

Next Steps