
RQSpline HMC

RQSpline_HMC_Bundle

Bases: ResourceStrategyBundle

A bundle that uses a rational-quadratic spline (RQSpline) normalizing flow as the model for global proposals and Hamiltonian Monte Carlo (HMC) as the local sampler.
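The global steps draw candidate positions from the trained flow and accept or reject them. The sketch below shows the standard independence Metropolis-Hastings rule that is typically used with such a proposal, written in plain JAX. It is an illustration, not flowMC's implementation: `log_accept_ratio` and `independence_mh_step` are hypothetical helpers, and the `sample_q` / `log_q` callables are stand-ins for sampling from and evaluating the trained flow.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_accept_ratio(logp_x, logp_prop, logq_x, logq_prop):
    # Independence MH: log[p(x') q(x)] - log[p(x) q(x')].
    return (logp_prop - logp_x) + (logq_x - logq_prop)


def independence_mh_step(key, x, logp_x, logpdf, sample_q, log_q):
    """One global step: draw x' ~ q independently of x, then accept or reject."""
    key_q, key_u = jax.random.split(key)
    x_prop = sample_q(key_q)  # hypothetical stand-in for a flow sample
    logp_prop = logpdf(x_prop)
    log_ratio = log_accept_ratio(logp_x, logp_prop, log_q(x), log_q(x_prop))
    accept = jnp.log(jax.random.uniform(key_u)) < log_ratio
    x_new = jnp.where(accept, x_prop, x)
    logp_new = jnp.where(accept, logp_prop, logp_x)
    return x_new, logp_new, accept


# Usage: standard-normal target, with a Gaussian standing in for the flow.
def std_normal_logpdf(x):
    return jnp.sum(norm.logpdf(x))


x0 = jnp.array([3.0, -3.0])
x1, logp1, accepted = independence_mh_step(
    jax.random.PRNGKey(0),
    x0,
    std_normal_logpdf(x0),
    std_normal_logpdf,
    lambda k: jax.random.normal(k, (2,)),
    std_normal_logpdf,
)
```

Because the proposal does not depend on the current position, a well-trained flow lets a chain jump between distant regions in a single step, which local gradient-based moves cannot do.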

It mirrors RQSpline_MALA_Bundle, but uses HMC instead of MALA (the Metropolis-adjusted Langevin algorithm) for the local steps.
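For reference, one HMC proposal with a diagonal mass matrix can be sketched in JAX as a leapfrog integration followed by a Metropolis correction. This is a textbook version, not flowMC's internal kernel: `leapfrog` and `hmc_step` are hypothetical names, `step_size` and `n_leapfrog` mirror `hmc_step_size` and `hmc_n_leapfrog`, and `mass_diag` is assumed to hold the diagonal mass-matrix elements that `condition_matrix` describes (flowMC's exact internal convention may differ).

```python
import jax
import jax.numpy as jnp


def leapfrog(x, p, step_size, n_leapfrog, mass_diag, logpdf):
    """Integrate Hamiltonian dynamics for n_leapfrog steps (kick-drift-kick)."""
    grad = jax.grad(logpdf)
    p = p + 0.5 * step_size * grad(x)  # initial half kick

    def body(_, state):
        x, p = state
        x = x + step_size * p / mass_diag  # full drift, velocity = M^{-1} p
        p = p + step_size * grad(x)        # full kick
        return x, p

    x, p = jax.lax.fori_loop(0, n_leapfrog - 1, body, (x, p))
    x = x + step_size * p / mass_diag      # last drift
    p = p + 0.5 * step_size * grad(x)      # final half kick
    return x, p


def hmc_step(key, x, logpdf, step_size=0.1, n_leapfrog=10, mass_diag=1.0):
    """One HMC proposal with diagonal mass matrix M, then MH accept/reject."""
    # Scalar mass is broadcast to every dimension, like a scalar condition_matrix.
    mass_diag = jnp.broadcast_to(jnp.asarray(mass_diag, dtype=x.dtype), x.shape)
    key_p, key_u = jax.random.split(key)
    p0 = jax.random.normal(key_p, x.shape) * jnp.sqrt(mass_diag)  # p ~ N(0, M)
    x1, p1 = leapfrog(x, p0, step_size, n_leapfrog, mass_diag, logpdf)

    def hamiltonian(x, p):
        return -logpdf(x) + 0.5 * jnp.sum(p**2 / mass_diag)

    log_accept = hamiltonian(x, p0) - hamiltonian(x1, p1)
    accept = jnp.log(jax.random.uniform(key_u)) < log_accept
    return jnp.where(accept, x1, x), accept


def std_normal_logpdf(x):
    return -0.5 * jnp.sum(x**2)


x_new, accepted = hmc_step(jax.random.PRNGKey(0), jnp.zeros(3), std_normal_logpdf)
```

A larger mass in one dimension slows the drift along it, which is how a per-dimension mass matrix acts as a per-dimension effective step size.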

Methods:

__init__: Build all resources and strategies for an RQSpline + HMC sampling run.

__init__(rng_key: Key, n_chains: int, n_dims: int, logpdf: Callable[[Float[Array, ' n_dim'], dict], FloatScalar], n_local_steps: int, n_global_steps: int, n_training_loops: int, n_production_loops: int, n_epochs: int, hmc_step_size: float = 0.1, hmc_n_leapfrog: int = 10, condition_matrix: float | Float[Array, ' n_dim'] = 1, adapt_step_size: bool = True, adapt_step_size_per_dim: bool = True, periodic: Optional[dict[int, tuple[float, float]]] = None, rq_spline_hidden_units: list[int] = [32, 32], rq_spline_n_bins: int = 8, rq_spline_n_layers: int = 4, n_NFproposal_batch_size: int = 10000, learning_rate: float = 0.001, batch_size: int = 10000, n_max_examples: int = 10000, history_window: int = 100, chain_batch_size: int = 0, local_thinning: int = 1, global_thinning: int = 1, early_stopping: bool = False, early_stopping_tolerance: float = 0.05, early_stopping_patience: int = 3, early_stopping_min_acceptance: float = 0.1, verbose: bool = False) -> None

Build all resources and strategies for an RQSpline + HMC sampling run.
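The signature expects `logpdf` to take a single position `x` of shape `(n_dim,)` plus a `data` dict, and to return a scalar. A minimal sketch of such a function for an isotropic Gaussian follows; the `data` keys `"mu"` and `"sigma"` are hypothetical. Writing it with `jax.numpy` is what allows the HMC steps to differentiate it with respect to `x`, and `jax.vmap` shows how the single-position function extends to a batch of chains.

```python
import jax
import jax.numpy as jnp


def logpdf(x, data):
    """Isotropic Gaussian log-density; "mu" and "sigma" are hypothetical data keys."""
    mu, sigma = data["mu"], data["sigma"]
    z = (x - mu) / sigma
    return -0.5 * jnp.sum(z**2) - x.size * jnp.log(sigma * jnp.sqrt(2.0 * jnp.pi))


data = {"mu": jnp.array([0.0, 1.0, -1.0]), "sigma": 2.0}

grad_logpdf = jax.grad(logpdf)                        # d logpdf / dx, as HMC needs
batched_logpdf = jax.vmap(logpdf, in_axes=(0, None))  # data is shared by all chains

positions = jnp.zeros((4, 3))             # 4 chains, 3 dimensions
values = batched_logpdf(positions, data)  # one log-density per chain
```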

Parameters:

rng_key (Key, required): JAX PRNGKey used to initialise the normalizing flow.
n_chains (int, required): Number of parallel MCMC chains.
n_dims (int, required): Dimensionality of the target distribution.
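logpdf (Callable, required): Log-density of the target, called as f(x, data) -> Float, where x is a single position of shape (n_dim,) and data is a dict. HMC uses its gradient with respect to x, so it should be written with JAX operations.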
n_local_steps (int, required): Number of HMC steps per training or production loop iteration.
n_global_steps (int, required): Number of normalizing-flow (NF) proposal steps per training or production loop iteration.
n_training_loops (int, required): Number of train-then-sample iterations (warmup).
n_production_loops (int, required): Number of production sampling iterations.
n_epochs (int, required): Number of NF training epochs per training loop.
hmc_step_size (float, default 0.1): Initial leapfrog step size.
hmc_n_leapfrog (int, default 10): Number of leapfrog steps per HMC proposal.
condition_matrix (float | Float[Array, n_dim], default 1): Diagonal mass-matrix elements, given either as a scalar (broadcast to every dimension) or as a per-dimension array.
adapt_step_size (bool, default True): Adapt the HMC step size during training.
adapt_step_size_per_dim (bool, default True): Also tune per-dimension effective step sizes by setting the mass matrix from the empirical standard deviation of recent chain positions. This runs after the adapt_step_size adaptation.
periodic (dict[int, tuple[float, float]] | None, default None): Periodic boundary conditions as {dim_index: (lower, upper)}.
rq_spline_hidden_units (list[int], default [32, 32]): Widths of the hidden layers in each conditioner MLP, one entry per layer.
rq_spline_n_bins (int, default 8): Number of RQ-spline bins.
rq_spline_n_layers (int, default 4): Number of masked coupling layers.
n_NFproposal_batch_size (int, default 10000): Batch size for evaluating the NF log-probability.
learning_rate (float, default 0.001): Adam learning rate for NF training.
batch_size (int, default 10000): Mini-batch size for NF training.
n_max_examples (int, default 10000): Maximum number of training examples used per training call.
history_window (int, default 100): Use only the last history_window stored steps as training data.
chain_batch_size (int, default 0): Process chains in sub-batches of this size to reduce peak memory; 0 disables batching.
local_thinning (int, default 1): Store every local_thinning-th local step.
global_thinning (int, default 1): Store every global_thinning-th global step.
early_stopping (bool, default False): Enable early stopping based on the stability of the global acceptance rate.
early_stopping_tolerance (float, default 0.05): Relative change in the global acceptance rate below which a loop counts as stable.
early_stopping_patience (int, default 3): Number of consecutive stable loops required before stopping.
early_stopping_min_acceptance (float, default 0.1): Minimum global acceptance rate that also triggers early stopping.
verbose (bool, default False): Enable progress bars and debug logging.
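The early-stopping parameters can be read as a patience rule on the global acceptance rate. The sketch below is one plausible interpretation, not flowMC's implementation. It assumes the relative change is measured between consecutive training loops as |a_t - a_(t-1)| / a_(t-1), and it leaves out early_stopping_min_acceptance because how that threshold combines with the stability check is not specified here. `should_stop` is a hypothetical helper.

```python
def should_stop(acceptance_history, tolerance=0.05, patience=3):
    """Return True once `patience` consecutive loops are "stable".

    A loop counts as stable when the global acceptance rate changed by less
    than `tolerance` relative to the previous loop (assumed definition).
    """
    stable_loops = 0
    for prev, curr in zip(acceptance_history, acceptance_history[1:]):
        rel_change = abs(curr - prev) / max(abs(prev), 1e-12)  # guard prev == 0
        stable_loops = stable_loops + 1 if rel_change < tolerance else 0
        if stable_loops >= patience:
            return True
    return False


# Global acceptance rate after each training loop (illustrative values).
history = [0.20, 0.30, 0.31, 0.315, 0.316]
```

With the defaults, the first jump from 0.20 to 0.30 resets the counter, and the next three small changes reach the patience of 3.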