Daily Papers

by AK and the research community

Jan 5

JaxMARL: Multi-Agent RL Environments in JAX

Benchmarks play an important role in the development of machine learning algorithms. For example, research in reinforcement learning (RL) has been heavily influenced by available environments and benchmarks. However, RL environments are traditionally run on the CPU, limiting their scalability with typical academic compute. Recent advancements in JAX have enabled the wider use of hardware acceleration to overcome these computational hurdles, enabling massively parallel RL training pipelines and environments. This is particularly useful for multi-agent reinforcement learning (MARL) research. First of all, multiple agents must be considered at each environment step, adding computational burden, and secondly, the sample complexity is increased due to non-stationarity, decentralised partial observability, or other MARL challenges. In this paper, we present JaxMARL, the first open-source code base that combines ease-of-use with GPU-enabled efficiency, and supports a large number of commonly used MARL environments as well as popular baseline algorithms. In wall-clock time, our experiments show that our JAX-based training pipeline is up to 12500x faster per run than existing approaches. This enables efficient and thorough evaluations, with the potential to alleviate the evaluation crisis of the field. We also introduce and benchmark SMAX, a vectorised, simplified version of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine. This not only enables GPU acceleration, but also provides a more flexible MARL environment, unlocking the potential for self-play, meta-learning, and other future applications in MARL. We provide code at https://github.com/flairox/jaxmarl.

  • 20 authors
·
Nov 16, 2023
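
The JaxMARL abstract above attributes its speed-up to running many environments in parallel on an accelerator. A minimal sketch of that general pattern in JAX follows. The toy environment, `reset`, `step`, `NUM_AGENTS` and `NUM_ENVS` are hypothetical names chosen for illustration and are not the JaxMARL API.

```python
import jax
import jax.numpy as jnp

NUM_AGENTS = 3   # hypothetical toy setting
NUM_ENVS = 1024  # environments stepped in parallel

def reset(key):
    """Sample initial 2-D positions for every agent in one environment."""
    return jax.random.uniform(key, (NUM_AGENTS, 2), minval=-1.0, maxval=1.0)

def step(state, actions):
    """Move each agent by its action; reward is negative distance to the origin."""
    new_state = jnp.clip(state + 0.1 * actions, -1.0, 1.0)
    rewards = -jnp.linalg.norm(new_state, axis=-1)  # one reward per agent
    return new_state, rewards

# vmap lifts the single-environment functions over a leading batch axis,
# and jit compiles the batched step into one accelerator program.
batched_reset = jax.vmap(reset)
batched_step = jax.jit(jax.vmap(step))

key = jax.random.PRNGKey(0)
reset_key, action_key = jax.random.split(key)
states = batched_reset(jax.random.split(reset_key, NUM_ENVS))
actions = jax.random.normal(action_key, (NUM_ENVS, NUM_AGENTS, 2))
states, rewards = batched_step(states, actions)
```

Because the environment is a pure function of its state, the same `step` can be written once for a single environment and batched afterwards, which is the design choice that makes the vectorisation cheap to express.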

Single-seed generation of Brownian paths and integrals for adaptive and high order SDE solvers

Despite the success of adaptive time-stepping in ODE simulation, it has so far seen few applications for Stochastic Differential Equations (SDEs). To simulate SDEs adaptively, methods such as the Virtual Brownian Tree (VBT) have been developed, which can generate Brownian motion (BM) non-chronologically. However, in most applications, knowing only the values of Brownian motion is not enough to achieve a high order of convergence; for that, we must compute time-integrals of BM such as $\int_s^t W_r \, dr$. With the aim of using high order SDE solvers adaptively, we extend the VBT to generate these integrals of BM in addition to the Brownian increments. A JAX-based implementation of our construction is included in the popular Diffrax library (https://github.com/patrick-kidger/diffrax). Since the entire Brownian path produced by VBT is uniquely determined by a single PRNG seed, previously generated samples need not be stored, which results in a constant memory footprint and enables experiment repeatability and strong error estimation. Based on binary search, the VBT's time complexity is logarithmic in the tolerance parameter $\varepsilon$. Unlike the original VBT algorithm, which was only precise at some dyadic times, we prove that our construction exactly matches the joint distribution of the Brownian motion and its time integrals at any query times, provided they are at least $\varepsilon$ apart. We present two applications of adaptive high order solvers enabled by our new VBT. Using adaptive solvers to simulate a high-volatility CIR model, we achieve more than twice the convergence order of constant stepping. We apply an adaptive third order underdamped or kinetic Langevin solver to an MCMC problem, where our approach outperforms the No U-Turn Sampler, while using only a tenth of its function evaluations.

  • 3 authors
·
May 10, 2024
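
The single-seed idea in the abstract above can be illustrated with a short JAX sketch of a basic Brownian tree: the path on [0, 1] is refined by bisection with Brownian-bridge midpoints, and every random draw is derived from one PRNG key. This is an assumption-laden illustration of the simpler scheme that is only exact at dyadic times. It does not generate the time integrals or reproduce the paper's construction, and it is not the Diffrax implementation. `brownian_value` and `depth` are hypothetical names.

```python
import jax
import jax.numpy as jnp

def brownian_value(key, t, depth=20):
    """Return W(t) for t in [0, 1], refined by bisection to resolution 2**-depth."""
    end_key, node_key = jax.random.split(key)
    s, u = 0.0, 1.0
    w_s, w_u = 0.0, float(jax.random.normal(end_key))  # W(0) = 0, W(1) ~ N(0, 1)
    for _ in range(depth):
        # Each tree node owns three keys: its midpoint draw and its two children.
        mid_key, left_key, right_key = jax.random.split(node_key, 3)
        m = 0.5 * (s + u)
        # Brownian bridge: W(m) | W(s), W(u) ~ N(mean of endpoints, (u - s) / 4)
        z = float(jax.random.normal(mid_key))
        w_m = 0.5 * (w_s + w_u) + (0.25 * (u - s)) ** 0.5 * z
        if t < m:
            u, w_u, node_key = m, w_m, left_key
        else:
            s, w_s, node_key = m, w_m, right_key
    # Linear interpolation inside the final interval (an approximation).
    return w_s + (w_u - w_s) * (t - s) / (u - s)

key = jax.random.PRNGKey(42)
w_first = brownian_value(key, 0.3)
w_again = brownian_value(key, 0.3)  # recomputed from the seed, nothing stored
```

Queries can be made in any order, no samples are cached, and the loop length grows with the logarithm of the requested resolution, which mirrors the constant-memory and logarithmic-time properties the abstract describes.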

GWKokab: An Implementation to Identify the Properties of Multiple Population of Gravitational Wave Sources

The rapidly increasing sensitivity of gravitational wave detectors is enabling the detection of a growing number of compact binary mergers. These events are crucial for understanding the population properties of compact binaries. However, many previous studies rely on computationally expensive inference frameworks, limiting their scalability. In this work, we present GWKokab, a JAX-based framework that enables modular model building with an independent rate for each subpopulation, such as BBH, BNS, and NSBH binaries. It provides accelerated inference using the normalizing-flow-based sampler flowMC and is also compatible with NumPyro samplers. To validate our framework, we generated two synthetic populations, one comprising spinning eccentric binaries and the other circular binaries, using a multi-source model. We then recovered their injected parameters at significantly reduced computational cost and demonstrated that the eccentricity distribution can be recovered even in spinning eccentric populations. We also reproduced results from two prior studies: one on non-spinning eccentric populations, and another on the BBH mass distribution using the third Gravitational Wave Transient Catalog (GWTC-3). We anticipate that GWKokab will not only reduce computational costs but also enable more detailed subpopulation analyses, such as their mass, spin, eccentricity, and redshift distributions in gravitational wave events, offering deeper insights into compact binary formation and evolution.

  • 3 authors
·
Sep 16, 2025

Grad DFT: a software library for machine learning enhanced density functional theory

Density functional theory (DFT) stands as a cornerstone method in computational quantum chemistry and materials science due to its remarkable versatility and scalability. Yet, it suffers from limitations in accuracy, particularly when dealing with strongly correlated systems. To address these shortcomings, recent work has begun to explore how machine learning can expand the capabilities of DFT; an endeavor with many open questions and technical challenges. In this work, we present Grad DFT: a fully differentiable JAX-based DFT library, enabling quick prototyping and experimentation with machine learning-enhanced exchange-correlation energy functionals. Grad DFT employs a pioneering parametrization of exchange-correlation functionals constructed using a weighted sum of energy densities, where the weights are determined using neural networks. Moreover, Grad DFT encompasses a comprehensive suite of auxiliary functions, notably featuring a just-in-time compilable and fully differentiable self-consistent iterative procedure. To support training and benchmarking efforts, we additionally compile a curated dataset of experimental dissociation energies of dimers, half of which contain transition metal atoms characterized by strong electronic correlations. The software library is tested against experimental results to study the generalization capabilities of a neural functional across potential energy surfaces and atomic species, as well as the effect of training data noise on the resulting model accuracy.

  • 5 authors
·
Sep 22, 2023
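
The parametrization described in the Grad DFT abstract above, a weighted sum of energy densities with neural-network weights, can be sketched in JAX as follows. This is a hedged illustration and not the Grad DFT API: the function names, the tiny MLP, the second toy energy density, the radial grid and the hydrogen-like test density are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp

def lda_exchange(rho):
    """Dirac/LDA exchange energy density per unit volume (spin-unpolarised)."""
    return -0.75 * (3.0 / jnp.pi) ** (1.0 / 3.0) * rho ** (4.0 / 3.0)

def energy_densities(rho):
    """Stack of candidate energy densities e_i(rho); the second is a toy term."""
    return jnp.stack([lda_exchange(rho), -rho ** (5.0 / 3.0)], axis=-1)

def weights(params, rho):
    """Tiny MLP mapping the local density to one weight per energy density."""
    hidden = jnp.tanh(rho[:, None] * params["w1"] + params["b1"])
    return jax.nn.sigmoid(hidden @ params["w2"] + params["b2"])

def xc_energy(params, rho, quad_weights):
    """E_xc ~= sum_g q_g * sum_i w_i(rho_g) * e_i(rho_g) on a quadrature grid."""
    local = jnp.sum(weights(params, rho) * energy_densities(rho), axis=-1)
    return jnp.sum(quad_weights * local)

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = {
    "w1": jax.random.normal(k1, (8,)),
    "b1": jnp.zeros(8),
    "w2": jax.random.normal(k2, (8, 2)),
    "b2": jnp.zeros(2),
}
r = jnp.linspace(0.01, 10.0, 200)
rho = jnp.exp(-2.0 * r) / jnp.pi                     # hydrogen-like 1s density
quad_weights = 4.0 * jnp.pi * r**2 * (r[1] - r[0])   # radial Riemann weights

energy = xc_energy(params, rho, quad_weights)
# The functional is differentiable end to end, so parameter gradients are direct.
grads = jax.grad(xc_energy)(params, rho, quad_weights)
```

Since the energy is an ordinary JAX function of the network parameters, `jax.grad` returns a gradient with the same dictionary structure as `params`, which is what a training loop over dissociation-energy data would consume.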

Consistent Sampling and Simulation: Molecular Dynamics with Energy-Based Diffusion Models

In recent years, diffusion models trained on equilibrium molecular distributions have proven effective for sampling biomolecules. Beyond direct sampling, the score of such a model can also be used to derive the forces that act on molecular systems. However, while classical diffusion sampling usually recovers the training distribution, the corresponding energy-based interpretation of the learned score is often inconsistent with this distribution, even for low-dimensional toy systems. We trace this inconsistency to inaccuracies of the learned score at very small diffusion timesteps, where the model must capture the correct evolution of the data distribution. In this regime, diffusion models fail to satisfy the Fokker-Planck equation, which governs the evolution of the score. We interpret this deviation as one source of the observed inconsistencies and propose an energy-based diffusion model with a Fokker-Planck-derived regularization term to enforce consistency. We demonstrate our approach by sampling and simulating multiple biomolecular systems, including fast-folding proteins, and by introducing a state-of-the-art transferable Boltzmann emulator for dipeptides that supports simulation and achieves improved consistency and efficient sampling. Our code, model weights, and self-contained JAX and PyTorch notebooks are available at https://github.com/noegroup/ScoreMD.

  • 5 authors
·
Jun 20, 2025
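
The abstract above hinges on a time-dependent density satisfying the Fokker-Planck equation. As a minimal sketch under stated assumptions, consider the one-dimensional pure diffusion dx = dW, whose density obeys the heat equation; writing p = exp(log p) gives d/dt log p = 0.5 * (d2/dx2 log p + (d/dx log p)^2). The JAX code below evaluates the residual of that identity by automatic differentiation for an exact Gaussian density. The noise process, the Gaussian data distribution and the names `log_p` and `fp_residual` are illustrative assumptions, and this is not the regularization term used in the paper.

```python
import jax
import jax.numpy as jnp

SIGMA0_SQ = 1.0  # variance of the (assumed) Gaussian data distribution at t = 0

def log_p(x, t):
    """Exact log-density of x_t when x_0 ~ N(0, SIGMA0_SQ) and dx = dW."""
    var = SIGMA0_SQ + t
    return -0.5 * x**2 / var - 0.5 * jnp.log(2.0 * jnp.pi * var)

def fp_residual(log_density, x, t):
    """d/dt log p - 0.5 * (d2/dx2 log p + (d/dx log p)**2) for dx = dW in 1-D."""
    dt = jax.grad(log_density, argnums=1)(x, t)
    score = jax.grad(log_density, argnums=0)
    dx = score(x, t)
    dxx = jax.grad(score, argnums=0)(x, t)
    return dt - 0.5 * (dxx + dx**2)

# The exact density satisfies the equation, so the residual vanishes.
residual = fp_residual(log_p, 0.7, 0.3)

# A density that ignores time cannot satisfy the equation.
bad_residual = fp_residual(lambda x, t: log_p(x, 0.0 * t), 0.7, 0.3)
```

A penalty built from the squared residual of a learned log-density, averaged over sampled points and times, is one generic way such a consistency condition could be turned into a training term.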