FAX: Scalable and Differentiable Federated Primitives in JAX
We present FAX, a JAX-based library designed to support large-scale distributed and federated computations in both data center and cross-device applications. FAX leverages JAX's sharding mechanisms to enable native targeting of TPUs and state-of-the-art JAX runtimes, including Pathways. FAX embeds building blocks for federated computations as primitives in JAX. This enables three key benefits. First, FAX computations can be translated to XLA HLO. Second, FAX provides a full implementation of federated automatic differentiation, greatly simplifying the expression of federated computations. Third, FAX computations can be interpreted out to existing production cross-device federated compute systems. We show that FAX provides an easily programmable, performant, and scalable framework for federated computations in the data center. FAX is available at https://github.com/google-research/google-research/tree/master/fax.
3 authors · Mar 11, 2024
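FAX differentiates through federated building blocks. The idea can be illustrated without FAX itself. This minimal sketch does not use FAX's API. It emulates a "clients" placement as a leading array axis in plain JAX. The helpers `federated_broadcast` and `federated_mean`, the client count, and the linear model are all hypothetical. The point is that `jax.grad` can differentiate one broadcast, per-client computation, and mean round.

```python
import jax
import jax.numpy as jnp

NUM_CLIENTS = 4  # hypothetical number of clients


def federated_broadcast(server_value):
    # Hypothetical stand-in for a broadcast: replicate the server value
    # along a leading "clients" axis.
    return jnp.broadcast_to(server_value, (NUM_CLIENTS,) + jnp.shape(server_value))


def federated_mean(client_values):
    # Hypothetical stand-in for a federated mean: average over the clients axis.
    return jnp.mean(client_values, axis=0)


def client_loss(w, x, y):
    # Per-client squared error of a linear model on that client's local data.
    return jnp.mean((x @ w - y) ** 2)


def round_loss(w, client_x, client_y):
    # Broadcast the server model, evaluate every client, aggregate at the server.
    client_w = federated_broadcast(w)
    losses = jax.vmap(client_loss)(client_w, client_x, client_y)
    return federated_mean(losses)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
client_x = jax.random.normal(key_x, (NUM_CLIENTS, 8, 3))  # 8 examples per client
client_y = jax.random.normal(key_y, (NUM_CLIENTS, 8))
w = jnp.zeros(3)

# Reverse-mode AD through broadcast -> per-client loss -> mean.
server_grad = jax.grad(round_loss)(w, client_x, client_y)
```

The adjoint of a broadcast is a sum over clients. As a result, `server_grad` here equals the average of the per-client gradients.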
CLAX: Fast and Flexible Neural Click Models in JAX
CLAX is a JAX-based library that implements classic click models using modern gradient-based optimization. While neural click models have emerged over the past decade, complex click models based on probabilistic graphical models (PGMs) have not systematically adopted gradient-based optimization, preventing practitioners from leveraging modern deep learning frameworks while preserving the interpretability of classic models. CLAX addresses this gap by replacing EM-based optimization with direct gradient-based optimization in a numerically stable manner. The framework's modular design enables the integration of any component, from embeddings and deep networks to custom modules, into classic click models for end-to-end optimization. We demonstrate CLAX's efficiency by running experiments on the full Baidu-ULTR dataset, comprising over a billion user sessions, in approximately 2 hours on a single GPU, orders of magnitude faster than traditional EM approaches. CLAX implements ten classic click models and serves both industry practitioners seeking to understand user behavior and improve ranking performance at scale and researchers developing new click models. CLAX is available at https://github.com/philipphager/clax.
3 authors · Nov 5, 2025
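To make the EM-versus-gradient contrast concrete, here is a sketch of one classic click model, the position-based model (PBM). It is fitted by plain gradient descent on the log-likelihood instead of EM. The model computes P(click) = P(examined | rank) · P(attractive | doc), with both factors as sigmoids of free logits. The likelihood is kept in log space, using `log_sigmoid` and a two-branch `log1mexp`, for numerical stability. This is not CLAX's API; the function names and the toy data layout are hypothetical.

```python
import jax
import jax.numpy as jnp


def log1mexp(x):
    # Numerically stable log(1 - exp(x)) for x < 0, using the usual two-branch form.
    return jnp.where(x > -jnp.log(2.0),
                     jnp.log(-jnp.expm1(x)),
                     jnp.log1p(-jnp.exp(x)))


def pbm_log_likelihood(params, ranks, docs, clicks):
    # Position-based model: P(click) = P(examined | rank) * P(attractive | doc),
    # with both probabilities parametrized as sigmoids of free logits.
    log_exam = jax.nn.log_sigmoid(params["exam_logits"][ranks])
    log_attr = jax.nn.log_sigmoid(params["attr_logits"][docs])
    log_click = log_exam + log_attr        # log P(click), computed in log space
    log_no_click = log1mexp(log_click)     # log P(no click)
    return jnp.sum(jnp.where(clicks, log_click, log_no_click))


def fit_pbm(ranks, docs, clicks, num_ranks, num_docs, steps=200, lr=0.1):
    # Plain gradient descent on the mean negative log-likelihood (no EM).
    params = {"exam_logits": jnp.zeros(num_ranks),
              "attr_logits": jnp.zeros(num_docs)}
    loss = lambda p: -pbm_log_likelihood(p, ranks, docs, clicks) / clicks.size
    grad_fn = jax.jit(jax.grad(loss))
    for _ in range(steps):
        grads = grad_fn(params)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params
```

Because the model is an ordinary differentiable function, either logit table could be replaced by the output of an embedding or neural network. That is the kind of integration the abstract describes.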
A projection-based framework for gradient-free and parallel learning
We present a feasibility-seeking approach to neural network training. This mathematical optimization framework is distinct from conventional gradient-based loss minimization and uses projection operators and iterative projection algorithms. We reformulate training as a large-scale feasibility problem: finding network parameters and states that satisfy local constraints derived from the network's elementary operations. Training then involves projecting onto these constraints, a local operation that can be parallelized across the network. We introduce PJAX, a JAX-based software framework that enables this paradigm. PJAX composes projection operators for elementary operations, automatically deriving the solution operators for the feasibility problems (akin to autodiff for derivatives). It inherently supports GPU/TPU acceleration, provides a familiar NumPy-like API, and is extensible. We train diverse architectures (MLPs, CNNs, RNNs) on standard benchmarks using PJAX, demonstrating its functionality and generality. Our results show that this approach is a compelling alternative to gradient-based training, with clear advantages in parallelism and the ability to handle non-differentiable operations.
4 authors · Jun 6, 2025
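As a minimal illustration of feasibility seeking by iterative projection, the sketch below uses the classic Kaczmarz method. It cyclically projects the current point onto each hyperplane $a_i \cdot x = b_i$ until all linear constraints hold. The linear constraints are an assumption standing in for the network-derived constraints in the text, and none of the names are PJAX's API.

```python
import jax
import jax.numpy as jnp


def project_hyperplane(x, a, b):
    # Orthogonal projection of x onto the set {z : a . z = b}.
    return x - (jnp.dot(a, x) - b) / jnp.dot(a, a) * a


@jax.jit
def kaczmarz_sweep(x, A, b):
    # One cyclic pass: project onto each constraint (row of A) in turn.
    def body(x, row):
        a_i, b_i = row
        return project_hyperplane(x, a_i, b_i), None

    x, _ = jax.lax.scan(body, x, (A, b))
    return x


def solve_feasibility(A, b, sweeps=200):
    # Repeated sweeps converge to a feasible point for a consistent system.
    x = jnp.zeros(A.shape[1])
    for _ in range(sweeps):
        x = kaczmarz_sweep(x, A, b)
    return x
```

Each projection touches only one constraint. This locality is what makes projection methods amenable to parallel schemes, though this sequential sweep does not exploit it.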
JaxMARL: Multi-Agent RL Environments in JAX
Benchmarks play an important role in the development of machine learning algorithms. For example, research in reinforcement learning (RL) has been heavily influenced by available environments and benchmarks. However, RL environments are traditionally run on the CPU, limiting their scalability with typical academic compute. Recent advancements in JAX have enabled the wider use of hardware acceleration to overcome these computational hurdles, enabling massively parallel RL training pipelines and environments. This is particularly useful for multi-agent reinforcement learning (MARL) research. First, multiple agents must be considered at each environment step, adding computational burden; second, the sample complexity is increased due to non-stationarity, decentralised partial observability, or other MARL challenges. In this paper, we present JaxMARL, the first open-source code base that combines ease of use with GPU-enabled efficiency and supports a large number of commonly used MARL environments as well as popular baseline algorithms. In terms of wall-clock time, our experiments show that our JAX-based training pipeline is up to 12500x faster per run than existing approaches. This enables efficient and thorough evaluations, with the potential to alleviate the evaluation crisis of the field. We also introduce and benchmark SMAX, a vectorised, simplified version of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine. This not only enables GPU acceleration but also provides a more flexible MARL environment, unlocking the potential for self-play, meta-learning, and other future applications in MARL. We provide code at https://github.com/flairox/jaxmarl.
20 authors · Nov 16, 2023
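The speedups described above come from writing environments as pure JAX functions. Rollouts can then be compiled with `jax.lax.scan` and batched across thousands of environments with `jax.vmap`. The sketch below shows that pattern on a hypothetical two-agent toy environment with a random policy. It is not a JaxMARL environment and does not use JaxMARL's API.

```python
import jax
import jax.numpy as jnp

NUM_AGENTS = 2  # hypothetical toy environment with two agents


def env_reset(key):
    # Each agent starts at a random scalar position in [-1, 1).
    return jax.random.uniform(key, (NUM_AGENTS,), minval=-1.0, maxval=1.0)


def env_step(state, actions):
    # Each agent moves by its action; the shared reward penalizes distance to 0.
    new_state = state + 0.1 * actions
    reward = -jnp.sum(jnp.abs(new_state))
    return new_state, reward


def rollout(key, num_steps=100):
    key_reset, key_act = jax.random.split(key)
    state = env_reset(key_reset)
    act_keys = jax.random.split(key_act, num_steps)

    def step(state, k):
        # Random policy: one action in {-1, 0, 1} per agent.
        actions = jax.random.randint(k, (NUM_AGENTS,), -1, 2).astype(jnp.float32)
        return env_step(state, actions)

    _, rewards = jax.lax.scan(step, state, act_keys)
    return rewards  # shape (num_steps,)


# Vectorize over independent environments and compile the whole rollout.
batched_rollout = jax.jit(jax.vmap(rollout))
```

For example, `batched_rollout(jax.random.split(jax.random.PRNGKey(0), 4096))` steps 4096 environments in parallel on one device.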
Single-seed generation of Brownian paths and integrals for adaptive and high order SDE solvers
Despite the success of adaptive time-stepping in ODE simulation, it has so far seen few applications for Stochastic Differential Equations (SDEs). To simulate SDEs adaptively, methods such as the Virtual Brownian Tree (VBT) have been developed, which can generate Brownian motion (BM) non-chronologically. However, in most applications, knowing only the values of Brownian motion is not enough to achieve a high order of convergence; for that, we must compute time-integrals of BM such as $\int_s^t W_r \, dr$. With the aim of using high-order SDE solvers adaptively, we extend the VBT to generate these integrals of BM in addition to the Brownian increments. A JAX-based implementation of our construction is included in the popular Diffrax library (https://github.com/patrick-kidger/diffrax). Since the entire Brownian path produced by VBT is uniquely determined by a single PRNG seed, previously generated samples need not be stored, which results in a constant memory footprint and enables experiment repeatability and strong error estimation. Because it is based on binary search, the VBT's time complexity is logarithmic in the tolerance parameter $\varepsilon$. Unlike the original VBT algorithm, which was only precise at some dyadic times, we prove that our construction exactly matches the joint distribution of the Brownian motion and its time integrals at any query times, provided they are at least $\varepsilon$ apart. We present two applications of adaptive high-order solvers enabled by our new VBT. Using adaptive solvers to simulate a high-volatility CIR model, we achieve more than twice the convergence order of constant stepping. We apply an adaptive third-order underdamped (kinetic) Langevin solver to an MCMC problem, where our approach outperforms the No U-Turn Sampler while using only a tenth of its function evaluations.
3 authors · May 10, 2024
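The single-seed property can be seen in a stripped-down VBT that generates Brownian values only, without the time integrals the paper adds. The tree bisects $[t_0, t_1]$. Each node samples its midpoint from the Brownian bridge $W(m) \mid W(s), W(t) \sim \mathcal{N}\big((W(s)+W(t))/2, (t-s)/4\big)$, using a key derived only from the path from the root. Below the tolerance the sketch falls back to linear interpolation, so it is exact only at dyadic nodes. The function name is hypothetical; in practice Diffrax provides this as `diffrax.VirtualBrownianTree`.

```python
import jax
import jax.numpy as jnp


def virtual_brownian_value(key, tau, t0=0.0, t1=1.0, eps=1e-3):
    """W(tau) for a Brownian motion with W(t0) = 0, determined entirely by `key`."""
    key, end_key = jax.random.split(key)
    s, t = t0, t1
    w_s = jnp.array(0.0)
    w_t = jnp.sqrt(t1 - t0) * jax.random.normal(end_key)  # W(t1) ~ N(0, t1 - t0)
    while t - s > eps:
        # Each tree node owns a key; child keys depend only on the path from
        # the root, so revisiting a node reproduces exactly the same sample.
        mid_key, left_key, right_key = jax.random.split(key, 3)
        mid = 0.5 * (s + t)
        # Brownian bridge: W(mid) | W(s), W(t) ~ N((W(s) + W(t)) / 2, (t - s) / 4).
        w_mid = 0.5 * (w_s + w_t) + 0.5 * jnp.sqrt(t - s) * jax.random.normal(mid_key)
        if tau < mid:
            t, w_t, key = mid, w_mid, left_key
        else:
            s, w_s, key = mid, w_mid, right_key
    # Below the tolerance, fall back to linear interpolation (approximate).
    return w_s + (tau - s) / (t - s) * (w_t - w_s)
```

Nothing is stored between calls. Querying times in any order, or repeating a query, yields the same path, and the loop runs about $\log_2((t_1 - t_0)/\varepsilon)$ times.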
Neural network emulator to constrain the high-$z$ IGM thermal state from Lyman-$α$ forest flux auto-correlation function
We present a neural network emulator to constrain the thermal parameters of the intergalactic medium (IGM) at $5.4 \le z \le 6.0$ using the Lyman-$\alpha$ (Ly$\alpha$) forest flux auto-correlation function. Our auto-differentiable JAX-based framework accelerates the surrogate model generation process using approximately 100 sparsely sampled Nyx hydrodynamical simulations with varying combinations of thermal parameters, i.e., the temperature at mean density $T_0$, the slope of the temperature-density relation $\gamma$, and the mean transmitted flux $\langle F \rangle$. We show that this emulator has a typical accuracy of 1.0% across the specified redshift range. Bayesian inference of the IGM thermal parameters, incorporating emulator uncertainty propagation, is further expedited using NumPyro Hamiltonian Monte Carlo. We compare both the inference results and the computational cost of our framework with the traditional nearest-neighbor interpolation approach applied to the same set of mock Ly$\alpha$ flux. By performing inference on 100 realistic mock data sets of the auto-correlation function and examining the credibility contours of the marginalized posteriors for $T_0$, $\gamma$, and $\langle F \rangle$ obtained with the emulator, we establish the statistical reliability of the measurements.
4 authors · Oct 8, 2024
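For readers unfamiliar with the summary statistic, this sketch computes a flux auto-correlation function from simulated skewers in `jax.numpy`. Because the code is plain `jax.numpy`, it stays differentiable. The sketch assumes the definition $\xi_F(\Delta v) = \langle F(v)\,F(v+\Delta v) \rangle$ on periodic skewers, with lags in pixels. Conversion to velocity lags and binning are omitted, and the exact convention used in the paper may differ. The function name is hypothetical.

```python
import jax.numpy as jnp


def flux_autocorrelation(flux, max_lag):
    """Assumed definition: xi_F(lag) = <F(v) F(v + lag)>.

    flux: array of shape (num_skewers, num_pixels), treated as periodic along
    the pixel axis. Lags are in pixels; converting to velocity lags (km/s) and
    binning are omitted.
    """
    # jnp.roll shifts each skewer by `lag` pixels; the mean averages over
    # all skewers and pixels.
    return jnp.stack([jnp.mean(flux * jnp.roll(flux, -lag, axis=1))
                      for lag in range(max_lag + 1)])
```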
GWKokab: An Implementation to Identify the Properties of Multiple Population of Gravitational Wave Sources
The rapidly increasing sensitivity of gravitational wave detectors is enabling the detection of a growing number of compact binary mergers. These events are crucial for understanding the population properties of compact binaries. However, many previous studies rely on computationally expensive inference frameworks, limiting their scalability. In this work, we present GWKokab, a JAX-based framework that enables modular model building with an independent rate for each subpopulation, such as BBH, BNS, and NSBH binaries. It provides accelerated inference using flowMC, a normalizing-flow-based sampler, and is also compatible with NumPyro samplers. To validate our framework, we generated two synthetic populations using a multi-source model, one comprising spinning eccentric binaries and the other circular binaries. We then recovered their injected parameters at significantly reduced computational cost and demonstrated that the eccentricity distribution can be recovered even in spinning eccentric populations. We also reproduced results from two prior studies: one on non-spinning eccentric populations, and another on the BBH mass distribution using the third Gravitational Wave Transient Catalog (GWTC-3). We anticipate that GWKokab will not only reduce computational costs but also enable more detailed analyses of subpopulation properties, such as their mass, spin, eccentricity, and redshift distributions, offering deeper insights into compact binary formation and evolution.
3 authors · Sep 16, 2025
Grad DFT: a software library for machine learning enhanced density functional theory
Density functional theory (DFT) stands as a cornerstone method in computational quantum chemistry and materials science due to its remarkable versatility and scalability. Yet, it suffers from limitations in accuracy, particularly when dealing with strongly correlated systems. To address these shortcomings, recent work has begun to explore how machine learning can expand the capabilities of DFT, an endeavor with many open questions and technical challenges. In this work, we present Grad DFT: a fully differentiable JAX-based DFT library, enabling quick prototyping and experimentation with machine learning-enhanced exchange-correlation energy functionals. Grad DFT employs a pioneering parametrization of exchange-correlation functionals constructed using a weighted sum of energy densities, where the weights are determined using neural networks. Moreover, Grad DFT encompasses a comprehensive suite of auxiliary functions, notably featuring a just-in-time compilable and fully differentiable self-consistent iterative procedure. To support training and benchmarking efforts, we additionally compile a curated dataset of experimental dissociation energies of dimers, half of which contain transition metal atoms characterized by strong electronic correlations. The software library is tested against experimental results to study the generalization capabilities of a neural functional across potential energy surfaces and atomic species, as well as the effect of training data noise on the resulting model accuracy.
5 authors · Sep 22, 2023
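The parametrization above can be read as $E_{xc} \approx \sum_g q_g \sum_i w_i(\mathbf{r}_g)\, e_i(\mathbf{r}_g)$. Here $e_i$ are energy densities on a quadrature grid with weights $q_g$, and the $w_i$ come from a neural network. The sketch below implements that reading with a tiny MLP over per-point features. The feature choice, the MLP, and all function names are hypothetical, and this is not Grad DFT's API. The self-consistent field loop is omitted.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_dim, hidden, out_dim):
    # Small two-layer MLP (hypothetical architecture).
    k1, k2 = jax.random.split(key)
    return {
        "w1": 0.1 * jax.random.normal(k1, (in_dim, hidden)), "b1": jnp.zeros(hidden),
        "w2": 0.1 * jax.random.normal(k2, (hidden, out_dim)), "b2": jnp.zeros(out_dim),
    }


def mlp_weights(params, features):
    # Per-grid-point weights, one per energy density, from local features.
    h = jnp.tanh(features @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]  # (num_points, num_densities)


def xc_energy(params, features, energy_densities, quad_weights):
    # E_xc ~= sum_g q_g * sum_i w_i(r_g) * e_i(r_g)
    w = mlp_weights(params, features)
    return jnp.sum(quad_weights * jnp.sum(w * energy_densities, axis=-1))
```

Because `xc_energy` is a pure JAX function, `jax.grad(xc_energy)` gives gradients with respect to the network parameters for training.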
A differentiable binary microlensing model using adaptive contour integration method
We present microlux, a JAX-based code that can compute the binary microlensing light curve and its derivatives both efficiently and accurately. The key feature of microlux is the implementation of a modified version of the adaptive sampling algorithm, originally proposed by V. Bozza, to account for the finite-source effect most efficiently. The efficiency and accuracy of microlux have been verified across the relevant parameter space for binary microlensing. As a differentiable code, microlux makes it possible to apply gradient-based algorithms to the search for, and posterior estimation of, microlensing models. As an example, we use microlux to model a real microlensing event and infer the model posterior via both the Fisher information matrix and Hamiltonian Monte Carlo, neither of which would have been possible without access to accurate model gradients.
2 authors · Jan 13, 2025
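To show why model gradients unlock the Fisher information matrix, this sketch computes $F = J^\top J / \sigma^2$. Here $J$ is the Jacobian of a light curve with respect to its parameters, obtained with `jax.jacfwd`. As a deliberate simplification, it swaps in the point-source point-lens (Paczyński) magnification with parameters $(t_0, u_0, t_E)$. That model is far simpler than the binary-lens, finite-source model microlux computes. It also assumes independent Gaussian noise of constant width $\sigma$. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def magnification(params, t):
    # Point-source point-lens (Paczynski) magnification; NOT the binary-lens,
    # finite-source model computed by microlux.
    t0, u0, tE = params
    u = jnp.sqrt(u0 ** 2 + ((t - t0) / tE) ** 2)
    return (u ** 2 + 2.0) / (u * jnp.sqrt(u ** 2 + 4.0))


def fisher_matrix(params, t, sigma):
    # F = J^T J / sigma^2 for i.i.d. Gaussian noise, J = d(light curve)/d(params).
    J = jax.jacfwd(magnification)(params, t)  # shape (num_times, 3)
    return J.T @ J / sigma ** 2
```

Inverting $F$ gives a Gaussian approximation to the parameter covariance.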
Craftax: A Lightning-Fast Benchmark for Open-Ended Reinforcement Learning
Benchmarks play a crucial role in the development and analysis of reinforcement learning (RL) algorithms. We identify that existing benchmarks used for research into open-ended learning fall into one of two categories. Either they are too slow for meaningful research to be performed without enormous computational resources, like Crafter, NetHack and Minecraft, or they are not complex enough to pose a significant challenge, like Minigrid and Procgen. To remedy this, we first present Craftax-Classic: a ground-up rewrite of Crafter in JAX that runs up to 250x faster than the Python-native original. A run of PPO using 1 billion environment interactions finishes in under an hour using only a single GPU and averages 90% of the optimal reward. To provide a more compelling challenge, we present the main Craftax benchmark, a significant extension of the Crafter mechanics with elements inspired by NetHack. Solving Craftax requires deep exploration, long-term planning and memory, as well as continual adaptation to novel situations as more of the world is discovered. We show that existing methods, including global and episodic exploration as well as unsupervised environment design, fail to make material progress on the benchmark. We believe that Craftax can, for the first time, allow researchers to experiment in a complex, open-ended environment with limited computational resources.
7 authors · Feb 26, 2024
Einstein Fields: A Neural Perspective To Computational General Relativity
We introduce Einstein Fields, a neural representation designed to compress computationally intensive four-dimensional numerical relativity simulations into compact implicit neural network weights. By modeling the metric, the core tensor field of general relativity, Einstein Fields enable the derivation of physical quantities via automatic differentiation. Unlike conventional neural fields (e.g., signed distance, occupancy, or radiance fields), Einstein Fields are Neural Tensor Fields, with the key difference that when the spacetime geometry of general relativity is encoded into a neural field representation, dynamics emerge naturally as a byproduct. Einstein Fields show remarkable potential, including continuum modeling of 4D spacetime, mesh-agnosticity, storage efficiency, derivative accuracy, and ease of use. We investigate these properties across several canonical test beds of general relativity and release an open-source JAX-based library, paving the way for more scalable and expressive approaches to numerical relativity. Code is available at https://github.com/AndreiB137/EinFields
4 authors · Jul 15, 2025
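As an example of deriving physical quantities from a metric by automatic differentiation, this sketch computes the Christoffel symbols $\Gamma^a_{bc} = \tfrac{1}{2} g^{ad}(\partial_b g_{dc} + \partial_c g_{db} - \partial_d g_{bc})$ with `jax.jacfwd`. It assumes a known analytic metric as a stand-in for the neural field: flat 2D space in polar coordinates, $g = \mathrm{diag}(1, r^2)$. It also works in 2D rather than 4D. The function names are hypothetical and not the EinFields API.

```python
import jax
import jax.numpy as jnp


def polar_metric(x):
    # Stand-in for a learned metric field g_ab(x): flat 2D space in polar
    # coordinates x = (r, theta), g = diag(1, r^2).
    r = x[0]
    return jnp.diag(jnp.stack([jnp.ones_like(r), r ** 2]))


def christoffel(metric_fn, x):
    # Gamma^a_bc = 1/2 g^ad (d_b g_dc + d_c g_db - d_d g_bc)
    g_inv = jnp.linalg.inv(metric_fn(x))
    dg = jax.jacfwd(metric_fn)(x)  # dg[i, j, k] = d_k g_ij
    term = (jnp.einsum("dcb->dbc", dg)     # d_b g_dc, indexed [d, b, c]
            + dg                           # d_c g_db, already indexed [d, b, c]
            - jnp.einsum("bcd->dbc", dg))  # d_d g_bc, indexed [d, b, c]
    return 0.5 * jnp.einsum("ad,dbc->abc", g_inv, term)
```

At $x = (r, \theta) = (2, 0.3)$ this reproduces the textbook values $\Gamma^r_{\theta\theta} = -r = -2$ and $\Gamma^\theta_{r\theta} = 1/r = 0.5$. Replacing `polar_metric` with a trained network changes nothing else in the code.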
LAST: Scalable Lattice-Based Speech Modelling in JAX
We introduce LAST, a LAttice-based Speech Transducer library in JAX. With an emphasis on flexibility, ease of use, and scalability, LAST implements the differentiable weighted finite state automaton (WFSA) algorithms needed for training and inference, which scale to a large WFSA such as a recognition lattice over the entire utterance. Although these WFSA algorithms are well known in the literature, new challenges arise from the performance characteristics of modern architectures and from nuances in automatic differentiation. We describe a suite of generally applicable techniques employed in LAST to address these challenges, and demonstrate their effectiveness with benchmarks on TPUv3 and V100 GPU.
4 authors · Apr 25, 2023
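A core differentiable WFSA algorithm is the forward (shortest-distance) computation in the log semiring. Its gradient with respect to the arc log-weights gives each arc's posterior occupancy probability, which is the quantity training needs. The sketch below runs it on a tiny acyclic automaton whose states are numbered in topological order. The Python-loop implementation and the function name are hypothetical simplifications. LAST's actual algorithms and API differ and are built for large lattices on accelerators.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def forward_log_weight(arc_log_weights, arcs, num_states):
    """Total log-semiring weight of all paths from state 0 to the final state.

    arcs: list of (src, dst) pairs over states numbered in topological order
    (src < dst), so a single left-to-right pass suffices.
    arc_log_weights: array with one log-weight per arc.
    """
    alpha = [jnp.array(0.0)] + [jnp.array(-jnp.inf)] * (num_states - 1)
    for state in range(1, num_states):
        incoming = [alpha[src] + arc_log_weights[i]
                    for i, (src, dst) in enumerate(arcs) if dst == state]
        if incoming:
            # Log-semiring "plus" over all arcs entering this state.
            alpha[state] = logsumexp(jnp.stack(incoming))
    return alpha[-1]
```

Take two parallel arcs with weights 1 and 3 into state 1, followed by a single arc into state 2. The arc posteriors from `jax.grad` are then 0.25, 0.75, and 1.0, since every path uses the last arc.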
SCENIC: A JAX Library for Computer Vision Research and Beyond
Scenic is an open-source JAX library with a focus on Transformer-based models for computer vision research and beyond. The goal of this toolkit is to facilitate rapid experimentation, prototyping, and research of new vision architectures and models. Scenic supports a diverse range of vision tasks (e.g., classification, segmentation, detection) and facilitates working on multi-modal problems, along with GPU/TPU support for multi-host, multi-device large-scale training. Scenic also offers optimized implementations of state-of-the-art research models spanning a wide range of modalities. Scenic has been successfully used for numerous projects and published papers and continues to serve as the library of choice for quick prototyping and publication of new research ideas.
5 authors · Oct 18, 2021
Consistent Sampling and Simulation: Molecular Dynamics with Energy-Based Diffusion Models
In recent years, diffusion models trained on equilibrium molecular distributions have proven effective for sampling biomolecules. Beyond direct sampling, the score of such a model can also be used to derive the forces that act on molecular systems. However, while classical diffusion sampling usually recovers the training distribution, the corresponding energy-based interpretation of the learned score is often inconsistent with this distribution, even for low-dimensional toy systems. We trace this inconsistency to inaccuracies of the learned score at very small diffusion timesteps, where the model must capture the correct evolution of the data distribution. In this regime, diffusion models fail to satisfy the Fokker–Planck equation, which governs the evolution of the score. We interpret this deviation as one source of the observed inconsistencies and propose an energy-based diffusion model with a Fokker–Planck-derived regularization term to enforce consistency. We demonstrate our approach by sampling and simulating multiple biomolecular systems, including fast-folding proteins, and by introducing a state-of-the-art transferable Boltzmann emulator for dipeptides that supports simulation and achieves improved consistency and efficient sampling. Our code, model weights, and self-contained JAX and PyTorch notebooks are available at https://github.com/noegroup/ScoreMD.
5 authors · Jun 20, 2025
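To make "Fokker–Planck consistency of the score" concrete, this sketch evaluates the residual of the score Fokker–Planck equation for an energy-based score $s(x,t) = -\nabla_x E(x,t)$. It assumes a variance-exploding SDE $dx = g(t)\,dW$. Under that SDE, the log-density equation $\partial_t \log p = \tfrac{g^2}{2}(\Delta \log p + \|\nabla \log p\|^2)$ becomes $\partial_t s = \tfrac{g^2}{2}\nabla_x(\nabla_x \cdot s + \|s\|^2)$ after taking the gradient. The SDE, the function names, and the unweighted residual are assumptions. The paper's actual regularizer, SDE, and weighting may differ.

```python
import jax
import jax.numpy as jnp


def score(energy_fn, x, t):
    # Energy-based parametrization: s(x, t) = -grad_x E(x, t).
    return -jax.grad(energy_fn, argnums=0)(x, t)


def fokker_planck_residual(energy_fn, g2, x, t):
    """Residual of d_t s = (g(t)^2 / 2) * grad_x(div_x s + |s|^2) for dx = g(t) dW.

    g2 is a callable returning g(t)^2 (assumed variance-exploding SDE).
    """
    s_fn = lambda x_, t_: score(energy_fn, x_, t_)

    def div_plus_norm(x_):
        # Divergence of the score (trace of its Jacobian) plus its squared norm.
        div = jnp.trace(jax.jacfwd(s_fn, argnums=0)(x_, t))
        s = s_fn(x_, t)
        return div + jnp.dot(s, s)

    ds_dt = jax.jacfwd(s_fn, argnums=1)(x, t)
    return ds_dt - 0.5 * g2(t) * jax.grad(div_plus_norm)(x)
```

As a sanity check, consider Gaussian data with variance $0.5$ diffused with $g^2 = 1$. The exact energy is $E = \|x\|^2 / (2(0.5 + t))$, for which the residual vanishes. A regularizer could penalize the squared norm of this residual at small $t$.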
SequeL: A Continual Learning Library in PyTorch and JAX
Continual Learning is an important and challenging problem in machine learning, where models must adapt to a continuous stream of new data without forgetting previously acquired knowledge. While existing frameworks are built on PyTorch, the rising popularity of JAX might lead to divergent codebases, ultimately hindering reproducibility and progress. To address this problem, we introduce SequeL, a flexible and extensible library for Continual Learning that supports both PyTorch and JAX frameworks. SequeL provides a unified interface for a wide range of Continual Learning algorithms, including regularization-based, replay-based, and hybrid approaches. The library is designed for modularity and simplicity, making the API suitable for both researchers and practitioners. We release SequeL (https://github.com/nik-dim/sequel) as an open-source library, enabling researchers and developers to easily experiment with and extend the library for their own purposes.
3 authors · Apr 21, 2023
Locality-Aware Automatic Differentiation on the GPU for Mesh-Based Computations
We present a high-performance system for automatic differentiation (AD) of functions defined on triangle meshes that exploits the inherent sparsity and locality of mesh-based energy functions to achieve fast gradient and Hessian computation on the GPU. Our system is designed around per-element forward-mode differentiation, enabling all local computations to remain in GPU registers or shared memory. Unlike reverse-mode approaches that construct and traverse global computation graphs, our method performs differentiation on the fly, minimizing memory traffic and avoiding global synchronization. Our programming model allows users to define local energy terms while the system handles parallel evaluation, derivative computation, and sparse Hessian assembly. We benchmark our system on a range of applications: cloth simulation, surface parameterization, mesh smoothing, and spherical manifold optimization. We achieve a geometric mean speedup of 6.2x over optimized PyTorch implementations for second-order derivatives, and a 2.76x speedup for Hessian-vector products. For first-order derivatives, our system is 6.38x, 2.89x, and 1.98x faster than Warp, JAX, and Dr.JIT, respectively, while remaining on par with hand-written derivatives.
3 authors · Aug 30, 2025
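The per-element idea can be mimicked in plain JAX with four steps:

1. Differentiate only a small local energy with forward mode (`jax.jacfwd`).
2. Batch the local derivatives over elements with `jax.vmap`.
3. Scatter-add local gradients into the global gradient.
4. Keep Hessians as local blocks for later sparse assembly.

This sketch uses a hypothetical mass-spring edge energy as the local term. It does not reproduce the paper's GPU system, which keeps computations in registers or shared memory. Sparse matrix assembly of the Hessian blocks is omitted.

```python
import jax
import jax.numpy as jnp


def spring_energy(xi, xj, rest_length=1.0):
    # Local energy of one edge element (hypothetical mass-spring stand-in).
    return 0.5 * (jnp.linalg.norm(xi - xj) - rest_length) ** 2


def element_energy(x_elem):
    # x_elem: (2, dim) positions of the element's two vertices.
    return spring_energy(x_elem[0], x_elem[1])


# Forward-mode gradient and forward-over-forward Hessian of the *local* energy,
# vectorized over elements.
elem_grad = jax.vmap(jax.jacfwd(element_energy))
elem_hess = jax.vmap(jax.jacfwd(jax.jacfwd(element_energy)))


def assemble(x, edges):
    # x: (num_vertices, dim); edges: (num_edges, 2) integer vertex indices.
    x_elems = x[edges]                                # (num_edges, 2, dim)
    g_local = elem_grad(x_elems)                      # (num_edges, 2, dim)
    h_local = elem_hess(x_elems)                      # (num_edges, 2, dim, 2, dim)
    grad = jnp.zeros_like(x).at[edges].add(g_local)   # scatter-add; shared vertices accumulate
    return grad, h_local
```

The assembled gradient matches `jax.grad` of the total energy. No global computation graph is needed, because each derivative only ever sees one element's vertices.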
Multilingual Universal Sentence Encoder for Semantic Retrieval
We introduce two pre-trained, retrieval-focused multilingual sentence encoding models, based on the Transformer and CNN model architectures, respectively. The models embed text from 16 languages into a single semantic space using a multi-task trained dual-encoder that learns tied representations using translation-based bridge tasks (Chidambaram et al., 2018). The models provide performance that is competitive with the state of the art on semantic retrieval (SR), translation pair bitext retrieval (BR), and retrieval question answering (ReQA). On English transfer learning tasks, our sentence-level embeddings approach, and in some cases exceed, the performance of monolingual, English-only sentence embedding models. Our models are made available for download on TensorFlow Hub.
12 authors · Jul 9, 2019
Make-It-Poseable: Feed-forward Latent Posing Model for 3D Humanoid Character Animation
Posing 3D characters is a fundamental task in computer graphics and vision. However, existing methods like auto-rigging and pose-conditioned generation often struggle with challenges such as inaccurate skinning weight prediction, topological imperfections, and poor pose conformance, limiting their robustness and generalizability. To overcome these limitations, we introduce Make-It-Poseable, a novel feed-forward framework that reformulates character posing as a latent-space transformation problem. Instead of deforming mesh vertices as in traditional pipelines, our method reconstructs the character in new poses by directly manipulating its latent representation. At the core of our method is a latent posing transformer that manipulates shape tokens based on skeletal motion. This process is facilitated by a dense pose representation for precise control. To ensure high-fidelity geometry and accommodate topological changes, we also introduce a latent-space supervision strategy and an adaptive completion module. Our method demonstrates superior performance in posing quality. It also naturally extends to 3D editing applications like part replacement and refinement.
Tencent · Dec 18, 2025
Jumanji: a Diverse Suite of Scalable Reinforcement Learning Environments in JAX
Open-source reinforcement learning (RL) environments have played a crucial role in driving progress in the development of AI algorithms. In modern RL research, there is a need for simulated environments that are performant, scalable, and modular to enable their utilization in a wider range of potential real-world applications. Therefore, we present Jumanji, a suite of diverse RL environments specifically designed to be fast, flexible, and scalable. Jumanji provides a suite of environments focusing on combinatorial problems frequently encountered in industry, as well as challenging general decision-making tasks. By leveraging the efficiency of JAX and hardware accelerators like GPUs and TPUs, Jumanji enables rapid iteration of research ideas and large-scale experimentation, ultimately empowering more capable agents. Unlike existing RL environment suites, Jumanji is highly customizable, allowing users to tailor the initial state distribution and problem complexity to their needs. Furthermore, we provide actor-critic baselines for each environment, accompanied by preliminary findings on scaling and generalization scenarios. Jumanji aims to set a new standard for speed, adaptability, and scalability of RL environments.
24 authors · Jun 16, 2023