I Ported DeepMind’s Disco103 from JAX to PyTorch

Here is a PyTorch port of the Disco103 update rule:

https://github.com/asystemoffields/disco-torch

pip install disco-torch

The port loads the pretrained disco_103.npz weights and reproduces the reference Catch benchmark (99% catch rate at 1000 steps). All meta-network outputs match the JAX implementation within float32 precision (<1e-6 max diff), and the full value pipeline is verified (14 fields, <6e-4 max diff).
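For anyone who wants to run a similar parity check on their own port, here is a minimal sketch of a per-field max-diff comparison. It is not the test code from the repo. The helper names (max_abs_diff, check_parity), the field names, and the toy arrays are all made up for illustration, and it assumes both frameworks' outputs have already been converted to NumPy arrays.

```python
import numpy as np


def max_abs_diff(reference, candidate):
    """Return {field: max |reference - candidate|} for every reference field."""
    report = {}
    for name, ref in reference.items():
        ref = np.asarray(ref, dtype=np.float64)
        out = np.asarray(candidate[name], dtype=np.float64)
        if ref.shape != out.shape:
            raise ValueError(f"{name}: shape {ref.shape} != {out.shape}")
        # Empty arrays have no elements to compare, so report zero difference.
        report[name] = float(np.max(np.abs(ref - out))) if ref.size else 0.0
    return report


def check_parity(reference, candidate, tol):
    """Return the full report plus only the fields whose diff exceeds tol."""
    report = max_abs_diff(reference, candidate)
    failures = {name: diff for name, diff in report.items() if diff > tol}
    return report, failures


# Toy stand-ins for the two implementations' outputs (hypothetical fields).
jax_out = {
    "logits": np.array([[0.1, 0.2], [0.3, 0.4]]),
    "value": np.array([1.0, 2.0]),
}
torch_out = {
    "logits": jax_out["logits"] + 5e-7,  # inside the tolerance
    "value": jax_out["value"] + 1e-3,    # deliberately outside it
}

report, failures = check_parity(jax_out, torch_out, tol=1e-6)
for name, diff in report.items():
    status = "FAIL" if name in failures else "ok"
    print(f"{name}: max diff {diff:.2e} [{status}]")
```

Reporting the max diff per field, rather than a single pass/fail, makes it easier to see which part of a pipeline drifts first when tolerances differ between stages.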

It includes a high-level DiscoTrainer API that handles meta-state management, target networks, the replay buffer, and the training loop:

from disco_torch import DiscoTrainer, collect_rollout

trainer = DiscoTrainer(agent, device=device)
for step in range(1000):
    rollout, obs, state = collect_rollout(agent, step_fn, obs, state, 29, device)
    logs = trainer.step(rollout)

Sharing in case it’s useful to the community. Slàinte!

submitted by /u/Far-Respect-4827