This post is by Bob.
The title is based on the similarly named classic film.
“Big” models moving from Stan to JAX
Ever since the big ML frameworks PyTorch and TensorFlow were released, the Stan developers have been worried that they would put Stan out of business (we built Stan’s autodiff before those packages existed, but after Theano). While that hasn’t quite happened yet, I now believe our days are numbered. For high-end applications, Stan is slowly, but surely, being replaced by JAX. Many places I go (I don’t want Andrew to jump on a hyperbolic use of “everywhere”), I hear about people switching from Stan to JAX.
Here are four examples:
1. At StanCon in Oxford in 2024, Elizaveta Semenova started her talk by saying something to the effect of, “I’m sorry to say this here, but I don’t use Stan any more—I switched to JAX through NumPyro for scalability.”
2. Mitzi Morris just started working as a contractor for the U.S. Centers for Disease Control and Prevention (CDC) (!? as they say in chess). Their public GitHub repositories have old Stan code that has since been replaced by JAX, for which they are building up a library of reusable code. It’s very hard to build reusable code in Stan given its blocked structure and the limited form of includes; Sean Pinkney has gone further than I thought possible with his helpful Stan functions project. The CDC models are for wastewater-informed forecasting—here’s the project overview.
3. Andrew posted a job announcement from the L.A. Dodgers baseball team a week ago that said, “We have a soft spot for jax and numpyro but Stan and PyMC folks are obviously always of interest.” Like Andrew, they apparently don’t like using their shift key.
4. Matt Hoffman’s been saying this for years and backing it up with adaptive ensemble samplers, convergence monitoring, etc. He, Pavel Sountsov, and Colin Carroll wrote a draft chapter for the second edition of the MCMC Handbook, Running Markov Chain Monte Carlo on Modern Hardware and Software. It contains complete instructions for massively parallelizing HMC on a GPU using JAX.
But what about the hardware?
The biggest obstacle for people moving is finding hardware on which to run JAX efficiently—it’s really tailored for parallel and GPU processing, and I don’t believe most Stan users have access to that kind of hardware for fitting their models. But I believe this is going to change over the next ten years. That, and I believe we’re going to get better and better Macs—the ARM chips are way faster than the Intel chips for the kind of random memory access that Stan programs need.
New samplers moving to JAX
New samplers like the microcanonical HMC of Jakob Robnik and Uroš Seljak (and more recently Reuben Cohn-Gordon) are being coded only in JAX. Like many others, they added their sampler (see the previous link) to the Blackjax package. They even have a competitor for posteriordb in the form of Inference Gym.
A very nice feature of putting things up on Blackjax is that you can use them with any Python-defined log density function—it doesn’t even need to come from JAX. Brian Ward managed to plug Stan models into JAX (by which I mean having JAX call Stan’s C++, not generating JAX code from Stan).
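To give a sense of how little glue is involved, here’s a minimal sketch of running Blackjax’s NUTS on a hand-written log density. The log density, step size, and mass matrix are placeholders I made up, and the exact Blackjax call signatures may vary by version.

```python
import jax
import jax.numpy as jnp
import blackjax

# Any Python-defined log density over a JAX array will do; a standard
# normal stands in here for a real model.
def logdensity(theta):
    return -0.5 * jnp.sum(theta ** 2)

rng_key = jax.random.key(0)
nuts = blackjax.nuts(logdensity, step_size=0.5, inverse_mass_matrix=jnp.ones(3))
state = nuts.init(jnp.zeros(3))

step = jax.jit(nuts.step)
for _ in range(1000):
    rng_key, sample_key = jax.random.split(rng_key)
    state, info = step(sample_key, state)
```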
Static vs. dynamic automatic differentiation
We built Stan with automatic differentiation before PyTorch, TensorFlow, or JAX existed. We went with the same dynamic design that PyTorch eventually chose, even though Matt Hoffman and I knew the static TensorFlow/JAX approach could be more performant. The problem was that we didn’t have the people to implement enough derivatives to do it that way. Instead, we just started autodiffing through functions in the Eigen matrix library (like matrix multiplication and division) and in Boost (like the Runge-Kutta 4/5 ODE solver and many of the special functions). The static approach of XLA (the compiler infrastructure under both JAX and TensorFlow) does limit the expressiveness of loops and conditionals, which can’t condition on parameter values through ordinary Python control flow, making it challenging, if not impossible, to write some iterative algorithms in JAX.
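Here’s a small sketch of what that restriction looks like in practice: under jit, an ordinary Python `if` can’t branch on a traced parameter value, so you express the conditional with `lax.cond`, which traces both branches into the static graph.

```python
import jax
import jax.numpy as jnp

@jax.jit
def log1p_exp(x):
    # A Python `if x > 30.0:` would fail here because x is a traced value;
    # lax.cond builds both branches into the compiled graph instead.
    return jax.lax.cond(
        x > 30.0,
        lambda x: x,                      # overflow-safe approximation
        lambda x: jnp.log1p(jnp.exp(x)),  # direct computation
        x,
    )

print(log1p_exp(2.0))
```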
Graphical modeling
Tools like BUGS, PyMC, and NumPyro are all fundamentally based on the notion of a directed acyclic graphical model. That is, you have nodes representing random variables, with each variable being conditionally independent of its non-descendants given the nodes that point to it. You specify the distribution of each node given the nodes on which it depends, and transforms are represented by deterministic nodes. The upside of constraining yourself to graphical models is that everything has to remain clearly generative (assuming you avoid improper flat priors, that is). This lets tools like PyMC automate a lot of the workflow in the same way we can with brms in Stan. When you go outside that paradigm, as you can in PyMC by adding density statements, the built-in workflow automation breaks. So while it’s possible, they generally don’t recommend it. This came up in an earlier blog post I wrote, What’s a generative model? PyMC and Stan edition.
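As a concrete illustration (a generic sketch with made-up variable names, not one of the models mentioned above), here’s how a directed graphical model reads in NumPyro: each sample statement is a node whose distribution is given in terms of its parents, and deterministic nodes record transforms.

```python
import numpyro
import numpyro.distributions as dist

def model(x, y=None):
    # Root nodes: priors with no parents.
    alpha = numpyro.sample("alpha", dist.Normal(0.0, 1.0))
    beta = numpyro.sample("beta", dist.Normal(0.0, 1.0))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    # Deterministic node: a transform of its parents.
    mu = numpyro.deterministic("mu", alpha + beta * x)
    # Observed leaf node, conditionally independent given mu and sigma.
    numpyro.sample("y", dist.Normal(mu, sigma), obs=y)
```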
Differentiable programming
Stan is not built on a graphical modeling foundation. You can write graphical models in Stan, but we just treat them as defining a log density (that was the leap that led to Stan—I was thinking about how to get JAGS to generate log densities rather than the conditional samplers that BUGS/JAGS generate). In Stan, we just declare constrained parameters and define a log density over them. That’s it (the Jacobian adjustment for the change of variables is kept under the hood). There are generated quantities, but those are conceptually applied after sampling.
Like Stan, JAX is also a differentiable programming language. Unlike Stan, it’s wonderfully compositional and general.
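What “compositional” means in practice is that grad, jit, and vmap are ordinary higher-order functions you can stack however you like. A tiny sketch:

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * x ** 2
df = jax.grad(f)                    # derivative of f
d2f = jax.grad(jax.grad(f))         # second derivative, just by composing
batched_df = jax.jit(jax.vmap(df))  # compiled, vectorized derivative

print(batched_df(jnp.linspace(0.0, 1.0, 5)))
```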
Writing JAX models like Stan models
As much as people like to use NumPyro and sometimes even PyMC to generate JAX code, I think it may be easier in the end to just write JAX directly. That way, nothing gets between you and JAX and you don’t have to figure out how to filter JAX through middleware. When you do that, the models can be organized very much like in Stan.
Brian Ward and I took some time to work through what a simple linear regression would look like coded this way in JAX. I went over it a couple weeks ago with Andrew and he didn’t think it was too bad. Here’s the example.
GitHub Gist: linear regression in JAX.
In this example, we first apply the constraining parameter transforms and compute the log Jacobian adjustment, then define the model log density directly. Although we didn’t need it for this simple example, the Oryx library for JAX provides an extensive set of constraining transforms with Jacobians. The example uses the really cool PyTree features of JAX to move between structured log densities and array-based serialized log densities. This is sooo cool, and the fact that it can all be compiled away is even cooler.
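The gist has the full example; here’s a stripped-down sketch of the same pattern (the parameter names and priors are mine, not the gist’s): take unconstrained parameters, apply the constraining transform and add its log Jacobian, accumulate the log density, and hand the gradient to whatever sampler you like.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

def log_density(params_unc, x, y):
    alpha, beta, sigma_unc = params_unc
    sigma = jnp.exp(sigma_unc)   # constraining transform to (0, inf)
    lp = sigma_unc               # log Jacobian of the exp transform
    lp += stats.norm.logpdf(alpha, 0.0, 1.0)
    lp += stats.norm.logpdf(beta, 0.0, 1.0)
    lp += stats.expon.logpdf(sigma, scale=1.0)
    lp += jnp.sum(stats.norm.logpdf(y, alpha + beta * x, sigma))
    return lp

grad_log_density = jax.jit(jax.grad(log_density))  # what HMC needs
```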
In JAX, there’s no syntactic sugar for distribution statements, but then even Andrew thinks those were a mistake in Stan. I still like them, though I admit they’ve caused a lot of confusion about how Stan actually works. It’s odd to find myself on the more permissive side of a language design discussion for once.
Generated quantities of the form used in Stan are trivial to code directly in JAX with vmap, as sketched below. Removing all these special constructs is super helpful for learnability, as is having the language embedded in Python. As much as Python is terrible for this kind of thing, much like R, because of its lack of static typing, its global interpreter lock, and its R-like scoping, I believe it’s well on its way to becoming the lingua franca of numerical analysis.
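Circling back to generated quantities, here’s a sketch of the vmap approach (all names hypothetical): write the per-draw generated-quantities function once, then map it over the posterior draws.

```python
import jax
import jax.numpy as jnp

def gen_quantities(draw, rng_key, x_new):
    alpha, beta, sigma = draw
    mu = alpha + beta * x_new
    # Posterior predictive simulation for one draw.
    return mu + sigma * jax.random.normal(rng_key, shape=x_new.shape)

draws = jnp.ones((1000, 3))                       # stand-in posterior draws
keys = jax.random.split(jax.random.key(0), 1000)  # one RNG key per draw
x_new = jnp.linspace(0.0, 1.0, 10)
y_rep = jax.vmap(gen_quantities, in_axes=(0, 0, None))(draws, keys, x_new)
```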
Generating JAX from Stan?
People have asked whether we’re going to work on generating JAX code from Stan programs. I doubt it, given how easy it is to define models directly in JAX and how few dedicated developers we now have. The whole point of Stan was to provide a structured way to compute derivatives for statistical models. We can do that directly in JAX, as the gist above shows.
Giving up working on Stan?
No, we’re not giving up on Stan. People still use BUGS! If history is any indication, Stan’s going to keep being used for a long time. We have lots of strategies for making it faster, adding samplers that work well on CPU but not GPU, etc. I don’t plan to be involved in coding for Stan any more. It’s just become too complicated for me. My plan is to write standalone samplers like WALNUTS, following Adrian Seyboldt’s lead with Nutpie. If you’re OK with Python but haven’t tried Nutpie, I highly recommend it—it’s twice as fast as Stan and more robust thanks to its adaptation. I’m rolling that adaptation into the new WALNUTS code, and maybe we’ll find the cycles to roll it into Stan itself after more testing.