So, a few days back I shared a post where I trained a tiny Qwen2.5-0.5B-Instruct model on smoltldr (a Reddit post summarization dataset of 2k rows) to output summaries capped at a max length of about 64, using RLVR with GRPO.
However, there was a catch!
- The wandb chart for average response length kept going down and saturated around 10-15 tokens on average. This was the result of me confusing character counts with token counts: I meant to cap at 64 tokens, but I accidentally capped at 64 characters!
Hence the charts showed a sharp decline and convergence toward a response length hovering around 15 tokens.
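The mix-up is easy to reproduce. Here's a minimal sketch of the bug (the toy whitespace tokenizer is a stand-in for illustration, not the actual Qwen2.5 tokenizer, and the variable names are mine):

```python
MAX_LENGTH = 64

def toy_tokenize(text: str) -> list[str]:
    # Hypothetical stand-in for tokenizer.encode(); real tokenizers yield subwords.
    return text.split()

response = "The post argues that small models can learn concise summaries with RLVR."

char_len = len(response)                  # what I accidentally penalized against
token_len = len(toy_tokenize(response))   # what I meant to penalize against

# The buggy penalty pushes responses toward 64 *characters* (~10-15 tokens);
# the intended one would push toward 64 *tokens*.
buggy_penalty = -abs(char_len - MAX_LENGTH)
intended_penalty = -abs(token_len - MAX_LENGTH)
```

Since a token is several characters on average, a 64-character target lands right in that 10-15 token band the charts converged to.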
I used two rewards:
- length_penalty: basically -abs(response_length - MAX_LENGTH)
- quality_reward: ROUGE-L, which is basically the longest common subsequence (LCS) against the golden summaries included in the dataset, to keep some structure in the generated responses and minimize degradation.
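For reference, the two rewards can be sketched roughly like this (the function names and the DP-based LCS are my reconstruction of the idea, not the exact code from the run):

```python
MAX_LENGTH = 64  # intended to be tokens; my bug measured characters

def length_penalty(response_len: int) -> float:
    # Zero when exactly at MAX_LENGTH, increasingly negative on either side.
    return -abs(response_len - MAX_LENGTH)

def lcs_len(a: list[str], b: list[str]) -> int:
    # Classic O(len(a) * len(b)) dynamic program for longest common subsequence.
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            dp[i + 1][j + 1] = dp[i][j] + 1 if x == y else max(dp[i][j + 1], dp[i + 1][j])
    return dp[len(a)][len(b)]

def rouge_l_f1(candidate: str, reference: str) -> float:
    # ROUGE-L: F1 over the LCS of the candidate and golden summary token sequences.
    c, r = candidate.split(), reference.split()
    lcs = lcs_len(c, r)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(c), lcs / len(r)
    return 2 * precision * recall / (precision + recall)
```

A response that copies the golden summary word for word scores 1.0, and anything with no overlapping words scores 0.0, so degenerate repeated-token outputs get no quality reward.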
I trained for one full epoch with a batch size of 2 at most (anything larger hit an OOM). The results were nearly identical to the previous run, however, with one crucial difference:
- Without a quality reward, my previous runs tried to game the reward by outputting stuff like "-" repeated 20 times, and that's it!
- Not this time, though: the rewards were nearly the same whether I included both rewards or just the length penalty, and there was no degradation in the rollouts after one full epoch, so I wonder why?
Anyways, next up:
- Find out why GRPO didn't try to game the reward system this time.
- Try out metrics other than ROUGE-L to maybe get better summaries.
- Set up LLM-as-a-judge to quantify the results.
- Train some of the HF SmolLM series now!
- What if I described the reward system and the MAX_LENGTH in the prompt itself along with the task?
- Try a different MAX_LENGTH?
https://preview.redd.it/mf7rux5lhyug1.png?width=800&format=png&auto=webp&s=bc54273f644ee2306b03834e037ab3e91f3b0582
https://preview.redd.it/1es4n61mhyug1.png?width=800&format=png&auto=webp&s=a8cc4249e646f03e8396cf79e640e27fcd1edfce
https://preview.redd.it/djsslwsmhyug1.png?width=800&format=png&auto=webp&s=91589c746ac7a2c43d724e4768e8cb610288dee4