Periodic
Adapters that translate Jim's periodic-parameter spec into the form each sampler backend expects.
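Jim specifies periodic parameters as a dict mapping each dimension index to its (lower, upper) bounds:
periodic_index = {1: (0.0, 2 * math.pi), ...}  # key = dimension index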
Each backend expects a different form: flowMC accepts the index-keyed dict
directly; BlackJAX NS-AW needs a stepper function on flat arrays; BlackJAX NSS
needs a stepper that returns a (position, accepted) tuple; and BlackJAX SMC needs
a displacement wrapper. The adapters below perform these conversions.
All adapters operate on flat JAX arrays of shape (n_dims,).
Functions:
| Name | Description |
|---|---|
| `to_displacement_wrapper` | Displacement wrapper for BlackJAX SMC (prior space). |
| `to_prior_space_stepper` | Stepper function for BlackJAX NSS (prior space). |
| `to_unit_cube_stepper` | Stepper function for BlackJAX NS-AW (unit-cube space). |
to_displacement_wrapper(periodic_index: Optional[dict[int, tuple[float, float]]], n_dims: int) -> Callable
Displacement wrapper for BlackJAX SMC (prior space).
Signature: wrapper_fn(proposed_displacement, current_position) -> wrapped_displacement
Displacement and position are flat JAX arrays of shape (n_dims,).
SMC's inner kernel operates on displacements. For a periodic parameter with
bounds (lower, upper) and period = upper - lower, the displacement is adjusted
so that current + wrapped_displacement stays within [lower, upper):
wrapped_displacement = lower + mod(current + disp - lower, period) - current
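A minimal sketch of how the formula above could be applied to flat arrays, assuming the periodic dimensions are selected with a precomputed boolean mask and that non-periodic dimensions keep their proposed displacement unchanged. The name `to_displacement_wrapper_sketch` is hypothetical; this illustrates the formula and is not the package's actual implementation.

```python
import math
from typing import Callable, Optional

import jax
import jax.numpy as jnp


def to_displacement_wrapper_sketch(
    periodic_index: Optional[dict[int, tuple[float, float]]], n_dims: int
) -> Callable:
    """Hypothetical re-implementation of the SMC displacement wrapper."""
    # Flat per-dimension arrays, built once outside the traced function.
    # Non-periodic dims get placeholder lower=0, period=1; the mask ignores them.
    mask = [False] * n_dims
    lower = [0.0] * n_dims
    period = [1.0] * n_dims
    for i, (lo, hi) in (periodic_index or {}).items():
        mask[i], lower[i], period[i] = True, lo, hi - lo
    mask, lower, period = jnp.array(mask), jnp.array(lower), jnp.array(period)

    def wrapper_fn(proposed_displacement, current_position):
        target = current_position + proposed_displacement
        # Map the target into [lower, upper), then express it as a displacement again.
        wrapped = lower + jnp.mod(target - lower, period) - current_position
        # Assumption: non-periodic dims keep the proposed displacement unchanged.
        return jnp.where(mask, wrapped, proposed_displacement)

    return wrapper_fn


if __name__ == "__main__":
    wrapper = jax.jit(to_displacement_wrapper_sketch({1: (0.0, 2 * math.pi)}, n_dims=2))
    current = jnp.array([0.0, 6.0])
    disp = wrapper(jnp.array([0.5, 1.0]), current)
    print(disp, current + disp)  # dim 1 lands at 7 - 2*pi instead of 7
```

Building `mask`, `lower` and `period` once, outside `wrapper_fn`, keeps the returned function free of Python control flow over `periodic_index`, so it can be passed through `jax.jit` or `jax.vmap` directly.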
to_prior_space_stepper(periodic_index: Optional[dict[int, tuple[float, float]]], n_dims: int) -> Callable
Stepper function for BlackJAX NSS (prior space).
Signature: stepper_fn(position, direction, step_size) -> (new_position, accepted)
Position and direction are flat JAX arrays of shape (n_dims,).
NSS requires the stepper to return a (position, bool) tuple.
Periodic parameters are wrapped with
lower + mod(pos + step_size * dir - lower, period), where period = upper - lower.
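A minimal sketch of a stepper with this signature, assuming non-periodic dimensions are simply stepped and that the stepper itself never rejects a move (so `accepted` is always `True`; any real rejection logic is not described here). The name `to_prior_space_stepper_sketch` is hypothetical and this is not the package's actual implementation.

```python
import math
from typing import Callable, Optional

import jax.numpy as jnp


def to_prior_space_stepper_sketch(
    periodic_index: Optional[dict[int, tuple[float, float]]], n_dims: int
) -> Callable:
    """Hypothetical re-implementation of the NSS prior-space stepper."""
    # Per-dimension arrays; non-periodic dims use placeholders hidden by the mask.
    mask = [False] * n_dims
    lower = [0.0] * n_dims
    period = [1.0] * n_dims
    for i, (lo, hi) in (periodic_index or {}).items():
        mask[i], lower[i], period[i] = True, lo, hi - lo
    mask, lower, period = jnp.array(mask), jnp.array(lower), jnp.array(period)

    def stepper_fn(position, direction, step_size):
        proposed = position + step_size * direction
        # Wrap periodic coordinates into [lower, upper); leave the rest as stepped.
        wrapped = lower + jnp.mod(proposed - lower, period)
        new_position = jnp.where(mask, wrapped, proposed)
        # Assumption: this sketch never rejects a step, so `accepted` is always True.
        accepted = jnp.array(True)
        return new_position, accepted

    return stepper_fn


if __name__ == "__main__":
    stepper = to_prior_space_stepper_sketch({0: (-math.pi, math.pi)}, n_dims=2)
    new_pos, accepted = stepper(jnp.array([3.0, 1.0]), jnp.array([1.0, 1.0]), 0.5)
    print(new_pos, accepted)  # dim 0: 3.5 wraps to 3.5 - 2*pi; dim 1: 1.5
```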
to_unit_cube_stepper(periodic_index: Optional[list[int]], n_dims: int) -> Callable
Stepper function for BlackJAX NS-AW (unit-cube space).
Signature: stepper_fn(position, direction, step_size) -> new_position
periodic_index is a list of dimension indices to wrap; bounds are implicit
because NS-AW always operates in [0, 1]^n_dims, so wrapping is always
mod(pos + step_size * dir, 1.0).
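A minimal sketch of the unit-cube stepper, assuming the list of indices is turned into a boolean mask and that non-periodic dimensions are returned as stepped, without clipping (the text above does not say how they are treated). The name `to_unit_cube_stepper_sketch` is hypothetical and this is not the package's actual implementation.

```python
from typing import Callable, Optional

import jax.numpy as jnp


def to_unit_cube_stepper_sketch(
    periodic_index: Optional[list[int]], n_dims: int
) -> Callable:
    """Hypothetical re-implementation of the NS-AW unit-cube stepper."""
    periodic = set(periodic_index or [])
    # Boolean mask of shape (n_dims,): True where the dimension wraps.
    mask = jnp.array([i in periodic for i in range(n_dims)])

    def stepper_fn(position, direction, step_size):
        proposed = position + step_size * direction
        # Bounds are implicit in the unit cube, so wrapping is just mod 1.
        # Assumption: non-periodic dims are returned as stepped, without clipping.
        return jnp.where(mask, jnp.mod(proposed, 1.0), proposed)

    return stepper_fn


if __name__ == "__main__":
    stepper = to_unit_cube_stepper_sketch([1], n_dims=2)
    print(stepper(jnp.array([0.9, 0.9]), jnp.array([1.0, 1.0]), 0.25))  # ~[1.15, 0.15]
```

Because `jnp.mod` takes the sign of the divisor, a step that leaves the cube through 0 also wraps correctly: -0.15 maps to 0.85.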