Training API
This module contains the main training entry point and training utilities.
Main Entry Point
Training Utilities
Configuration Classes
The training module works with Hydra’s DictConfig. Configuration structure:
training:
  micro_batch_size: 4          # Batch size per device
  num_batches: 3738            # Total training batches
  num_epochs: 1                # Number of epochs
  checkpoint_dir: ./checkpoints/ckpts/
  save_interval_steps: 100     # Save checkpoint every N steps
  eval_interval_steps: 500     # Evaluate every N steps
optimizer:
  learning_rate: 3e-6          # Peak learning rate
  warmup_ratio: 0.1            # Warmup as a fraction of total steps
  max_grad_norm: 0.1           # Gradient clipping threshold
grpo:
  num_generations: 4           # Responses per prompt
  beta: 0.08                   # KL divergence weight
  epsilon: 0.2                 # PPO clipping range
generation:
  max_prompt_length: 256
  max_generation_steps: 512
  temperature: 0.9
  top_k: 50
  top_p: 1.0
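As a rough sketch of how these values relate to each other, the snippet below derives a few step counts from the defaults above. It assumes, without confirmation from this page, that one optimizer step is taken per batch, that the total is num_batches * num_epochs, and that warmup_ratio is applied to that total. The helper derive_schedule is hypothetical and not part of agent_tunix.

```python
# Hypothetical helper; the step arithmetic is an assumption, not the module's code.
config = {
    "training": {
        "num_batches": 3738,
        "num_epochs": 1,
        "save_interval_steps": 100,
        "eval_interval_steps": 500,
    },
    "optimizer": {"warmup_ratio": 0.1},
}


def derive_schedule(cfg):
    """Derive step counts from the config, assuming one optimizer step per batch."""
    training = cfg["training"]
    total_steps = training["num_batches"] * training["num_epochs"]
    return {
        "total_steps": total_steps,
        # Assumed: warmup covers warmup_ratio of all steps, rounded down.
        "warmup_steps": int(total_steps * cfg["optimizer"]["warmup_ratio"]),
        "num_checkpoints": total_steps // training["save_interval_steps"],
        "num_evaluations": total_steps // training["eval_interval_steps"],
    }


schedule = derive_schedule(config)
print(schedule)
```

Under these assumptions the defaults give 3738 total steps, 373 warmup steps, 37 checkpoints, and 7 evaluation passes.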
Training Flow
The training process follows these steps:
Configuration Loading: Hydra loads and composes configuration from YAML files
Seed Management: Sets random seeds for reproducibility
Dataset Preparation: Loads and tokenizes training data (GSM8K by default)
Model Setup:
  - Loads base model from Hugging Face or Kaggle
  - Applies LoRA to specified layers
  - Initializes reference and policy models
Training Loop:
  - Generate multiple responses per prompt
  - Compute rewards (correctness + format matching)
  - Calculate policy gradients with KL penalty
  - Update model weights
  - Log metrics to Weights & Biases
Checkpointing: Saves model weights and configuration periodically
Evaluation: Periodically evaluates on validation set
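The reward step in the training loop combines a correctness check with a format check. The sketch below illustrates one way such a reward could be built. The answer tag format, the function names, and the equal weighting of the two components are all assumptions made for illustration; the rewards actually used by the module may differ.

```python
import re

# Assumed output format, for illustration only: free-form reasoning followed by
# "<answer>NUMBER</answer>". The project's actual format may differ.
ANSWER_PATTERN = re.compile(r"<answer>\s*(-?\d+(?:\.\d+)?)\s*</answer>")


def format_reward(response):
    """Return 1.0 if the response contains a well-formed answer tag, else 0.0."""
    return 1.0 if ANSWER_PATTERN.search(response) else 0.0


def correctness_reward(response, reference_answer):
    """Return 1.0 if the extracted number equals the reference answer, else 0.0."""
    match = ANSWER_PATTERN.search(response)
    if match is None:
        return 0.0
    return 1.0 if float(match.group(1)) == float(reference_answer) else 0.0


def total_reward(response, reference_answer):
    """Sum of the two components; equal weighting is an assumption."""
    return format_reward(response) + correctness_reward(response, reference_answer)


responses = [
    "6 * 7 = 42 <answer>42</answer>",     # right format, right answer
    "I think it is <answer>41</answer>",  # right format, wrong answer
    "The answer is 42",                   # no answer tag
]
rewards = [total_reward(r, "42") for r in responses]
print(rewards)
```

Keeping the format reward separate from the correctness reward means a response that follows the expected structure still receives some signal even when its answer is wrong.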
GRPO Algorithm
Group Relative Policy Optimization (GRPO) proceeds as follows:
For each prompt, generate K responses (num_generations)
Compute rewards for each response
Normalize each reward relative to the other responses for the same prompt (the "group relative" step)
Compute policy gradients with PPO clipping
Apply KL divergence penalty to stay close to reference model
Update weights using optimizer with gradient clipping
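The normalization step can be sketched in a few lines. This is a minimal illustration that assumes the common formulation (subtract the group mean, divide by the group standard deviation); the function name, the use of the population standard deviation, and the small eps stabilizer are assumptions, not details taken from this module.

```python
import math


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize the rewards of one prompt's responses to zero mean within the group."""
    mean = sum(rewards) / len(rewards)
    variance = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = math.sqrt(variance)
    # eps avoids division by zero when every response gets the same reward.
    return [(r - mean) / (std + eps) for r in rewards]


# One prompt with num_generations = 4 responses.
rewards = [2.0, 1.0, 0.0, 1.0]
advantages = group_relative_advantages(rewards)
print(advantages)
```

Responses that score above the group mean get a positive advantage and those below get a negative one. When all responses in a group receive the same reward, every advantage is zero and that prompt contributes no policy gradient.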
Key hyperparameters:
num_generations: Number of candidate responses to generate per prompt
beta: Weight of KL divergence penalty (controls deviation from reference)
epsilon: PPO clipping range for gradient updates
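To show where beta and epsilon enter, here is a minimal per-token sketch of a clipped surrogate loss with a KL penalty. It assumes the formulation commonly used for GRPO, including the non-negative KL estimate exp(d) - d - 1; the function name and argument names are hypothetical, and the module's actual loss may differ in its details.

```python
import math


def grpo_token_loss(logp_new, logp_old, logp_ref, advantage, epsilon=0.2, beta=0.08):
    """Loss for one token: clipped surrogate plus a beta-weighted KL penalty."""
    # Probability ratio between the current policy and the policy that sampled the token.
    ratio = math.exp(logp_new - logp_old)
    # epsilon bounds how far the ratio may move away from 1.
    clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
    surrogate = min(ratio * advantage, clipped_ratio * advantage)
    # Non-negative KL estimate with d = logp_ref - logp_new (assumed estimator).
    d = logp_ref - logp_new
    kl = math.exp(d) - d - 1.0
    # The surrogate is maximized, so it enters the loss with a minus sign.
    return -surrogate + beta * kl


# Policy unchanged and equal to the reference: no clipping, no KL penalty.
print(grpo_token_loss(-1.0, -1.0, -1.0, advantage=1.0))
# Ratio of about 1.65 is clipped to 1 + epsilon = 1.2.
print(grpo_token_loss(-0.5, -1.0, -0.5, advantage=1.0))
```

A larger beta pulls the policy more strongly toward the reference model, while a smaller epsilon limits how much a single update can change the probability of a sampled token.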
Example Usage
Basic training with defaults:
from agent_tunix.train import train

if __name__ == "__main__":
    # train() is wrapped by a Hydra decorator, which loads and composes the config
    train()
Command-line examples:
# Default configuration
python run_training.py
# Override learning rate
python run_training.py optimizer.learning_rate=1e-5
# Use different model
python run_training.py model=gemma3_1b
# Multiple overrides
python run_training.py model=gemma3_1b training.micro_batch_size=2 optimizer.learning_rate=3e-6
# Use experiment preset
python run_training.py +experiment=quick_test
# Parameter sweep
python run_training.py --multirun optimizer.learning_rate=1e-7,1e-6,1e-5,1e-4
Tips and Best Practices
Start with quick_test: Use +experiment=quick_test for 10-step validation runs
Monitor GPU memory: Use nvidia-smi during training
Check configuration: Use python run_training.py --cfg job before running
Save checkpoints frequently: Lower save_interval_steps for important runs
Enable logging: W&B is enabled by default, set wandb_disabled=true to disable
Use gradual unfreezing: Start with smaller learning rates and increase if loss plateaus
Next Steps
Training Guide - Detailed training guide
Hyperparameter Tuning - Hyperparameter tuning strategies
Evaluation API - Evaluation API reference