
Periodic

Adapters that translate Jim's periodic-parameter spec into the form each sampler backend expects.

periodic_index = {1: (0.0, 2 * math.pi), ...}   # key = dimension index, value = (lower, upper) bounds

Each backend expects a different form: flowMC accepts the index-keyed dict directly; BlackJAX NS-AW needs a stepper function that works in the unit cube; BlackJAX NSS needs a stepper that returns a (position, accepted) tuple; and BlackJAX SMC needs a displacement wrapper. The three adapters below perform these conversions.

All adapters operate on flat JAX arrays of shape (n_dims,).
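One way to wrap periodic entries of a flat (n_dims,) array is to expand the index-keyed spec into a boolean mask plus per-dimension lower bounds and periods, then apply the modulo only where the mask is set. The sketch below assumes that approach. `periodic_arrays` and `wrap` are hypothetical helper names written for illustration, not part of Jim's API, and giving non-periodic dimensions dummy bounds (lower = 0, period = 1) is an assumption made so `jnp.mod` never divides by zero.

```python
import math
from typing import Optional

import jax.numpy as jnp


def periodic_arrays(
    periodic_index: Optional[dict[int, tuple[float, float]]], n_dims: int
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Hypothetical helper: expand an index-keyed spec into flat (n_dims,) arrays.

    Returns (mask, lower, period). Non-periodic dimensions get mask=False and
    dummy bounds (lower=0, period=1) so that jnp.mod never divides by zero.
    """
    mask = [False] * n_dims
    lower = [0.0] * n_dims
    period = [1.0] * n_dims
    for i, (lo, hi) in (periodic_index or {}).items():
        if not 0 <= i < n_dims:
            raise ValueError(f"dimension index {i} out of range for n_dims={n_dims}")
        if not hi > lo:
            raise ValueError(f"upper bound must exceed lower bound for dimension {i}")
        mask[i] = True
        lower[i] = lo
        period[i] = hi - lo
    return jnp.array(mask), jnp.array(lower), jnp.array(period)


def wrap(x, mask, lower, period):
    """Map periodic coordinates of x into [lower, upper); leave the rest unchanged."""
    return jnp.where(mask, lower + jnp.mod(x - lower, period), x)


# Usage: dimension 1 is an angle on [0, 2*pi); dimensions 0 and 2 are not periodic.
mask, lower, period = periodic_arrays({1: (0.0, 2 * math.pi)}, n_dims=3)
x = jnp.array([5.0, 7.0, -1.0])
y = wrap(x, mask, lower, period)
print(y)  # dimension 1 wraps to 7 - 2*pi (about 0.717); the others are unchanged
```

Keeping the mask and bounds as fixed-shape arrays means the wrap is a single `jnp.where` with no Python branching on array values, which keeps it compatible with `jax.jit`.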

Functions:

| Name | Description |
|---|---|
| to_displacement_wrapper | Displacement wrapper for BlackJAX SMC (prior space). |
| to_prior_space_stepper | Stepper function for BlackJAX NSS (prior space). |
| to_unit_cube_stepper | Stepper function for BlackJAX NS-AW (unit-cube space). |

to_displacement_wrapper(periodic_index: Optional[dict[int, tuple[float, float]]], n_dims: int) -> Callable

Displacement wrapper for BlackJAX SMC (prior space).

Signature: wrapper_fn(proposed_displacement, current_position) -> wrapped_displacement

Displacement and position are flat JAX arrays of shape (n_dims,). SMC's inner kernel operates on displacements. For periodic parameters, the displacement is adjusted so that current + wrapped_displacement stays within [lower, upper):

wrapped_displacement = lower + mod(current + disp - lower, period) - current

where period = upper - lower.
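A minimal sketch of such a wrapper, assuming the formula above is applied only to the dimensions listed in periodic_index and that all other dimensions pass through unchanged. `to_displacement_wrapper_sketch` is a hypothetical stand-in written for illustration, not Jim's implementation.

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp


def to_displacement_wrapper_sketch(
    periodic_index: Optional[dict[int, tuple[float, float]]], n_dims: int
) -> Callable:
    """Illustrative re-implementation; not the library's own code."""
    if not periodic_index:
        # No periodic parameters: pass the displacement through unchanged.
        return lambda disp, current: disp
    mask = jnp.zeros(n_dims, dtype=bool)
    lower = jnp.zeros(n_dims)
    period = jnp.ones(n_dims)  # dummy period of 1 avoids mod-by-zero on non-periodic dims
    for i, (lo, hi) in periodic_index.items():
        mask = mask.at[i].set(True)
        lower = lower.at[i].set(lo)
        period = period.at[i].set(hi - lo)

    def wrapper_fn(proposed_displacement, current_position):
        # wrapped = lower + mod(current + disp - lower, period) - current
        target = current_position + proposed_displacement
        wrapped = lower + jnp.mod(target - lower, period) - current_position
        return jnp.where(mask, wrapped, proposed_displacement)

    return wrapper_fn


wrapper_fn = jax.jit(to_displacement_wrapper_sketch({0: (0.0, 1.0)}, n_dims=2))
current = jnp.array([0.9, 0.9])
disp = jnp.array([0.3, 0.3])
new = current + wrapper_fn(disp, current)
print(new)  # dimension 0 wraps to about 0.2; dimension 1 moves freely to about 1.2
```

Subtracting current at the end is what turns a wrapped position back into a displacement, so the caller can keep adding displacements to positions as the SMC kernel expects.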

to_prior_space_stepper(periodic_index: Optional[dict[int, tuple[float, float]]], n_dims: int) -> Callable

Stepper function for BlackJAX NSS (prior space).

Signature: stepper_fn(position, direction, step_size) -> (new_position, accepted)

Position and direction are flat JAX arrays of shape (n_dims,). NSS requires the stepper to return a (position, accepted) tuple whose second element is a bool. Periodic parameters are wrapped with lower + mod(pos + step_size * dir - lower, period), where period = upper - lower.
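A minimal sketch of an NSS-style stepper under the same masking assumption. How the real adapter sets the accepted flag is not described here, so this sketch simply reports every step as accepted; treat that choice, and the name `to_prior_space_stepper_sketch`, as hypothetical.

```python
import math
from typing import Callable, Optional

import jax.numpy as jnp


def to_prior_space_stepper_sketch(
    periodic_index: Optional[dict[int, tuple[float, float]]], n_dims: int
) -> Callable:
    """Illustrative re-implementation of a prior-space NSS stepper."""
    mask = jnp.zeros(n_dims, dtype=bool)
    lower = jnp.zeros(n_dims)
    period = jnp.ones(n_dims)  # dummy period for non-periodic dimensions
    for i, (lo, hi) in (periodic_index or {}).items():
        mask = mask.at[i].set(True)
        lower = lower.at[i].set(lo)
        period = period.at[i].set(hi - lo)

    def stepper_fn(position, direction, step_size):
        proposal = position + step_size * direction
        wrapped = lower + jnp.mod(proposal - lower, period)
        new_position = jnp.where(mask, wrapped, proposal)
        # Assumption: every step is reported as accepted.
        return new_position, jnp.array(True)

    return stepper_fn


stepper_fn = to_prior_space_stepper_sketch({1: (0.0, 2 * math.pi)}, n_dims=2)
new_position, accepted = stepper_fn(jnp.array([0.0, 6.0]), jnp.array([1.0, 1.0]), 0.5)
print(new_position, accepted)  # angle 6.5 wraps to 6.5 - 2*pi (about 0.217)
```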

to_unit_cube_stepper(periodic_index: Optional[list[int]], n_dims: int) -> Callable

Stepper function for BlackJAX NS-AW (unit-cube space).

Signature: stepper_fn(position, direction, step_size) -> new_position

periodic_index is a list of dimension indices to wrap. Bounds are implicit because NS-AW always operates in [0, 1]^n_dims, so each listed dimension is wrapped with mod(pos + step_size * dir, 1.0).
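A minimal sketch of a unit-cube stepper. Because the period is fixed at 1.0, only a boolean mask built from the index list is needed. `to_unit_cube_stepper_sketch` is a hypothetical name, and in this sketch non-periodic dimensions are returned unwrapped, so a proposal can leave [0, 1]; the sketch does not handle that case.

```python
from typing import Callable, Optional

import jax.numpy as jnp


def to_unit_cube_stepper_sketch(periodic_index: Optional[list[int]], n_dims: int) -> Callable:
    """Illustrative re-implementation of a unit-cube NS-AW stepper."""
    mask = jnp.zeros(n_dims, dtype=bool)
    if periodic_index:
        # Mark every listed dimension as periodic.
        mask = mask.at[jnp.array(periodic_index)].set(True)

    def stepper_fn(position, direction, step_size):
        proposal = position + step_size * direction
        # Bounds are implicitly [0, 1), so the period is always 1.0.
        return jnp.where(mask, jnp.mod(proposal, 1.0), proposal)

    return stepper_fn


stepper_fn = to_unit_cube_stepper_sketch([0], n_dims=2)
out = stepper_fn(jnp.array([0.9, 0.9]), jnp.array([1.0, 1.0]), 0.25)
print(out)  # dimension 0 wraps to about 0.15; dimension 1 stays at about 1.15
```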