RQSpline HMC
flowMC.resource_strategy_bundle.RQSpline_HMC
RQSpline_HMC_Bundle
Bases: ResourceStrategyBundle
A bundle that uses a Rational Quadratic Spline (RQSpline) normalizing flow as the global proposal model and Hamiltonian Monte Carlo (HMC) as the local sampler.
This is similar to the RQSpline_MALA_Bundle but uses HMC instead of MALA for local sampling.
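flowMC's own HMC kernel is written in JAX and is not reproduced here. The NumPy code below is a minimal sketch of what `hmc_step_size`, `hmc_n_leapfrog` and a diagonal mass matrix control. It performs one textbook HMC transition: draw a momentum, run the leapfrog integrator, then accept or reject with a Metropolis test on the total energy. The helper names `leapfrog`, `hamiltonian`, `hmc_step` and `mass_diag` are hypothetical. The sketch assumes momenta are drawn as p ~ N(0, diag(mass_diag)) and does not show how flowMC maps `condition_matrix` onto that matrix. It also omits step-size adaptation and, for brevity, calls `logpdf(x)` without the `data` argument.

```python
import numpy as np


def leapfrog(x, p, grad_logpdf, step_size, n_leapfrog, mass_diag):
    """Leapfrog integration of Hamiltonian dynamics with a diagonal mass matrix.

    Uses kinetic energy 0.5 * sum(p**2 / mass_diag), so dx/dt = p / mass_diag.
    """
    x = np.array(x, dtype=float)
    p = np.array(p, dtype=float)
    p = p + 0.5 * step_size * grad_logpdf(x)      # initial half step in momentum
    for i in range(n_leapfrog):
        x = x + step_size * p / mass_diag          # full step in position
        if i < n_leapfrog - 1:
            p = p + step_size * grad_logpdf(x)     # full step in momentum
    p = p + 0.5 * step_size * grad_logpdf(x)       # final half step in momentum
    return x, p


def hamiltonian(x, p, logpdf, mass_diag):
    """Potential energy -logpdf(x) plus kinetic energy."""
    return -logpdf(x) + 0.5 * np.sum(p**2 / mass_diag)


def hmc_step(rng, x, logpdf, grad_logpdf, step_size=0.1, n_leapfrog=10, mass_diag=1.0):
    """One HMC transition: draw momentum, integrate, then Metropolis accept/reject."""
    x = np.asarray(x, dtype=float)
    # A scalar mass is broadcast to every dimension (hypothetical convention).
    mass_diag = np.broadcast_to(np.asarray(mass_diag, dtype=float), x.shape)
    p0 = rng.normal(size=x.shape) * np.sqrt(mass_diag)    # p ~ N(0, diag(mass_diag))
    x1, p1 = leapfrog(x, p0, grad_logpdf, step_size, n_leapfrog, mass_diag)
    log_accept = hamiltonian(x, p0, logpdf, mass_diag) - hamiltonian(x1, p1, logpdf, mass_diag)
    if np.log(rng.uniform()) < log_accept:
        return x1, True
    return x, False


# Usage: sample a 2-D standard normal.
logpdf = lambda x: -0.5 * np.sum(x**2)
grad_logpdf = lambda x: -x
rng = np.random.default_rng(0)
x, n_accepted = np.zeros(2), 0
for _ in range(1000):
    x, accepted = hmc_step(rng, x, logpdf, grad_logpdf)
    n_accepted += accepted
print(n_accepted / 1000)  # fraction of accepted proposals
```

Larger step sizes increase the leapfrog energy error and so lower the acceptance rate. More leapfrog steps lengthen each trajectory at the cost of more gradient evaluations per proposal.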
resources = {'logpdf': logpdf, 'positions_training': positions_training, 'log_prob_training': log_prob_training, 'local_accs_training': local_accs_training, 'global_accs_training': global_accs_training, 'loss_buffer': loss_buffer, 'positions_production': position_production, 'log_prob_production': log_prob_production, 'local_accs_production': local_accs_production, 'global_accs_production': global_accs_production, 'local_sampler': local_sampler, 'global_sampler': global_sampler, 'model': model, 'optimizer': optimizer, 'sampler_state': sampler_state}
instance-attribute
strategies = {'local_stepper': local_stepper, 'global_stepper': global_stepper, 'model_trainer': model_trainer, 'update_state': update_state, 'update_global_step': update_global_step, 'update_local_step': update_local_step, 'reset_steppers': reset_steppers_lambda, 'update_model': update_model_lambda, 'adapt_local_sampler': adapt_local_sampler, 'check_early_stop': check_early_stop}
instance-attribute
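The `check_early_stop` strategy is driven by the `early_stopping_*` constructor parameters documented below, but its exact rule is not spelled out on this page. The sketch below shows one plausible reading and is not flowMC's code. It assumes a training loop counts as stable when the global acceptance rate changes by less than `tolerance` relative to the previous loop. It also assumes stopping requires `patience` consecutive stable loops and an acceptance rate of at least `min_acceptance`. The class `EarlyStopMonitor` is hypothetical, and the acceptance values in the usage lines are made up for illustration.

```python
class EarlyStopMonitor:
    """Hypothetical early-stopping rule on the per-loop global acceptance rate."""

    def __init__(self, tolerance=0.05, patience=3, min_acceptance=0.1):
        self.tolerance = tolerance
        self.patience = patience
        self.min_acceptance = min_acceptance
        self.previous = None
        self.stable_count = 0

    def update(self, acceptance):
        """Record one training loop's global acceptance rate; return True to stop."""
        if (
            self.previous is not None
            and self.previous > 0
            and abs(acceptance - self.previous) / self.previous < self.tolerance
        ):
            self.stable_count += 1   # relative change small enough: loop is stable
        else:
            self.stable_count = 0    # first loop or a large change: reset the streak
        self.previous = acceptance
        # Assumed rule: enough consecutive stable loops AND rate above the floor.
        return self.stable_count >= self.patience and acceptance >= self.min_acceptance


# Usage with made-up acceptance rates from successive training loops.
monitor = EarlyStopMonitor(tolerance=0.05, patience=3, min_acceptance=0.1)
for loop, acc in enumerate([0.1, 0.3, 0.31, 0.315, 0.312]):
    if monitor.update(acc):
        print(f"stop after training loop {loop}")  # stop after training loop 4
        break
```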
strategy_order = strategy_order
instance-attribute
__repr__()
__init__(rng_key: Key, n_chains: int, n_dims: int, logpdf: Callable[[Float[Array, ' n_dim'], dict], Float], n_local_steps: int, n_global_steps: int, n_training_loops: int, n_production_loops: int, n_epochs: int, hmc_step_size: float = 0.1, hmc_n_leapfrog: int = 10, condition_matrix: Float | Float[Array, ' n_dim'] = 1, adapt_step_size: bool = True, periodic: Optional[dict[int, tuple[float, float]]] = None, rq_spline_hidden_units: list[int] = [32, 32], rq_spline_n_bins: int = 8, rq_spline_n_layers: int = 4, n_NFproposal_batch_size: int = 10000, learning_rate: float = 0.001, batch_size: int = 10000, n_max_examples: int = 10000, history_window: int = 100, chain_batch_size: int = 0, local_thinning: int = 1, global_thinning: int = 1, early_stopping: bool = False, early_stopping_tolerance: float = 0.05, early_stopping_patience: int = 3, early_stopping_min_acceptance: float = 0.1, verbose: bool = False) -> None
Build all resources and strategies for an RQSpline + HMC sampling run.
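The signature types `logpdf` as `Callable[[Float[Array, ' n_dim'], dict], Float]`. It receives one position of shape `(n_dim,)` plus a data dictionary and returns a scalar. Below is a minimal example for a diagonal Gaussian target; the data keys `"mu"` and `"sigma"` are hypothetical. It returns the log-density only up to an additive constant, which is all MCMC acceptance ratios require. The body uses only operators and array methods, so it also accepts `jax.numpy` arrays. NumPy is used here only to call it.

```python
import numpy as np


def gaussian_logpdf(x, data):
    """Log-density of N(mu, diag(sigma**2)) at a single point x, up to a constant.

    x: array of shape (n_dim,); data: dict with hypothetical keys "mu" and "sigma".
    Uses only operators and array methods, so it also accepts jax.numpy arrays.
    """
    z = (x - data["mu"]) / data["sigma"]
    return -0.5 * (z ** 2).sum()


data = {"mu": np.array([1.0, -2.0]), "sigma": np.array([0.5, 2.0])}
print(gaussian_logpdf(np.array([1.5, 0.0]), data))  # -1.0
```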
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rng_key` | `Key` | JAX PRNGKey used to initialise the normalizing flow. | required |
| `n_chains` | `int` | Number of parallel MCMC chains. | required |
| `n_dims` | `int` | Dimensionality of the target distribution. | required |
| `logpdf` | `Callable` | Log-PDF of the target distribution, called as `logpdf(x, data)` with a position `x` of shape `(n_dim,)` and a `dict` of data; returns a scalar. | required |
| `n_local_steps` | `int` | HMC steps per training/production loop iteration. | required |
| `n_global_steps` | `int` | NF-proposal steps per training/production loop iteration. | required |
| `n_training_loops` | `int` | Number of train-then-sample iterations (warmup). | required |
| `n_production_loops` | `int` | Number of production sampling iterations. | required |
| `n_epochs` | `int` | NF training epochs per training loop. | required |
| `hmc_step_size` | `float` | Initial leapfrog step size. Defaults to 0.1. | `0.1` |
| `hmc_n_leapfrog` | `int` | Number of leapfrog steps per HMC proposal. Defaults to 10. | `10` |
| `condition_matrix` | `float \| Float[Array, n_dim]` | Diagonal mass-matrix elements; a scalar (broadcast to all dimensions) or a per-dimension array. Defaults to 1. | `1` |
| `adapt_step_size` | `bool` | Adapt the HMC step size during training. Defaults to True. | `True` |
| `periodic` | `dict[int, tuple[float, float]] \| None` | Periodic boundary conditions as a mapping `{dim_index: (lower, upper)}`. Defaults to None. | `None` |
| `rq_spline_hidden_units` | `list[int]` | Hidden units per conditioner MLP layer. Defaults to `[32, 32]`. | `[32, 32]` |
| `rq_spline_n_bins` | `int` | Number of RQ-spline bins. Defaults to 8. | `8` |
| `rq_spline_n_layers` | `int` | Number of masked coupling layers. Defaults to 4. | `4` |
| `n_NFproposal_batch_size` | `int` | Batch size for NF log-prob evaluation. Defaults to 10000. | `10000` |
| `learning_rate` | `float` | Adam learning rate for NF training. Defaults to 1e-3. | `0.001` |
| `batch_size` | `int` | Mini-batch size for NF training. Defaults to 10000. | `10000` |
| `n_max_examples` | `int` | Maximum number of training examples per call. Defaults to 10000. | `10000` |
| `history_window` | `int` | Use only the last `history_window` samples for NF training. Defaults to 100. | `100` |
| `chain_batch_size` | `int` | Process chains in sub-batches of this size to reduce peak memory; 0 disables batching. Defaults to 0. | `0` |
| `local_thinning` | `int` | Store every `local_thinning`-th local sample. Defaults to 1. | `1` |
| `global_thinning` | `int` | Store every `global_thinning`-th global sample. Defaults to 1. | `1` |
| `early_stopping` | `bool` | Enable early stopping based on the stability of the global acceptance rate. Defaults to False. | `False` |
| `early_stopping_tolerance` | `float` | Relative change in the global acceptance rate below which a loop counts as stable. Defaults to 0.05. | `0.05` |
| `early_stopping_patience` | `int` | Number of consecutive stable loops required before stopping. Defaults to 3. | `3` |
| `early_stopping_min_acceptance` | `float` | Minimum global acceptance rate that also triggers early stopping. Defaults to 0.1. | `0.1` |
| `verbose` | `bool` | Enable progress bars and debug logging. Defaults to False. | `False` |
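`rq_spline_n_bins` sets how many bins each monotone rational-quadratic spline has inside the coupling layers. The NumPy function below is a sketch of the transform itself. It evaluates a one-dimensional rational-quadratic spline (the form used in neural spline flows) from given knot positions and positive knot derivatives. It is not flowMC's implementation. The conditioner MLP that produces the knots, the behaviour outside the spline interval and the log-determinant are all omitted, and the name `rq_spline_forward` is hypothetical.

```python
import numpy as np


def rq_spline_forward(x, x_knots, y_knots, derivs):
    """Evaluate a monotone rational-quadratic spline at points x.

    x_knots, y_knots: strictly increasing arrays of K + 1 knot coordinates (K bins).
    derivs: K + 1 positive derivatives at the knots.
    x must lie inside [x_knots[0], x_knots[-1]].
    """
    x = np.asarray(x, dtype=float)
    x_knots = np.asarray(x_knots, dtype=float)
    y_knots = np.asarray(y_knots, dtype=float)
    derivs = np.asarray(derivs, dtype=float)
    # Bin index for each x; clipping maps the right edge into the last bin.
    k = np.clip(np.searchsorted(x_knots, x, side="right") - 1, 0, len(x_knots) - 2)
    x_lo, x_hi = x_knots[k], x_knots[k + 1]
    y_lo, y_hi = y_knots[k], y_knots[k + 1]
    d_lo, d_hi = derivs[k], derivs[k + 1]
    width, height = x_hi - x_lo, y_hi - y_lo
    s = height / width                 # average slope of the bin
    xi = (x - x_lo) / width            # relative position inside the bin, in [0, 1]
    t = xi * (1.0 - xi)
    numerator = height * (s * xi**2 + d_lo * t)
    denominator = s + (d_hi + d_lo - 2.0 * s) * t
    return y_lo + numerator / denominator


# Usage: a 3-bin spline passes through its knots and is strictly increasing.
x_knots = np.array([-1.0, 0.0, 0.5, 1.0])
y_knots = np.array([-1.0, -0.2, 0.6, 1.0])
derivs = np.array([1.0, 2.0, 0.5, 1.0])
print(rq_spline_forward(x_knots, x_knots, y_knots, derivs))  # equals y_knots up to rounding
```

With unit derivatives and `y_knots == x_knots`, every bin reduces to the identity map. More bins let the spline bend more sharply at the cost of more parameters for the conditioner to predict.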