Sampler

Top-level API that users primarily interact with.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_dim | int | Dimension of the parameter space. | required |
| n_chains | int | Number of chains to sample. | required |
| rng_key | Key | JAX PRNGKey. | required |
| resources | dict[str, Resource] | Resources to be used by the sampler. | None |
| strategies | dict[str, Strategy] | Strategies to be used by the sampler. | None |
| strategy_order | list[str] | Order of strategies to execute. | None |
| resource_strategy_bundles | ResourceStrategyBundle | Pre-configured bundle containing resources and strategies. | None |
| checkpoint_interval | float | Minimum wall-clock seconds that must elapse since the previous write before a new checkpoint is written. Default 600 (10 minutes). Set to 0 to disable checkpointing entirely. | 600.0 |

Methods:

| Name | Description |
| --- | --- |
| __init__ | Initialize the sampler. |
| deserialize | Deserialize the sampler object. |
| sample | Execute the strategy loop and populate resource buffers with samples. |
| serialize | Serialize the sampler object. |

__init__(*, n_dim: int, n_chains: int, rng_key: Key, resources: Optional[dict[str, Resource]] = None, strategies: Optional[dict[str, Strategy]] = None, strategy_order: Optional[list[str]] = None, resource_strategy_bundles: Optional[ResourceStrategyBundle] = None, outdir: str = './outdir/', checkpoint_interval: float = 600.0) -> None

Initialize the sampler.

Provide either resources, strategies, and strategy_order together, or a pre-configured resource_strategy_bundles bundle.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_dim | int | Dimension of the parameter space. | required |
| n_chains | int | Number of parallel chains. | required |
| rng_key | Key | JAX PRNG key. | required |
| resources | Optional[dict[str, Resource]] | Dictionary of named resources (kernels, buffers, etc.). Must be paired with strategies. | None |
| strategies | Optional[dict[str, Strategy]] | Dictionary of named strategies. Must be paired with resources. | None |
| strategy_order | Optional[list[str]] | Ordered list of strategy names to execute on each call to sample. | None |
| resource_strategy_bundles | Optional[ResourceStrategyBundle] | Pre-configured bundle that provides resources, strategies, and ordering. | None |
| outdir | str | Directory where checkpoint.pkl is written. Created automatically if it does not exist. | './outdir/' |
| checkpoint_interval | float | Minimum wall-clock seconds that must elapse since the previous write before a new checkpoint is written. Default 600 (10 minutes). Set to 0 to disable checkpointing. When checkpointing is enabled, the JAX XLA compilation cache is also activated at {outdir}/jax_cache/ so that compiled step functions are reused across processes (no recompilation on resume). | 600.0 |

Raises: ValueError: If neither resources/strategies nor resource_strategy_bundles is provided.
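The documented ValueError condition can be illustrated with a minimal sketch. The helper name `resolve_configuration` and its return labels are hypothetical and not part of the library. Which path wins when both forms are supplied is an assumption here, and strategy_order is left unchecked because the documented condition does not mention it.

```python
from typing import Optional


def resolve_configuration(
    resources: Optional[dict] = None,
    strategies: Optional[dict] = None,
    resource_strategy_bundles: Optional[object] = None,
) -> str:
    """Hypothetical helper: report which configuration path would be used."""
    # Explicit resources and strategies must be supplied as a pair.
    if resources is not None and strategies is not None:
        return "explicit"  # assumption: the explicit pair takes precedence
    # Otherwise fall back to the pre-configured bundle, if any.
    if resource_strategy_bundles is not None:
        return "bundle"
    # Neither form was provided: mirror the documented error.
    raise ValueError(
        "Provide either resources and strategies, or resource_strategy_bundles."
    )
```

Passing only one of resources or strategies without a bundle also raises in this sketch, since the two are documented as needing to be paired.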

deserialize()

Deserialize the sampler object.

sample(initial_position: Float[Array, 'n_chains n_dim'], data: dict)

Execute the strategy loop and populate resource buffers with samples.

If outdir is set and checkpoint_interval > 0, the sampler atomically writes a checkpoint to {outdir}/checkpoint.pkl after each complete training loop, provided at least checkpoint_interval seconds have elapsed since the previous write. On the next call with the same outdir, the sampler detects the existing file, validates it against the current configuration, and resumes from the strategy immediately after the last checkpointed one.
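An interval-gated atomic write of this kind might look like the following sketch. It is not the sampler's actual implementation: the function name `maybe_write_checkpoint`, the `state` argument, and the choice to write immediately when there is no previous write are all assumptions. It uses only the standard library, writing to a temporary file and renaming it so a reader never sees a half-written checkpoint.

```python
import os
import pickle
import tempfile
import time


def maybe_write_checkpoint(state, outdir, checkpoint_interval, last_write_time):
    """Hypothetical helper: write {outdir}/checkpoint.pkl if enough time has passed.

    Returns the time of the most recent write (unchanged if the write is skipped).
    """
    if checkpoint_interval <= 0:
        return last_write_time  # checkpointing disabled
    now = time.monotonic()
    if last_write_time is not None and now - last_write_time < checkpoint_interval:
        return last_write_time  # too soon since the previous write
    os.makedirs(outdir, exist_ok=True)
    # Write to a temporary file in the same directory, then rename it over
    # the target; os.replace is atomic when both paths share a filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=outdir, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(state, f)
        os.replace(tmp_path, os.path.join(outdir, "checkpoint.pkl"))
    except BaseException:
        # Clean up the partial temporary file before re-raising.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return now
```

A caller would thread the returned timestamp through successive loop iterations, so that only the first qualifying iteration after each interval triggers a write.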

Checkpoint validation raises ValueError if:

  • n_dim or n_chains differ from the checkpoint.
  • strategy_order differs from the checkpoint.
  • The logpdf fingerprint (evaluated at the first chain's position) differs by more than 1e-6, indicating a change in the likelihood or data.
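The three checks above can be sketched as a plain function over two dictionaries. The dictionary layout, the key name `logpdf_fingerprint`, and the function name `validate_checkpoint` are hypothetical; only the compared quantities and the 1e-6 tolerance come from the description. The fingerprint check is skipped when either side lacks one, as an assumption about how a missing fingerprint is handled.

```python
def validate_checkpoint(checkpoint: dict, current: dict, tol: float = 1e-6) -> None:
    """Hypothetical helper: raise ValueError if the checkpoint does not match."""
    # Shape parameters must match exactly.
    for key in ("n_dim", "n_chains"):
        if checkpoint[key] != current[key]:
            raise ValueError(
                f"{key} mismatch: checkpoint has {checkpoint[key]}, "
                f"current run has {current[key]}"
            )
    # The strategy names must match in both content and order.
    if list(checkpoint["strategy_order"]) != list(current["strategy_order"]):
        raise ValueError("strategy_order differs from the checkpoint")
    # Compare the scalar logpdf fingerprints when both are available.
    old = checkpoint.get("logpdf_fingerprint")
    new = current.get("logpdf_fingerprint")
    if old is not None and new is not None and abs(old - new) > tol:
        raise ValueError(
            "logpdf fingerprint differs; the likelihood or data may have changed"
        )
```

Using an absolute tolerance keeps the comparison robust to small floating-point noise while still flagging a changed likelihood or dataset.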

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| initial_position | Float[Array, 'n_chains n_dim'] | Starting chain positions, shape (n_chains, n_dim) or broadcastable. Ignored when resuming from a checkpoint (the checkpointed positions are used instead). | required |
| data | dict | Arbitrary data dict forwarded to every strategy call. Must be consistent with any checkpoint on disk; changes are detected via the logpdf fingerprint when available. | required |

serialize()

Serialize the sampler object.