[P] Implementing Better PyTorch Schedulers
TL;DR: Current PyTorch schedulers are limited to learning rate (LR) changes, which often leads to hardcoded, error-prone logic in training loops for anything more complex. I built a flexible suite for scheduling any optimizer hyperparam (LR, momentum, betas, etc.), with support for custom functions, presets, cyclic patterns, and per-group overrides. It’s stateless where possible, picklable for checkpointing, and well-tested.
It currently lives in my research monorepo, but I can separate it into a standalone package if there’s enough interest. Would love feedback!
Why
I’ve been working on replicating (a subset of) training techniques from KellerJordan/modded-nanogpt for my baseline experiments, and realized I needed a reusable scheduling suite. But looking at how scheduling is typically done, and how it’s done in modded-nanogpt, neither approach looked particularly reusable.
As you probably know, when you create a PyTorch optimizer, its hyperparameters are stored in param_groups: a list of dicts, where each dict holds a group of model parameters ('params') together with that group’s hyperparams.
For example, here’s a realistic setup where you might want different weight decay for feature extractors vs. classifiers (common in fine-tuning scenarios):
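```python
import torch.nn as nn
import torch.optim as optim

# Stand-in for SomeLargeModel (e.g., a vision transformer) with a feature extractor and a classifier head
model = nn.ModuleDict({'feature_extractor': nn.Linear(128, 64), 'classifier': nn.Linear(64, 10)})
optimizer = optim.AdamW([
    {'params': model.feature_extractor.parameters(), 'weight_decay': 0.1},  # Group 0: high decay for stability
    {'params': model.classifier.parameters(), 'weight_decay': 0.01},        # Group 1: lower decay for faster adaptation
], lr=1e-3, weight_decay=0.05)  # Defaults, overridden per group where specified

# Per-group overrides take precedence over defaults; unspecified keys (like lr) fall back to them
assert optimizer.param_groups[0]['weight_decay'] == 0.1
assert optimizer.param_groups[1]['weight_decay'] == 0.01
assert optimizer.param_groups[0]['lr'] == 1e-3
```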
You’re allowed (and it’s common) to modify these param_groups mid-training to implement scheduling. For instance, you might decay weight decay over time or adjust Adam’s betas for better convergence.
Here is how you would typically perform such a change manually:
```python
# Manual mid-training adjustment (common pattern when a Trainer/scheduler isn't flexible enough)
global_step = 0
for epoch in range(num_epochs):
    for batch in dataloader:
        # ... compute loss, backward
        optimizer.step()
        global_step += 1

        # Manual mid-training tweak: reduce weight decay after warmup
        if global_step > warmup_steps:
            for group in optimizer.param_groups:
                group['weight_decay'] *= 0.99  # Simple decay, compounding every step
```
This is straightforward for basic cases, but things get messy with more complexity. For example, look at KellerJordan/modded-nanogpt. They use a combined NorMuon+Adam optimizer where different parameter groups need different scheduling: projection matrices use Muon with momentum warmup/cooldown, while embeddings use Adam with higher weight decay. The scheduling logic is spread across:
- A param_table dict defining per-param lr_mul, wd_mul, and adam_betas
- A TrainingSchedule class that computes LR based on training stage and cooldown
- A get_muon_momentum() function for Muon’s momentum warmup/cooldown
- Manual updates in step_optimizers() that sets p_cfg.lr and p_cfg.momentum each step
This is a real research codebase with many contributors, and the coupling between scheduling and training logic makes it hard to experiment with different schedules without touching multiple files.
In other words, it’s “smelly” code: scheduling logic that lives inside the training loop is hard to change and hard to test in isolation.
PyTorch Schedulers (flawed)
Enter PyTorch’s built-in torch.optim.lr_scheduler, which is meant to clean this up, but only for LR. Basic usage mirrors the manual tweak but abstracts it away:
```python
from torch.optim.lr_scheduler import StepLR

optimizer = optim.AdamW(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)  # Decay LR by 0.1x every 30 epochs

for epoch in range(num_epochs):
    for batch in dataloader:
        # ... compute loss, backward
        optimizer.step()
    scheduler.step()  # Updates LR once per epoch (not per batch in this case)
```
Under the hood, scheduler.step() calls _update_lr() (defined in the LRScheduler base class at L284), which:
- Calls get_lr() to compute the new learning rate for each param group
- Iterates through optimizer.param_groups and calls _update_param_group_val(param_group, "lr", lr) to set each group’s 'lr' key
The key point: _update_param_group_val (defined at L83) is just a generic helper that does param_group[key] = val (with special handling for Tensor values), and the base class always passes "lr" as the key.
As a result, these schedulers are hardcoded to only handle LR, not momentum, betas, weight decay, or anything else you might want to schedule (which, as the modded-nanogpt example shows, people do all the time). The partial exceptions are CyclicLR and OneCycleLR, which can also cycle momentum (or Adam’s beta1) via their cycle_momentum option, but only tied to their own LR cycle. Why is "lr" hardcoded instead of allowing any param_group key? It’s literally just a string argument. The limitation is artificial, and it forces everyone to reimplement scheduling for non-LR hyperparams from scratch.
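To make the “it’s just a string” point concrete, here is a minimal sketch of a LambdaLR-style scheduler that works for any param_group key. KeyLambdaScheduler is a hypothetical name, not a PyTorch class; it only assumes the optimizer exposes param_groups as a list of dicts, and, unlike PyTorch’s helper, it does not special-case Tensor-valued entries.

```python
from typing import Callable, List


class KeyLambdaScheduler:
    """Hypothetical LambdaLR-style scheduler for an arbitrary param_group key.

    Each step sets group[key] = base_value * factor_fn(step), where base_value
    is the group's value of `key` when the scheduler was created. Works with
    anything exposing `param_groups` (a list of dicts), e.g. a torch optimizer.
    Tensor-valued entries are not special-cased in this sketch.
    """

    def __init__(self, optimizer, key: str, factor_fn: Callable[[int], float]):
        missing = [i for i, g in enumerate(optimizer.param_groups) if key not in g]
        if missing:
            raise KeyError(f"param groups {missing} have no {key!r} entry")
        self.optimizer = optimizer
        self.key = key
        self.factor_fn = factor_fn
        self.base_values: List[float] = [g[key] for g in optimizer.param_groups]
        self.last_step = 0
        self._apply()  # Step 0: factor_fn(0) is applied immediately, like LambdaLR

    def _apply(self) -> None:
        factor = self.factor_fn(self.last_step)
        for group, base in zip(self.optimizer.param_groups, self.base_values):
            group[self.key] = base * factor

    def step(self) -> None:
        self.last_step += 1
        self._apply()
```

With a real optimizer you would write something like KeyLambdaScheduler(optimizer, "weight_decay", lambda s: 0.99 ** s) and call step() after optimizer.step(); nothing in the class is specific to "lr".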
Now, on to the design of the schedulers themselves. Most derive from LRScheduler and implement their own get_lr() method. Functionally, many could be expressed as LambdaLR with an appropriate lambda.
For instance, StepLR is equivalent to a lambda that multiplies the LR by gamma every step_size epochs, and CosineAnnealingLR is equivalent to a cosine lambda. However, they’re implemented as separate classes with their own closed-form formulas (via _get_closed_form_lr()), which can be more efficient and readable.
(Btw, ReduceLROnPlateau is the odd one out: it behaves like a callback, adjusting the LR based on a monitored metric passed to step(metrics) rather than on the step count.)
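As a rough illustration of the equivalence claim, here are StepLR’s and CosineAnnealingLR’s closed-form formulas written as plain multiplier functions that could be handed to LambdaLR (which multiplies each group’s initial LR by the returned factor). The function names are illustrative, and the match assumes the scheduler is the only one acting on the optimizer. The cosine version needs the base LR because eta_min is an absolute value, so with several groups of different base LRs you would need one lambda per group.

```python
import math
from functools import partial


def step_lr_factor(epoch: int, step_size: int, gamma: float) -> float:
    """StepLR as a multiplier: scale by gamma once every step_size epochs."""
    return gamma ** (epoch // step_size)


def cosine_annealing_factor(epoch: int, t_max: int, base_lr: float, eta_min: float = 0.0) -> float:
    """CosineAnnealingLR's closed form, divided by base_lr to get a LambdaLR multiplier."""
    cos_term = 0.5 * (1.0 + math.cos(math.pi * epoch / t_max))
    return (eta_min + (base_lr - eta_min) * cos_term) / base_lr


# Should match StepLR(optimizer, step_size=30, gamma=0.1) when used as LambdaLR(optimizer, lr_lambda=step_lambda)
step_lambda = partial(step_lr_factor, step_size=30, gamma=0.1)

# Should match CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-5) for a group whose initial LR is 1e-3
cosine_lambda = partial(cosine_annealing_factor, t_max=100, base_lr=1e-3, eta_min=1e-5)
```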
LambdaLR is the most flexible of PyTorch’s schedulers, but it’s inconvenient to use with multiple param groups.
For example, if you want a custom lambda for group 2, you must provide dummies for groups 0 and 1 (constants, which aren’t “real” schedules):
```python
from torch.optim.lr_scheduler import LambdaLR

def constant_lambda(_):
    return 1.0  # Dummy for groups 0 and 1

def decay_lambda(epoch):
    return 1.0 - epoch / 100  # The actual schedule, for group 2

# LambdaLR needs exactly one lambda per param group (three here)
scheduler = LambdaLR(optimizer, lr_lambda=[constant_lambda, constant_lambda, decay_lambda])
```
Clunky, right? Changing total training length? Your lambdas hardcode it, so tweaks mean rewriting (though factories/partials help, it’s still boilerplate). Advanced schemes like cyclic schedules? CosineAnnealingWarmRestarts exists, but it’s LR-only and inflexible for custom cycles or non-LR params.
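For reference, the “factories/partials” workaround looks something like this: a hypothetical expand_group_lambdas helper fills untouched groups with a constant multiplier, and functools.partial keeps the training length in one place. It works, but it is exactly the kind of boilerplate described above.

```python
from functools import partial
from typing import Callable, Dict, List

LrLambda = Callable[[int], float]


def constant(_: int) -> float:
    return 1.0  # Identity multiplier: leaves the group's LR unchanged


def linear_decay(epoch: int, total_epochs: int) -> float:
    # Same shape as decay_lambda, but the horizon is a parameter (and clamped at 0)
    return max(1.0 - epoch / total_epochs, 0.0)


def expand_group_lambdas(num_groups: int, overrides: Dict[int, LrLambda]) -> List[LrLambda]:
    """Hypothetical helper: build LambdaLR's one-lambda-per-group list from sparse overrides."""
    unknown = sorted(set(overrides) - set(range(num_groups)))
    if unknown:
        raise ValueError(f"no such param groups: {unknown}")
    return [overrides.get(i, constant) for i in range(num_groups)]


total_epochs = 100  # Change the horizon here instead of inside every lambda
lr_lambdas = expand_group_lambdas(3, {2: partial(linear_decay, total_epochs=total_epochs)})
# scheduler = LambdaLR(optimizer, lr_lambda=lr_lambdas)  # optimizer with 3 param groups
```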
My Scheduling Suite
So, what really is a schedule? At its core, it’s a pure function: f(step: int, total_steps: int) -> value (any type, not just float). It maps training progress to a param value, which you then apply via optimizer.param_groups[i][param_name] = value. No state, no side effects, just deterministic computation (great for reproducibility).
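In plain Python, the idea looks like the sketch below: each schedule is an ordinary function of (step, total_steps), the return type is whatever the param_group key expects (Adam’s betas are a tuple, for instance), and applying a schedule is a single dict assignment. The specific functions and values are made up for illustration, and apply_schedules is a hypothetical helper, not part of the suite.

```python
import math
from typing import Any, Callable, Dict, List

ScheduleFn = Callable[[int, int], Any]


def cosine_lr(step: int, total_steps: int) -> float:
    # Cosine decay from 1e-3 down to 0 over the run
    return 1e-3 * 0.5 * (1.0 + math.cos(math.pi * min(step, total_steps) / total_steps))


def adam_betas(step: int, total_steps: int) -> tuple:
    # Non-float value: ramp beta1 from 0.85 to 0.95, keep beta2 fixed
    frac = min(step / total_steps, 1.0)
    return (0.85 + 0.10 * frac, 0.999)


def apply_schedules(param_groups: List[Dict[str, Any]], schedules: Dict[str, ScheduleFn],
                    step: int, total_steps: int) -> None:
    # The only side effect lives here: write each schedule's value into every group
    for group in param_groups:
        for param_name, fn in schedules.items():
            group[param_name] = fn(step, total_steps)
```

With a torch optimizer you would pass optimizer.param_groups each step; Adam reads group["betas"] on every optimizer.step(), so a newly assigned tuple takes effect on the next update.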
In my suite, this primitive is user-facing via ParamSchedule (end users are expected to use it directly):
```python
from research_lib.training.scheduling import ParamSchedule

def linear_decay(step: int, total_steps: int) -> float:
    return 1.0 - (step / total_steps) * 0.9  # Decays from 1.0 to 0.1

lr_schedule = ParamSchedule(param_name="lr", schedule_fn=linear_decay)
value = lr_schedule(500, 1000)  # 0.55
```
For common patterns, presets (subclasses of the primitive) are provided: e.g., WarmupStableDecaySchedule for warmup → stable → decay:
```python
from research_lib.training.scheduling import WarmupStableDecaySchedule

lr_schedule = WarmupStableDecaySchedule(
    param_name="lr",
    warmup_steps=100,
    cooldown_frac=0.5,
    min_value=0.0,
    max_value=1.0,
    decay_type="cosine",
)
```
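For a rough idea of what a warmup → stable → decay function computes, here is one possible standalone version (a sketch, not the suite’s implementation). Its semantics are assumptions: warmup ramps linearly from min_value to max_value, the last cooldown_frac of training decays back to min_value, and decay_type picks a linear or cosine cooldown.

```python
import math


def warmup_stable_decay(step: int, total_steps: int, *, warmup_steps: int, cooldown_frac: float,
                        min_value: float, max_value: float, decay_type: str = "cosine") -> float:
    """One possible WSD shape (hypothetical; not WarmupStableDecaySchedule's code)."""
    if not 0.0 <= cooldown_frac <= 1.0:
        raise ValueError("cooldown_frac must be in [0, 1]")
    if decay_type not in ("cosine", "linear"):
        raise ValueError(f"unknown decay_type: {decay_type!r}")
    cooldown_steps = round(cooldown_frac * total_steps)
    decay_start = total_steps - cooldown_steps
    if warmup_steps > decay_start:
        raise ValueError("warmup overlaps the cooldown phase")

    if step < warmup_steps:  # Warmup: linear ramp min_value -> max_value
        return min_value + (max_value - min_value) * step / warmup_steps
    if step < decay_start:  # Stable phase
        return max_value
    # Cooldown: progress goes 0 -> 1 over the final cooldown_steps
    progress = min((step - decay_start) / max(cooldown_steps, 1), 1.0)
    if decay_type == "linear":
        shape = 1.0 - progress
    else:
        shape = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_value + (max_value - min_value) * shape
```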
Need reusable patterns? Subclass the primitive and override the schedule_fn attribute.
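A minimal sketch of that subclassing pattern, using a stand-in MiniParamSchedule class since ParamSchedule’s internals aren’t shown here: the base class calls self.schedule_fn, so a subclass can supply it as a method. A side benefit of classes over closures or lambdas is that instances of module-level classes pickle cleanly, which matters for checkpointing.

```python
import pickle
from typing import Any, Callable, Optional

ScheduleFn = Callable[[int, int], Any]


class MiniParamSchedule:
    """Stand-in for a ParamSchedule-style primitive (hypothetical, not the suite's code)."""

    def __init__(self, param_name: str, schedule_fn: Optional[ScheduleFn] = None):
        self.param_name = param_name
        if schedule_fn is not None:
            self.schedule_fn = schedule_fn  # Instance attribute holding a plain function
        if not callable(getattr(self, "schedule_fn", None)):
            raise TypeError("pass schedule_fn or define it in a subclass")

    def __call__(self, step: int, total_steps: int) -> Any:
        return self.schedule_fn(step, total_steps)


class ExponentialDecaySchedule(MiniParamSchedule):
    """Reusable preset: decays geometrically from 1.0 to final_value over total_steps."""

    def __init__(self, param_name: str, final_value: float):
        if not 0.0 < final_value <= 1.0:
            raise ValueError("final_value must be in (0, 1]")
        self.final_value = final_value
        super().__init__(param_name)  # schedule_fn comes from the method below

    def schedule_fn(self, step: int, total_steps: int) -> float:
        return self.final_value ** min(step / total_steps, 1.0)


wd_schedule = ExponentialDecaySchedule("weight_decay", final_value=0.01)
restored = pickle.loads(pickle.dumps(wd_schedule))  # Round-trips; an instance holding a lambda would not
```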
For cyclic schedules (e.g., for continual training), enter “wrapper land” via the wrappers submodule. Wrappers are composable callables that wrap a base_fn:
```python
from research_lib.training.scheduling import wrappers as sw

base_fn = ...  # e.g., a decay schedule
cyclic_fn = sw.Cyclic(base_fn, cycle_steps=1000)  # Repeats every 1000 steps
lr_schedule = ParamSchedule("lr", cyclic_fn)
```
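For intuition, a cyclic wrapper can be as small as the sketch below. CyclicWrapper is a hypothetical name, and its semantics (restart base_fn every cycle_steps steps and treat each cycle as a full run, ignoring the outer total_steps) are an assumption about what “repeats every 1000 steps” means; sw.Cyclic may behave differently. Writing it as a class with __call__ rather than a closure keeps it picklable, and because it has the same (step, total_steps) signature as any schedule, wrappers can nest.

```python
import math
from typing import Any, Callable

ScheduleFn = Callable[[int, int], Any]


class CyclicWrapper:
    """Hypothetical cyclic wrapper: restart base_fn every cycle_steps steps."""

    def __init__(self, base_fn: ScheduleFn, cycle_steps: int):
        if cycle_steps <= 0:
            raise ValueError("cycle_steps must be positive")
        self.base_fn = base_fn
        self.cycle_steps = cycle_steps

    def __call__(self, step: int, total_steps: int) -> Any:
        # The outer total_steps is ignored: each cycle is a complete run of base_fn
        return self.base_fn(step % self.cycle_steps, self.cycle_steps)


def cosine_decay(step: int, total_steps: int) -> float:
    # 1.0 -> 0.0 over total_steps
    return 0.5 * (1.0 + math.cos(math.pi * step / total_steps))


cyclic_fn = CyclicWrapper(cosine_decay, cycle_steps=1000)
```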
Finally, the runtime layer: ParamScheduler binds schedules to an optimizer, tracks state for checkpointing, and supports global schedules plus per-group overrides:
```python
from research_lib.training.scheduling import ParamScheduler

scheduler = ParamScheduler(
    optimizer=optimizer,
    global_schedules=[lr_schedule, momentum_schedule],
    group_overrides={1: [slow_lr_schedule]},  # Override for group 1
    total_steps=10000,
)

# In the training loop
optimizer.step()
scheduler.step()  # Applies all schedules, increments the internal step

# Checkpointing: scheduler.state_dict() / load_state_dict()
```
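To show how the runtime layer’s pieces fit together, here is a minimal sketch. MiniParamScheduler and Schedule are hypothetical stand-ins, not the suite’s classes, and the behavior is assumed: per-group overrides replace the global schedule for the same param_name, values are written once at construction (step 0) and again after every step(), and the only checkpointed state is the step counter.

```python
from typing import Any, Callable, Dict, List, NamedTuple, Optional


class Schedule(NamedTuple):
    param_name: str
    fn: Callable[[int, int], Any]  # fn(step, total_steps) -> value


class MiniParamScheduler:
    """Hypothetical runtime layer: binds schedules to an optimizer's param_groups."""

    def __init__(self, optimizer, global_schedules: List[Schedule],
                 group_overrides: Optional[Dict[int, List[Schedule]]] = None, *, total_steps: int):
        if total_steps <= 0:
            raise ValueError("total_steps must be positive")
        group_overrides = group_overrides or {}
        unknown = sorted(set(group_overrides) - set(range(len(optimizer.param_groups))))
        if unknown:
            raise ValueError(f"overrides for non-existent groups: {unknown}")
        # Resolve once, up front: for each group, param_name -> schedule
        self._plan: List[Dict[str, Schedule]] = []
        for i, group in enumerate(optimizer.param_groups):
            plan = {s.param_name: s for s in global_schedules}
            plan.update({s.param_name: s for s in group_overrides.get(i, [])})
            for name in plan:
                if name not in group:
                    raise KeyError(f"group {i} has no {name!r} entry")
            self._plan.append(plan)
        self.optimizer = optimizer
        self.total_steps = total_steps
        self.step_count = 0
        self._apply()

    def _apply(self) -> None:
        for group, plan in zip(self.optimizer.param_groups, self._plan):
            for name, schedule in plan.items():
                group[name] = schedule.fn(self.step_count, self.total_steps)

    def step(self) -> None:
        self.step_count += 1
        self._apply()

    def state_dict(self) -> Dict[str, int]:
        # Schedules are pure, so the step counter is the only state worth saving
        return {"step_count": self.step_count, "total_steps": self.total_steps}

    def load_state_dict(self, state: Dict[str, int]) -> None:
        self.step_count = state["step_count"]
        self.total_steps = state["total_steps"]
        self._apply()
```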
I followed these design principles:
- “No restriction on action space” (schedules can do anything PyTorch allows)
- “Make illegal states unrepresentable” (required args aren’t optional; validation at __init__)
- Minimize coupling (schedules are pure; the optimizer is bound at runtime)
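As an illustration of the second and third principles together, here is a sketch of a preset as a frozen dataclass: its arguments have no defaults, so omitting one fails at construction; __post_init__ rejects invalid configurations before any training step runs; and the schedule never touches an optimizer. LinearDecay and is_non_increasing are hypothetical names, and the helper only shows the kind of monotonicity check a test suite might run.

```python
import math
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class LinearDecay:
    """Hypothetical preset: linear decay from start_value to end_value over total_steps."""
    start_value: float  # No defaults: forgetting an argument raises TypeError
    end_value: float

    def __post_init__(self) -> None:
        # Reject illegal configurations at construction, not mid-training
        if not (math.isfinite(self.start_value) and math.isfinite(self.end_value)):
            raise ValueError("values must be finite")
        if self.end_value > self.start_value:
            raise ValueError("end_value > start_value would be an increase, not a decay")

    def __call__(self, step: int, total_steps: int) -> float:
        if total_steps <= 0:
            raise ValueError("total_steps must be positive")
        frac = min(max(step / total_steps, 0.0), 1.0)
        return self.start_value + (self.end_value - self.start_value) * frac


def is_non_increasing(fn: Callable[[int, int], float], total_steps: int) -> bool:
    """Monotonicity check: evaluate every step and compare neighbours."""
    values = [fn(step, total_steps) for step in range(total_steps + 1)]
    return all(later <= earlier for earlier, later in zip(values, values[1:]))
```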
It’s thoroughly tested (e.g., pickling, and validation checks like monotonicity). Thoughts? Does this solve pain points you’ve hit? Link to submodule here: LMK if I should extract it!
submitted by /u/shivvorz