[P] eqx-learn: Classical machine learning using JAX and Equinox
Hello everyone!
I am writing here to share a library I am currently developing for research use, which fills a niche for me in the Equinox/JAX ecosystem: eqx-learn.
I am using Equinox as the foundation for my radio-frequency modelling library ParamRF, and I have absolutely loved its mixed OO/functional style. However, my research also requires classical ML models (specifically PCA and Gaussian Process Regression), and I could not find an Equinox-native library in the ecosystem that was as straightforward and consistent as scikit-learn.
eqx-learn aims to address this with a JAX-based take on the scikit-learn API. All models in the library are ultimately Equinox Modules, and can be fit using the library’s free “fit” function. The design is such that models simply “advertise” their capabilities by implementing specific methods (e.g. solve(X, y), condition(X, y), loss()), and the “fit” function then fits/trains the model accordingly. I believe that this decoupling of model capabilities from the fitting algorithm suits the JAX style better, and that it has a lot of potential.
At the moment, eqx-learn addresses all my research needs, but I thought it might be useful to share the library online so that people know it exists, and to mention that I am happy to accept PRs for additional models and fitting algorithms!
Although there are no docs, there are short examples in the repo :).
Happy coding!
Cheers, Gary
submitted by /u/gvcallen