
Prior

jimgw.core.prior

Prior

Bases: Module

Base class for prior distributions.

This class should not be used directly. It provides a common interface and bookkeeping for parameter names and transforms.

n_dims: int property
parameter_names: tuple[str, ...] = tuple(parameter_names) instance-attribute
__init__(parameter_names: list[str])

Parameters:

parameter_names (list[str], required): A list of names for the parameters of the prior.

add_name(x: Float[Array, n_dims]) -> dict[str, Float]

Turn an array into a dictionary keyed by parameter_names.

Parameters:

x (Array, required): An array of parameters. Shape (n_dims,).
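
For illustration only, the plain-Python sketch below mimics what add_name is documented to do: it pairs each entry of a flat parameter array with the matching name in parameter_names. The helper add_name_sketch and the example names are hypothetical. The real method works on JAX arrays, and its implementation may differ.

```python
def add_name_sketch(parameter_names: tuple[str, ...], x: list[float]) -> dict[str, float]:
    """Map x[i] to parameter_names[i] (hypothetical stand-in for Prior.add_name)."""
    if len(x) != len(parameter_names):
        # A length mismatch means the array and the names are out of sync.
        raise ValueError(f"expected {len(parameter_names)} values, got {len(x)}")
    # zip preserves order, so the i-th value is stored under the i-th name.
    return dict(zip(parameter_names, x))


print(add_name_sketch(("mass_1", "mass_2"), [30.0, 25.0]))  # {'mass_1': 30.0, 'mass_2': 25.0}
```

The mapping is positional, so the order of parameter_names must match the layout of the array.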

__call__(x: dict[str, Float]) -> Float
log_prob(z: dict[str, Float]) -> Float abstractmethod
sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']] abstractmethod

CompositePrior

Bases: Prior

Base class for priors composed of other priors. Its subclasses, SequentialTransformPrior and CombinePrior, are used to build complex prior distributions from simpler ones.

Attributes:

base_prior (tuple[Prior, ...]): Tuple of prior objects.

parameter_names (tuple[str, ...]): Names of all parameters in the composite prior.

base_prior: tuple[Prior, ...] = tuple(priors) instance-attribute
parameter_names = tuple([name for prior in priors for name in (prior.parameter_names)]) instance-attribute
n_dims: int property
__repr__()
__init__(priors: list[Prior])
trace_prior_parent(output: Optional[list[Prior]] = None) -> list[Prior]

Recursively collect all leaf (non-composite) priors.

Parameters:

output (Optional[list[Prior]], default None): Accumulator list. If None, a new list is created.

Returns:

list[Prior]: List of all leaf prior objects in this composite.
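
The recursion described above can be sketched with toy stand-in classes. ToyLeaf and ToyComposite are hypothetical, not jimgw classes. Each composite walks its base_prior tuple, descends into nested composites, and appends leaves to a shared accumulator.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyLeaf:
    """Stand-in for a non-composite prior (hypothetical)."""
    name: str


@dataclass
class ToyComposite:
    """Stand-in for a CompositePrior holding child priors in base_prior (hypothetical)."""
    base_prior: tuple = field(default_factory=tuple)

    def trace_prior_parent(self, output: Optional[list] = None) -> list:
        # Create the accumulator only on the outermost call.
        if output is None:
            output = []
        for prior in self.base_prior:
            if isinstance(prior, ToyComposite):
                # Recurse into nested composites, sharing the same accumulator.
                prior.trace_prior_parent(output)
            else:
                output.append(prior)
        return output


tree = ToyComposite((ToyLeaf("a"), ToyComposite((ToyLeaf("b"), ToyLeaf("c")))))
print([leaf.name for leaf in tree.trace_prior_parent()])  # ['a', 'b', 'c']
```

The None default avoids Python's shared-mutable-default pitfall: a literal [] default would be reused across calls.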

add_name(x: Float[Array, n_dims]) -> dict[str, Float]

Turn an array into a dictionary keyed by parameter_names.

Parameters:

x (Array, required): An array of parameters. Shape (n_dims,).

__call__(x: dict[str, Float]) -> Float
log_prob(z: dict[str, Float]) -> Float abstractmethod
sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']] abstractmethod

LogisticDistribution

Bases: Prior

One-dimensional logistic distribution prior.

Attributes:

parameter_names (list[str]): Name of the parameter.

parameter_names: tuple[str, ...] = tuple(parameter_names) instance-attribute
n_dims: int property
__repr__()
__init__(parameter_names: list[str], **kwargs)
sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample from a logistic distribution.

Parameters:

rng_key (Key, required): A random key to use for sampling.

n_samples (int, required): The number of samples to draw.

Returns:

dict[str, Float[Array, ' n_samples']]: Samples from the distribution. The keys are the names of the parameters.

log_prob(z: dict[str, Float]) -> Float
add_name(x: Float[Array, n_dims]) -> dict[str, Float]

Turn an array into a dictionary keyed by parameter_names.

Parameters:

x (Array, required): An array of parameters. Shape (n_dims,).

__call__(x: dict[str, Float]) -> Float
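
A minimal plain-Python sketch of the math behind this prior follows. It assumes the standard logistic distribution (location 0, scale 1), which this page does not state explicitly. It evaluates the log density and draws samples by inverse-CDF (logit) sampling. It is not jimgw's implementation, which uses JAX.

```python
import math
import random


def logistic_log_prob(x: float) -> float:
    """Log density of the standard logistic distribution (location 0, scale 1)."""
    # log p(x) = -x - 2*log(1 + exp(-x)). The density is symmetric, so using |x|
    # keeps exp() from overflowing for large negative x.
    a = abs(x)
    return -a - 2.0 * math.log1p(math.exp(-a))


def logistic_sample(rng: random.Random, n_samples: int) -> list[float]:
    """Inverse-CDF sampling: if u ~ Uniform(0, 1), then log(u / (1 - u)) is standard logistic."""
    samples = []
    while len(samples) < n_samples:
        u = rng.random()  # in [0, 1)
        if u > 0.0:  # skip the rare exact zero, where log(u) is undefined
            samples.append(math.log(u / (1.0 - u)))
    return samples


print(round(logistic_log_prob(0.0), 4))  # -1.3863, i.e. -2*log(2)
print(len(logistic_sample(random.Random(0), 5)))  # 5
```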

StandardNormalDistribution

Bases: Prior

One-dimensional standard normal (Gaussian) distribution prior.

Attributes:

parameter_names (list[str]): Name of the parameter.

parameter_names: tuple[str, ...] = tuple(parameter_names) instance-attribute
n_dims: int property
__repr__()
__init__(parameter_names: list[str], **kwargs)
sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample from a standard normal distribution.

Parameters:

rng_key (Key, required): A random key to use for sampling.

n_samples (int, required): The number of samples to draw.

Returns:

dict[str, Float[Array, ' n_samples']]: Samples from the distribution. The keys are the names of the parameters.

log_prob(z: dict[str, Float]) -> Float
add_name(x: Float[Array, n_dims]) -> dict[str, Float]

Turn an array into a dictionary keyed by parameter_names.

Parameters:

x (Array, required): An array of parameters. Shape (n_dims,).

__call__(x: dict[str, Float]) -> Float

UniformDistribution

Bases: Prior

One-dimensional uniform distribution prior over [0, 1].

Attributes:

parameter_names (list[str]): Name of the parameter.

xmin: float = 0.0 class-attribute instance-attribute
xmax: float = 1.0 class-attribute instance-attribute
parameter_names: tuple[str, ...] = tuple(parameter_names) instance-attribute
n_dims: int property
__repr__()
__init__(parameter_names: list[str], **kwargs)
sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample from a uniform distribution.

Parameters:

rng_key (Key, required): A random key to use for sampling.

n_samples (int, required): The number of samples to draw.

Returns:

dict[str, Float[Array, ' n_samples']]: Samples from the distribution. The keys are the names of the parameters.

log_prob(z: dict[str, Float]) -> Float
add_name(x: Float[Array, n_dims]) -> dict[str, Float]

Turn an array into a dictionary keyed by parameter_names.

Parameters:

x (Array, required): An array of parameters. Shape (n_dims,).

__call__(x: dict[str, Float]) -> Float

SequentialTransformPrior

Bases: CompositePrior

Prior distribution transformed by a sequence of bijective transforms.

Attributes:

base_prior (tuple[Prior, ...]): The base prior to transform (must have length 1).

transforms (tuple[BijectiveTransform, ...]): Transforms applied sequentially in the forward direction.

parameter_names (tuple[str, ...]): Names of the parameters after all transforms.

transforms: tuple[BijectiveTransform, ...] = tuple(transforms) instance-attribute
parameter_names = tuple(transform.propagate_name(self.parameter_names)) instance-attribute
n_dims: int property
base_prior: tuple[Prior, ...] = tuple(priors) instance-attribute
__repr__()
__init__(base_prior: list[Prior], transforms: list[BijectiveTransform])

Parameters:

base_prior (list[Prior], required): A single-element list containing the base prior.

transforms (list[BijectiveTransform], required): Ordered list of bijective transforms to apply to samples from the base prior.

sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample by drawing from the base prior and applying all transforms.

Parameters:

rng_key (Key, required): JAX PRNG key.

n_samples (int, required): Number of samples to draw.

Returns:

dict[str, Float[Array, ' n_samples']]: Transformed samples keyed by parameter name.

log_prob(z: dict[str, Float]) -> Float

Evaluate the log-probability of a transformed sample z.

Applies the inverse transforms in reverse order, accumulating log-Jacobian determinants, then evaluates the base prior.

Parameters:

z (dict[str, Float], required): Sample in the transformed (output) space.

Returns:

Float: Log-probability of z under the induced distribution.
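
The inverse-transform recipe above is the change-of-variables formula: log p_Z(z) = log p_X(f^-1(z)) + log|d f^-1 / dz|. The one-dimensional sketch below applies it to a chain of affine maps on a standard-normal base. ToyBijection and affine are hypothetical stand-ins, not jimgw's BijectiveTransform API.

```python
import math
from typing import Callable, NamedTuple


class ToyBijection(NamedTuple):
    """Hypothetical 1-D bijection: forward map, inverse map, and log|d inverse / dz|."""
    forward: Callable[[float], float]
    inverse: Callable[[float], float]
    inverse_log_jac: Callable[[float], float]


def affine(scale: float, shift: float) -> ToyBijection:
    # z = scale * x + shift, so x = (z - shift) / scale and |dx/dz| = 1 / |scale|.
    return ToyBijection(
        forward=lambda x: scale * x + shift,
        inverse=lambda z: (z - shift) / scale,
        inverse_log_jac=lambda z: -math.log(abs(scale)),
    )


def std_normal_log_prob(x: float) -> float:
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def transformed_log_prob(
    z: float, base_log_prob: Callable[[float], float], transforms: list[ToyBijection]
) -> float:
    """Undo the transforms last-to-first, accumulating log-Jacobians, then score the base."""
    log_jac = 0.0
    for t in reversed(transforms):
        log_jac += t.inverse_log_jac(z)  # evaluated at the current output-side point
        z = t.inverse(z)
    return base_log_prob(z) + log_jac


# Scale by 2, then shift by 1: z = 2x + 1 with x ~ N(0, 1), so z ~ N(1, 2^2).
chain = [affine(2.0, 0.0), affine(1.0, 1.0)]
print(round(transformed_log_prob(1.0, std_normal_log_prob, chain), 4))  # -1.6121
```

Iterating in reverse applies the chain rule: each Jacobian factor is evaluated at the point reached after undoing the transforms that come later in the chain.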

transform(x: dict[str, Float]) -> dict[str, Float]

Apply all transforms sequentially (forward direction).

Parameters:

x (dict[str, Float], required): Sample in the base prior space.

Returns:

dict[str, Float]: Transformed sample.

add_name(x: Float[Array, n_dims]) -> dict[str, Float]

Turn an array into a dictionary keyed by parameter_names.

Parameters:

x (Array, required): An array of parameters. Shape (n_dims,).

__call__(x: dict[str, Float]) -> Float
trace_prior_parent(output: Optional[list[Prior]] = None) -> list[Prior]

Recursively collect all leaf (non-composite) priors.

Parameters:

output (Optional[list[Prior]], default None): Accumulator list. If None, a new list is created.

Returns:

list[Prior]: List of all leaf prior objects in this composite.

BoundedMixin

Mixin class that adds bounds checking to log_prob.

Place this mixin BEFORE the main prior class in the list of base classes (e.g., class MyPrior(BoundedMixin, SequentialTransformPrior)). Python's method resolution order then runs the bounds check before delegating to the base prior's log_prob.

Classes using this mixin can override the xmin and xmax attributes to set bounds. By default, the bounds are (-inf, inf), meaning no bounds checking.

The mixin returns -inf for values outside [xmin, xmax].

Note: This is a mixin and should not be used standalone. It relies on the parameter_names attribute provided by the Prior class and on the log_prob implementation of the base class.

xmin: float = -jnp.inf class-attribute instance-attribute
xmax: float = jnp.inf class-attribute instance-attribute
log_prob(z: dict[str, Float]) -> Float
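
To show why the base-class order matters, here is a plain-Python sketch of the mixin pattern for a single parameter. ToyBase, ToyBoundedMixin and ToyTruncated are hypothetical classes, not jimgw's. The sketch uses a Python if for clarity. Under jax.jit, an if on a traced value would fail, so a JAX implementation needs an array-level selection such as jnp.where.

```python
import math


class ToyBase:
    """Hypothetical unbounded 1-D prior: standard-normal log density."""
    parameter_names = ("x",)

    def log_prob(self, z: dict[str, float]) -> float:
        x = z["x"]
        return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


class ToyBoundedMixin:
    """Hypothetical stand-in for BoundedMixin, specialized to one parameter."""
    xmin: float = -math.inf
    xmax: float = math.inf

    def log_prob(self, z: dict[str, float]) -> float:
        value = z[self.parameter_names[0]]  # relies on the host class providing parameter_names
        if value < self.xmin or value > self.xmax:
            return -math.inf
        # super() continues along the MRO, which reaches ToyBase.log_prob.
        return super().log_prob(z)


class ToyTruncated(ToyBoundedMixin, ToyBase):  # mixin listed BEFORE the base class
    xmin = 0.0  # the density is truncated here, not renormalized


prior = ToyTruncated()
print(prior.log_prob({"x": -1.0}))  # -inf
print([cls.__name__ for cls in ToyTruncated.__mro__])
# ['ToyTruncated', 'ToyBoundedMixin', 'ToyBase', 'object']
```

If the bases were reversed (ToyBase, ToyBoundedMixin), ToyBase.log_prob would come first in the MRO and the bounds check would be silently skipped.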

CombinePrior

Bases: CompositePrior

Multivariate prior constructed by joining multiple independent priors.

The joint log-probability is the sum of the individual log-probabilities, which is valid when all component priors are independent.
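
A minimal sketch of that sum rule in plain Python follows. The two component densities and the parameter names "a" and "b" are hypothetical; jimgw's own components are Prior objects evaluated on JAX values.

```python
import math
from typing import Callable


def uniform_log_prob(x: float, lo: float, hi: float) -> float:
    """Log density of Uniform(lo, hi); -inf outside the interval."""
    return -math.log(hi - lo) if lo <= x <= hi else -math.inf


def normal_log_prob(x: float, mu: float, sigma: float) -> float:
    """Log density of Normal(mu, sigma^2)."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


# Hypothetical components: each parameter name maps to an independent 1-D log density.
components: dict[str, Callable[[float], float]] = {
    "a": lambda v: uniform_log_prob(v, 0.0, 2.0),
    "b": lambda v: normal_log_prob(v, 0.0, 1.0),
}


def combined_log_prob(z: dict[str, float]) -> float:
    # Independence means p(a, b) = p(a) * p(b), so the log densities simply add.
    return sum(log_p(z[name]) for name, log_p in components.items())


print(round(combined_log_prob({"a": 1.0, "b": 0.0}), 4))  # -1.6121 = -log(2) - log(2*pi)/2
```

A single out-of-support component contributes -inf, which makes the whole sum -inf, as expected for a joint density.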

Attributes:

base_prior (tuple[Prior, ...]): Tuple of independent priors.

parameter_names (tuple[str, ...]): Names of all parameters in the combined prior.

base_prior: tuple[Prior, ...] = field(default_factory=tuple) class-attribute instance-attribute
parameter_names = tuple([name for prior in priors for name in (prior.parameter_names)]) instance-attribute
n_dims: int property
__repr__()
__init__(priors: list[Prior])

Parameters:

priors (list[Prior], required): List of independent prior objects to combine.

sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample from all component priors independently.

Parameters:

rng_key (Key, required): JAX PRNG key (split internally for each component).

n_samples (int, required): Number of samples to draw.

Returns:

dict[str, Float[Array, ' n_samples']]: Combined samples from all component priors, keyed by parameter name.

log_prob(z: dict[str, Float]) -> Float

Evaluate the joint log-probability as the sum of component log-probabilities.

Parameters:

z (dict[str, Float], required): Dictionary of parameter values.

Returns:

Float: Sum of log-probabilities from all component priors.

add_name(x: Float[Array, n_dims]) -> dict[str, Float]

Turn an array into a dictionary keyed by parameter_names.

Parameters:

x (Array, required): An array of parameters. Shape (n_dims,).

__call__(x: dict[str, Float]) -> Float
trace_prior_parent(output: Optional[list[Prior]] = None) -> list[Prior]

Recursively collect all leaf (non-composite) priors.

Parameters:

output (Optional[list[Prior]], default None): Accumulator list. If None, a new list is created.

Returns:

list[Prior]: List of all leaf prior objects in this composite.

UniformPrior

Bases: SequentialTransformPrior

Uniform prior over a finite interval [xmin, xmax].
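
This page does not list the transforms UniformPrior uses. The sketch below assumes, as the SequentialTransformPrior base suggests, that a [0, 1] uniform draw is pushed through the affine map u -> xmin + (xmax - xmin) u. It then shows the resulting constant log density. The helper names are hypothetical.

```python
import math
import random


def uniform_interval_sample(rng: random.Random, xmin: float, xmax: float, n_samples: int) -> list[float]:
    # Push Uniform(0, 1) draws through the affine map u -> xmin + (xmax - xmin) * u.
    return [xmin + (xmax - xmin) * rng.random() for _ in range(n_samples)]


def uniform_interval_log_prob(x: float, xmin: float, xmax: float) -> float:
    # Base density is 1 on [0, 1]; the inverse map has |du/dx| = 1 / (xmax - xmin).
    if not (xmin <= x <= xmax):
        return -math.inf
    return -math.log(xmax - xmin)


draws = uniform_interval_sample(random.Random(42), 10.0, 80.0, 5)
print(all(10.0 <= d <= 80.0 for d in draws))  # True
print(round(uniform_interval_log_prob(30.0, 10.0, 80.0), 4))  # -4.2485, i.e. -log(70)
```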

Attributes:

xmin (float): Lower bound of the interval.

xmax (float): Upper bound of the interval.

parameter_names (list[str]): Name of the parameter.

xmax: float = xmax instance-attribute
xmin: float = xmin instance-attribute
parameter_names = tuple(transform.propagate_name(self.parameter_names)) instance-attribute
n_dims: int property
base_prior: tuple[Prior, ...] = tuple(priors) instance-attribute
transforms: tuple[BijectiveTransform, ...] = tuple(transforms) instance-attribute
__repr__()
__init__(xmin: float, xmax: float, parameter_names: list[str])
add_name(x: Float[Array, n_dims]) -> dict[str, Float]

Turn an array into a dictionary keyed by parameter_names.

Parameters:

x (Array, required): An array of parameters. Shape (n_dims,).

__call__(x: dict[str, Float]) -> Float
log_prob(z: dict[str, Float]) -> Float

Evaluate the log-probability of a transformed sample z.

Applies the inverse transforms in reverse order, accumulating log-Jacobian determinants, then evaluates the base prior.

Parameters:

z (dict[str, Float], required): Sample in the transformed (output) space.

Returns:

Float: Log-probability of z under the induced distribution.

sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample by drawing from the base prior and applying all transforms.

Parameters:

rng_key (Key, required): JAX PRNG key.

n_samples (int, required): Number of samples to draw.

Returns:

dict[str, Float[Array, ' n_samples']]: Transformed samples keyed by parameter name.

trace_prior_parent(output: Optional[list[Prior]] = None) -> list[Prior]

Recursively collect all leaf (non-composite) priors.

Parameters:

output (Optional[list[Prior]], default None): Accumulator list. If None, a new list is created.

Returns:

list[Prior]: List of all leaf prior objects in this composite.

transform(x: dict[str, Float]) -> dict[str, Float]

Apply all transforms sequentially (forward direction).

Parameters:

x (dict[str, Float], required): Sample in the base prior space.

Returns:

dict[str, Float]: Transformed sample.

GaussianPrior

Bases: SequentialTransformPrior

Gaussian (normal) prior with specified mean and standard deviation.

Attributes:

mu (float): Mean of the distribution.

sigma (float): Standard deviation of the distribution.

parameter_names (list[str]): Name of the parameter.

mu: float = mu instance-attribute
sigma: float = sigma instance-attribute
parameter_names = tuple(transform.propagate_name(self.parameter_names)) instance-attribute
n_dims: int property
base_prior: tuple[Prior, ...] = tuple(priors) instance-attribute
transforms: tuple[BijectiveTransform, ...] = tuple(transforms) instance-attribute
__repr__()
__init__(mu: float, sigma: float, parameter_names: list[str])

A convenience wrapper around the StandardNormalDistribution class that scales and shifts the distribution according to the given standard deviation and mean.

Parameters:

mu (float): The mean of the distribution.

sigma (float): The standard deviation of the distribution.

parameter_names (list[str]): A list of names for the parameters of the prior.
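
Assume the wrapper maps a standard-normal draw x to z = mu + sigma x with sigma > 0; the individual transforms are not listed on this page. The change-of-variables rule used by SequentialTransformPrior.log_prob then recovers the familiar Gaussian log density:

```latex
\begin{aligned}
z &= \mu + \sigma x, \qquad x \sim \mathcal{N}(0, 1), \\
x &= \frac{z - \mu}{\sigma}, \qquad \left|\frac{dx}{dz}\right| = \frac{1}{\sigma}, \\
\log p_Z(z) &= \log p_X\!\left(\frac{z - \mu}{\sigma}\right) + \log\left|\frac{dx}{dz}\right|
  = -\frac{(z - \mu)^2}{2\sigma^2} - \log\sigma - \frac{1}{2}\log(2\pi).
\end{aligned}
```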

add_name(x: Float[Array, n_dims]) -> dict[str, Float]

Turn an array into a dictionary keyed by parameter_names.

Parameters:

x (Array, required): An array of parameters. Shape (n_dims,).

__call__(x: dict[str, Float]) -> Float
log_prob(z: dict[str, Float]) -> Float

Evaluate the log-probability of a transformed sample z.

Applies the inverse transforms in reverse order, accumulating log-Jacobian determinants, then evaluates the base prior.

Parameters:

z (dict[str, Float], required): Sample in the transformed (output) space.

Returns:

Float: Log-probability of z under the induced distribution.

sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample by drawing from the base prior and applying all transforms.

Parameters:

rng_key (Key, required): JAX PRNG key.

n_samples (int, required): Number of samples to draw.

Returns:

dict[str, Float[Array, ' n_samples']]: Transformed samples keyed by parameter name.

trace_prior_parent(output: Optional[list[Prior]] = None) -> list[Prior]

Recursively collect all leaf (non-composite) priors.

Parameters:

output (Optional[list[Prior]], default None): Accumulator list. If None, a new list is created.

Returns:

list[Prior]: List of all leaf prior objects in this composite.

transform(x: dict[str, Float]) -> dict[str, Float]

Apply all transforms sequentially (forward direction).

Parameters:

x (dict[str, Float], required): Sample in the base prior space.

Returns:

dict[str, Float]: Transformed sample.

SinePrior

Bases: BoundedMixin, SequentialTransformPrior

Prior with PDF proportional to sin(x) over [0, pi].
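
The sketch below works out this density in plain Python. Because the integral of sin over [0, pi] is 2, the normalized PDF is sin(x) / 2. Its CDF is (1 - cos x) / 2, which gives the inverse-CDF sampler x = arccos(1 - 2u). The helper names are hypothetical, and the page does not state that jimgw samples this way.

```python
import math
import random


def sine_log_prob(x: float) -> float:
    """Log of the normalized density p(x) = sin(x) / 2 on [0, pi]; -inf elsewhere."""
    if not (0.0 <= x <= math.pi):
        return -math.inf
    s = math.sin(x)
    return math.log(s) - math.log(2.0) if s > 0.0 else -math.inf


def sine_sample(rng: random.Random, n_samples: int) -> list[float]:
    """Inverse-CDF sampling: F(x) = (1 - cos x) / 2, so x = arccos(1 - 2u)."""
    return [math.acos(1.0 - 2.0 * rng.random()) for _ in range(n_samples)]


print(round(sine_log_prob(math.pi / 2), 4))  # -0.6931, i.e. -log(2)
```

Under this density, cos(x) is uniform on [-1, 1], which is the distribution of the polar angle of an isotropically oriented direction.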

Attributes:

parameter_names (list[str]): Name of the parameter.

xmin: float = 0.0 class-attribute instance-attribute
xmax: float = jnp.pi class-attribute instance-attribute
parameter_names = tuple(transform.propagate_name(self.parameter_names)) instance-attribute
n_dims: int property
base_prior: tuple[Prior, ...] = tuple(priors) instance-attribute
transforms: tuple[BijectiveTransform, ...] = tuple(transforms) instance-attribute
__repr__()
__init__(parameter_names: list[str])
add_name(x: Float[Array, n_dims]) -> dict[str, Float]

Turn an array into a dictionary keyed by parameter_names.

Parameters:

x (Array, required): An array of parameters. Shape (n_dims,).

__call__(x: dict[str, Float]) -> Float
log_prob(z: dict[str, Float]) -> Float
sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample by drawing from the base prior and applying all transforms.

Parameters:

rng_key (Key, required): JAX PRNG key.

n_samples (int, required): Number of samples to draw.

Returns:

dict[str, Float[Array, ' n_samples']]: Transformed samples keyed by parameter name.

trace_prior_parent(output: Optional[list[Prior]] = None) -> list[Prior]

Recursively collect all leaf (non-composite) priors.

Parameters:

output (Optional[list[Prior]], default None): Accumulator list. If None, a new list is created.

Returns:

list[Prior]: List of all leaf prior objects in this composite.

transform(x: dict[str, Float]) -> dict[str, Float]

Apply all transforms sequentially (forward direction).

Parameters:

x (dict[str, Float], required): Sample in the base prior space.

Returns:

dict[str, Float]: Transformed sample.

CosinePrior

Bases: BoundedMixin, SequentialTransformPrior

Prior with PDF proportional to cos(x) over [-pi/2, pi/2].
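
A plain-Python sketch of the math follows, with hypothetical helper names. The integral of cos over [-pi/2, pi/2] is 2, so the normalized PDF is cos(x) / 2. Its CDF is (1 + sin x) / 2, which inverts to x = arcsin(2u - 1).

```python
import math
import random


def cosine_log_prob(x: float) -> float:
    """Log of the normalized density p(x) = cos(x) / 2 on [-pi/2, pi/2]; -inf elsewhere."""
    if not (-math.pi / 2 <= x <= math.pi / 2):
        return -math.inf
    c = math.cos(x)
    return math.log(c) - math.log(2.0) if c > 0.0 else -math.inf


def cosine_sample(rng: random.Random, n_samples: int) -> list[float]:
    """Inverse-CDF sampling: F(x) = (1 + sin x) / 2, so x = arcsin(2u - 1)."""
    return [math.asin(2.0 * rng.random() - 1.0) for _ in range(n_samples)]


print(round(cosine_log_prob(0.0), 4))  # -0.6931, i.e. -log(2)
```

Since cos(x) = sin(x + pi/2), this is the sine density shifted by pi/2. Equivalently, sin(x) is uniform on [-1, 1].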

Attributes:

parameter_names (list[str]): Name of the parameter.

xmin: float = -jnp.pi / 2 class-attribute instance-attribute
xmax: float = jnp.pi / 2 class-attribute instance-attribute
parameter_names = tuple(transform.propagate_name(self.parameter_names)) instance-attribute
n_dims: int property
base_prior: tuple[Prior, ...] = tuple(priors) instance-attribute
transforms: tuple[BijectiveTransform, ...] = tuple(transforms) instance-attribute
__repr__()
__init__(parameter_names: list[str])
add_name(x: Float[Array, n_dims]) -> dict[str, Float]

Turn an array into a dictionary keyed by parameter_names.

Parameters:

x (Array, required): An array of parameters. Shape (n_dims,).

__call__(x: dict[str, Float]) -> Float
log_prob(z: dict[str, Float]) -> Float
sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample by drawing from the base prior and applying all transforms.

Parameters:

rng_key (Key, required): JAX PRNG key.

n_samples (int, required): Number of samples to draw.

Returns:

dict[str, Float[Array, ' n_samples']]: Transformed samples keyed by parameter name.

trace_prior_parent(output: Optional[list[Prior]] = None) -> list[Prior]

Recursively collect all leaf (non-composite) priors.

Parameters:

output (Optional[list[Prior]], default None): Accumulator list. If None, a new list is created.

Returns:

list[Prior]: List of all leaf prior objects in this composite.

transform(x: dict[str, Float]) -> dict[str, Float]

Apply all transforms sequentially (forward direction).

Parameters:

x (dict[str, Float], required): Sample in the base prior space.

Returns:

dict[str, Float]: Transformed sample.

UniformSpherePrior

Bases: CombinePrior

Uniform prior over a sphere, parameterized by magnitude, theta, and phi.
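
This page does not list the component priors. The sketch below therefore rests on three assumptions: the magnitude is uniform on [0, max_mag] (max_mag defaults to 1.0 in __init__), theta is a polar angle with density proportional to sin(theta), and phi is an azimuth uniform on [0, 2*pi]. These choices give an isotropic direction. With a uniform magnitude, the points are not uniform in the ball's volume; that would require a magnitude density proportional to r^2. All names are hypothetical.

```python
import math
import random


def sample_sphere_sketch(
    rng: random.Random, n_samples: int, max_mag: float = 1.0
) -> list[tuple[float, float, float]]:
    """Draw hypothetical (mag, theta, phi) triples: isotropic direction, uniform magnitude."""
    draws = []
    for _ in range(n_samples):
        mag = max_mag * rng.random()                 # assumed Uniform(0, max_mag)
        theta = math.acos(1.0 - 2.0 * rng.random())  # polar angle, density proportional to sin(theta)
        phi = 2.0 * math.pi * rng.random()           # azimuth, Uniform(0, 2*pi)
        draws.append((mag, theta, phi))
    return draws


def to_cartesian(mag: float, theta: float, phi: float) -> tuple[float, float, float]:
    """Convert (magnitude, polar angle theta, azimuth phi) to Cartesian (x, y, z)."""
    s = math.sin(theta)
    return (mag * s * math.cos(phi), mag * s * math.sin(phi), mag * math.cos(theta))


draws = sample_sphere_sketch(random.Random(0), 20000)
# For an isotropic direction, cos(theta) is uniform on [-1, 1], so its mean is near 0.
mean_cos = sum(math.cos(theta) for _, theta, _ in draws) / len(draws)
print(round(mean_cos, 2))
```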

Attributes:

parameter_names (list[str]): Names of the vector, theta, and phi parameters.

parameter_names = tuple([name for prior in priors for name in (prior.parameter_names)]) instance-attribute
n_dims: int property
base_prior: tuple[Prior, ...] = field(default_factory=tuple) class-attribute instance-attribute
__repr__()
__init__(parameter_names: list[str], max_mag: float = 1.0)
add_name(x: Float[Array, n_dims]) -> dict[str, Float]

Turn an array into a dictionary keyed by parameter_names.

Parameters:

x (Array, required): An array of parameters. Shape (n_dims,).

__call__(x: dict[str, Float]) -> Float
log_prob(z: dict[str, Float]) -> Float

Evaluate the joint log-probability as the sum of component log-probabilities.

Parameters:

z (dict[str, Float], required): Dictionary of parameter values.

Returns:

Float: Sum of log-probabilities from all component priors.

sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample from all component priors independently.

Parameters:

rng_key (Key, required): JAX PRNG key (split internally for each component).

n_samples (int, required): Number of samples to draw.

Returns:

dict[str, Float[Array, ' n_samples']]: Combined samples from all component priors, keyed by parameter name.

trace_prior_parent(output: Optional[list[Prior]] = None) -> list[Prior]

Recursively collect all leaf (non-composite) priors.

Parameters:

output (Optional[list[Prior]], default None): Accumulator list. If None, a new list is created.

Returns:

list[Prior]: List of all leaf prior objects in this composite.

RayleighPrior

Bases: BoundedMixin, SequentialTransformPrior

Rayleigh distribution prior with scale parameter sigma.
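
For reference, the sketch below covers the standard Rayleigh density p(x) = (x / sigma^2) exp(-x^2 / (2 sigma^2)) for x > 0. It computes the log density and draws inverse-CDF samples from F(x) = 1 - exp(-x^2 / (2 sigma^2)). It is plain Python with hypothetical helper names, not the transform chain jimgw actually uses, which this page does not list.

```python
import math
import random


def rayleigh_log_prob(x: float, sigma: float) -> float:
    """Log of p(x) = (x / sigma^2) * exp(-x^2 / (2 sigma^2)) for x > 0; -inf otherwise."""
    if x <= 0.0:
        return -math.inf
    return math.log(x) - 2.0 * math.log(sigma) - x * x / (2.0 * sigma * sigma)


def rayleigh_sample(rng: random.Random, sigma: float, n_samples: int) -> list[float]:
    """Inverse-CDF sampling: x = sigma * sqrt(-2 * log(1 - u)) with u ~ Uniform(0, 1)."""
    # rng.random() is in [0, 1), so 1 - u is in (0, 1] and the log is always defined.
    return [sigma * math.sqrt(-2.0 * math.log(1.0 - rng.random())) for _ in range(n_samples)]


print(round(rayleigh_log_prob(2.0, 2.0), 4))  # -1.1931, i.e. -log(sigma) - 1/2 at x = sigma
```

The -inf branch for x <= 0 mirrors the xmin = 0.0 bound that the class sets through BoundedMixin.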

Attributes:

sigma (float): Scale parameter of the Rayleigh distribution.

parameter_names (list[str]): Name of the parameter.

xmin: float = 0.0 class-attribute instance-attribute
xmax: float = jnp.inf class-attribute instance-attribute
sigma: float = sigma instance-attribute
parameter_names = tuple(transform.propagate_name(self.parameter_names)) instance-attribute
n_dims: int property
base_prior: tuple[Prior, ...] = tuple(priors) instance-attribute
transforms: tuple[BijectiveTransform, ...] = tuple(transforms) instance-attribute
__repr__()
__init__(sigma: float, parameter_names: list[str])
add_name(x: Float[Array, n_dims]) -> dict[str, Float]

Turn an array into a dictionary keyed by parameter_names.

Parameters:

x (Array, required): An array of parameters. Shape (n_dims,).

__call__(x: dict[str, Float]) -> Float
log_prob(z: dict[str, Float]) -> Float
sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample by drawing from the base prior and applying all transforms.

Parameters:

rng_key (Key, required): JAX PRNG key.

n_samples (int, required): Number of samples to draw.

Returns:

dict[str, Float[Array, ' n_samples']]: Transformed samples keyed by parameter name.

trace_prior_parent(output: Optional[list[Prior]] = None) -> list[Prior]

Recursively collect all leaf (non-composite) priors.

Parameters:

output (Optional[list[Prior]], default None): Accumulator list. If None, a new list is created.

Returns:

list[Prior]: List of all leaf prior objects in this composite.

transform(x: dict[str, Float]) -> dict[str, Float]

Apply all transforms sequentially (forward direction).

Parameters:

x (dict[str, Float], required): Sample in the base prior space.

Returns:

dict[str, Float]: Transformed sample.

PowerLawPrior

Bases: SequentialTransformPrior

Power-law prior over [xmin, xmax] with exponent alpha.
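
A plain-Python sketch of a truncated power law follows. It assumes the density is proportional to x^alpha on [xmin, xmax] with 0 < xmin < xmax; the page does not spell out the normalization convention. The case alpha = -1 needs a separate branch because the general antiderivative x^(alpha+1) / (alpha+1) becomes log(x). All helper names are hypothetical.

```python
import math
import random


def power_law_log_norm(xmin: float, xmax: float, alpha: float) -> float:
    """log Z, where Z = integral of x**alpha over [xmin, xmax] (assumes 0 < xmin < xmax)."""
    if abs(alpha + 1.0) < 1e-12:
        return math.log(math.log(xmax / xmin))
    a1 = alpha + 1.0
    return math.log((xmax ** a1 - xmin ** a1) / a1)


def power_law_log_prob(x: float, xmin: float, xmax: float, alpha: float) -> float:
    """Normalized log density alpha*log(x) - log Z inside [xmin, xmax]; -inf outside."""
    if not (xmin <= x <= xmax):
        return -math.inf
    return alpha * math.log(x) - power_law_log_norm(xmin, xmax, alpha)


def power_law_sample(rng: random.Random, xmin: float, xmax: float, alpha: float, n_samples: int) -> list[float]:
    """Inverse-CDF sampling for the truncated power law."""
    out = []
    for _ in range(n_samples):
        u = rng.random()
        if abs(alpha + 1.0) < 1e-12:
            out.append(xmin * (xmax / xmin) ** u)  # alpha = -1 is log-uniform
        else:
            a1 = alpha + 1.0
            out.append((xmin ** a1 + u * (xmax ** a1 - xmin ** a1)) ** (1.0 / a1))
    return out


print(round(power_law_log_prob(1.0, 1.0, 10.0, -1.0), 4))  # -0.834, i.e. -log(log(10))
```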

Attributes:

xmin (float): Lower bound of the interval.

xmax (float): Upper bound of the interval.

alpha (float): Power-law exponent.

parameter_names (list[str]): Name of the parameter.

xmax: float = xmax instance-attribute
xmin: float = xmin instance-attribute
alpha: float = alpha instance-attribute
parameter_names = tuple(transform.propagate_name(self.parameter_names)) instance-attribute
n_dims: int property
base_prior: tuple[Prior, ...] = tuple(priors) instance-attribute
transforms: tuple[BijectiveTransform, ...] = tuple(transforms) instance-attribute
__repr__()
__init__(xmin: float, xmax: float, alpha: float, parameter_names: list[str])
add_name(x: Float[Array, n_dims]) -> dict[str, Float]

Turn an array into a dictionary keyed by parameter_names.

Parameters:

x (Array, required): An array of parameters. Shape (n_dims,).

__call__(x: dict[str, Float]) -> Float
log_prob(z: dict[str, Float]) -> Float

Evaluate the log-probability of a transformed sample z.

Applies the inverse transforms in reverse order, accumulating log-Jacobian determinants, then evaluates the base prior.

Parameters:

z (dict[str, Float], required): Sample in the transformed (output) space.

Returns:

Float: Log-probability of z under the induced distribution.

sample(rng_key: Key, n_samples: int) -> dict[str, Float[Array, ' n_samples']]

Sample by drawing from the base prior and applying all transforms.

Parameters:

rng_key (Key, required): JAX PRNG key.

n_samples (int, required): Number of samples to draw.

Returns:

dict[str, Float[Array, ' n_samples']]: Transformed samples keyed by parameter name.

trace_prior_parent(output: Optional[list[Prior]] = None) -> list[Prior]

Recursively collect all leaf (non-composite) priors.

Parameters:

output (Optional[list[Prior]], default None): Accumulator list. If None, a new list is created.

Returns:

list[Prior]: List of all leaf prior objects in this composite.

transform(x: dict[str, Float]) -> dict[str, Float]

Apply all transforms sequentially (forward direction).

Parameters:

x (dict[str, Float], required): Sample in the base prior space.

Returns:

dict[str, Float]: Transformed sample.