Samplers
Jim supports several JAX sampler backends behind a unified interface.
You select one by passing a typed config object as the `sampler_config` argument of `Jim`.
After `jim.sample()`, retrieve posterior samples with:

```python
samples = jim.get_samples()  # dict[str, np.ndarray] keyed by parameter name
```
Sampler overview
| Sampler | Algorithm | Evidence estimate | Extra install | Prior constraint |
|---|---|---|---|---|
| flowMC | Normalizing-flow-enhanced MCMC | No | No | None |
| NS-AW | Nested sampling (bilby/dynesty-style acceptance walk) | Yes | Yes (`nested-sampling` group) | Uniform prior; unit-cube sampling space |
| NSS | Nested slice sampling | Yes | Yes (`nested-sampling` group) | Normalised prior |
| SMC | Sequential Monte Carlo | Yes | No | Normalised prior |
flowMC
flowMC runs parallel MCMC chains enhanced by a normalizing flow that learns the posterior shape during training, then uses that learned geometry to make global proposals during production.
```python
from jimgw.core.jim import Jim
from jimgw.samplers.config import FlowMCConfig

jim = Jim(
    likelihood,
    prior,
    sampler_config=FlowMCConfig(
        n_chains=1000,
        n_local_steps=100,
        n_global_steps=1000,
        n_training_loops=20,
        n_production_loops=10,
    ),
)
jim.sample()
samples = jim.get_samples()
```
Key parameters:

- `n_chains`: number of parallel MCMC chains.
- `n_training_loops` / `n_production_loops`: how many rounds of training (flow updates) and production (sample collection) to run.
- `n_local_steps` / `n_global_steps`: local MCMC steps and flow-proposal steps per loop.
- `local_kernel`: MCMC kernel for local steps; one of `"MALA"` (default), `"HMC"`, or `"GRW"`.
- `parallel_tempering`: parallel-tempering settings; disabled by default. Enable it with `parallel_tempering=True` (uses defaults), a plain dict of settings such as `{"n_temperatures": 8}`, or a `ParallelTemperingConfig` instance.
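To make the local/global alternation concrete, here is a toy, NumPy-only sketch. It is not flowMC's implementation, and every name in it is hypothetical. It makes four assumptions. First, a single Gaussian fitted to recent chain positions stands in for the normalizing flow. Second, local steps are Gaussian random-walk Metropolis, in the spirit of the `"GRW"` kernel. Third, global steps are independence Metropolis-Hastings proposals drawn from the fitted density. Fourth, the target is a correlated 2D Gaussian. The keyword arguments only mirror the config fields above.

```python
import numpy as np

TARGET_COV = np.array([[1.0, 0.9], [0.9, 1.0]])  # toy correlated 2D Gaussian posterior
TARGET_PREC = np.linalg.inv(TARGET_COV)


def log_quad(y, mean, prec):
    # Unnormalised Gaussian log-density, evaluated row-wise for y of shape (n, d).
    d = y - mean
    return -0.5 * np.einsum("ci,ij,cj->c", d, prec, d)


def log_target(x):
    return log_quad(x, np.zeros(2), TARGET_PREC)


def toy_flow_enhanced_mcmc(n_chains=200, n_local_steps=20, n_global_steps=20,
                           n_training_loops=5, n_production_loops=5,
                           step_size=0.3, seed=0):
    rng = np.random.default_rng(seed)
    x = 3.0 * rng.normal(size=(n_chains, 2))  # over-dispersed starting points
    logp = log_target(x)
    mean, cov = np.zeros(2), 9.0 * np.eye(2)   # initial (broad) proposal density
    history, production = [], []
    for loop in range(n_training_loops + n_production_loops):
        training = loop < n_training_loops
        record = history if training else production
        # Local steps: Gaussian random-walk Metropolis (symmetric proposal).
        for _ in range(n_local_steps):
            prop = x + step_size * rng.normal(size=x.shape)
            logp_prop = log_target(prop)
            accept = np.log(rng.uniform(size=n_chains)) < logp_prop - logp
            x[accept], logp[accept] = prop[accept], logp_prop[accept]
            record.append(x.copy())
        # "Training": refit the proposal density (stand-in for the flow) to recent positions.
        if training:
            recent = np.concatenate(history[-(n_local_steps + n_global_steps):])
            mean = recent.mean(axis=0)
            cov = np.cov(recent, rowvar=False) + 1e-6 * np.eye(2)
        chol, prec_q = np.linalg.cholesky(cov), np.linalg.inv(cov)
        # Global steps: independence Metropolis-Hastings with proposals drawn from q.
        for _ in range(n_global_steps):
            prop = mean + rng.normal(size=x.shape) @ chol.T
            logp_prop = log_target(prop)
            log_ratio = (logp_prop - logp) + (log_quad(x, mean, prec_q)
                                              - log_quad(prop, mean, prec_q))
            accept = np.log(rng.uniform(size=n_chains)) < log_ratio
            x[accept], logp[accept] = prop[accept], logp_prop[accept]
            record.append(x.copy())
    # Shape: (n_production_loops * (n_local_steps + n_global_steps) * n_chains, 2).
    return np.concatenate(production)
```

The fitted density is frozen after the training loops. The production phase is therefore a standard Metropolis-Hastings chain with a fixed mixture of local and global proposals.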
Repository: GW-JAX-Team/flowMC
References: Wong, K. W. K., Gabrié, M., Foreman-Mackey, D., "flowMC: Normalizing flow enhanced sampling package for probabilistic inference in JAX", arXiv:2211.06397, JOSS 8 (83) 5021 (2023). Wong, K. W. K., Isi, M., Edwards, T. D. P., "Fast Gravitational-wave Parameter Estimation without Compromises", arXiv:2302.05333, ApJ 958 129 (2023).
BlackJAX SMC
Sequential Monte Carlo (SMC) maintains a population of particles and gradually shifts them from the prior toward the posterior through a sequence of intermediate temperature steps.
Normalised-prior requirement: SMC computes a Bayesian evidence estimate and therefore requires a normalised prior. All built-in Jim priors are normalised. If you add custom constraints, check whether the resulting distribution is still normalised; if so, override `is_normalized` to return `True`. Jim raises a `ValueError` at construction if this condition is not met.
```python
from jimgw.samplers.config import BlackJAXSMCConfig

jim = Jim(
    likelihood,
    prior,
    sampler_config=BlackJAXSMCConfig(
        n_particles=2000,
        n_mcmc_steps_per_dim=100,
    ),
)
jim.sample()
samples = jim.get_samples()
```
Key parameters:

- `n_particles`: particle ensemble size.
- `n_mcmc_steps_per_dim`: MCMC steps per dimension at each temperature step.
- `target_ess_fraction`: target effective sample size (ESS) as a fraction of `n_particles` (default `0.9`). At each step the sampler advances the temperature as far as it can while the ESS (the effective number of contributing particles) stays at this threshold. Values in `(0, 1]` are valid when `persistent_sampling=False`; with persistent sampling the value may exceed `1.0` because particles are recycled across steps. Only used with adaptive temperature selection (no effect with a fixed `temperature_ladder`).
- `target_ess`: target ESS as an absolute particle count. `target_ess_fraction` and `target_ess` are mutually exclusive; set one or the other, not both. When `persistent_sampling=False`, it must be `<= n_particles`.
- `persistent_sampling`: whether to retain particles from all temperature steps (default `True`).
- `temperature_ladder`: explicit temperature schedule. If given, the sampler steps through this fixed ladder and ignores `target_ess_fraction` and `target_ess`.
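For the adaptive (non-persistent) case, one common way to choose the next temperature is a one-dimensional search. The sampler looks for the largest inverse temperature whose incremental importance weights still have the target ESS. The sketch below assumes the particles are equally weighted at the current inverse temperature `beta` and uses the Kish ESS estimate. It bisects on the new inverse temperature. The function names are hypothetical, and BlackJAX's own tempering solver may differ in detail. Persistent sampling, which pools particles from all steps, is not covered.

```python
import numpy as np


def kish_ess(log_weights):
    # Effective sample size (sum w)^2 / sum w^2, computed stably from log-weights.
    w = np.exp(log_weights - np.max(log_weights))
    return w.sum() ** 2 / np.sum(w ** 2)


def next_inverse_temperature(log_likelihoods, beta, target_ess_fraction=0.9, tol=1e-10):
    """Largest b in (beta, 1] whose incremental weights keep ESS >= target.

    Assumes the particles are equally weighted at `beta` (e.g. right after resampling).
    """
    target = target_ess_fraction * log_likelihoods.size

    def ess_at(b):
        # Moving from beta to b reweights each particle by L ** (b - beta).
        return kish_ess((b - beta) * log_likelihoods)

    if ess_at(1.0) >= target:
        return 1.0  # can jump straight to the posterior
    # ESS(beta) = n >= target > ESS(1), and ESS never increases as b grows: bisect.
    lo, hi = beta, 1.0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if ess_at(mid) >= target:
            lo = mid
        else:
            hi = mid
    return lo
```

Bisection is safe here because the Kish ESS of weights `L ** delta` is non-increasing in `delta`.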
Repository: blackjax-devs/blackjax
BlackJAX nested samplers
The two BlackJAX nested-sampling backends need extra dependencies, including a maintained fork of BlackJAX. Install them with:

```shell
uv sync --group nested-sampling
```

This pulls in:

- `blackjax`, pinned to the `jim` branch of `GW-JAX-Team/blackjax`, which carries the BlackJAX nested-sampling module.
BlackJAX NS-AW
Nested sampling with a bilby/dynesty-style adaptive differential-evolution acceptance-walk inner kernel.
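Here is a minimal sketch of a single acceptance walk. It assumes a uniform prior on the unit cube and a fixed set of live points. Each proposal jumps along the difference of two randomly chosen live points (differential evolution). That proposal is symmetric and the constrained prior is flat, so a move is accepted exactly when it stays inside the cube and above the likelihood threshold. Several details are assumptions rather than the real kernel's behaviour. The function name is hypothetical. So are the 50/50 mix of jump scales and the stopping rule, which ends the walk after `n_target` acceptances or `max_mcmc` proposals. The adaptive tuning of the real kernel is omitted.

```python
import numpy as np


def de_acceptance_walk(start, live_points, log_l, log_l_star,
                       n_target=60, max_mcmc=5000, rng=None):
    """Walk from `start` until n_target proposals are accepted or max_mcmc are made.

    Target: uniform on {u in [0, 1]^n_dims : log_l(u) > log_l_star}.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.array(start, dtype=float)
    n_dims = x.size
    n_accept = n_proposed = 0
    while n_accept < n_target and n_proposed < max_mcmc:
        # Differential-evolution jump along the difference of two distinct live points.
        i, j = rng.choice(len(live_points), size=2, replace=False)
        gamma = 1.0 if rng.uniform() < 0.5 else 2.38 / np.sqrt(2 * n_dims)
        prop = x + gamma * (live_points[i] - live_points[j])
        n_proposed += 1
        # Symmetric proposal + flat constrained prior: accept iff the point is allowed.
        if np.all((prop >= 0.0) & (prop <= 1.0)) and log_l(prop) > log_l_star:
            x = prop
            n_accept += 1
    return x, n_accept, n_proposed
```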
Unit-cube requirement: this sampler works in the unit hypercube `[0, 1]^n_dims`. All parameters must be mapped into `[0, 1]` via sample transforms, which the CLI constructs automatically; when you construct `Jim` yourself, pass them as `sample_transforms`.
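The idea behind such a mapping can be sketched in plain NumPy. A uniform prior needs only an affine rescaling. A non-uniform prior can be pushed through its cumulative distribution function (CDF), so that its image is uniform on `[0, 1]`; the "Uniform prior" constraint in the overview table suggests this is what NS-AW expects. The helper names and the 10 to 80 range in the test are hypothetical. In practice you would use Jim's own sample transforms rather than these functions.

```python
import numpy as np


def uniform_to_unit(x, lower, upper):
    # Uniform(lower, upper) prior -> [0, 1]: an affine rescaling.
    return (np.asarray(x, dtype=float) - lower) / (upper - lower)


def unit_to_uniform(u, lower, upper):
    # Inverse of uniform_to_unit.
    return lower + np.asarray(u, dtype=float) * (upper - lower)


def sine_to_unit(iota):
    # Prior p(iota) = sin(iota) / 2 on [0, pi]; its CDF is (1 - cos(iota)) / 2.
    return 0.5 * (1.0 - np.cos(iota))


def unit_to_sine(u):
    # Inverse CDF: a uniform u in [0, 1] maps to an iota with the sine prior.
    return np.arccos(1.0 - 2.0 * np.asarray(u, dtype=float))
```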
```python
from jimgw.samplers.config import BlackJAXNSAWConfig

jim = Jim(
    likelihood,
    prior,
    sampler_config=BlackJAXNSAWConfig(
        n_live=1000,
        n_delete_frac=0.5,
        n_target=60,
        max_mcmc=5000,
        max_proposals=1000,
        termination_dlogz=0.1,
    ),
    sample_transforms=sample_transforms,
    likelihood_transforms=likelihood_transforms,
)
jim.sample()
samples = jim.get_samples()
```
Key parameters:

- `n_live`: number of live points; more live points give more accurate results but run slower.
- `n_delete_frac`: fraction of live points replaced per iteration (e.g. `0.5` replaces half the live points each step).
- `n_target`: target number of accepted proposals per walk.
- `max_mcmc`: maximum number of proposals before giving up on a dead point.
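The toy, NumPy-only nested-sampling loop below shows how `n_live`, `n_delete_frac`, and a `termination_dlogz`-style stopping rule interact on a 1D problem where the answer is known (log Z is about 0). It makes several assumptions, and it is not the BlackJAX code. Rejection sampling from the prior replaces the acceptance walk, which is only practical in a toy like this. Prior volumes use the standard expected-shrinkage approximation. `termination_dlogz` is read as "stop once the live points could raise log Z by less than this amount". The function and likelihood are hypothetical.

```python
import numpy as np

MU, SIGMA = 0.5, 0.05  # hypothetical toy likelihood parameters


def log_likelihood(x):
    # Normalised 1D Gaussian; with a Uniform(0, 1) prior the evidence Z is ~1 (log Z ~ 0).
    return -0.5 * ((x - MU) / SIGMA) ** 2 - np.log(SIGMA * np.sqrt(2.0 * np.pi))


def toy_nested_sampling(n_live=500, n_delete_frac=0.5, termination_dlogz=0.1, seed=0):
    rng = np.random.default_rng(seed)
    live = rng.uniform(0.0, 1.0, n_live)            # initial draws from the prior
    live_logl = log_likelihood(live)
    n_delete = max(1, int(n_live * n_delete_frac))  # points replaced per iteration
    log_x, log_z = 0.0, -np.inf                     # log prior volume left, running log Z
    while True:
        order = np.argsort(live_logl)
        dead_logl = live_logl[order[:n_delete]]
        # Removing the worst of n, n-1, ... points shrinks log X by 1/n, 1/(n-1), ... on average.
        for i, logl in enumerate(dead_logl):
            shrink = 1.0 / (n_live - i)
            log_shell = log_x + np.log1p(-np.exp(-shrink))  # prior volume of the removed shell
            log_z = np.logaddexp(log_z, log_shell + logl)
            log_x -= shrink
        threshold = dead_logl[-1]
        # Refill by rejection sampling from the prior above the threshold
        # (a stand-in for the acceptance-walk or slice-sampling inner kernels).
        new = np.empty(0)
        while new.size < n_delete:
            cand = rng.uniform(0.0, 1.0, 4096)
            new = np.concatenate([new, cand[log_likelihood(cand) > threshold]])
        live = np.concatenate([live[order[n_delete:]], new[:n_delete]])
        live_logl = log_likelihood(live)
        # Stop once the remaining live points could raise log Z by less than the tolerance.
        if np.logaddexp(log_z, live_logl.max() + log_x) - log_z < termination_dlogz:
            break
    # Each final live point receives an equal share of the remaining prior volume.
    return float(np.logaddexp(log_z, log_x + np.log(np.mean(np.exp(live_logl)))))
```

A larger `n_delete_frac` compresses the prior volume faster per iteration (about `log(1 / (1 - n_delete_frac))` in log X). The cost is that more points share the same likelihood threshold.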
Reference: Prathaban, M., Yallup, D., Alvey, J., Yang, M., Templeton, W., Handley, W., "Gravitational-wave inference at GPU speed: A bilby-like nested sampling kernel within blackjax-ns", arXiv:2509.04336 (Sep 2025).
BlackJAX NSS
Nested sampling with a slice-sampling inner kernel. Unlike NS-AW, it does not require a unit-cube prior and works in any bounded sampling space.
Normalised-prior requirement: NSS computes a Bayesian evidence estimate and therefore requires a normalised prior. All built-in Jim priors are normalised. If you add custom constraints, check whether the resulting distribution is still normalised; if so, override `is_normalized` to return `True`. Jim raises a `ValueError` at construction if this condition is not met.
```python
from jimgw.samplers.config import BlackJAXNSSConfig

jim = Jim(
    likelihood,
    prior,
    sampler_config=BlackJAXNSSConfig(
        n_live=1000,
        n_delete_frac=0.5,
        num_inner_steps_per_dim=20,
        termination_dlogz=0.1,
    ),
    sample_transforms=sample_transforms,
    likelihood_transforms=likelihood_transforms,
)
jim.sample()
samples = jim.get_samples()
```
Key parameters:

- `n_live`: number of live points.
- `n_delete_frac`: fraction of live points replaced per iteration.
- `num_inner_steps_per_dim`: slice-sampler steps per dimension per dead point; increase for strongly correlated posteriors.
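The following is a minimal sketch of drawing a replacement point with a slice sampler confined to the current likelihood contour. It is not the BlackJAX kernel, which may choose directions and widths differently. It assumes a uniform prior on a box and hit-and-run directions (a random unit vector per update). Each update uses the stepping-out and shrinkage procedures from Neal's slice sampling, and a walk makes `num_inner_steps_per_dim * n_dims` updates. The constrained prior is flat, so the slice is the constrained region itself and no auxiliary height is drawn. The function name and `width` are hypothetical.

```python
import numpy as np


def constrained_slice_walk(x0, log_l, log_l_star, lower, upper,
                           num_inner_steps_per_dim=20, width=0.1, rng=None):
    """Hit-and-run slice sampling, uniform on {x in [lower, upper] : log_l(x) > log_l_star}."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.array(x0, dtype=float)
    n_dims = x.size

    def inside(y):
        return bool(np.all(y >= lower) and np.all(y <= upper) and log_l(y) > log_l_star)

    for _ in range(num_inner_steps_per_dim * n_dims):
        d = rng.normal(size=n_dims)
        d /= np.linalg.norm(d)  # random unit direction
        # Step out: place an interval of length `width` around x, then grow each
        # end by `width` until it falls outside the constrained region.
        left = -width * rng.uniform()
        right = left + width
        while inside(x + left * d):
            left -= width
        while inside(x + right * d):
            right += width
        # Shrink: sample uniformly on [left, right]; on rejection, move that bound in.
        while True:
            t = rng.uniform(left, right)
            y = x + t * d
            if inside(y):
                x = y
                break
            if t < 0.0:
                left = t
            else:
                right = t
    return x
```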
Repository: handley-lab/blackjax
References: Yallup, D., Prathaban, M., Alvey, J., Handley, W., "Parallel Nested Slice Sampling for Gravitational Wave Parameter Estimation", arXiv:2509.24949 (Sep 2025). Yallup, D., Kroupa, N., Handley, W., "Nested Slice Sampling", OpenReview (2025).
Checkpointing and resuming
All samplers support checkpoint/resume so long-running jobs can survive interruptions.
Set checkpoint_dir to a directory and checkpoint_interval to the minimum wall-clock seconds between writes:
```python
from jimgw.samplers.config import BlackJAXSMCConfig

jim = Jim(
    likelihood,
    prior,
    sampler_config=BlackJAXSMCConfig(
        n_particles=2000,
        checkpoint_dir="./my_run",
        checkpoint_interval=600,  # write at most every 10 minutes
    ),
)
jim.sample()
```
The checkpoint is written atomically: the state is first written to `checkpoint.pkl.tmp`, which then replaces `checkpoint.pkl`, so a crash mid-write never leaves a corrupt file.
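The write-then-rename pattern can be sketched with the standard library alone. This is a hypothetical helper pair, not Jim's code. It assumes the state is picklable and that the temporary and final files live on the same filesystem, which is what makes `os.replace` atomic on POSIX systems. The reader side shows what "detects the existing file" amounts to.

```python
import os
import pickle


def write_checkpoint(state, checkpoint_dir):
    """Pickle `state` to checkpoint_dir/checkpoint.pkl via a temporary file and a rename."""
    os.makedirs(checkpoint_dir, exist_ok=True)  # created automatically if absent
    final_path = os.path.join(checkpoint_dir, "checkpoint.pkl")
    tmp_path = final_path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(state, f)
        f.flush()
        os.fsync(f.fileno())  # make sure the bytes reach disk before the rename
    os.replace(tmp_path, final_path)  # atomic on POSIX within one filesystem
    return final_path


def read_checkpoint(checkpoint_dir):
    """Return the saved state, or None when there is nothing to resume from."""
    path = os.path.join(checkpoint_dir, "checkpoint.pkl")
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)
```

A reader never sees a half-written `checkpoint.pkl`. It either finds the previous complete file or the new complete one.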
To resume after an interruption, construct the same config pointing at the same `checkpoint_dir` and call `jim.sample()` again. The sampler detects the existing file and picks up from the last saved state:
```python
# resume: identical config, same checkpoint_dir
jim = Jim(
    likelihood,
    prior,
    sampler_config=BlackJAXSMCConfig(
        n_particles=2000,
        checkpoint_dir="./my_run",
        checkpoint_interval=600,
    ),
)
jim.sample()  # resumes from ./my_run/checkpoint.pkl
```
The same fields work identically for FlowMCConfig, BlackJAXNSAWConfig, and BlackJAXNSSConfig.
| Field | Default | Notes |
|---|---|---|
| `checkpoint_dir` | `None` (disabled) | Directory where `checkpoint.pkl` is written; created automatically if absent. |
| `checkpoint_interval` | `0.0` (disabled) | Minimum seconds between writes; `0` disables checkpointing entirely. |
Validation: setting `checkpoint_interval > 0` without `checkpoint_dir` raises a `ValidationError` at config construction time.
When using the CLI, checkpointing is enabled automatically (600 s, writing to output.dir).
Set checkpoint_interval = 0 in the [sampler] block to opt out.
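The two fields combine into a simple write policy. The sketch below is hypothetical and not Jim's implementation. A plain `ValueError` stands in for the config's `ValidationError`. Whether the timer starts at construction or at the first write is an assumption; here it starts at construction, and the injectable `start_time`/`now` arguments exist only to make the logic testable.

```python
import time
from typing import Optional


class CheckpointPolicy:
    """Decide when to write, mirroring checkpoint_dir / checkpoint_interval semantics."""

    def __init__(self, checkpoint_dir: Optional[str] = None,
                 checkpoint_interval: float = 0.0,
                 start_time: Optional[float] = None):
        if checkpoint_interval > 0 and checkpoint_dir is None:
            # Stand-in for the ValidationError raised by the real config.
            raise ValueError("checkpoint_interval > 0 requires checkpoint_dir")
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_interval = checkpoint_interval
        self._last_write = time.monotonic() if start_time is None else start_time

    @property
    def enabled(self) -> bool:
        # An interval of 0 disables checkpointing even when a directory is set.
        return self.checkpoint_dir is not None and self.checkpoint_interval > 0

    def should_write(self, now: Optional[float] = None) -> bool:
        """True at most once per `checkpoint_interval` seconds of wall-clock time."""
        now = time.monotonic() if now is None else now
        if not self.enabled or now - self._last_write < self.checkpoint_interval:
            return False
        self._last_write = now
        return True
```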
Periodic parameters
All samplers accept a periodic field to handle parameters that wrap around (e.g. angles).
Pass a dict of parameter_name: (lower, upper) bounds:
```python
config = FlowMCConfig(
    ...,
    periodic={"phase_c": (0.0, 6.283185), "psi": (0.0, 3.141593)},
)
```
BlackJAX NS-AW operates in [0, 1] per dimension, so its periodic field takes a plain list of parameter names:
```python
config = BlackJAXNSAWConfig(
    ...,
    periodic=["phase_c", "psi"],
)
```
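What "wrap around" means for a proposal can be shown with a small NumPy sketch. A point that steps past one bound re-enters from the other, by adding or subtracting whole periods. For NS-AW the period is simply `[0, 1)` in the unit cube, which is why a list of names is enough there. Both helpers are hypothetical illustrations, not Jim's internals.

```python
import numpy as np


def wrap_periodic(x, lower, upper):
    """Map x into [lower, upper) by adding or subtracting whole periods.

    (Floating-point rounding can land exactly on `upper` for tiny negative offsets.)
    """
    period = upper - lower
    return lower + np.mod(np.asarray(x, dtype=float) - lower, period)


def random_walk_step_periodic(u, step, rng):
    # Random-walk proposal in the unit cube with wrap-around on [0, 1).
    return wrap_periodic(u + step * rng.normal(size=np.shape(u)), 0.0, 1.0)
```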
Posterior samples
jim.get_samples() is the primary way to retrieve posterior samples.
```python
samples = jim.get_samples()
# keys: prior parameter names + "log_likelihood"
samples["M_c"]             # np.ndarray: chirp mass in prior space
samples["log_likelihood"]  # np.ndarray: per-sample log-likelihood
```
Each backend's `get_samples()` returns equally weighted posterior samples:

- NS-AW / NSS: uses anesthetic's `posterior_points()` to resample the dead-point collection to equal-weight samples.
- SMC (persistent): resamples all-temperature particles weighted by the persistent-sampling weights.
- SMC (non-persistent): returns all final-temperature particles.
- flowMC: returns all production samples across all chains.
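One common way to turn weighted samples (such as nested-sampling dead points or persistent SMC particles) into equally weighted ones is rejection thinning. Each sample is kept with probability proportional to its weight, normalised by the largest weight. This sketch illustrates the general idea only. Whether anesthetic's `posterior_points()` or the SMC backend uses exactly this scheme is not claimed here, and the function name is hypothetical.

```python
import numpy as np


def equal_weight_thin(samples, log_weights, rng=None):
    """Keep sample i with probability w_i / max(w); survivors are equally weighted."""
    rng = np.random.default_rng() if rng is None else rng
    w = np.exp(log_weights - np.max(log_weights))  # largest weight becomes 1
    keep = rng.uniform(size=w.size) < w
    return samples[keep]
```

Thinning never duplicates a sample. Multinomial resampling with replacement is the usual alternative when a fixed output size is needed.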
Pass n_samples to jim.get_samples() to further downsample uniformly without replacement:
```python
samples = jim.get_samples(n_samples=2000)
```
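Downsampling a dict of per-sample arrays without replacement comes down to drawing one index set and applying it to every key, so that parameters and log-likelihoods stay aligned row by row. The helper below is a hypothetical sketch of that idea, not Jim's implementation.

```python
import numpy as np


def downsample(samples, n_samples, rng=None):
    """Uniformly pick n_samples rows without replacement, consistently across keys."""
    rng = np.random.default_rng() if rng is None else rng
    n_total = len(next(iter(samples.values())))
    if n_samples > n_total:
        raise ValueError(f"requested {n_samples} samples but only {n_total} exist")
    idx = rng.choice(n_total, size=n_samples, replace=False)
    return {name: values[idx] for name, values in samples.items()}
```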
Run diagnostics
jim.get_diagnostics() is a thin wrapper around the sampler's own get_diagnostics(), which returns a plain dict[str, Any].
```python
diag = jim.get_diagnostics()
diag["n_likelihood_evaluations"]  # int: total number of likelihood calls
diag["sampling_time"]             # float: wall-clock sampling time in seconds
```
Backend-specific keys:
```python
# flowMC
diag["n_training_loops_actual"]       # int: training loops run (may be fewer than configured)
diag["training_loss_history"]         # np.ndarray: normalizing-flow loss per epoch
diag["acceptance_training_local"]     # np.ndarray: local acceptance rate per training loop
diag["acceptance_training_global"]    # np.ndarray: global acceptance rate per training loop
diag["acceptance_production_local"]   # np.ndarray: local acceptance rate per production loop
diag["acceptance_production_global"]  # np.ndarray: global acceptance rate per production loop

# NS-AW and NSS (these also include the evidence estimate)
diag["n_iterations"]  # int: number of nested-sampling steps
diag["log_Z"]         # float: log Bayesian evidence
diag["log_Z_error"]   # float: standard deviation from 100 bootstrap samples

# SMC
diag["acceptance_history"]  # np.ndarray: mean acceptance rate at each step
diag["ess_history"]         # np.ndarray: ESS at each step
# Adaptive temperature only:
diag["n_iterations"]        # int: number of temperature steps
diag["tempering_schedule"]  # np.ndarray: inverse temperature at each step
# Persistent sampling only:
diag["persistent_log_Z"]    # np.ndarray: cumulative log Z after each step
diag["log_Z"]               # float: final log Bayesian evidence
```
Writing your own sampler
This section is for advanced users who want to integrate a custom sampling backend with Jim. It requires familiarity with JAX and the Jim sampler internals.
Subclass Sampler, implement three methods, and register it:
- `_sample(rng_key, initial_position)`: run the sampler and store the results. The base class wraps this in `sample()`, which also records `sampling_time`.
- `get_samples()`: return a dict with `"samples"` and `"log_likelihood"` keys.
- `_get_diagnostics()`: return a plain dict of diagnostic information. The base class wraps this in `get_diagnostics()`, which injects `sampling_time`.
```python
from typing import Any, Literal, Optional

import numpy as np

from jimgw.samplers import register_sampler
from jimgw.samplers.base import Sampler
from jimgw.samplers.config import BaseSamplerConfig


class MyConfig(BaseSamplerConfig):
    type: Literal["my-sampler"] = "my-sampler"
    n_steps: int = 1000


class MySampler(Sampler):
    _config: MyConfig

    def __init__(self, *, n_dims, log_prior_fn, log_likelihood_fn,
                 log_posterior_fn, config: Optional[MyConfig] = None,
                 parameter_names=(), periodic=None):
        if config is None:
            config = MyConfig()
        super().__init__(n_dims=n_dims, log_prior_fn=log_prior_fn,
                         log_likelihood_fn=log_likelihood_fn,
                         log_posterior_fn=log_posterior_fn, config=config)
        self._result = None

    def _sample(self, rng_key, initial_position) -> None:
        # initial_position: shape (n_chains, n_dims), drawn from the prior by Jim.
        # ... run your sampler for self._config.n_steps steps ...
        self._result = np.asarray(initial_position)

    def get_samples(self) -> dict[str, np.ndarray]:
        if self._result is None:
            raise RuntimeError("call sample() first")
        return {
            "samples": self._result,
            "log_likelihood": np.zeros(self._result.shape[0]),
        }

    def _get_diagnostics(self) -> dict[str, Any]:
        if self._result is None:
            raise RuntimeError("call sample() first")
        return {
            "n_likelihood_evaluations": self._config.n_steps,
        }


register_sampler("my-sampler", lambda: MySampler)
```
Then pass MyConfig(n_steps=500) as sampler_config to Jim.
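The `_sample` placeholder above simply echoes `initial_position`. As a sketch of what could go in its place, here is a NumPy random-walk Metropolis sampler that runs one chain per row and returns the final positions together with their log-likelihoods, so that `get_samples()` could report real values instead of zeros. It rests on three assumptions. First, `log_posterior_fn` and `log_likelihood_fn` accept a single 1D position and return a scalar; check how Jim actually calls them. Second, a NumPy seed stands in for Jim's JAX `rng_key`. Third, the function name is hypothetical. A production backend would instead vectorise with JAX.

```python
import numpy as np


def random_walk_metropolis(log_posterior_fn, log_likelihood_fn, initial_position,
                           n_steps, step_size=1.0, seed=0):
    """Run one Gaussian random-walk Metropolis chain per row of initial_position.

    Returns the final positions, shape (n_chains, n_dims), and their log-likelihoods.
    """
    rng = np.random.default_rng(seed)
    x = np.array(initial_position, dtype=float)  # copy; shape (n_chains, n_dims)
    logp = np.array([log_posterior_fn(xi) for xi in x])
    for _ in range(n_steps):
        prop = x + step_size * rng.normal(size=x.shape)
        logp_prop = np.array([log_posterior_fn(p) for p in prop])
        # Symmetric proposal, so the acceptance ratio is just the posterior ratio.
        accept = np.log(rng.uniform(size=len(x))) < logp_prop - logp
        x[accept] = prop[accept]
        logp[accept] = logp_prop[accept]
    log_l = np.array([log_likelihood_fn(xi) for xi in x])
    return x, log_l
```

Inside `_sample`, you might store both return values and have `get_samples()` return them under `"samples"` and `"log_likelihood"`.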