
RQSpline HMC

flowMC.resource_strategy_bundle.RQSpline_HMC

RQSpline_HMC_Bundle

Bases: ResourceStrategyBundle

A bundle that uses a rational-quadratic spline (RQSpline) normalizing flow as the global proposal model and Hamiltonian Monte Carlo (HMC) as the local sampler.

This is similar to the RQSpline_MALA_Bundle but uses HMC instead of MALA for local sampling.
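The HMC local step can be illustrated with a textbook transition: draw a momentum, integrate Hamilton's equations with the leapfrog scheme, then accept or reject on the change in total energy. This is a minimal sketch, not flowMC's kernel. The function name `hmc_proposal` and the argument `mass_diag` are hypothetical. The sketch takes this page's description of `condition_matrix` literally and treats its scalar or per-dimension value as the diagonal of the mass matrix M.

```python
import jax
import jax.numpy as jnp


def hmc_proposal(logpdf, x, data, key, step_size=0.1, n_leapfrog=10, mass_diag=1.0):
    """One textbook HMC transition with a diagonal mass matrix (illustrative only)."""
    # Broadcast a scalar mass to every dimension (assumed analogue of a scalar condition_matrix)
    mass_diag = jnp.broadcast_to(jnp.asarray(mass_diag, dtype=x.dtype), x.shape)
    grad_logpdf = jax.grad(logpdf)  # gradient with respect to x only
    key_momentum, key_accept = jax.random.split(key)

    # Draw momentum p ~ N(0, M) with M = diag(mass_diag)
    p0 = jax.random.normal(key_momentum, x.shape) * jnp.sqrt(mass_diag)

    def hamiltonian(x, p):
        # Potential energy -log p(x) plus kinetic energy p^T M^{-1} p / 2
        return -logpdf(x, data) + 0.5 * jnp.sum(p**2 / mass_diag)

    # Leapfrog: half momentum step, alternating full steps, final half momentum step
    p = p0 + 0.5 * step_size * grad_logpdf(x, data)
    x_new = x
    for i in range(n_leapfrog):
        x_new = x_new + step_size * p / mass_diag
        if i < n_leapfrog - 1:
            p = p + step_size * grad_logpdf(x_new, data)
    p = p + 0.5 * step_size * grad_logpdf(x_new, data)

    # Metropolis correction on the change in total energy
    log_accept = hamiltonian(x, p0) - hamiltonian(x_new, p)
    accept = jnp.log(jax.random.uniform(key_accept)) < log_accept
    return jnp.where(accept, x_new, x), accept
```

Compared with a single MALA step, each HMC proposal costs `n_leapfrog` gradient evaluations but can travel much further along the target before the accept/reject test.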

resources = {
    'logpdf': logpdf,
    'positions_training': positions_training,
    'log_prob_training': log_prob_training,
    'local_accs_training': local_accs_training,
    'global_accs_training': global_accs_training,
    'loss_buffer': loss_buffer,
    'positions_production': position_production,
    'log_prob_production': log_prob_production,
    'local_accs_production': local_accs_production,
    'global_accs_production': global_accs_production,
    'local_sampler': local_sampler,
    'global_sampler': global_sampler,
    'model': model,
    'optimizer': optimizer,
    'sampler_state': sampler_state
}  instance-attribute
strategies = {
    'local_stepper': local_stepper,
    'global_stepper': global_stepper,
    'model_trainer': model_trainer,
    'update_state': update_state,
    'update_global_step': update_global_step,
    'update_local_step': update_local_step,
    'reset_steppers': reset_steppers_lambda,
    'update_model': update_model_lambda,
    'adapt_local_sampler': adapt_local_sampler,
    'check_early_stop': check_early_stop
}  instance-attribute
strategy_order = strategy_order  instance-attribute
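One way to picture how `resources`, `strategies` and `strategy_order` fit together is a loop that walks `strategy_order` and looks each name up in `strategies`, with every strategy reading and updating shared state. The sketch below is a toy model under that assumption only: `run_strategies`, the one-argument call signature and the `"trace"` resource are hypothetical, and flowMC's real strategies take and return more than a resources dict.

```python
from typing import Callable, Dict, List

Resources = Dict[str, list]


def local_stepper(resources: Resources) -> Resources:
    # Toy stand-in for a local (HMC) stepping strategy: records that it ran
    resources["trace"].append("local")
    return resources


def global_stepper(resources: Resources) -> Resources:
    # Toy stand-in for a global (NF-proposal) stepping strategy
    resources["trace"].append("global")
    return resources


def run_strategies(
    strategies: Dict[str, Callable[[Resources], Resources]],
    strategy_order: List[str],
    resources: Resources,
) -> Resources:
    # Walk the ordered list of names; the same name may appear more than once
    for name in strategy_order:
        resources = strategies[name](resources)
    return resources


strategies = {"local_stepper": local_stepper, "global_stepper": global_stepper}
strategy_order = ["local_stepper", "global_stepper", "local_stepper"]
result = run_strategies(strategies, strategy_order, {"trace": []})
```

Keeping the order as a separate list of names lets the same strategy object be reused at several points in a run without duplicating it in the `strategies` dict.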
__repr__()
__init__(
    rng_key: Key,
    n_chains: int,
    n_dims: int,
    logpdf: Callable[[Float[Array, ' n_dim'], dict], Float],
    n_local_steps: int,
    n_global_steps: int,
    n_training_loops: int,
    n_production_loops: int,
    n_epochs: int,
    hmc_step_size: float = 0.1,
    hmc_n_leapfrog: int = 10,
    condition_matrix: Float | Float[Array, ' n_dim'] = 1,
    adapt_step_size: bool = True,
    periodic: Optional[dict[int, tuple[float, float]]] = None,
    rq_spline_hidden_units: list[int] = [32, 32],
    rq_spline_n_bins: int = 8,
    rq_spline_n_layers: int = 4,
    n_NFproposal_batch_size: int = 10000,
    learning_rate: float = 0.001,
    batch_size: int = 10000,
    n_max_examples: int = 10000,
    history_window: int = 100,
    chain_batch_size: int = 0,
    local_thinning: int = 1,
    global_thinning: int = 1,
    early_stopping: bool = False,
    early_stopping_tolerance: float = 0.05,
    early_stopping_patience: int = 3,
    early_stopping_min_acceptance: float = 0.1,
    verbose: bool = False
) -> None

Build all resources and strategies for an RQSpline + HMC sampling run.
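A minimal sketch of preparing the inputs for the constructor: a JAX-traceable `logpdf` with the `f(x, data)` signature, checked by vectorising it over one position per chain, and a keyword-argument dict covering the required parameters plus the HMC-specific ones. The Gaussian target and every numeric setting are illustrative placeholders, not tuned recommendations, and `bundle_kwargs` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def logpdf(x, data):
    # Anisotropic Gaussian (up to a constant); per-dimension scales come from `data`
    return -0.5 * jnp.sum((x / data["sigma"]) ** 2)


n_chains, n_dims = 20, 3
data = {"sigma": jnp.array([1.0, 2.0, 0.5])}
key_positions, key_nf = jax.random.split(jax.random.PRNGKey(0))

# Evaluate the target on a batch of positions, one per chain; `data` is shared
positions = jax.random.normal(key_positions, (n_chains, n_dims))
log_probs = jax.vmap(logpdf, in_axes=(0, None))(positions, data)

# Required constructor arguments plus the HMC-specific ones, passed by keyword
bundle_kwargs = dict(
    rng_key=key_nf,
    n_chains=n_chains,
    n_dims=n_dims,
    logpdf=logpdf,
    n_local_steps=50,
    n_global_steps=50,
    n_training_loops=10,
    n_production_loops=5,
    n_epochs=10,
    hmc_step_size=0.1,
    hmc_n_leapfrog=10,
    condition_matrix=jnp.ones(n_dims),  # per-dimension form of the default 1
    periodic=None,
)
```

With flowMC installed, these keywords match the `__init__` signature above, so `RQSpline_HMC_Bundle(**bundle_kwargs)` is the corresponding call.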

Parameters:

rng_key (Key, required): JAX PRNGKey used to initialise the normalizing flow.
n_chains (int, required): Number of parallel MCMC chains.
n_dims (int, required): Dimensionality of the target distribution.
logpdf (Callable, required): Log-density of the target, called as f(x, data) with x of shape (n_dim,) and data a dict, returning a scalar Float.
n_local_steps (int, required): HMC steps per training/production loop iteration.
n_global_steps (int, required): NF-proposal steps per training/production loop iteration.
n_training_loops (int, required): Number of train-then-sample iterations (warmup).
n_production_loops (int, required): Number of production sampling iterations.
n_epochs (int, required): NF training epochs per training loop.
hmc_step_size (float, default 0.1): Initial leapfrog step size.
hmc_n_leapfrog (int, default 10): Number of leapfrog steps per HMC proposal.
condition_matrix (float | Float[Array, n_dim], default 1): Diagonal mass-matrix elements; a scalar (broadcast to all dimensions) or a per-dimension array.
adapt_step_size (bool, default True): Adapt the HMC step size during training.
periodic (dict[int, tuple[float, float]] | None, default None): Periodic boundary conditions as {dim_index: (lower, upper)}.
rq_spline_hidden_units (list[int], default [32, 32]): Hidden units per conditioner MLP layer.
rq_spline_n_bins (int, default 8): Number of RQ-spline bins.
rq_spline_n_layers (int, default 4): Number of masked coupling layers.
n_NFproposal_batch_size (int, default 10000): NF log-prob evaluation batch size.
learning_rate (float, default 0.001): Adam learning rate for NF training.
batch_size (int, default 10000): Mini-batch size for NF training.
n_max_examples (int, default 10000): Maximum training examples per call.
history_window (int, default 100): Use only the last history_window stored steps as training data.
chain_batch_size (int, default 0): Process chains in sub-batches of this size to reduce peak memory; 0 disables batching.
local_thinning (int, default 1): Store every local_thinning-th local step.
global_thinning (int, default 1): Store every global_thinning-th global step.
early_stopping (bool, default False): Enable early stopping based on global acceptance rate stability.
early_stopping_tolerance (float, default 0.05): Relative change threshold for early stopping.
early_stopping_patience (int, default 3): Consecutive stable loops required before stopping.
early_stopping_min_acceptance (float, default 0.1): Minimum global acceptance rate that also triggers early stopping.
verbose (bool, default False): Enable progress bars and debug logging.
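The thinning, history-window and periodic options all act on plain arrays of positions. The hypothetical helpers below are a minimal sketch, not flowMC functions, and rest on three assumptions: stored samples have shape (n_chains, n_steps, n_dims), thinning keeps steps 0, k, 2k, ..., and a periodic dimension wraps into the half-open interval [lower, upper).

```python
import jax.numpy as jnp


def thin_and_window(samples, thinning, history_window):
    """Keep every `thinning`-th step, then only the last `history_window` kept steps.

    `samples` has shape (n_chains, n_steps, n_dims); the result is flattened to
    (n_examples, n_dims), the layout a flow would be trained on.
    """
    kept = samples[:, ::thinning]       # assumed: steps 0, k, 2k, ... are stored
    recent = kept[:, -history_window:]  # most recent stored steps of each chain
    return recent.reshape(-1, samples.shape[-1])


def wrap_periodic(x, periodic):
    """Map each coordinate listed in `periodic` back into [lower, upper)."""
    for dim, (lower, upper) in periodic.items():
        width = upper - lower
        # jnp.mod takes the sign of the divisor, so negative offsets wrap correctly
        x = x.at[..., dim].set(lower + jnp.mod(x[..., dim] - lower, width))
    return x
```

For example, with 10 stored steps per chain, `thinning=2` keeps steps 0, 2, 4, 6, 8 and `history_window=3` then leaves steps 4, 6, 8 of every chain; a value of -0.5 in a dimension declared as `(0.0, 1.0)` wraps to 0.5.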