
Common

Bijection

Bases: Module

Base class for bijective transformations.

Subclasses must implement forward and inverse. The default __call__ delegates to forward.

This is an abstract template and should not be used directly.

Methods:

Name      Description
__call__  Apply the forward transformation.
forward   Transform from input space to output space.
inverse   Transform from output space back to input space.
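The contract described above (subclasses implement forward and inverse, and __call__ delegates to forward) can be illustrated with a minimal, framework-agnostic sketch. It assumes NumPy in place of JAX and abc.ABC in place of equinox.Module; ElementwiseAffine is a hypothetical subclass written only for illustration and is not part of this module.

```python
from abc import ABC, abstractmethod

import numpy as np


class Bijection(ABC):
    """NumPy stand-in (assumption) for the Equinox-based Bijection base class."""

    def __call__(self, x, condition):
        # The default __call__ simply delegates to forward.
        return self.forward(x, condition)

    @abstractmethod
    def forward(self, x, condition):
        ...

    @abstractmethod
    def inverse(self, x, condition):
        ...


class ElementwiseAffine(Bijection):
    """Hypothetical example bijection: y = exp(log_scale) * x + shift."""

    def __init__(self, log_scale, shift):
        self.log_scale = np.asarray(log_scale, dtype=float)
        self.shift = np.asarray(shift, dtype=float)

    def forward(self, x, condition):
        # condition is ignored by this toy bijection.
        y = np.exp(self.log_scale) * x + self.shift
        # dy_i/dx_i = exp(log_scale_i), so the per-dimension log-det is log_scale_i.
        return y, self.log_scale.copy()

    def inverse(self, x, condition):
        # Undo the shift, then the scale; the log-det flips sign.
        x_in = (x - self.shift) * np.exp(-self.log_scale)
        return x_in, -self.log_scale.copy()


bij = ElementwiseAffine(log_scale=[0.0, np.log(2.0)], shift=[1.0, -1.0])
y, log_det = bij(np.array([1.0, 3.0]), condition=np.zeros(0))
x_back, inv_log_det = bij.inverse(y, condition=np.zeros(0))
```

Because an elementwise affine map has a diagonal Jacobian, its per-dimension log-det is just log_scale, and the inverse returns the negation.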

__call__(x: Float[Array, ' n_dim'], condition: Float[Array, ' n_condition']) -> tuple[Float[Array, ' n_dim'], Float[Array, ' n_dim']]

Apply the forward transformation.

Parameters:

Name       Type                       Description              Default
x          Float[Array, n_dim]        Input array.             required
condition  Float[Array, n_condition]  Conditioning variables.  required

Returns:

Type                                                   Description
tuple[Float[Array, ' n_dim'], Float[Array, ' n_dim']]  Transformed output and per-dimension log-det Jacobian.

forward(x: Float[Array, ' n_dim'], condition: Float[Array, ' n_condition']) -> tuple[Float[Array, ' n_dim'], Float[Array, ' n_dim']] (abstract method)

Transform from input space to output space.

Parameters:

Name       Type                       Description              Default
x          Float[Array, n_dim]        Input array.             required
condition  Float[Array, n_condition]  Conditioning variables.  required

Returns:

Type                                                   Description
tuple[Float[Array, ' n_dim'], Float[Array, ' n_dim']]  Transformed output and per-dimension log-det Jacobian.
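One way to sanity-check a forward implementation is to compare its per-dimension log-det against a finite-difference estimate. The sketch below assumes a hypothetical elementwise bijection y = sinh(x), whose Jacobian is diagonal, and uses NumPy in place of JAX; neither function is part of this module.

```python
import numpy as np


def forward(x, condition):
    """Hypothetical elementwise bijection y = sinh(x); condition is ignored."""
    y = np.sinh(x)
    # Diagonal Jacobian: per-dimension log-det is log|dy_i/dx_i| = log(cosh(x_i)).
    log_det = np.log(np.cosh(x))
    return y, log_det


def finite_difference_log_det(f, x, condition, eps=1e-6):
    """Estimate log|dy_i/dx_i| by central differences (assumes a diagonal Jacobian)."""
    out = np.empty_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        y_plus, _ = f(x + step, condition)
        y_minus, _ = f(x - step, condition)
        out[i] = np.log(np.abs((y_plus[i] - y_minus[i]) / (2 * eps)))
    return out


x = np.array([-1.0, 0.0, 0.5])
_, analytic = forward(x, condition=np.zeros(0))
numeric = finite_difference_log_det(forward, x, condition=np.zeros(0))
```

For bijections with non-diagonal Jacobians, this per-dimension check does not apply directly; only the sum of the log-det terms is then comparable to log|det J|.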

inverse(x: Float[Array, ' n_dim'], condition: Float[Array, ' n_condition']) -> tuple[Float[Array, ' n_dim'], Float[Array, ' n_dim']] (abstract method)

Transform from output space back to input space.

Parameters:

Name       Type                       Description                               Default
x          Float[Array, n_dim]        Array in the output (transformed) space.  required
condition  Float[Array, n_condition]  Conditioning variables.                   required

Returns:

Type                                                   Description
tuple[Float[Array, ' n_dim'], Float[Array, ' n_dim']]  Array mapped back to the input space and per-dimension log-det Jacobian of the inverse transformation.
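A forward/inverse pair should round-trip, and for the same point the forward and inverse log-det terms should cancel. The sketch below checks both properties for a hypothetical sinh/arcsinh pair written in NumPy as a stand-in for JAX; these functions are illustrative only.

```python
import numpy as np


def forward(x, condition):
    # Hypothetical elementwise bijection y = sinh(x); condition is ignored.
    return np.sinh(x), np.log(np.cosh(x))


def inverse(y, condition):
    # x = arcsinh(y); d/dy arcsinh(y) = 1 / sqrt(1 + y^2).
    return np.arcsinh(y), -0.5 * np.log1p(y**2)


x = np.array([-2.0, 0.3, 1.5])
cond = np.zeros(0)
y, fwd_log_det = forward(x, cond)
x_back, inv_log_det = inverse(y, cond)
```

Since cosh(x)^2 = 1 + sinh(x)^2, the forward term log(cosh(x)) equals 0.5 * log(1 + y^2), so the two log-det terms sum to zero in every dimension.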

Distribution

Bases: Module

Base class for probability distributions.

Subclasses must implement log_prob and sample. The default __call__ delegates to log_prob.

This is an abstract template and should not be used directly.

Methods:

Name      Description
__call__  Evaluate the log-probability of x.

__call__(x: Array, key: Optional[Key] = None) -> Array

Evaluate the log-probability of x.

Parameters:

Name  Type           Description                                   Default
x     Array          Input sample.                                 required
key   Optional[Key]  Unused; reserved for subclass compatibility.  None

Returns:

Type   Description
Array  Log-probability of x.
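The Distribution contract (subclasses implement log_prob and sample, and __call__ delegates to log_prob while ignoring key) can be sketched as follows. This assumes NumPy in place of JAX, abc.ABC in place of equinox.Module, and a numpy Generator in place of a JAX PRNG key; the sample(rng, n_samples) signature and the StandardNormal class are hypothetical, since this page does not document sample.

```python
import math
from abc import ABC, abstractmethod

import numpy as np


class Distribution(ABC):
    """NumPy stand-in (assumption) for the Equinox-based Distribution base class."""

    def __call__(self, x, key=None):
        # key is unused; __call__ only delegates to log_prob.
        return self.log_prob(x)

    @abstractmethod
    def log_prob(self, x):
        ...

    @abstractmethod
    def sample(self, rng, n_samples):
        ...


class StandardNormal(Distribution):
    """Hypothetical example: independent N(0, 1) in n_dim dimensions."""

    def __init__(self, n_dim):
        self.n_dim = n_dim

    def log_prob(self, x):
        x = np.asarray(x, dtype=float)
        return -0.5 * np.sum(x**2, axis=-1) - 0.5 * self.n_dim * math.log(2 * math.pi)

    def sample(self, rng, n_samples):
        # rng is a numpy Generator standing in for a JAX PRNG key.
        return rng.standard_normal((n_samples, self.n_dim))


dist = StandardNormal(n_dim=2)
lp = dist(np.zeros(2))
samples = dist.sample(np.random.default_rng(0), n_samples=5)
```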

Gaussian

Bases: Distribution

Multivariate Gaussian distribution.

Parameters:

Name       Type   Description                                                        Default
mean       Array  Mean.                                                              required
cov        Array  Covariance matrix.                                                 required
learnable  bool   Whether the mean and covariance matrix are learnable parameters.  False

Attributes:

Name  Type   Description
mean  Array  Mean.
cov   Array  Covariance matrix.
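For reference, the multivariate Gaussian log-density is log N(x | mean, cov) = -0.5 * ((x - mean)^T cov^{-1} (x - mean) + log|cov| + n_dim * log(2 pi)). The sketch below evaluates it with a Cholesky factorization in NumPy; it is a standalone illustration, not this class's implementation, and it ignores the learnable flag.

```python
import math

import numpy as np


def gaussian_log_prob(x, mean, cov):
    """log N(x | mean, cov) for a single point x of shape (n_dim,)."""
    x = np.asarray(x, dtype=float)
    mean = np.asarray(mean, dtype=float)
    cov = np.asarray(cov, dtype=float)
    n_dim = mean.shape[0]
    # Cholesky factor L with cov = L @ L.T (cov must be positive definite).
    chol = np.linalg.cholesky(cov)
    # Solve L z = (x - mean); the Mahalanobis term is then z @ z.
    z = np.linalg.solve(chol, x - mean)
    # log|cov| = 2 * sum(log(diag(L))).
    log_det_cov = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (z @ z + log_det_cov + n_dim * math.log(2.0 * math.pi))


lp = gaussian_log_prob([1.0, 2.0], mean=[0.0, 0.0], cov=np.diag([1.0, 4.0]))
```

With cov = diag(1, 4) and x = (1, 2), the Mahalanobis term is 1 + 1 = 2 and log|cov| = log 4, so the result is -1 - log 2 - log(2 pi).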

MLP

Bases: Module

Multilayer perceptron.

Parameters:

Name   Type       Description                                                                                              Default
shape  List[int]  Layer widths of the MLP: the first element is the input dimension and the last is the output dimension.  required
key    Key        Random key.                                                                                              required

Attributes:

Name        Type      Description
layers      List      List of layers.
activation  Callable  Activation function.
use_bias    bool      Whether to use bias.
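The way a shape list maps to layers can be sketched in NumPy: each consecutive pair (n_in, n_out) in shape yields one weight matrix and, if use_bias is set, one bias vector. The initialization scale, the tanh default, and applying the activation on every layer except the last are all assumptions for illustration; this page does not specify them, and init_mlp and mlp_forward are hypothetical helpers.

```python
import numpy as np


def init_mlp(shape, rng, use_bias=True):
    """Hypothetical initializer: one (weight, bias) pair per consecutive pair in shape."""
    layers = []
    for n_in, n_out in zip(shape[:-1], shape[1:]):
        # Scale by 1/sqrt(n_in); the real module's initialization may differ.
        w = rng.standard_normal((n_out, n_in)) / np.sqrt(n_in)
        b = np.zeros(n_out) if use_bias else None
        layers.append((w, b))
    return layers


def mlp_forward(layers, x, activation=np.tanh):
    """Apply the activation after every layer except the last (an assumption)."""
    for i, (w, b) in enumerate(layers):
        x = w @ x
        if b is not None:
            x = x + b
        if i < len(layers) - 1:
            x = activation(x)
    return x


layers = init_mlp([2, 16, 16, 3], np.random.default_rng(0))
out = mlp_forward(layers, np.array([0.5, -1.0]))
```

A shape of [2, 16, 16, 3] therefore gives three layers mapping 2 inputs through two hidden layers of width 16 to 3 outputs.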

MaskedCouplingLayer

Bases: Bijection

Masked coupling layer.

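$$f(x) = (1 - m) \odot b\big(x;\, c(m \odot x;\, z)\big) + m \odot x,$$

where $b$ is the inner bijector, $m$ is the mask, $c$ is the conditioner, $z$ denotes the conditioning variables, and $\odot$ is elementwise multiplication.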
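The formula above can be sketched in NumPy with a hypothetical affine inner bijector and a toy conditioner; neither is this module's actual bijector or conditioner. The key property it demonstrates is that the conditioner only sees m * x, which the layer leaves unchanged, so the inverse can recompute the same parameters from m * y.

```python
import numpy as np


def conditioner(masked_x, z):
    """Hypothetical conditioner c(m*x; z): returns (log_scale, shift)."""
    h = np.tanh(masked_x.sum() + z.sum())
    return 0.5 * h * np.ones_like(masked_x), h * np.ones_like(masked_x)


def affine_forward(x, params):
    """Hypothetical inner bijector b(x; params): y = exp(log_scale) * x + shift."""
    log_scale, shift = params
    return np.exp(log_scale) * x + shift, log_scale


def masked_coupling_forward(x, z, mask):
    params = conditioner(mask * x, z)
    y_inner, log_det_inner = affine_forward(x, params)
    # f(x) = (1 - m) * b(x; c(m*x; z)) + m * x
    y = (1 - mask) * y_inner + mask * x
    # Masked dimensions are copied unchanged, so they contribute zero log-det.
    log_det = (1 - mask) * log_det_inner
    return y, log_det


def masked_coupling_inverse(y, z, mask):
    # m*y == m*x, so the conditioner sees the same input as in the forward pass.
    log_scale, shift = conditioner(mask * y, z)
    x_inner = (y - shift) * np.exp(-log_scale)
    x = (1 - mask) * x_inner + mask * y
    return x, -(1 - mask) * log_scale


mask = np.array([1.0, 0.0, 1.0, 0.0])
x = np.array([0.2, -0.7, 1.1, 0.4])
z = np.array([0.3])
y, log_det = masked_coupling_forward(x, z, mask)
x_back, inv_log_det = masked_coupling_inverse(y, z, mask)
```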

Parameters:

Name      Type       Description                                                                                 Default
bijector  Bijection  Inner bijector applied by the masked coupling layer.                                        required
mask      Array      Mask: 0 for input variables that are transformed, 1 for input variables that are left unchanged.  required