Nss

BlackJAX Nested Slice Sampling (NSS).

BlackJAXNSSSampler

Bases: Sampler

BlackJAX Nested Slice Sampler (NSS).

NSS combines nested sampling with an adaptive slice-sampling inner kernel. It works directly in the sampling space defined by sample_transforms, so no unit-cube constraint is required. The sampler operates on flat arrays of shape (n_dims,), although the underlying NSS kernel is pytree-generic.

Configure via BlackJAXNSSConfig.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_dims | int | Dimension of the sampling space. | required |
| log_prior_fn | Callable | Log-prior callable (arr,) -> float. | required |
| log_likelihood_fn | Callable | Log-likelihood callable (arr,) -> float. | required |
| log_posterior_fn | Callable | Log-posterior callable (arr,) -> float. | required |
| config | Optional[BlackJAXNSSConfig] | Optional BlackJAXNSSConfig; defaults to all-default values. | None |
| periodic | Optional[dict[int, tuple[float, float]]] | Optional periodic-parameter spec in index space, dict[int, (lo, hi)], where the key is the dimension index and the value is the (lower, upper) period bounds. None means no periodic parameters. Provided by Jim after resolving names. | None |
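
To make the callable signatures and the periodic spec concrete, here is a minimal sketch of inputs with the documented shapes. Everything in it is hypothetical: the three-dimensional problem, the box bounds, the toy likelihood, and the choice to form the log-posterior as log-prior plus log-likelihood are assumptions for illustration only. NumPy is used so the snippet runs on its own; with the actual sampler the same functions would presumably be written with jax.numpy so they can be traced. Constructing BlackJAXNSSSampler itself is not shown.

```python
import numpy as np

n_dims = 3  # hypothetical problem size

# Hypothetical box bounds for a uniform prior over the flat array.
lower = np.array([-5.0, -5.0, 0.0])
upper = np.array([5.0, 5.0, 2.0 * np.pi])


def log_prior_fn(arr):
    """Uniform log-prior on the box; -inf outside it."""
    inside = np.all((arr >= lower) & (arr <= upper))
    log_volume = np.sum(np.log(upper - lower))
    return float(-log_volume) if inside else float("-inf")


def log_likelihood_fn(arr):
    """Toy likelihood: Gaussian in dims 0-1, cosine term in the angle (dim 2)."""
    return float(-0.5 * np.sum(arr[:2] ** 2) + np.cos(arr[2]))


def log_posterior_fn(arr):
    """Unnormalised log-posterior, assumed here to be prior plus likelihood."""
    return log_prior_fn(arr) + log_likelihood_fn(arr)


# Periodic spec in index space: dimension 2 is an angle with period [0, 2*pi).
periodic = {2: (0.0, 2.0 * np.pi)}

x = np.array([0.5, -1.0, np.pi])  # a flat array of shape (n_dims,)
print(log_posterior_fn(x))
```

Each callable takes a single flat array and returns a scalar, and every key in periodic is a valid dimension index below n_dims.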

Methods:

| Name | Description |
| --- | --- |
| get_samples | Return equally-weighted posterior samples via anesthetic's posterior_points. |

get_samples() -> dict[str, np.ndarray]

Return equally-weighted posterior samples via anesthetic's posterior_points.

Uses NestedSamples.posterior_points to resample the nested-sampling dead points into a set of truly equal-weight samples: each row is duplicated in proportion to its integer weight.
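
The idea of turning weighted dead points into equal-weight rows can be sketched in plain NumPy. This is not anesthetic's implementation: the helper name equal_weight_resample, the normalisation by the largest weight, and the stochastic rounding are all assumptions chosen to illustrate "rows duplicated in proportion to integer weights".

```python
import numpy as np


def equal_weight_resample(samples, weights, rng):
    """Hypothetical helper: repeat each row by a stochastically rounded integer weight."""
    w = np.asarray(weights, dtype=float)
    scaled = w / w.max()                # largest weight maps to exactly one copy
    base = np.floor(scaled)             # guaranteed number of copies
    frac = scaled - base                # probability of one extra copy
    counts = (base + (rng.random(len(w)) < frac)).astype(int)
    return np.repeat(samples, counts, axis=0), counts


rng = np.random.default_rng(0)
n, n_dims = 5000, 2
samples = rng.normal(size=(n, n_dims))  # stand-in for dead points
weights = rng.random(n)                 # stand-in for posterior weights

resampled, counts = equal_weight_resample(samples, weights, rng)

# With equal weights, a plain mean estimates the same quantity as the weighted mean.
weighted_mean = np.average(samples, axis=0, weights=weights)
plain_mean = resampled.mean(axis=0)
print(resampled.shape, weighted_mean, plain_mean)
```

After resampling, downstream code can treat every row identically, which is why summaries such as means or histograms need no weight argument.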

Returns:

| Type | Description |
| --- | --- |
| dict[str, ndarray] | Dict with keys "samples" (shape (n, n_dims)) and "log_likelihood" (shape (n,)). |