
# Hyperparameter Reference

Quick-reference index by category. Click any parameter name to jump to its description.

## Required

| Parameter | Description |
| --- | --- |
| rng_key | JAX PRNG key |
| n_chains | Number of parallel chains |
| n_dims | Dimensionality of the parameter space |
| logpdf | Log-density function |
| n_local_steps | Local MCMC steps per loop |
| n_global_steps | NF proposal steps per loop |
| n_training_loops | Number of training phase loops |
| n_production_loops | Number of production phase loops |
| n_epochs | Training epochs per loop |

## Local sampler — MALA

| Parameter | Default | Description |
| --- | --- | --- |
| mala_step_size | 1e-1 | Initial MALA step size |
| adapt_step_size | True | Auto-adapt step size |
| periodic | None | Periodic dimensions |

## Local sampler — HMC

| Parameter | Default | Description |
| --- | --- | --- |
| hmc_step_size | 0.1 | Initial leapfrog step size |
| hmc_n_leapfrog | 10 | Leapfrog steps per proposal |
| condition_matrix | 1 | Diagonal inverse mass matrix |
| adapt_step_size | True | Auto-adapt step size |
| periodic | None | Periodic dimensions |

## Local sampler — Gaussian random walk

| Parameter | Default | Description |
| --- | --- | --- |
| grw_step_size | 1e-1 | Initial random walk step size |
| adapt_step_size | True | Auto-adapt step size |
| periodic | None | Periodic dimensions |

## Normalizing flow

| Parameter | Default | Description |
| --- | --- | --- |
| rq_spline_n_layers | 4 | Number of coupling layers |
| rq_spline_hidden_units | [32, 32] | Hidden layer widths |
| rq_spline_n_bins | 8 | Spline bins per layer |
| n_NFproposal_batch_size | 10000 | NF sampling batch size |

## Training

| Parameter | Default | Description |
| --- | --- | --- |
| learning_rate | 1e-3 | AdamW learning rate |
| batch_size | 10000 | Training mini-batch size |
| n_max_examples | 10000 | Max samples in training buffer |
| history_window | 100 | Most-recent steps per chain used as NF training data |

## Execution

| Parameter | Default | Description |
| --- | --- | --- |
| chain_batch_size | 0 | Split vmap over chains |
| local_thinning | 1 | Thinning for local samples |
| global_thinning | 1 | Thinning for global samples |
| verbose | False | Print progress |

## Early stopping

| Parameter | Default | Description |
| --- | --- | --- |
| early_stopping | False | Enable early stopping |
| early_stopping_tolerance | 0.05 | Relative acceptance-rate change threshold |
| early_stopping_patience | 3 | Consecutive loops to confirm stability |
| early_stopping_min_acceptance | 0.1 | Acceptance rate that triggers early stopping |

## Parallel tempering (PT bundles only)

| Parameter | Default | Description |
| --- | --- | --- |
| logprior | flat | Log-prior function |
| n_temperatures | 5 | Number of temperature levels |
| max_temperature | 5.0 | Highest temperature |
| n_tempered_steps | -1 | Swap attempts per loop |

## Required parameters

### rng_key

A JAX PRNG key used to initialise all random state inside the bundle (chain positions, model weights, optimizer state).

### n_chains

Number of parallel chains. The algorithm parallelises all chains simultaneously via vmap, so the performance benefit saturates once the chains fill the available hardware threads. More chains give a better picture of the global landscape, which helps train the normalizing flow faster and reduces the chance of mode collapse.

### n_dims

Dimensionality of the parameter space. Must match the input dimension of logpdf.

### logpdf

A callable with signature logpdf(x, data) -> float where x is a 1-D array of length n_dims and data is the auxiliary data passed to Sampler.sample. The function must be JAX-differentiable if you use MALA or HMC.
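For illustration, here is a minimal standard-normal logpdf with this signature. This is a toy sketch: in practice you would implement it with jax.numpy so that MALA and HMC can differentiate it.

```python
import math

def logpdf(x, data):
    # Standard-normal log-density over len(x) dimensions; `data` is unused
    # in this toy example. For MALA/HMC, write this with jax.numpy so JAX
    # can differentiate it.
    return -0.5 * sum(xi * xi for xi in x) - 0.5 * len(x) * math.log(2.0 * math.pi)
```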

### n_local_steps

Number of local MCMC steps executed per chain in each training and production loop. Together with n_chains this controls how many samples are added to the training buffer per loop.

### n_global_steps

Number of normalizing-flow proposal steps executed per chain in each loop. These are independence Metropolis–Hastings moves drawn from the current flow, so they are most effective once the flow has been trained for a few loops.

### n_training_loops

Number of local–global–train cycles to run in the training phase. During each cycle the local sampler runs, the flow is trained on the accumulated samples, and the global proposer runs. Increasing this gives the flow more opportunities to improve, at the cost of a longer warmup.

### n_production_loops

Number of local–global cycles to run in the production phase. The flow is not updated during production, so detailed balance is restored and standard MCMC convergence diagnostics apply. All production samples are stored.

### n_epochs

Number of gradient-descent epochs in each flow-training step (one training step per loop). Higher values improve the flow fit but increase training time per loop.


## Local sampler parameters

### mala_step_size

MALA bundles only. Initial step size for MALA. Default 1e-1. Accepts either a scalar (broadcast to all dimensions) or a 1-D array of length n_dims to set a different initial step size per dimension. The optimal MALA acceptance rate is ~57%, and adapt_step_size will drive the step size toward that target automatically.

### hmc_step_size

HMC bundles only. Initial leapfrog step size. Default 0.1. The optimal HMC acceptance rate is ~65%.

### hmc_n_leapfrog

HMC bundles only. Number of leapfrog steps per HMC proposal. Default 10. More steps produce proposals further along the trajectory but increase cost per step.

### condition_matrix

HMC bundles only. Diagonal of the inverse mass matrix (preconditioning matrix). Default 1 (scalar, broadcast to all dimensions). Pass a 1-D array of length n_dims to set per-dimension scales. Useful when parameters have very different characteristic scales.

### grw_step_size

GRW bundles only. Initial step size for the Gaussian random walk proposal. Default 1e-1. Accepts either a scalar (broadcast to all dimensions) or a 1-D array of length n_dims to set a different initial step size per dimension. The optimal random-walk acceptance rate is ~23%.

### adapt_step_size

Whether to automatically adapt the local step size during the training phase. Default True. The step size is updated after each training loop using a simple multiplicative rule targeting the algorithm-specific optimal acceptance rate. Adaptation is disabled during the production phase.
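The exact update rule is internal to the bundle, but conceptually it behaves like the following hypothetical multiplicative rule. The function name and the gain factor are illustrative, not part of the API.

```python
def multiplicative_adapt(step_size, acceptance_rate, target_rate, gain=1.1):
    # Hypothetical rule: widen the step when accepting more often than the
    # target rate, narrow it when accepting less often.
    if acceptance_rate > target_rate:
        return step_size * gain
    return step_size / gain
```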

### periodic

A dict[int, tuple[float, float]] mapping dimension indices to (lower, upper) bounds for periodic dimensions. Default None (no periodic dimensions). Proposals in periodic dimensions are automatically wrapped into [lower, upper].

```python
# e.g. dim 0 is periodic on [0, 2π] and dim 2 on [-1, 1]
periodic = {0: (0.0, 6.283), 2: (-1.0, 1.0)}
```
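The wrapping itself is equivalent to a modulo operation; a sketch (wrap is an illustrative helper, not part of the API):

```python
def wrap(value, lower, upper):
    # Map a proposed coordinate back into [lower, upper) for a periodic dimension.
    return lower + (value - lower) % (upper - lower)
```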

## Normalizing flow parameters

### rq_spline_n_layers

Number of masked coupling layers in the RQ-spline normalizing flow. Default 4. More layers increase expressive power at the cost of memory and training time.

### rq_spline_hidden_units

List of hidden layer widths for the conditioner MLPs inside each coupling layer. Default [32, 32]. Increasing the width or depth improves the flow's capacity to represent complex distributions.

### rq_spline_n_bins

Number of bins for the rational-quadratic spline bijection. Default 8. More bins give a finer piecewise approximation.

### n_NFproposal_batch_size

Batch size used when running the NF proposal over many steps. Default 10000. When n_global_steps exceeds this value, the flow samples are computed in chunks of this size using jax.lax.map rather than a single vmap, reducing peak memory.


## Training parameters

### learning_rate

Learning rate for the AdamW optimizer used to train the normalizing flow. Default 1e-3.

### batch_size

Mini-batch size for each gradient step during flow training. Default 10000. The training loss is computed on random mini-batches of the accumulated samples. Within the memory budget, larger batches tend to be better because the training dataset is continuously evolving and overfitting is not a concern.

### n_max_examples

Maximum number of samples retained in the training buffer. Default 10000. When the buffer exceeds this size, only the most recent n_max_examples samples are kept for training. Choose this based on your device's memory capacity. Setting it too low risks mode collapse, as the flow may forget regions of the posterior not recently visited by the chains.

### history_window

Number of most-recent position steps per chain used when selecting training data for the normalizing flow. Default 100. At each training step, only the last history_window positions from each chain's buffer are considered; n_max_examples samples are then drawn from that pool. Reducing this value focuses the flow on the current mode but risks forgetting regions not recently visited.
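With illustrative numbers, the selection works out as follows (the values and variable names here are for exposition only):

```python
n_chains, history_window, n_max_examples = 20, 100, 1000  # illustrative values

# Positions eligible as NF training data: the last history_window steps per chain.
pool_size = n_chains * history_window
# Samples actually used for training, capped by n_max_examples.
n_selected = min(n_max_examples, pool_size)
```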


## Execution parameters

### chain_batch_size

If non-zero, splits the vmap over chains into sequential batches of this size. Default 0 (no splitting; all chains are vmapped at once). Use this when vmapping over all chains simultaneously exceeds device memory.
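For example, with illustrative values, the number of sequential batches is:

```python
import math

n_chains, chain_batch_size = 1000, 256  # illustrative values
# Chains are processed in ceil(n_chains / chain_batch_size) sequential vmap batches.
n_batches = math.ceil(n_chains / chain_batch_size)
```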

### local_thinning

Store every local_thinning-th position from the local sampler. Default 1 (store all). All n_local_steps are always executed — thinning only controls which steps are written to the buffer. Increasing it reduces memory and the number of training samples available to the normalizing flow, but also reduces autocorrelation in the stored samples. The acceptance rate stored per slot is the mean over the corresponding thinning window. Note that local_thinning must not exceed n_local_steps.
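A quick sanity check with illustrative numbers:

```python
n_local_steps, local_thinning = 100, 5  # illustrative values
# All 100 steps run, but only every 5th position is written to the buffer.
stored_per_chain = n_local_steps // local_thinning
```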

### global_thinning

Store every global_thinning-th position from the NF global proposal. Default 1. Same semantics as local_thinning: all n_global_steps are always executed, but only every global_thinning-th result is written to the buffer. This affects memory, the acceptance-rate values seen by early stopping, and the number of samples available in the training buffer. global_thinning must not exceed n_global_steps.

A key practical motivation for increasing global_thinning is that the normalizing flow can never perfectly represent the posterior, so the global acceptance rate is typically low. When many consecutive proposals are rejected, the chain stays at the same position, and storing every step fills the buffer with duplicate points that contribute nothing to the posterior estimate or NF training. Thinning skips over these repeated positions while still executing all n_global_steps, preserving the full opportunity to make large jumps whenever a proposal is accepted.

### verbose

Print training loss and acceptance rates during sampling. Default False.


## Early stopping parameters

### early_stopping

Whether to enable early stopping of the training phase. Default False. When enabled, the training phase ends early and the sampler transitions to the production phase once either of the two stopping conditions is met (see early_stopping_tolerance and early_stopping_min_acceptance below). The first 3 training loops are always completed regardless.

### early_stopping_tolerance

Relative change threshold for the global acceptance rate stability check. Early stopping triggers when both the mean global acceptance rate and its coefficient of variation across chains change by less than this fraction for early_stopping_patience consecutive loops: |current - prev| / prev < tolerance. Default 0.05.
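The check described above can be sketched as follows, applied separately to the mean acceptance rate and to its coefficient of variation (is_stable is an illustrative name, not part of the API):

```python
def is_stable(prev, current, tolerance=0.05):
    # Relative-change stability check: |current - prev| / prev < tolerance.
    return abs(current - prev) / prev < tolerance
```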

### early_stopping_patience

Number of consecutive loops for which either stopping criterion must be satisfied before the training phase ends. Default 3.

### early_stopping_min_acceptance

Global acceptance rate threshold that, once reached, triggers early stopping regardless of stability (after early_stopping_patience consecutive loops). Default 0.1. Set to 0 to disable this trigger and rely solely on the stability check.


## Parallel tempering parameters

These parameters are only available in the PT bundles: RQSpline_MALA_PT_Bundle, RQSpline_HMC_PT_Bundle, RQSpline_GRW_PT_Bundle.

### logprior

Log-prior function with signature logprior(x, data) -> float. Default is a flat prior that always returns 0. The tempered log-density is

\[\log \tilde{p}_T(x) = \frac{1}{T} \log \mathcal{L}(x) + \log \pi(x)\]

so the prior is preserved at all temperatures.
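In code, the tempering amounts to the following sketch (function names are illustrative, not the bundle's API):

```python
def make_tempered_logpdf(loglikelihood, logprior, T):
    # log p_T(x) = (1/T) * log L(x) + log pi(x): only the likelihood is
    # tempered, so the prior is preserved at every temperature.
    def tempered(x, data):
        return loglikelihood(x, data) / T + logprior(x, data)
    return tempered
```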

### n_temperatures

Number of temperature levels, including the target (temperature = 1). Default 5. Chains at higher temperatures explore the prior more freely and occasionally swap positions with the target chains, helping escape local modes.

### max_temperature

Highest temperature in the ladder. Default 5.0. The temperatures are spaced linearly between 1 and max_temperature. A larger value allows more prior-like exploration but reduces the swap acceptance rate between adjacent levels.
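The linear spacing can be sketched as (temperature_ladder is an illustrative helper):

```python
def temperature_ladder(n_temperatures, max_temperature):
    # Linearly spaced temperatures from the target (T = 1) up to max_temperature.
    step = (max_temperature - 1.0) / (n_temperatures - 1)
    return [1.0 + i * step for i in range(n_temperatures)]
```

With the defaults n_temperatures=5 and max_temperature=5.0 this gives the ladder [1.0, 2.0, 3.0, 4.0, 5.0].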

### n_tempered_steps

Number of parallel-tempering swap attempts per loop. Default -1, which is interpreted as equal to n_local_steps. Reduce this to save computation if swap acceptance is low.