
Smc

BlackJAX SMC samplers for Jim.

Supports four mode combinations selected by BlackJAXSMCConfig:

  • persistent_sampling=True, temperature_ladder=None → adaptive persistent SMC
  • persistent_sampling=True, temperature_ladder=given → fixed-ladder persistent SMC
  • persistent_sampling=False, temperature_ladder=None → adaptive tempered SMC
  • persistent_sampling=False, temperature_ladder=given → fixed-ladder tempered SMC
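In the two adaptive modes the next temperature must be chosen on the fly. A common rule picks the largest next inverse temperature whose incremental importance weights keep the normalized effective sample size (ESS) above a target. This page does not say that this sampler uses that rule, so treat it as an assumption. The NumPy sketch below finds that temperature by bisection. `ess_fraction`, `next_temperature` and `target_ess` are hypothetical names, not part of Jim or BlackJAX.

```python
import numpy as np


def ess_fraction(log_weights: np.ndarray) -> float:
    """Normalized effective sample size (1 / sum(w^2)) / N, in (0, 1]."""
    w = np.exp(log_weights - log_weights.max())  # subtract max for stability
    w /= w.sum()
    return float(1.0 / np.sum(w**2) / len(w))


def next_temperature(log_likelihoods: np.ndarray, beta: float,
                     target_ess: float = 0.5, tol: float = 1e-8) -> float:
    """Bisect for the largest beta' in (beta, 1] whose incremental weights
    exp((beta' - beta) * log L) keep the normalized ESS >= target_ess.
    (Hypothetical helper; assumes ESS decreases as beta' grows.)"""
    if ess_fraction((1.0 - beta) * log_likelihoods) >= target_ess:
        return 1.0  # jumping straight to the posterior is acceptable
    lo, hi = beta, 1.0  # invariant: ESS(lo) >= target, ESS(hi) < target
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if ess_fraction((mid - beta) * log_likelihoods) >= target_ess:
            lo = mid
        else:
            hi = mid
    return lo
```

A fixed-ladder mode skips this search and simply steps through the given `temperature_ladder`.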

BlackJAXSMCSampler

Bases: Sampler

BlackJAX SMC sampler.

Uses a Gaussian random-walk MCMC inner kernel whose initial proposal covariance is estimated from the starting particles. With adaptive temperature selection, the covariance is re-estimated at each step.
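The NumPy sketch below shows one such inner-kernel step. It assumes two things this page does not confirm: the tempered target prior(x) · likelihood(x)^β, and the 2.38/√d proposal scaling heuristic. Passing `cov` keeps a fixed covariance, such as the initial estimate. Leaving it `None` re-estimates the covariance from the current particles on each call, which mirrors the adaptive behaviour described above. `rw_mh_step` is a hypothetical helper.

```python
import numpy as np


def rw_mh_step(rng, particles, log_prior_fn, log_likelihood_fn, beta, cov=None):
    """One Gaussian random-walk Metropolis step per particle (hypothetical).

    Targets prior(x) * likelihood(x)**beta. If `cov` is None, the proposal
    covariance is re-estimated from the current particle cloud.
    """
    n, d = particles.shape
    if cov is None:
        cov = np.atleast_2d(np.cov(particles, rowvar=False))
    scale = 2.38 / np.sqrt(d)  # classic RW-MH scaling heuristic (assumption)
    chol = np.linalg.cholesky(cov + 1e-10 * np.eye(d))  # small jitter
    # Rows of z @ chol.T have covariance chol @ chol.T = cov.
    proposals = particles + scale * rng.standard_normal((n, d)) @ chol.T

    def log_target(xs):
        return np.array([log_prior_fn(x) + beta * log_likelihood_fn(x) for x in xs])

    # Symmetric proposal, so the acceptance ratio is just the target ratio.
    log_alpha = log_target(proposals) - log_target(particles)
    accept = np.log(rng.uniform(size=n)) < log_alpha
    new_particles = np.where(accept[:, None], proposals, particles)
    return new_particles, float(accept.mean())
```

With a standard-normal prior and a standard-normal likelihood at β = 1, repeated steps should settle on N(0, 0.5·I).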

Supports checkpoint/resume via config.checkpoint_dir. After each tempering iteration (subject to config.checkpoint_interval), the sampler atomically writes its state to checkpoint.pkl in that directory. If that file already exists when the sampler starts, it resumes from it.
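An atomic write usually means writing to a temporary file in the same directory and then renaming it over the target. A crash mid-write then leaves the previous checkpoint.pkl intact instead of a truncated one. The stdlib-only sketch below shows that pattern. `save_checkpoint_atomic`, `load_checkpoint` and the contents of `state` are hypothetical. This page does not describe the sampler's real checkpoint format.

```python
import os
import pickle
import tempfile


def save_checkpoint_atomic(state, checkpoint_dir):
    """Pickle `state` to checkpoint_dir/checkpoint.pkl without ever exposing a
    half-written file: write a temp file in the same directory, fsync it, then
    os.replace() it over the target (an atomic rename on POSIX filesystems)."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, "checkpoint.pkl")
    fd, tmp_path = tempfile.mkstemp(dir=checkpoint_dir, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(state, f)
            f.flush()
            os.fsync(f.fileno())  # make sure bytes hit disk before the rename
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)  # never leave stray temp files behind
        raise
    return path


def load_checkpoint(checkpoint_dir):
    """Return the saved state if checkpoint.pkl exists, else None (fresh start)."""
    path = os.path.join(checkpoint_dir, "checkpoint.pkl")
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)
```

A resume-aware loop would call `load_checkpoint` once at start-up. It would then call `save_checkpoint_atomic` every `checkpoint_interval` iterations.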

Operates on flat (n_dims,) arrays.

Parameters:

  • n_dims (int, required): Dimension of the sampling space.
  • log_prior_fn (Callable, required): Log-prior callable (arr,) -> float.
  • log_likelihood_fn (Callable, required): Log-likelihood callable (arr,) -> float.
  • log_posterior_fn (Callable, required): Log-posterior callable (arr,) -> float.
  • config (Optional[BlackJAXSMCConfig], default None): Optional BlackJAXSMCConfig; defaults to all-default values.
  • periodic (Optional[dict[int, tuple[float, float]]], default None): Optional periodic-parameter spec in index space, dict[int, (lo, hi)], where the key is the dimension index and the value is the (lower, upper) period bounds. None means no periodic parameters. Provided by Jim after resolving names.
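The toy two-dimensional setup below illustrates the callable signatures and a `periodic` spec, with dimension 0 a phase of period [0, 2π). The `wrap_periodic` helper shows one plausible way to apply such a spec: mapping coordinates back into [lo, hi). This page does not document how the sampler actually uses `periodic`. All names and densities here are illustrative, and the log-posterior is assumed to equal log-prior plus log-likelihood.

```python
import numpy as np

n_dims = 2
# Index-space periodic spec: dimension 0 has period bounds [0, 2*pi).
periodic = {0: (0.0, 2.0 * np.pi)}


def wrap_periodic(x, periodic):
    """Map each periodic coordinate into [lo, hi), up to floating-point rounding.
    Works on a single (n_dims,) point or a batch of shape (n, n_dims)."""
    x = np.array(x, dtype=float, copy=True)
    for i, (lo, hi) in periodic.items():
        x[..., i] = lo + np.mod(x[..., i] - lo, hi - lo)
    return x


def log_prior_fn(x):
    # Uniform phase on [0, 2*pi) plus a standard-normal second coordinate.
    return float(-np.log(2.0 * np.pi) - 0.5 * x[1] ** 2 - 0.5 * np.log(2.0 * np.pi))


def log_likelihood_fn(x):
    # Depends on the phase only through cos(), so it respects the period.
    return float(3.0 * np.cos(x[0] - 1.0) - 0.5 * (x[1] - 0.5) ** 2)


def log_posterior_fn(x):
    # Unnormalized log-posterior (assumed: log-prior + log-likelihood).
    return log_prior_fn(x) + log_likelihood_fn(x)
```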

Methods:

  • get_samples: Return posterior samples.

get_samples() -> dict[str, np.ndarray]

Return posterior samples.

When persistent_sampling=True: samples are drawn with replacement from the particles of all temperatures, weighted by the persistent-sampling weight formula. The number of returned samples approximately equals the effective sample size estimate 1 / max(weights), where the weights are normalized to sum to one.
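The resampling step can be sketched as follows. The sketch takes the persistent-sampling weights as given, because their formula is not reproduced on this page. Rounding 1 / max(w) down with `floor` is an assumption, and `resample_by_ess` is a hypothetical name.

```python
import numpy as np


def resample_by_ess(rng, particles, log_likelihoods, weights):
    """Draw floor(1 / max(w)) samples with replacement from a pooled, weighted
    particle set. `weights` need not be normalized (hypothetical helper)."""
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()  # normalize so that 1 / max(w) is an ESS estimate
    n_out = int(np.floor(1.0 / w.max()))
    idx = rng.choice(len(w), size=n_out, replace=True, p=w)
    return {"samples": particles[idx], "log_likelihood": log_likelihoods[idx]}
```

With one particle holding half of the total weight, 1 / max(w) = 2, so two samples come back.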

When persistent_sampling=False: returns all final-temperature particles with equal weight.

Returns:

  • dict[str, ndarray]: Dict with keys "samples" (shape (n, n_dims)) and "log_likelihood" (shape (n,)).
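The sampler works on flat (n_dims,) arrays, so the caller must map the columns of "samples" back to parameter names. Per the parameter notes, Jim handles that name resolution. The sketch below splits a result dict using a hypothetical `parameter_names` list and reports a mean and a 90% interval per parameter. In the test it runs on a synthetic stand-in, not real sampler output.

```python
import numpy as np


def summarize(result, parameter_names):
    """Per-parameter mean and 5%/95% quantiles from a get_samples()-style dict
    (hypothetical helper; column j is assumed to be parameter_names[j])."""
    samples = np.asarray(result["samples"])
    log_l = np.asarray(result["log_likelihood"])
    if samples.ndim != 2 or samples.shape[1] != len(parameter_names):
        raise ValueError("need one name per column of result['samples']")
    if log_l.shape != (samples.shape[0],):
        raise ValueError("log_likelihood must have shape (n,)")
    summary = {}
    for j, name in enumerate(parameter_names):
        col = samples[:, j]
        lo, hi = np.quantile(col, [0.05, 0.95])
        summary[name] = {"mean": float(col.mean()), "q05": float(lo), "q95": float(hi)}
    return summary
```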