
Base

Sampler abstraction for Jim.

This module defines Sampler, an abstract base class that encapsulates everything Jim needs from a JAX sampler backend.

Samplers operate entirely in the sampling space (flat arrays of shape (n_dims,)). They have zero knowledge of parameter names, transforms, or prior/likelihood details beyond what the injected callables provide. Jim is responsible for building those callables and for converting the sampling-space arrays returned by Sampler.get_samples back to a named prior-space dict via Jim.get_samples.
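The conversion back to named parameters is Jim's job, not the sampler's. The sketch below illustrates that step under stated assumptions. The helper to_prior_space, the parameter names, and the use of jnp.exp as the inverse transform (as if the sampling space stored log-parameters) are all hypothetical. Jim.get_samples applies its own transforms.

```python
import jax.numpy as jnp
import numpy as np


def to_prior_space(samples, parameter_names, inverse_transform):
    """Hypothetical helper: map sampling-space rows to a named prior-space dict.

    samples: array of shape (n, n_dims) in the sampling space.
    parameter_names: sequence of n_dims names, in column order.
    inverse_transform: callable mapping an (n, n_dims) array back to prior space.
    """
    samples = jnp.asarray(samples)
    if samples.ndim != 2 or samples.shape[1] != len(parameter_names):
        raise ValueError("samples must have shape (n, len(parameter_names))")
    prior_space = inverse_transform(samples)
    # One 1-D NumPy column per named parameter.
    return {name: np.asarray(prior_space[:, i]) for i, name in enumerate(parameter_names)}


# Illustrative only: the sampling space stores log(mass) and log(distance).
flat = jnp.log(jnp.array([[30.0, 400.0], [35.0, 500.0]]))
named = to_prior_space(flat, ["mass", "distance"], jnp.exp)
```

The sampler only ever sees arrays like flat. Names and transforms live entirely on the Jim side.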

Sampler

Bases: ABC

Abstract base class for JAX sampler backends.

Each backend receives four injected callables from Jim and operates entirely in the sampling space (flat arrays of shape (n_dims,)). It has no knowledge of parameter names, transforms, or likelihood/prior details beyond what the callables provide.

Initial positions are always supplied by the caller (Jim draws them from the prior before calling sample); samplers never draw initial samples themselves.
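To make the contract concrete, here is a sketch of a backend written against a hypothetical mirror of the documented interface. ToySampler, its _run hook, the "sampling_time" diagnostics key, and RandomWalkSampler are all hypothetical. The real Sampler may divide the work differently, and only two of the four injected callables are modelled here. What the sketch does follow from the documentation is the data flow: the caller supplies initial positions of shape (n, n_dims), sample records wall-clock time, and get_samples returns a dict with "samples" and "log_likelihood".

```python
import time
from abc import ABC, abstractmethod
from typing import Any, Callable

import jax
import jax.numpy as jnp
import numpy as np

LogDensity = Callable[[jax.Array], jax.Array]


class ToySampler(ABC):
    """Hypothetical mirror of the documented contract; not jimgw's Sampler."""

    def __init__(self, log_prior: LogDensity, log_likelihood: LogDensity):
        # Two injected callables for illustration; Jim's backends receive four.
        self.log_prior = log_prior
        self.log_likelihood = log_likelihood
        self._diagnostics: dict[str, Any] = {}

    def sample(self, rng_key: jax.Array, initial_position: jax.Array) -> None:
        # Time the whole run; _run must block on device work before returning.
        start = time.perf_counter()
        self._run(rng_key, initial_position)
        self._diagnostics["sampling_time"] = time.perf_counter() - start

    @abstractmethod
    def _run(self, rng_key: jax.Array, initial_position: jax.Array) -> None:
        """Hypothetical backend hook that does the actual sampling."""

    @abstractmethod
    def get_samples(self) -> dict[str, np.ndarray]:
        """Return {"samples": (n, n_dims), "log_likelihood": (n,)}."""

    def get_diagnostics(self) -> dict[str, Any]:
        return dict(self._diagnostics)


class RandomWalkSampler(ToySampler):
    """Vectorised random-walk Metropolis; keeps only each chain's final state."""

    def __init__(self, log_prior, log_likelihood, n_steps=200, step_size=0.5):
        super().__init__(log_prior, log_likelihood)
        self.n_steps = n_steps
        self.step_size = step_size

    def _run(self, rng_key, initial_position):
        log_post = jax.vmap(lambda x: self.log_prior(x) + self.log_likelihood(x))

        def step(carry, key):
            x, lp = carry
            k_prop, k_acc = jax.random.split(key)
            proposal = x + self.step_size * jax.random.normal(k_prop, x.shape)
            lp_prop = log_post(proposal)
            # Metropolis acceptance, evaluated independently for each chain.
            accept = jnp.log(jax.random.uniform(k_acc, lp.shape)) < lp_prop - lp
            x = jnp.where(accept[:, None], proposal, x)
            lp = jnp.where(accept, lp_prop, lp)
            return (x, lp), None

        keys = jax.random.split(rng_key, self.n_steps)
        init = (initial_position, log_post(initial_position))
        (final, _), _ = jax.lax.scan(step, init, keys)
        # np.asarray waits for the device computation, so sampling_time is honest.
        self._samples = np.asarray(final)
        self._log_likelihood = np.asarray(jax.vmap(self.log_likelihood)(final))

    def get_samples(self) -> dict[str, np.ndarray]:
        return {"samples": self._samples, "log_likelihood": self._log_likelihood}


# Illustrative densities: standard normal prior, unit-width likelihood centred at 1.
log_prior = lambda x: -0.5 * jnp.sum(x**2)
log_likelihood = lambda x: -0.5 * jnp.sum((x - 1.0) ** 2)
sampler = RandomWalkSampler(log_prior, log_likelihood)
key_init, key_run = jax.random.split(jax.random.PRNGKey(0))
initial_position = jax.random.normal(key_init, (64, 2))  # stand-in for Jim's prior draw
sampler.sample(key_run, initial_position)
out = sampler.get_samples()
```

Keeping the timer in the base class means every backend reports sampling time the same way, while the backend-specific loop stays behind a single hook.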

Methods:

Name | Description
--- | ---
get_diagnostics | Return run-level diagnostics.
get_samples | Return posterior samples after internal post-processing.
sample | Run the sampler and record wall-clock sampling time.

get_diagnostics() -> dict[str, Any]

Return run-level diagnostics.

Only valid after sample has been called.

get_samples() -> dict[str, np.ndarray] (abstract method)

Return posterior samples after internal post-processing.

Returns a dict with exactly two keys:

  • "samples" — 2-D np.ndarray of shape (n, n_dims) in the sampling space.
  • "log_likelihood" — 1-D np.ndarray of shape (n,) with the per-sample log-likelihood values.

Backends that produce weighted samples, such as nested sampling (NS) and persistent SMC, perform importance resampling internally so that the returned samples are equally weighted.

Only valid after sample has been called.
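The resampling step mentioned above can be sketched as follows. This uses multinomial resampling, one common way to turn weighted samples into equally weighted ones; it is not necessarily the scheme the NS or persistent SMC backends use. The function resample_equal_weights and its arguments are hypothetical. The returned dict follows the two-key layout documented for get_samples.

```python
import jax
import jax.numpy as jnp
import numpy as np


def resample_equal_weights(rng_key, samples, log_likelihood, log_weights, n_out):
    """Hypothetical multinomial resampling into the get_samples dict layout.

    samples: (m, n_dims); log_likelihood and log_weights: (m,).
    """
    # Normalise the weights in log space for numerical stability.
    probs = jax.nn.softmax(log_weights)
    idx = jax.random.choice(rng_key, samples.shape[0], shape=(n_out,), replace=True, p=probs)
    # Index samples and log-likelihoods together so rows stay paired.
    return {
        "samples": np.asarray(samples[idx]),
        "log_likelihood": np.asarray(log_likelihood[idx]),
    }


key = jax.random.PRNGKey(0)
samples = jnp.arange(8.0).reshape(4, 2)  # row i is [2i, 2i + 1]
log_likelihood = jnp.array([-1.0, -2.0, -3.0, -4.0])
log_weights = jnp.array([0.0, -1.0, -jnp.inf, -jnp.inf])  # last two rows carry zero weight
out = resample_equal_weights(key, samples, log_likelihood, log_weights, n_out=100)
```

Multinomial resampling is the simplest option. Systematic or stratified resampling gives lower variance for the same number of draws.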

sample(rng_key: Key, initial_position: Float[Array, 'n n_dims']) -> None

Run the sampler and record wall-clock sampling time.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
rng_key | Key | JAX PRNG key. | required
initial_position | Float[Array, 'n n_dims'] | Starting positions in sampling space, shape (n, n_dims). The expected value of n depends on the backend (n_chains for flowMC, n_live for NS, n_particles for SMC). | required
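A sketch of how a caller might prepare the two arguments, assuming a box-shaped uniform prior and a hypothetical draw_initial_position helper (Jim draws from its actual prior). Splitting the key keeps the randomness of the initial draw independent of the sampler's own.

```python
import jax
import jax.numpy as jnp


def draw_initial_position(rng_key, n, prior_low, prior_high):
    """Hypothetical: draw n starting points uniformly inside a box prior."""
    prior_low = jnp.asarray(prior_low)
    prior_high = jnp.asarray(prior_high)
    # Uniform draws in [0, 1), then scaled and shifted into the box.
    u = jax.random.uniform(rng_key, (n, prior_low.shape[0]))
    return prior_low + u * (prior_high - prior_low)


key_init, key_sample = jax.random.split(jax.random.PRNGKey(42))
n_live = 500  # for an NS backend, n is the number of live points
initial_position = draw_initial_position(key_init, n_live, [0.0, -1.0], [1.0, 1.0])
# key_sample would then be passed as rng_key: sampler.sample(key_sample, initial_position)
```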