How to bridge the gap between Torch and JAX performance?
Hi, I am working on an RL project for my studies that uses a variant of SAC. The algorithm benefits greatly from being written in JAX, but for this project I have to use PyTorch because we wanted to try the Genesis-World simulation engine, which provides Torch tensors.
The problem is that the PyTorch reimplementation is about 5× slower than the JAX version, even with torch.compile and after avoiding common performance mistakes. Without torch.compile, it is around 15× slower.
The reason seems to be that the algorithm involves many gradient update steps inside a loop, something like:
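```python
# pseudocode for the idea
for batch in batches:  # placeholder: ~1000 small mini-batches
    optimizer.zero_grad()
    loss = loss_fn(model(batch))  # separate name, so the loss function is not overwritten
    loss.backward()
    optimizer.step()
```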
This whole loop is just one training iteration (it runs ~1000 update steps). It is important for the algorithm that it performs many small updates.
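Because each update is small, per-step overhead (Python dispatch and CUDA kernel launches) may dominate over actual arithmetic. One technique aimed at exactly this is capturing the whole step (forward, backward and `optimizer.step()`) into a CUDA graph and replaying it, following the "whole-network capture" pattern from the PyTorch CUDA Graphs notes. This is a minimal sketch under assumptions: a CUDA device, fixed batch shapes, no CPU syncs (such as `.item()`) inside the step, and an optimizer that supports capture (Adam needs `capturable=True`). `capture_train_step` and the toy network are hypothetical stand-ins for the real SAC update.

```python
import torch


def capture_train_step(model, loss_fn, optimizer, static_x, static_y, warmup_steps=3):
    """Hypothetical helper: record forward, backward and optimizer.step()
    into one CUDA graph and return a function that replays it.

    static_x / static_y are preallocated CUDA tensors with fixed shapes;
    each new batch is copied into them before a replay.
    """
    # Warm up on a side stream before capture. These steps really update
    # the parameters, so real batches should be used here in practice.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(warmup_steps):
            optimizer.zero_grad(set_to_none=True)
            loss_fn(model(static_x), static_y).backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(side)

    # Grads set to None so backward() allocates .grad from the graph's pool.
    optimizer.zero_grad(set_to_none=True)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_loss = loss_fn(model(static_x), static_y)
        static_loss.backward()
        optimizer.step()

    def step(x, y):
        # Copy the new batch into the captured buffers and replay all recorded
        # kernels; the captured backward refills the same .grad tensors,
        # so no zero_grad() is needed between replays.
        static_x.copy_(x)
        static_y.copy_(y)
        graph.replay()
        return static_loss  # overwritten in place on every replay

    return step


if torch.cuda.is_available():
    net = torch.nn.Sequential(
        torch.nn.Linear(64, 256), torch.nn.ReLU(), torch.nn.Linear(256, 1)
    ).cuda()
    # capturable=True is required for Adam to be captured in a CUDA graph.
    opt = torch.optim.Adam(net.parameters(), lr=1e-3, capturable=True)
    static_x = torch.randn(256, 64, device="cuda")
    static_y = torch.randn(256, 1, device="cuda")
    step = capture_train_step(net, torch.nn.functional.mse_loss, opt, static_x, static_y)
    for _ in range(100):
        loss = step(torch.randn(256, 64, device="cuda"), torch.randn(256, 1, device="cuda"))
```

The trade-off is flexibility: anything data-dependent on the CPU side (variable batch sizes, Python branching on tensor values) has to stay outside the captured region, and data coming from the simulator has to be copied into the static buffers.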
JAX compiles everything: the forward pass, backward pass, optimizer step, and even the whole loop. PyTorch doesn't seem to match this. It compiles the forward pass and maybe the backward pass, but zero_grad() and optimizer.step() still cause graph breaks.
The torch.compile documentation is quite difficult to follow. I found several suggestions for compiling the optimizer step, zero_grad, and the backward pass, and I tried implementing them, but the optimizer graph still shows graph breaks in the same places as before.
From what I’ve read, this kind of workload benefits the most from JAX. Still, I find it surprising that there’s no way to achieve similar performance in PyTorch. I don’t expect it to be automatic; I’m looking for tools or techniques that give me more manual control to improve performance.
It also feels odd that such a common forward–backward–optimizer pipeline cannot be optimized well in PyTorch. I can’t use gradient accumulation, because the many small updates are important for learning my embeddings. I tried rewriting things in the functional PyTorch style, but I am not sure it helps, and the functional optimizers from torchopt can’t be compiled with torch.compile.
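A functional alternative that avoids torchopt is to keep parameters and optimizer state as plain dicts of tensors, get gradients from `torch.func.grad_and_value` with `torch.func.functional_call`, and write the update rule by hand. The whole step is then one pure function, closer to the JAX style, that `torch.compile` can try to trace as a unit (tracing `torch.func` transforms under `torch.compile` is only supported in recent releases, so this is an assumption to verify on your version). The hand-written Adam below (no weight decay) and the names `adam_init` / `functional_adam_step` are hypothetical illustrations, not torchopt's or `torch.optim`'s implementation.

```python
import torch
from torch.func import functional_call, grad_and_value

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(8, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
# Functional style: parameters live in a dict, not inside the module.
params = {name: p.detach().clone() for name, p in model.named_parameters()}


def loss_fn(params, x, y):
    pred = functional_call(model, params, (x,))
    return torch.nn.functional.mse_loss(pred, y)


def adam_init(params):
    return {
        "m": {k: torch.zeros_like(p) for k, p in params.items()},
        "v": {k: torch.zeros_like(p) for k, p in params.items()},
        "t": torch.zeros(()),  # step count kept as a tensor, not a Python int
    }


def functional_adam_step(params, state, x, y, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """One pure update: returns new params, new optimizer state and the loss."""
    grads, loss = grad_and_value(loss_fn)(params, x, y)
    t = state["t"] + 1
    new_params, new_m, new_v = {}, {}, {}
    for k, p in params.items():
        g = grads[k]
        m = b1 * state["m"][k] + (1 - b1) * g
        v = b2 * state["v"][k] + (1 - b2) * g * g
        m_hat = m / (1 - b1 ** t)  # bias correction
        v_hat = v / (1 - b2 ** t)
        new_params[k] = p - lr * m_hat / (v_hat.sqrt() + eps)
        new_m[k], new_v[k] = m, v
    return new_params, {"m": new_m, "v": new_v, "t": t}, loss


# torch.compile is lazy: nothing is compiled until compiled_step is called.
compiled_step = torch.compile(functional_adam_step)

x, y = torch.randn(64, 8), torch.randn(64, 1)
state = adam_init(params)
losses = []
for _ in range(100):
    # Eager here; swapping in compiled_step compiles the whole update.
    params, state, loss = functional_adam_step(params, state, x, y)
    losses.append(loss.item())
```

Keeping the step pure also makes it a natural candidate for the CUDA-graph or `mode="reduce-overhead"` route, since inputs and outputs are explicit tensors.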
How could I implement something like this more efficiently?
submitted by /u/Little_swift