RQSpline HMC
RQSpline_HMC_Bundle
Bases: ResourceStrategyBundle
A bundle that uses a Rational Quadratic Spline as a normalizing flow model and Hamiltonian Monte Carlo as a local sampler.
This is similar to the RQSpline_MALA_Bundle but uses HMC instead of MALA for local sampling.
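For readers unfamiliar with the local sampler, the following is a minimal, library-independent sketch of one HMC proposal using leapfrog integration on a 1-D standard normal target. The names `leapfrog` and `hmc_step` are illustrative and not part of the bundle. The `step_size`, `n_leapfrog`, and `mass` arguments only loosely mirror `hmc_step_size`, `hmc_n_leapfrog`, and `condition_matrix`; how the bundle applies them internally is an assumption here.

```python
import math
import random


def logpdf(x):
    # Log-density of a standard normal, up to an additive constant.
    return -0.5 * x * x


def grad_logpdf(x):
    # Gradient of the log-density above.
    return -x


def leapfrog(x, p, step_size, n_leapfrog, mass=1.0):
    """Integrate Hamiltonian dynamics with the leapfrog scheme."""
    p = p + 0.5 * step_size * grad_logpdf(x)      # initial half step in momentum
    for i in range(n_leapfrog):
        x = x + step_size * p / mass              # full step in position
        if i < n_leapfrog - 1:
            p = p + step_size * grad_logpdf(x)    # full step in momentum
    p = p + 0.5 * step_size * grad_logpdf(x)      # final half step in momentum
    return x, p


def hmc_step(x, rng, step_size=0.1, n_leapfrog=10, mass=1.0):
    """Propose with leapfrog, then accept or reject with a Metropolis test."""
    p0 = rng.gauss(0.0, math.sqrt(mass))          # momentum ~ N(0, mass)
    x_new, p_new = leapfrog(x, p0, step_size, n_leapfrog, mass)
    h_old = -logpdf(x) + 0.5 * p0 * p0 / mass     # Hamiltonian before
    h_new = -logpdf(x_new) + 0.5 * p_new * p_new / mass  # Hamiltonian after
    accept = rng.random() < math.exp(min(0.0, h_old - h_new))
    return (x_new if accept else x), accept


# Short chain on the toy target (assumed settings, for illustration only).
rng = random.Random(0)
x, n_accepted = 0.0, 0
for _ in range(1000):
    x, accepted = hmc_step(x, rng)
    n_accepted += accepted
acceptance_rate = n_accepted / 1000
```

A smaller step size keeps the energy error, and hence the rejection rate, low at the cost of shorter trajectories. That trade-off is what step-size adaptation is meant to balance.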
Methods:
| Name | Description |
|---|---|
| `__init__` | Build all resources and strategies for an RQSpline + HMC sampling run. |
__init__(rng_key: Key, n_chains: int, n_dims: int, logpdf: Callable[[Float[Array, ' n_dim'], dict], FloatScalar], n_local_steps: int, n_global_steps: int, n_training_loops: int, n_production_loops: int, n_epochs: int, hmc_step_size: float = 0.1, hmc_n_leapfrog: int = 10, condition_matrix: float | Float[Array, ' n_dim'] = 1, adapt_step_size: bool = True, adapt_step_size_per_dim: bool = True, periodic: Optional[dict[int, tuple[float, float]]] = None, rq_spline_hidden_units: list[int] = [32, 32], rq_spline_n_bins: int = 8, rq_spline_n_layers: int = 4, n_NFproposal_batch_size: int = 10000, learning_rate: float = 0.001, batch_size: int = 10000, n_max_examples: int = 10000, history_window: int = 100, chain_batch_size: int = 0, local_thinning: int = 1, global_thinning: int = 1, early_stopping: bool = False, early_stopping_tolerance: float = 0.05, early_stopping_patience: int = 3, early_stopping_min_acceptance: float = 0.1, verbose: bool = False) -> None
Build all resources and strategies for an RQSpline + HMC sampling run.
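The `logpdf` argument follows the `Callable[[Float[Array, ' n_dim'], dict], FloatScalar]` annotation in the signature above: it receives one position vector and a dictionary and returns a scalar. A minimal sketch of such a function, assuming a standard normal target and an unused data dictionary (both are illustrative choices, not taken from the library):

```python
import jax
import jax.numpy as jnp


def logpdf(x, data):
    # x: 1-D array of shape (n_dim,); data: dict of auxiliary inputs (unused here).
    n_dim = x.shape[0]
    return -0.5 * jnp.sum(x**2) - 0.5 * n_dim * jnp.log(2.0 * jnp.pi)


# A PRNG key of the kind expected by `rng_key`.
rng_key = jax.random.PRNGKey(0)

# Sanity check: evaluate the function for several chains at once by mapping
# over the leading axis; the dict argument is passed through unchanged.
n_chains, n_dims = 4, 2
positions = jnp.zeros((n_chains, n_dims))
values = jax.vmap(logpdf, in_axes=(0, None))(positions, {})
```

Writing the function for a single position and letting `jax.vmap` handle the chain axis keeps it compatible with the per-sample signature. Whether the bundle vectorises it in exactly this way is an assumption.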
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rng_key` | `Key` | JAX PRNGKey used to initialise the normalizing flow. | required |
| `n_chains` | `int` | Number of parallel MCMC chains. | required |
| `n_dims` | `int` | Dimensionality of the target distribution. | required |
| `logpdf` | `Callable` | Log-PDF | required |
| `n_local_steps` | `int` | HMC steps per training/production loop iteration. | required |
| `n_global_steps` | `int` | NF-proposal steps per training/production loop iteration. | required |
| `n_training_loops` | `int` | Number of train-then-sample iterations (warmup). | required |
| `n_production_loops` | `int` | Number of production sampling iterations. | required |
| `n_epochs` | `int` | NF training epochs per training loop. | required |
| `hmc_step_size` | `float` | Initial leapfrog step size. Defaults to 0.1. | `0.1` |
| `hmc_n_leapfrog` | `int` | Number of leapfrog steps per HMC proposal. Defaults to 10. | `10` |
| `condition_matrix` | `float \| Float[Array, n_dim]` | Diagonal mass-matrix elements; scalar (broadcast) or per-dimension array. Defaults to 1. | `1` |
| `adapt_step_size` | `bool` | Adapt the HMC step size during training. Defaults to True. | `True` |
| `adapt_step_size_per_dim` | `bool` | Also tune per-dimension effective step sizes via the mass matrix using the empirical std of recent chain positions. Runs after `adapt_step_size`. Defaults to True. | `True` |
| `periodic` | `dict[int, tuple[float, float]] \| None` | Periodic boundary conditions as … | `None` |
| `rq_spline_hidden_units` | `list[int]` | Hidden units per conditioner MLP layer. Defaults to [32, 32]. | `[32, 32]` |
| `rq_spline_n_bins` | `int` | Number of RQ-spline bins. Defaults to 8. | `8` |
| `rq_spline_n_layers` | `int` | Number of masked coupling layers. Defaults to 4. | `4` |
| `n_NFproposal_batch_size` | `int` | NF log-prob evaluation batch size. Defaults to 10000. | `10000` |
| `learning_rate` | `float` | Adam learning rate for NF training. Defaults to 1e-3. | `0.001` |
| `batch_size` | `int` | Mini-batch size for NF training. Defaults to 10000. | `10000` |
| `n_max_examples` | `int` | Maximum training examples per call. Defaults to 10000. | `10000` |
| `history_window` | `int` | Use only the last … | `100` |
| `chain_batch_size` | `int` | Process chains in sub-batches of this size to reduce peak memory. 0 disables batching. Defaults to 0. | `0` |
| `local_thinning` | `int` | Store every … | `1` |
| `global_thinning` | `int` | Store every … | `1` |
| `early_stopping` | `bool` | Enable early stopping based on global acceptance rate stability. Defaults to False. | `False` |
| `early_stopping_tolerance` | `float` | Relative change threshold for early stopping. Defaults to 0.05. | `0.05` |
| `early_stopping_patience` | `int` | Consecutive stable loops required before stopping. Defaults to 3. | `3` |
| `early_stopping_min_acceptance` | `float` | Minimum global acceptance rate that also triggers early stopping. Defaults to 0.1. | `0.1` |
| `verbose` | `bool` | Enable progress bars and debug logging. Defaults to False. | `False` |
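The three `early_stopping_*` parameters can be read together as a stability test on the global acceptance rate. The sketch below is one plausible reading, not the library's implementation: it assumes the check runs once per training loop, that "stable" means the relative change between consecutive loops stays below the tolerance, and that the latest rate must be at least the minimum acceptance. The function name `should_stop` and the list-based history are hypothetical.

```python
def should_stop(acceptance_history, tolerance=0.05, patience=3, min_acceptance=0.1):
    """Return True once the acceptance rate has been stable for `patience` loops.

    Assumed rule (illustrative only): the last `patience` loop-to-loop relative
    changes are all below `tolerance`, and the latest rate is >= `min_acceptance`.
    """
    if len(acceptance_history) < patience + 1:
        return False                      # not enough loops to judge stability
    recent = acceptance_history[-(patience + 1):]
    if recent[-1] < min_acceptance:
        return False                      # flow proposals are still rarely accepted
    for prev, curr in zip(recent, recent[1:]):
        if prev <= 0.0:
            return False                  # relative change is undefined
        if abs(curr - prev) / prev >= tolerance:
            return False                  # rate is still moving
    return True


# Acceptance rate per training loop (made-up numbers for illustration).
history = [0.10, 0.30, 0.50, 0.51, 0.52, 0.52]
stop_now = should_stop(history)
```

With the defaults, at least four recorded loops are needed before the test can pass, because three consecutive changes have to be compared.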