Smc
BlackJAX SMC samplers for Jim.
Supports four mode combinations, selected by BlackJAXSMCConfig:
- persistent_sampling=True, temperature_ladder=None → adaptive persistent SMC
- persistent_sampling=True, temperature_ladder=given → fixed-ladder persistent SMC
- persistent_sampling=False, temperature_ladder=None → adaptive tempered SMC
- persistent_sampling=False, temperature_ladder=given → fixed-ladder tempered SMC
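
The mapping from the two fields to the four modes can be sketched as follows. This is only an illustration: `ModeConfig` and `describe_mode` are hypothetical names that mirror the two fields listed above, not part of the library, and the real BlackJAXSMCConfig may carry other fields and defaults.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class ModeConfig:
    """Hypothetical stand-in holding only the two mode-selecting fields."""

    persistent_sampling: bool
    temperature_ladder: Optional[Sequence[float]]


def describe_mode(config: ModeConfig) -> str:
    """Return the mode name implied by the two fields."""
    # No ladder given means temperatures are chosen adaptively.
    schedule = "adaptive" if config.temperature_ladder is None else "fixed-ladder"
    # persistent_sampling switches between persistent and plain tempered SMC.
    kind = "persistent" if config.persistent_sampling else "tempered"
    return f"{schedule} {kind} SMC"


for persistent in (True, False):
    for ladder in (None, [0.0, 0.5, 1.0]):
        print(persistent, ladder, "->", describe_mode(ModeConfig(persistent, ladder)))
```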
BlackJAXSMCSampler
Bases: Sampler
BlackJAX SMC sampler.
Uses a Gaussian random-walk MCMC inner kernel with initial covariance estimated from the starting particles. With adaptive temperature selection the covariance is re-estimated at each step.
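
To make the idea concrete, here is a minimal NumPy sketch of a Gaussian random-walk Metropolis step whose proposal covariance is estimated from the current particles. It is not the BlackJAX kernel used by the sampler: `estimate_covariance`, `random_walk_step`, `standard_normal_logpdf`, and the diagonal jitter are all assumptions made for illustration.

```python
import numpy as np


def estimate_covariance(particles: np.ndarray) -> np.ndarray:
    """Empirical covariance of particles with shape (n_particles, n_dims)."""
    cov = np.atleast_2d(np.cov(particles, rowvar=False))
    # Small diagonal jitter (an assumption) keeps the matrix positive definite.
    return cov + 1e-9 * np.eye(cov.shape[0])


def random_walk_step(rng, particles, log_density, cov):
    """One Metropolis step per particle with a shared Gaussian proposal."""
    n_particles, n_dims = particles.shape
    noise = rng.multivariate_normal(np.zeros(n_dims), cov, size=n_particles)
    proposals = particles + noise
    # The proposal is symmetric, so the acceptance ratio is the density ratio.
    log_ratio = log_density(proposals) - log_density(particles)
    accepted = np.log(rng.uniform(size=n_particles)) < log_ratio
    return np.where(accepted[:, None], proposals, particles), accepted


def standard_normal_logpdf(x: np.ndarray) -> np.ndarray:
    """Toy target (unnormalised), evaluated row-wise."""
    return -0.5 * np.sum(x**2, axis=1)


rng = np.random.default_rng(0)
particles = rng.normal(size=(256, 3))
for _ in range(3):
    # Re-estimating the covariance each step mirrors the adaptive case.
    cov = estimate_covariance(particles)
    particles, accepted = random_walk_step(rng, particles, standard_normal_logpdf, cov)
    print("acceptance rate:", accepted.mean())
```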
Supports checkpoint/resume via config.checkpoint_dir: a checkpoint.pkl
file is written atomically after each tempering iteration (subject
to config.checkpoint_interval), and the sampler resumes from it if one
already exists at that path.
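
One common way to get an atomic write is to pickle to a temporary file in the same directory and then rename it over the target. The sketch below assumes that pattern; `save_checkpoint`, `load_checkpoint`, and `should_checkpoint` are hypothetical helpers, and the sampler's actual implementation and state layout may differ.

```python
import os
import pickle
import tempfile
from pathlib import Path

CHECKPOINT_NAME = "checkpoint.pkl"


def save_checkpoint(checkpoint_dir, state) -> None:
    """Write state to checkpoint.pkl without exposing a half-written file."""
    directory = Path(checkpoint_dir)
    directory.mkdir(parents=True, exist_ok=True)
    # The temporary file lives in the same directory so the rename stays
    # on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as handle:
            pickle.dump(state, handle)
        # os.replace swaps the file in as a single operation.
        os.replace(tmp_path, directory / CHECKPOINT_NAME)
    except BaseException:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise


def load_checkpoint(checkpoint_dir):
    """Return the saved state, or None when no checkpoint exists yet."""
    path = Path(checkpoint_dir) / CHECKPOINT_NAME
    if not path.exists():
        return None
    with path.open("rb") as handle:
        return pickle.load(handle)


def should_checkpoint(iteration: int, interval: int) -> bool:
    """Assumed reading of checkpoint_interval: save every `interval` iterations."""
    return interval > 0 and iteration % interval == 0
```

A run loop would call `load_checkpoint` once at start-up and `save_checkpoint` whenever `should_checkpoint` is true, so an interrupted run restarts from the last completed tempering iteration.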
Operates on flat (n_dims,) arrays.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_dims | int | Dimension of the sampling space. | required |
| log_prior_fn | Callable | Log-prior callable | required |
| log_likelihood_fn | Callable | Log-likelihood callable | required |
| log_posterior_fn | Callable | Log-posterior callable | required |
| config | Optional[BlackJAXSMCConfig] | Optional | None |
| periodic | Optional[dict[int, tuple[float, float]]] | Optional periodic-parameter spec in index space, | None |
Methods:
| Name | Description |
|---|---|
| get_samples | Return posterior samples. |
get_samples() -> dict[str, np.ndarray]
Return posterior samples.
When persistent_sampling=True: samples are drawn with replacement
from all-temperature particles weighted by the persistent-sampling
weight formula. The number of returned samples approximately equals
the effective sample size 1 / max(weights).
When persistent_sampling=False: returns all final-temperature
particles with equal weight.
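
The persistent-sampling branch can be sketched as a weighted draw with replacement, sized by 1 / max(weights). The sketch takes the weights as given, since the persistent-sampling weight formula itself is not reproduced here; `resample_persistent` is a hypothetical helper, and rounding the draw count down is an assumption.

```python
import numpy as np


def resample_persistent(rng, particles: np.ndarray, weights) -> np.ndarray:
    """Draw with replacement from pooled particles according to weights."""
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()  # normalise so the weights sum to 1
    # Effective sample size as described above: 1 / max(weights).
    n_draws = max(1, int(1.0 / weights.max()))
    indices = rng.choice(len(weights), size=n_draws, replace=True, p=weights)
    return particles[indices]


rng = np.random.default_rng(0)
pooled = rng.normal(size=(1000, 2))    # stand-in for all-temperature particles
raw_weights = rng.uniform(size=1000)   # stand-in for persistent-sampling weights
samples = resample_persistent(rng, pooled, raw_weights)
print(samples.shape)
```

With perfectly uniform weights the draw count equals the number of pooled particles, and it shrinks as a few particles come to dominate the weights.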
Returns:
| Type | Description |
|---|---|
| dict[str, ndarray] | Dict with keys |