GRPO Training Script Walkthrough
Line-by-line explanation of train_grpo.py
This document explains the structure and intent of each part of the training script below. Code cells are shown but not executed.
1) Header & Citation
The file begins with a filename comment, a note pointing to the upstream project for ongoing developments, and a BibTeX-style citation for the demo. This establishes provenance and a reference you can cite in reports.
# train_grpo.py
#
# See https://github.com/willccbb/verifiers for ongoing developments
#
"""
citation:
@misc{brown2025grpodemo,
title={Granular Format Rewards for Eliciting Mathematical Reasoning Capabilities in Small Language Models},
author={Brown, William},
howpublished={\url{https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb}},
date = {2025-01-25},
note = {GitHub Gist}
}
"""
2) Imports
These imports bring in regex utilities, PyTorch, Hugging Face datasets & transformers, parameter-efficient fine‑tuning (PEFT/LoRA), and TRL’s GRPO trainer/configs.
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
Why these matter:
- datasets: loads GSM8K (grade school math questions).
- transformers: model and tokenizer loading.
- peft: optional LoRA adapter config.
- trl: GRPO (Group Relative Policy Optimization) RL training abstractions.
3) System Prompt & XML Output Format
The script enforces a strict XML‑like format for chain‑of‑thought reasoning and the final answer. This makes automatic checking and reward shaping much easier.
# Load and prep dataset
= """
SYSTEM_PROMPT Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
= """\
XML_COT_FORMAT <reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
What it does:
- SYSTEM_PROMPT guides the model to structure its output.
- XML_COT_FORMAT is used for potential few-shot exemplars (currently commented out); an example of filling it in follows.
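For instance, the commented-out one-shot exemplar in the dataset loader (next section) fills the template like this; the strings here are illustrative:

example_turn = XML_COT_FORMAT.format(
    reasoning="7 is the largest single-digit prime.",  # illustrative reasoning text
    answer="7",
)
print(example_turn)  # prints the <reasoning>/<answer> block with both fields filled in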
4) Extraction Helpers
Two small helpers parse answers out of model outputs or dataset strings.
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip().replace(",", "").replace("$", "")
Why:
- extract_xml_answer pulls whatever is between <answer> ... </answer> in the model response.
- extract_hash_answer normalizes GSM8K references (the dataset encodes gold answers after ####).
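A quick illustration of both helpers on made-up strings (not part of the script):

model_output = "<reasoning>\n2 + 2 = 4\n</reasoning>\n<answer>\n4\n</answer>"
gsm8k_reference = "Adding the two gives 4.\n#### 4"

print(extract_xml_answer(model_output))      # "4"
print(extract_hash_answer(gsm8k_reference))  # "4"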
5) Dataset Loader (get_gsm8k_questions)
Maps each GSM8K example into a chat-style prompt with a system message and the user’s math question. The few‑shot block is available but commented for zero‑shot.
# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #    reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
            #    answer="7"
            #)},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()
Key points:
- Loads the openai/gsm8k dataset ('main' config) for the chosen split.
- Converts each item to a list of chat turns.
- Extracts the numeric gold answer for reward comparison.
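To sanity-check the mapping, you can inspect one mapped example (the exact question and answer will vary):

example = dataset[0]
print(example['prompt'][0]['role'])   # 'system'  (the SYSTEM_PROMPT turn)
print(example['prompt'][-1]['role'])  # 'user'    (the GSM8K question)
print(example['answer'])              # gold answer string, e.g. "72"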
6) Reward Functions
The RL trainer will call these functions to compute multiple reward signals per sample. They compose formatting checks with task correctness.
6.1 Correctness Reward
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
- Parses the <answer> field and gives a 2.0 reward if it exactly equals the gold answer; otherwise 0.0.
6.2 Integer Type Reward
def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
- Encourages numeric answers (helps push toward simple integers typical in GSM8K).
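Note that str.isdigit() is strict: it accepts only non-negative integer strings, so decimals or negative answers get no reward here. For example:

print("42".isdigit())   # True  -> reward 0.5
print("-3".isdigit())   # False -> reward 0.0
print("3.5".isdigit())  # False -> reward 0.0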
6.3 Strict & Soft Format Rewards
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
- Strict requires the exact newline layout (including a trailing newline); soft accepts any whitespace between the sections. Both use re.match, so matching is anchored at the start of the completion. The example below shows the difference.
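Feeding a single made-up completion to each function (using the same completions structure the trainer passes, i.e. a list of message lists):

loose  = [[{"content": "<reasoning>2 + 2 = 4</reasoning>\n<answer>4</answer>"}]]
strict = [[{"content": "<reasoning>\n2 + 2 = 4\n</reasoning>\n<answer>\n4\n</answer>\n"}]]

print(soft_format_reward_func(completions=loose))     # [0.5]  tags present, layout loose
print(strict_format_reward_func(completions=loose))   # [0.0]  newline layout not matched
print(strict_format_reward_func(completions=strict))  # [0.5]  exact layout matched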
6.4 XML Token Count Reward
def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]
- Gives small partial rewards for hitting the required tags exactly once and slightly penalizes any trailing junk after the closing tags.
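For example, on made-up completions:

well_formed = "<reasoning>\nstep\n</reasoning>\n<answer>\n7\n</answer>\n"
print(count_xml(well_formed))                  # 0.5   (all four tag checks pass, nothing after </answer>)
print(count_xml(well_formed + "PS: thanks!"))  # 0.478 (trailing text after </answer> is penalized)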
7) Model Choice & Output Paths
Selects the base model and sets the output/run names based on whether the model string contains "Llama".
#model_name = "meta-llama/Llama-3.2-1B-Instruct"
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

if "Llama" in model_name:
    output_dir = "outputs/Llama-1B-GRPO"
    run_name = "Llama-1B-GRPO-gsm8k"
else:
    output_dir = "outputs/Qwen-1.5B-GRPO"
    run_name = "Qwen-1.5B-GRPO-gsm8k"
Tip: Change model_name to switch base models; the output and run names update automatically.
8) GRPO Training Arguments
Configures optimization hyperparameters and generation behavior used by TRL’s GRPO.
training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=786,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)
Highlights:
- bf16=True assumes bfloat16 support (e.g., recent NVIDIA GPUs; see the quick check below).
- num_generations=16 samples multiple completions per prompt for group-based RL.
- max_*_length caps prompt and completion tokens to control compute.
- Logging/reporting uses Weights & Biases (report_to="wandb").
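If you are unsure whether your GPU supports bfloat16, a quick check (not part of the script) is:

import torch

print(torch.cuda.is_available())       # True on a working CUDA setup
print(torch.cuda.is_bf16_supported())  # True if the GPU can run bf16 kernels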
9) Optional LoRA (PEFT) Configuration
Defines adapter ranks/targets if you choose parameter‑efficient tuning. (Commented out later in trainer).
peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)
Note: The script warns PEFT may not work with multi‑GPU runs here; test single‑GPU first.
10) Model & Tokenizer Initialization
Loads the base model in bfloat16, requests Flash‑Attention v2 kernels, moves it to CUDA, and prepares the tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
Why:
- Flash‑Attention 2 can speed up long‑context training.
- Ensures padding token exists (many chat models use EOS as pad).
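If flash-attn isn't installed, a common fallback is PyTorch's built-in scaled-dot-product attention; a minimal sketch, reusing the variables above:

# Fallback sketch: SDPA attention instead of Flash-Attention 2 (no flash-attn dependency)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
).to("cuda")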
11) GRPO Trainer Assembly
Wires the model, tokenizer, dataset, and list of reward functions into TRL's GRPOTrainer.
# use peft at your own risk; not working for me with multi-GPU training
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
    #peft_config=peft_config
)
The order of the reward functions only matters for inspecting their individual logs; TRL aggregates the rewards per sample. Here, structure and correctness signals are combined; a rough upper bound on the aggregate is worked out below.
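As a back-of-envelope check, and assuming TRL sums the reward functions' outputs with unit weights, the best possible aggregate per completion is:

# Maximum per-completion reward under the functions defined above (assumed unit weights)
max_reward = 0.5 + 0.5 + 0.5 + 0.5 + 2.0  # xmlcount + soft + strict + int + correctness
print(max_reward)  # 4.0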
12) Start Training
Finally, training is kicked off.
trainer.train()
Appendix: Setup & Practical Notes
- Dependencies: pip install torch transformers datasets trl peft accelerate flash-attn wandb (match CUDA/PyTorch wheels to your system).
- GPU: This configuration expects a CUDA GPU with bfloat16 support for speed/memory efficiency.
- W&B: Ensure wandb is logged in if you keep report_to="wandb".
- Few-shot: Uncomment the block in get_gsm8k_questions to add a one-shot example in the prompt.
- Safety: Rewards assume the <answer> field exactly matches the gold answer string; you can add normalization if needed (e.g., strip punctuation or parse integers), as sketched below.
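A hypothetical normalization helper (not part of the script) could look like this; swap it into correctness_reward_func if your answers need cleaning:

def normalize_answer(text: str) -> str:
    # Strip whitespace, trailing periods, thousands separators, and dollar signs
    return text.strip().rstrip(".").replace(",", "").replace("$", "")

# e.g. inside correctness_reward_func:
# return [2.0 if normalize_answer(r) == normalize_answer(a) else 0.0
#         for r, a in zip(extracted_responses, answer)]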