Models API
This module contains model architectures and utilities for loading and configuring models.
Available Models
Gemma3 270M
Lightweight model suitable for resource-constrained environments:
# Configuration: conf/model/gemma3_270m.yaml
model_family: gemma3
model_size: 270m
lora_rank: 32
lora_alpha: 32.0
Use:
python run_training.py model=gemma3_270m
Memory requirements:
VRAM: ~11GB (with batch size 1)
Single GPU: RTX 2080 Ti or RTX A4000
LoRA rank: 8-32 (lower for memory constraints)
Gemma3 1B
Standard model for balanced performance and efficiency:
# Configuration: conf/model/gemma3_1b.yaml
model_family: gemma3
model_size: 1b
lora_rank: 32
lora_alpha: 32.0
Use:
python run_training.py model=gemma3_1b
Memory requirements:
VRAM: ~48GB (with batch size 4)
Single GPU: RTX A6000
LoRA rank: 16-64
Gemma3 4B
Larger model for higher-capacity tasks:
# Configuration: conf/model/gemma3_4b.yaml
model_family: gemma3
model_size: 4b
lora_rank: 64
lora_alpha: 64.0
Use:
python run_training.py model=gemma3_4b
Memory requirements:
VRAM: ~80GB (with batch size 8)
Multiple GPUs: H100 or A100
LoRA rank: 32-128
Model Configuration
Key configuration parameters:
model:
model_family: gemma3 # Model family name
model_size: 1b # Size variant (270m, 1b, 4b)
lora_rank: 32 # LoRA rank for low-rank adaptation
lora_alpha: 32.0 # LoRA alpha scaling factor
lora_module_path: ".*pattern" # Regex for which layers to apply LoRA
mesh_shape: [[1,4], # Parallelism shape
["fsdp","tp"]] # fsdp: fully sharded, tp: tensor parallel
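As a rough illustration of how these fields fit together, the sketch below mirrors them in a Python dataclass with a few basic checks. `ModelConfig` and its validation rules are hypothetical and are not part of this framework's API; the accepted size variants are taken from the comment in the configuration above.

```python
from dataclasses import dataclass

# Size variants listed in the configuration comment above.
VALID_SIZES = ("270m", "1b", "4b")


@dataclass(frozen=True)
class ModelConfig:
    """Hypothetical mirror of the YAML fields; not the framework's own class."""

    model_family: str = "gemma3"
    model_size: str = "1b"
    lora_rank: int = 32
    lora_alpha: float = 32.0
    lora_module_path: str = ".*pattern"
    mesh_shape: tuple = ((1, 4), ("fsdp", "tp"))

    def __post_init__(self):
        if self.model_size not in VALID_SIZES:
            raise ValueError(
                f"model_size must be one of {VALID_SIZES}, got {self.model_size!r}"
            )
        if self.lora_rank <= 0:
            raise ValueError("lora_rank must be a positive integer")
        sizes, names = self.mesh_shape
        if len(sizes) != len(names):
            raise ValueError("mesh_shape needs one axis name per axis size")


config = ModelConfig(model_size="4b", lora_rank=64, lora_alpha=64.0)
print(config.model_size, config.lora_rank)
```

Validating at construction time means a typo such as an unsupported size fails immediately rather than partway through model loading.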
LoRA (Low-Rank Adaptation)
LoRA is a parameter-efficient fine-tuning technique that:
Freezes the base model weights
Adds low-rank trainable adapters to specified layers
Reduces the trainable parameters from 100% of the model to roughly 1%, depending on the rank and on which layers are adapted
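The idea can be sketched in a few lines of plain Python. This is a minimal illustration that assumes the conventional LoRA formulation, where the output is `W x + (alpha / rank) * B (A x)`; the function names and the tiny matrices are made up for the example, and this framework's internal implementation may differ.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, alpha):
    """Return W x + (alpha / rank) * B (A x).

    W is the frozen base weight (d_out x d_in). A (rank x d_in) and
    B (d_out x rank) are the trainable low-rank adapters.
    """
    rank = len(A)
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / rank
    return [b + scale * u for b, u in zip(base, update)]


W = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]  # frozen weight, d_out=2, d_in=3
A = [[0.5, 0.5, 0.5]]                    # trainable, rank=1
B_zero = [[0.0], [0.0]]                  # B conventionally starts at zero
B_trained = [[1.0], [-1.0]]              # illustrative values after training
x = [1.0, 2.0, 3.0]

print(lora_forward(W, A, B_zero, x, alpha=1.0))     # [7.0, 2.0], same as W x
print(lora_forward(W, A, B_trained, x, alpha=1.0))  # [10.0, -1.0]
```

With `B` at zero the adapted layer reproduces the base model exactly, so fine-tuning starts from the pretrained behavior and only `A` and `B` receive gradient updates.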
Configuration:
lora_rank: 32 # Rank of adaptation matrices (4-128)
lora_alpha: 32.0 # Scaling factor (usually = rank)
A higher rank gives the adapters more capacity, but also increases the number of trainable parameters and the memory they need.
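To make the rank trade-off concrete, the sketch below counts adapter parameters for a single weight matrix. It assumes the standard LoRA factorization (one `rank x d_in` matrix and one `d_out x rank` matrix per adapted weight); the 1024 x 1024 layer size is illustrative and is not taken from any Gemma3 configuration.

```python
def lora_param_count(d_in, d_out, rank):
    """Trainable parameters for one adapted weight: A is rank x d_in, B is d_out x rank."""
    return rank * (d_in + d_out)


d_in, d_out = 1024, 1024  # illustrative layer size only
full = d_in * d_out       # parameters in the frozen weight itself

for rank in (8, 32, 128):
    adapter = lora_param_count(d_in, d_out, rank)
    print(f"rank={rank:>3}: {adapter:>7} adapter params, "
          f"ratio to frozen weight = {adapter / full:.4f}")
```

Adapter size grows linearly with rank, so doubling the rank doubles the trainable parameters for every adapted layer. The fraction for a whole model also depends on the layer sizes and on how many modules `lora_module_path` selects.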
Typical settings:
Memory constrained (11GB): rank 8-16
Standard (48GB): rank 32-64
High capacity (80GB+): rank 64-128
Module Selection
Control which layers get LoRA:
lora_module_path: ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum"
This regex pattern applies LoRA to:
Attention projections (the q_einsum, kv_einsum, and attn_vec_einsum operations, i.e. query, key/value, and attention output)
MLP layers (gate, up, down projections)
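A quick way to check what a pattern selects is to run it against a list of module paths. In the sketch below the module paths are hypothetical, and it assumes the pattern must match the whole path (`re.fullmatch`); the framework may use different path separators or matching semantics.

```python
import re

LORA_MODULE_PATH = (
    ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum"
)

# Hypothetical module paths, for illustration only.
module_paths = [
    "layers/0/attn/q_einsum",
    "layers/0/attn/kv_einsum",
    "layers/0/attn/attn_vec_einsum",
    "layers/0/mlp/gate_proj",
    "layers/0/mlp/up_proj",
    "layers/0/mlp/down_proj",
    "layers/0/pre_attention_norm",
    "embedder/input_embedding",
]

pattern = re.compile(LORA_MODULE_PATH)

# Assumption: a module is selected only if the whole path matches.
selected = [path for path in module_paths if pattern.fullmatch(path)]

for path in module_paths:
    status = "LoRA adapter" if path in selected else "frozen only"
    print(f"{path}: {status}")
```

Normalization and embedding modules do not match any alternative in the pattern, so under these assumptions they stay frozen without adapters.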
Custom Model Configuration
Create a custom model in conf/model/custom.yaml:
model_family: gemma3
model_size: 4b
lora_rank: 16
lora_alpha: 16.0
lora_module_path: ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|.*attn_vec_einsum"
mesh_shape: [[1, 4], ["fsdp", "tp"]]
Use it:
python run_training.py model=custom
Distributed Training
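For multi-GPU training, set mesh_shape. The first list gives the size of each mesh axis and the second list names the axes; the product of the sizes should equal the number of GPUs: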
# Single GPU
mesh_shape: [[1, 1], ["fsdp", "tp"]]
# 2 GPUs along the fsdp axis (sharded data parallelism)
mesh_shape: [[2, 1], ["fsdp", "tp"]]
# 4 GPUs in a 2x2 grid (2-way fsdp x 2-way tp)
mesh_shape: [[2, 2], ["fsdp", "tp"]]
# 4 GPUs along the tp axis (tensor parallelism)
mesh_shape: [[1, 4], ["fsdp", "tp"]]
Where:
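fsdp: Fully Sharded Data Parallel - shards the model's parameters across GPUs while each GPU processes a different slice of the batch
tp: Tensor Parallel - splits individual weight tensors across GPUs so that each GPU computes part of every layer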
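Before launching a run it can help to confirm that a mesh shape matches the hardware. The helper below is a hypothetical sketch, not part of the framework; it only checks that every axis has a name and that the axis sizes multiply to the device count.

```python
from math import prod


def check_mesh_shape(mesh_shape, num_devices):
    """Return {axis_name: size} if mesh_shape is consistent with num_devices."""
    sizes, names = mesh_shape
    if len(sizes) != len(names):
        raise ValueError("each mesh axis needs exactly one name")
    if any(size < 1 for size in sizes):
        raise ValueError("mesh axis sizes must be >= 1")
    if prod(sizes) != num_devices:
        raise ValueError(
            f"mesh {list(sizes)} covers {prod(sizes)} devices, "
            f"but {num_devices} are available"
        )
    return dict(zip(names, sizes))


print(check_mesh_shape([[2, 2], ["fsdp", "tp"]], num_devices=4))
print(check_mesh_shape([[1, 4], ["fsdp", "tp"]], num_devices=4))
```

A mismatch such as `[[2, 1], ["fsdp", "tp"]]` on a 4-GPU machine raises a `ValueError` here instead of failing later during device placement.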
Model Loading
Models are loaded from:
Hugging Face Hub: For public models
Kaggle Models: For Kaggle-hosted weights
Local checkpoint: For previous training runs
The framework automatically handles model downloading and caching.
Memory Requirements by Model
Suggested model, batch size, and LoRA rank for each GPU class:
RTX 2080 Ti (11GB):
- Model: gemma3_270m
- Batch size: 1
- LoRA rank: 8-16
RTX A6000 (48GB):
- Model: gemma3_1b
- Batch size: 4
- LoRA rank: 32-64
H100/A100 (80GB):
- Model: gemma3_4b
- Batch size: 8
- LoRA rank: 64-128
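The list above can be turned into a small lookup that picks a starting point from the available VRAM. This is a hypothetical helper, not a framework function; the thresholds are copied from the list and are approximate, so treat the result as a first guess rather than a guarantee that a run will fit.

```python
# (minimum VRAM in GB, model, batch size, LoRA rank range), copied from the list above.
PRESETS = [
    (80, "gemma3_4b", 8, (64, 128)),
    (48, "gemma3_1b", 4, (32, 64)),
    (11, "gemma3_270m", 1, (8, 16)),
]


def recommend_preset(vram_gb):
    """Return the largest preset whose VRAM threshold is met, or None if none fits."""
    for min_vram, model, batch_size, rank_range in PRESETS:
        if vram_gb >= min_vram:
            return {
                "model": model,
                "batch_size": batch_size,
                "lora_rank": rank_range,
            }
    return None


print(recommend_preset(24))
```

For a 24GB card this falls back to the gemma3_270m preset, since the next tier in the list assumes 48GB.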
Advanced Topics
See Distributed Training for:
Tensor parallel training
Fully sharded data parallel setup
Multi-node distributed training
Next Steps
Training Guide - Training guide
Distributed Training - Distributed training setup
Training API - Training API reference