Base
Sampler abstraction for Jim.
This module defines Sampler, an abstract base
class that encapsulates everything Jim needs from a JAX sampler backend.
Samplers operate entirely in the sampling space (flat arrays of shape
(n_dims,)). They have zero knowledge of parameter names, transforms, or
prior/likelihood details beyond what the injected callables provide.
Jim is responsible for building those callables and for converting
the sampling-space arrays returned by Sampler.get_samples back to a
named prior-space dict via Jim.get_samples.
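The sketch below illustrates the split between named parameters and the flat sampling-space layout. It packs a dict of per-parameter columns into an array of shape (n, n_dims) and unpacks it again. The helpers `to_flat` and `to_named` and the parameter names are hypothetical, and the sketch uses an identity transform. Jim's own conversion in Jim.get_samples also maps samples back through its transforms to prior space, which is not shown here.

```python
import numpy as np

# Hypothetical parameter ordering; Jim defines its own ordering and names.
PARAM_NAMES = ["alpha", "beta", "gamma"]


def to_flat(named: dict[str, np.ndarray], names: list[str]) -> np.ndarray:
    """Stack per-parameter 1-D arrays into a flat (n, n_dims) array."""
    return np.stack([np.asarray(named[k], dtype=float) for k in names], axis=-1)


def to_named(flat: np.ndarray, names: list[str]) -> dict[str, np.ndarray]:
    """Split a flat (n, n_dims) array back into named columns (identity transform)."""
    if flat.ndim != 2 or flat.shape[1] != len(names):
        raise ValueError(f"expected shape (n, {len(names)}), got {flat.shape}")
    return {k: flat[:, i] for i, k in enumerate(names)}


named = {
    "alpha": np.array([30.0, 31.0]),
    "beta": np.array([0.8, 0.9]),
    "gamma": np.array([400.0, 410.0]),
}
flat = to_flat(named, PARAM_NAMES)
print(flat.shape)  # (2, 3)
roundtrip = to_named(flat, PARAM_NAMES)
```

A sampler backend only ever sees something like `flat`; the names live entirely on the caller's side.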
Sampler
Bases: ABC
Abstract base class for JAX sampler backends.
Each backend receives four injected callables from
Jim and operates entirely in the sampling
space (flat arrays of shape (n_dims,)). It has no knowledge of
parameter names, transforms, or likelihood/prior details beyond what
the callables provide.
Initial positions are always supplied by the caller (Jim draws them from
the prior before calling sample); samplers never draw initial
samples themselves.
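As a rough illustration of the caller's side of this contract, the sketch below draws starting points in the sampling space and returns them with the (n, n_dims) layout that sample expects. It assumes a uniform box prior and uses a NumPy `Generator` in place of a JAX PRNG key; `draw_initial_positions` is a hypothetical helper, not part of Jim.

```python
import numpy as np


def draw_initial_positions(
    rng: np.random.Generator, lower, upper, n: int
) -> np.ndarray:
    """Draw n starting points uniformly inside the box [lower, upper).

    Assumes a uniform box prior (hypothetical); returns shape (n, n_dims).
    """
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    if lower.ndim != 1 or lower.shape != upper.shape:
        raise ValueError("lower and upper must be 1-D arrays of equal length")
    # low/high broadcast against the trailing n_dims axis of `size`.
    return rng.uniform(lower, upper, size=(n, lower.shape[0]))


rng = np.random.default_rng(0)
init = draw_initial_positions(rng, lower=[0.0, -1.0], upper=[1.0, 1.0], n=4)
print(init.shape)  # (4, 2)
```

The resulting array is passed straight to the backend, which never draws its own starting points.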
Methods:
| Name | Description |
|---|---|
| get_diagnostics | Return run-level diagnostics. |
| get_samples | Return posterior samples after internal post-processing. |
| sample | Run the sampler and record wall-clock sampling time. |
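The class below sketches one plausible shape for this interface and is not Jim's actual code. Here sample acts as a template method that times a backend-specific hook, get_samples is abstract, and both accessors refuse to run before sampling. The hook name `_run`, the diagnostics key `sampling_time_s`, and the toy `IdentityBackend` are hypothetical, and a NumPy `Generator` stands in for the JAX key.

```python
import time
from abc import ABC, abstractmethod
from typing import Any, Optional

import numpy as np


class SamplerSketch(ABC):
    """Illustrative mirror of the documented Sampler interface (not Jim's code)."""

    def __init__(self) -> None:
        self._sampling_time: Optional[float] = None

    def sample(self, rng_key: np.random.Generator, initial_position: np.ndarray) -> None:
        """Run the backend and record wall-clock sampling time."""
        start = time.perf_counter()
        self._run(rng_key, initial_position)  # hypothetical backend hook
        self._sampling_time = time.perf_counter() - start

    def _check_sampled(self) -> None:
        if self._sampling_time is None:
            raise RuntimeError("call sample() before reading results")

    def get_diagnostics(self) -> dict[str, Any]:
        self._check_sampled()
        return {"sampling_time_s": self._sampling_time}

    @abstractmethod
    def _run(self, rng_key: np.random.Generator, initial_position: np.ndarray) -> None: ...

    @abstractmethod
    def get_samples(self) -> dict[str, np.ndarray]: ...


class IdentityBackend(SamplerSketch):
    """Toy backend that stores its starting points as the 'samples'."""

    def _run(self, rng_key, initial_position):
        self._samples = np.asarray(initial_position, dtype=float)

    def get_samples(self) -> dict[str, np.ndarray]:
        self._check_sampled()
        n = self._samples.shape[0]
        return {"samples": self._samples, "log_likelihood": np.zeros(n)}


backend = IdentityBackend()
backend.sample(np.random.default_rng(0), np.ones((3, 2)))
out = backend.get_samples()
print(out["samples"].shape)  # (3, 2)
print(sorted(backend.get_diagnostics()))  # ['sampling_time_s']
```

Timing inside the base class keeps the measurement consistent across backends, since each subclass only implements the work itself.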
get_diagnostics() -> dict[str, Any]
Return run-level diagnostics.
Only valid after sample has been called.
get_samples() -> dict[str, np.ndarray]
abstractmethod
Return posterior samples after internal post-processing.
Returns a dict with exactly two keys:
- "samples": 2-D np.ndarray of shape (n, n_dims) in the sampling space.
- "log_likelihood": 1-D np.ndarray of shape (n,) with the per-sample log-likelihood values.
Backends that produce weighted samples (nested sampling (NS), persistent SMC) perform importance resampling internally, so the returned samples are equally weighted.
Only valid after sample has been called.
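The page does not say which resampling scheme the weighted backends use. The sketch below assumes plain multinomial resampling and shows how weighted draws become equally weighted ones while keeping "samples" and "log_likelihood" aligned row by row. The function `resample_equal_weights` is hypothetical.

```python
import numpy as np


def resample_equal_weights(
    rng: np.random.Generator,
    samples: np.ndarray,         # (N, n_dims) weighted draws
    log_weights: np.ndarray,     # (N,) unnormalised log importance weights
    log_likelihood: np.ndarray,  # (N,) per-draw log-likelihood
    n_out: int,
) -> dict[str, np.ndarray]:
    """Multinomial resampling of weighted draws into equally weighted ones."""
    # Subtract the max before exponentiating to avoid overflow.
    w = np.exp(log_weights - np.max(log_weights))
    p = w / w.sum()
    idx = rng.choice(samples.shape[0], size=n_out, replace=True, p=p)
    # Index both arrays with the same draws so rows stay paired.
    return {"samples": samples[idx], "log_likelihood": log_likelihood[idx]}


samples = np.arange(6.0).reshape(3, 2)
log_w = np.array([0.0, -np.inf, -np.inf])  # all weight on the first draw
log_l = np.array([-1.0, -2.0, -3.0])
out = resample_equal_weights(np.random.default_rng(0), samples, log_w, log_l, n_out=5)
print(out["samples"])  # five copies of [0. 1.]
```

Lower-variance schemes such as systematic resampling would fit the same two-key contract.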
sample(rng_key: Key, initial_position: Float[Array, 'n n_dims']) -> None
Run the sampler and record wall-clock sampling time.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rng_key | Key | JAX PRNG key. | required |
| initial_position | Float[Array, 'n n_dims'] | Starting positions in sampling space, shape (n, n_dims). | required |
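The parameter table gives initial_position a leading n axis. One common reading, assumed here and not confirmed by this page, is one starting point per independent chain or walker. The sketch shows that layout with a vectorised random-walk Metropolis written in NumPy. It is not how any Jim backend is implemented. `rw_metropolis` and `std_normal_logpdf` are hypothetical, and a NumPy `Generator` stands in for the JAX key.

```python
import numpy as np


def rw_metropolis(rng, log_density, initial_position, n_steps, step_size):
    """Vectorised random-walk Metropolis: one chain per row of initial_position.

    initial_position has shape (n, n_dims); returns draws of shape
    (n_steps, n, n_dims).
    """
    x = np.array(initial_position, dtype=float)  # copy; caller's array untouched
    logp = log_density(x)  # shape (n,)
    chain = np.empty((n_steps,) + x.shape)
    for t in range(n_steps):
        proposal = x + step_size * rng.standard_normal(x.shape)
        logp_prop = log_density(proposal)
        # Symmetric proposal: accept each chain with probability min(1, p'/p).
        accept = np.log(rng.uniform(size=x.shape[0])) < logp_prop - logp
        x = np.where(accept[:, None], proposal, x)
        logp = np.where(accept, logp_prop, logp)
        chain[t] = x
    return chain


def std_normal_logpdf(x):
    # Unnormalised standard-normal log-density, evaluated row-wise.
    return -0.5 * np.sum(x**2, axis=-1)


rng = np.random.default_rng(1)
init = rng.uniform(-1.0, 1.0, size=(8, 2))  # n = 8 chains, n_dims = 2
draws = rw_metropolis(rng, std_normal_logpdf, init, n_steps=2000, step_size=0.8)
print(draws.shape)  # (2000, 8, 2)
```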