GRPO Training Script Walkthrough

Line-by-line explanation of train_grpo.py

This document explains the structure and intent of each part of the training script below. Code cells are shown but not executed.

1) Header & Citation

The file begins with a filename comment, a note pointing to the upstream project for ongoing developments, and a BibTeX-style citation for the demo. This establishes provenance and a reference you can cite in reports.

# train_grpo.py
#
# See https://github.com/willccbb/verifiers for ongoing developments
#
"""
citation:

@misc{brown2025grpodemo,
  title={Granular Format Rewards for Eliciting Mathematical Reasoning Capabilities in Small Language Models},
  author={Brown, William},
  howpublished={\url{https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb}},
  date = {2025-01-25},
  note = {GitHub Gist}
}
"""

2) Imports

These imports bring in regex utilities, PyTorch, Hugging Face datasets & transformers, parameter-efficient fine‑tuning (PEFT/LoRA), and TRL’s GRPO trainer/configs.

import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

Why these matter:

  • datasets: loads GSM8K (grade school math questions).
  • transformers: model and tokenizer loading.
  • peft: optional LoRA adapter config.
  • trl: GRPO (Group Relative Policy Optimization) RL training abstractions.
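
A quick optional environment check before training (illustrative only; the script does not pin any particular versions):

import torch
import transformers, datasets, peft, trl
print(torch.__version__, transformers.__version__, datasets.__version__, peft.__version__, trl.__version__)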

3) System Prompt & XML Output Format

The script enforces a strict XML‑like format for chain‑of‑thought reasoning and the final answer. This makes automatic checking and reward shaping much easier.

# Load and prep dataset

SYSTEM_PROMPT = """
Respond in the following format:

<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

What it does:

  • SYSTEM_PROMPT guides the model to structure its output.
  • XML_COT_FORMAT is used for potential few‑shot exemplars (currently commented out).
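
For instance, filling the template with illustrative values yields exactly the layout the reward functions below look for:

print(XML_COT_FORMAT.format(reasoning="2 + 2 = 4", answer="4"))
# <reasoning>
# 2 + 2 = 4
# </reasoning>
# <answer>
# 4
# </answer>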

4) Extraction Helpers

Two small helpers parse answers out of model outputs or dataset strings.

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip().replace(",", "").replace("$", "")

Why:

  • extract_xml_answer pulls whatever is between <answer> ... </answer> in the model response.
  • extract_hash_answer normalizes GSM8K references (the dataset encodes gold answers after ####).
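
A quick sanity check of both helpers on illustrative strings (not taken from the dataset):

print(extract_xml_answer("<reasoning>\n2 + 2 = 4\n</reasoning>\n<answer>\n4\n</answer>"))  # "4"
print(extract_hash_answer("She sold 48 + 24 = 72 clips.\n#### 72"))                        # "72"
print(extract_hash_answer("No marker here"))                                               # None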

5) Dataset Loader (get_gsm8k_questions)

Maps each GSM8K example into a chat-style prompt with a system message and the user's math question. A one-shot exemplar block is included but commented out, so prompting is zero-shot by default.

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #    reasoning="9 is divisble by 3 and 8 is divisible by 2, but 7 is prime.",
            #    answer="7"
            #)},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

Key points:

  • Loads openai/gsm8k split.
  • Converts each item to a list of chat turns.
  • Extracts the numeric gold answer for reward comparison.
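
A quick peek at one mapped record (illustrative; the actual question text comes from the dataset download):

example = dataset[0]
print([m['role'] for m in example['prompt']])  # ['system', 'user']
print(example['answer'])                       # bare gold answer string, e.g. '72'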

6) Reward Functions

The RL trainer will call these functions to compute multiple reward signals per sample. They compose formatting checks with task correctness.

6.1 Correctness Reward

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

  • Parses the <answer> field and gives 2.0 reward if it exactly equals the gold answer; otherwise 0.0. The print statement logs the first question, gold answer, and response of each batch for quick inspection (see the illustrative call below).
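
An illustrative call with a hypothetical one-sample batch, showing the shapes TRL passes in (lists of chat-message lists, plus the gold-answer list):

toy_prompts = [[{'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': 'What is 2 + 2?'}]]
toy_completions = [[{'role': 'assistant',
                     'content': '<reasoning>\nAdding gives 4.\n</reasoning>\n<answer>\n4\n</answer>\n'}]]
print(correctness_reward_func(toy_prompts, toy_completions, ['4']))  # [2.0] (after the debug print)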

6.2 Integer Type Reward

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

  • Encourages numeric answers (helps push toward the simple integer answers typical of GSM8K); see the checks below.
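
Note that str.isdigit() only accepts bare non-negative integers, so decimals, negatives, or comma-separated values earn nothing here. Illustrative checks:

print(int_reward_func([[{'content': '<answer>\n42\n</answer>'}]]))   # [0.5]
print(int_reward_func([[{'content': '<answer>\n3.5\n</answer>'}]]))  # [0.0]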

6.3 Strict & Soft Format Rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses] 
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses] 
    return [0.5 if match else 0.0 for match in matches]

  • Strict requires the exact newline layout, including a trailing newline after </answer> and nothing else following it; soft only requires the two tagged blocks in order, with arbitrary whitespace between them and anything allowed afterward. Both use re.match, so the completion must start with <reasoning>. See the example below.
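
For example, a completion that uses the tags but collapses the required newlines passes the soft check and fails the strict one:

resp = "<reasoning>Add the numbers.</reasoning>\n<answer>4</answer>"
print(soft_format_reward_func([[{'content': resp}]]))    # [0.5]
print(strict_format_reward_func([[{'content': resp}]]))  # [0.0]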

6.4 XML Token Count Reward

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

  • Gives small partial rewards for hitting each required tag exactly once and slightly penalizes any trailing text after the closing tags; see the illustration below.
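
Two illustrative inputs: a clean block earns the full 0.5, while trailing text after the closing tag is lightly penalized:

clean = "<reasoning>\nstep\n</reasoning>\n<answer>\n42\n</answer>\n"
noisy = "<reasoning>\nstep\n</reasoning>\n<answer>\n42\n</answer>\nThanks!"
print(count_xml(clean))  # 0.5
print(count_xml(noisy))  # ≈ 0.486 (0.001 per trailing character, deducted twice here)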

7) Model Choice & Output Paths

Selects the base model and sets the output/run names based on whether the model string contains "Llama".

#model_name = "meta-llama/Llama-3.2-1B-Instruct"
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

if "Llama" in model_name:
    output_dir = "outputs/Llama-1B-GRPO"
    run_name = "Llama-1B-GRPO-gsm8k"
else:
    output_dir="outputs/Qwen-1.5B-GRPO"
    run_name="Qwen-1.5B-GRPO-gsm8k"

Tip: Change model_name to switch bases; downstream names update automatically.

8) GRPO Training Arguments

Configures optimization hyperparameters and generation behavior used by TRL’s GRPO.

training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=786,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)

Highlights:

  • bf16=True assumes hardware bfloat16 support (e.g., NVIDIA Ampere-class GPUs or newer).
  • num_generations=16 samples multiple completions per prompt for group‑based RL.
  • max_*_length caps prompt and completion tokens to control compute.
  • Logging/reporting uses Weights & Biases (report_to="wandb").
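
Rough arithmetic implied by these numbers, shown below as a standalone sketch. It assumes each prompt in the per-device batch is expanded into num_generations sampled completions; exactly how TRL shards prompts and generations across devices varies by version, so treat the per-device figures as illustrative:

num_generations = 16          # completions sampled per prompt: one GRPO comparison group
max_prompt_length = 256
max_completion_length = 786
per_device_train_batch_size = 1
gradient_accumulation_steps = 4

max_tokens_per_sequence = max_prompt_length + max_completion_length                       # 1042
prompts_per_optimizer_step = per_device_train_batch_size * gradient_accumulation_steps    # 4 per device
completions_per_optimizer_step = prompts_per_optimizer_step * num_generations             # 64 per device
print(max_tokens_per_sequence, prompts_per_optimizer_step, completions_per_optimizer_step)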

9) Optional LoRA (PEFT) Configuration

Defines the adapter rank, scaling, and target modules for parameter-efficient tuning. (The corresponding peft_config argument is commented out in the trainer call below.)

peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)

Note: The script warns PEFT may not work with multi‑GPU runs here; test single‑GPU first.
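
An optional sanity check, not part of the original script: once the base model is loaded (next section), wrap it with the adapter to see how few parameters LoRA actually trains.

from peft import get_peft_model

# Assumes `model` has already been loaded as in section 10.
peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()  # prints trainable vs. total parameter counts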

10) Model & Tokenizer Initialization

Loads the base model in bfloat16, requests Flash‑Attention v2 kernels, moves it to CUDA, and prepares the tokenizer.

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

Why:

  • Flash‑Attention 2 can speed up long‑context training.
  • Ensures padding token exists (many chat models use EOS as pad).
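
If flash-attn is not installed (or the GPU does not support it), a common fallback is PyTorch's built-in scaled-dot-product attention; a sketch under that assumption, not part of the original script:

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # PyTorch SDPA kernels instead of flash_attention_2
    device_map=None
).to("cuda")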

11) GRPO Trainer Assembly

Wires the model, tokenizer, dataset, and list of reward functions into TRL’s GRPOTrainer.

# use peft at your own risk; not working for me with multi-GPU training
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
    #peft_config=peft_config
)

The order of the reward functions only affects how they appear in logs; TRL evaluates each function on every sampled completion and sums the outputs into a single scalar reward. With this set, a correct and perfectly formatted completion can earn up to 0.5 + 0.5 + 0.5 + 0.5 + 2.0 = 4.0, combining structural and correctness signals.

12) Start Training

Finally, training is kicked off.

trainer.train()

Appendix: Setup & Practical Notes

  • Dependencies: pip install torch transformers datasets trl peft accelerate wandb, plus flash-attn (typically installed with pip install flash-attn --no-build-isolation and a CUDA toolchain matching your PyTorch wheels).
  • GPU: This configuration expects a CUDA GPU with bfloat16 support for speed/memory efficiency.
  • W&B: Ensure wandb is logged in if you keep report_to="wandb".
  • Few‑shot: Uncomment the block in get_gsm8k_questions to add a one‑shot example in the prompt.
  • Safety: Rewards assume the <answer> field is exactly the gold answer string; you can add normalization if needed (e.g., strip punctuation or parse integers).
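
For example, a small normalization helper (hypothetical, not part of the script) could be applied to both the extracted response and the gold answer before comparison:

def normalize_answer(text: str) -> str:
    # Hypothetical helper: trim whitespace and drop thousands separators,
    # currency symbols, and a trailing period before comparing answers.
    return text.strip().replace(",", "").replace("$", "").rstrip(".")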