I Ported DeepMind’s Disco103 from JAX to PyTorch
Here is a PyTorch port of the Disco103 update rule:
https://github.com/asystemoffields/disco-torch
pip install disco-torch
The port loads the pretrained disco_103.npz weights and reproduces the reference Catch benchmark (99% catch rate at 1000 steps). All meta-network outputs match the JAX implementation within float32 precision (<1e-6 max diff), and the full value pipeline is verified (14 fields, <6e-4 max diff).
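For anyone curious what a parity check like this looks like, here is a minimal sketch. It is not the port's actual test code. The helper names (max_abs_diff, compare_outputs) and the field names are hypothetical, and plain Python lists stand in for arrays so it runs with no dependencies. With NumPy arrays, the per-field comparison would typically reduce to np.abs(a - b).max().

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two
    equally shaped nested lists (or two scalars)."""
    if isinstance(a, (list, tuple)):
        if not isinstance(b, (list, tuple)) or len(a) != len(b):
            raise ValueError("shape mismatch between reference and candidate")
        return max((max_abs_diff(x, y) for x, y in zip(a, b)), default=0.0)
    return abs(float(a) - float(b))


def compare_outputs(reference, candidate, tol):
    """Compare two dicts of named outputs.

    Returns (report, ok): report maps each field name to its max
    absolute difference, ok is True when every field is within tol.
    """
    if reference.keys() != candidate.keys():
        raise KeyError("reference and candidate expose different fields")
    report = {name: max_abs_diff(reference[name], candidate[name])
              for name in reference}
    ok = all(diff <= tol for diff in report.values())
    return report, ok


# Hypothetical outputs: "field_a" matches exactly, "field_b" differs
# by 2**-21 (about 4.8e-7) in one element.
reference = {"field_a": [[0.25, 0.75]], "field_b": [0.5, -1.0]}
candidate = {"field_a": [[0.25, 0.75]], "field_b": [0.5, -1.0 + 2**-21]}

report, ok = compare_outputs(reference, candidate, tol=1e-6)
for name, diff in report.items():
    print(f"{name}: max diff {diff:.3e}")
print("within tolerance:", ok)
```

Reporting the worst difference per field, rather than a single pass/fail flag, makes it easier to see which output drifts when a tolerance is exceeded.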
It includes a high-level DiscoTrainer API that handles meta-state management, target networks, replay buffer, and the training loop:
from disco_torch import DiscoTrainer, collect_rollout
trainer = DiscoTrainer(agent, device=device)
for step in range(1000):
    rollout, obs, state = collect_rollout(agent, step_fn, obs, state, 29, device)
    logs = trainer.step(rollout)
Sharing in case it’s useful to the community. Slàinte!
submitted by /u/Far-Respect-4827