Common
Bijection
Bases: Module
Base class for bijective transformations.
Subclasses must implement `forward` and `inverse`.
The default `__call__` delegates to `forward`.
This is an abstract template and should not be used directly.
Methods:
| Name | Description |
|---|---|
| `__call__` | Apply the forward transformation. |
| `forward` | Transform from input space to output space. |
| `inverse` | Transform from output space back to input space. |
__call__(x: Float[Array, ' n_dim'], condition: Float[Array, ' n_condition']) -> tuple[Float[Array, ' n_dim'], Float[Array, ' n_dim']]
Apply the forward transformation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Float[Array, "n_dim"]` | Input array. | required |
| `condition` | `Float[Array, "n_condition"]` | Conditioning variables. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, "n_dim"], Float[Array, "n_dim"]]` | Transformed output and per-dimension log-det Jacobian. |
forward(x: Float[Array, ' n_dim'], condition: Float[Array, ' n_condition']) -> tuple[Float[Array, ' n_dim'], Float[Array, ' n_dim']]
abstractmethod
Transform from input space to output space.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Float[Array, "n_dim"]` | Input array. | required |
| `condition` | `Float[Array, "n_condition"]` | Conditioning variables. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, "n_dim"], Float[Array, "n_dim"]]` | Transformed output and per-dimension log-det Jacobian. |
inverse(x: Float[Array, ' n_dim'], condition: Float[Array, ' n_condition']) -> tuple[Float[Array, ' n_dim'], Float[Array, ' n_dim']]
abstractmethod
Transform from output space back to input space.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Float[Array, "n_dim"]` | Array in the output (transformed) space. | required |
| `condition` | `Float[Array, "n_condition"]` | Conditioning variables. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, "n_dim"], Float[Array, "n_dim"]]` | Inverse output and per-dimension log-det Jacobian. |
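As an illustration of the interface described above, here is a minimal sketch of an elementwise affine map exposing `forward`, `inverse`, and a `__call__` that delegates to `forward`. The class name `ScaleShift` and its parameters `log_scale` and `shift` are hypothetical. It is a plain Python stand-in built on `jax.numpy` only and does not subclass the library's `Bijection`.

```python
import jax.numpy as jnp


class ScaleShift:
    """Hypothetical elementwise affine map y = x * exp(log_scale) + shift.

    Plain Python stand-in that mirrors the documented interface;
    it is not the library's Bijection class.
    """

    def __init__(self, log_scale, shift):
        self.log_scale = log_scale
        self.shift = shift

    def forward(self, x, condition):
        # `condition` is accepted for interface compatibility but unused here.
        y = x * jnp.exp(self.log_scale) + self.shift
        # For an elementwise map, the per-dimension log|dy/dx| is log_scale.
        return y, self.log_scale

    def inverse(self, x, condition):
        # Here `x` lives in the output (transformed) space.
        z = (x - self.shift) * jnp.exp(-self.log_scale)
        return z, -self.log_scale

    def __call__(self, x, condition):
        # Mirror the documented default: __call__ delegates to forward.
        return self.forward(x, condition)


bijection = ScaleShift(
    log_scale=jnp.array([0.5, -1.0]),
    shift=jnp.array([1.0, 2.0]),
)
x = jnp.array([0.3, -0.7])
condition = jnp.zeros(1)

y, log_det_fwd = bijection(x, condition)
x_back, log_det_inv = bijection.inverse(y, condition)
```

Applying `inverse` to the output of `forward` should recover the input, and the two per-dimension log-det terms should cancel, which is a convenient sanity check for any new subclass.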
Distribution
Bases: Module
Base class for probability distributions.
Subclasses must implement `log_prob` and `sample`.
The default `__call__` delegates to `log_prob`.
This is an abstract template and should not be used directly.
Methods:
| Name | Description |
|---|---|
| `__call__` | Evaluate the log-probability of `x`. |
__call__(x: Array, key: Optional[Key] = None) -> Array
Evaluate the log-probability of x.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input sample. | required |
| `key` | `Optional[Key]` | Unused; reserved for subclass compatibility. | `None` |
Returns:
| Type | Description |
|---|---|
| `Array` | Log-probability of `x`. |
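The contract described above (implement `log_prob` and `sample`, with `__call__` delegating to `log_prob`) can be sketched as follows. The class `StandardNormal` is hypothetical, and the `sample(key, n_samples)` signature is an assumption, since this page does not document the signature of `sample`. The class is a plain Python stand-in and does not subclass the library's `Distribution`.

```python
import jax
import jax.numpy as jnp


class StandardNormal:
    """Hypothetical isotropic unit Gaussian mirroring the documented interface."""

    def __init__(self, n_dim):
        self.n_dim = n_dim

    def log_prob(self, x):
        # log N(x; 0, I) = -0.5 * |x|^2 - 0.5 * n_dim * log(2 * pi)
        return -0.5 * jnp.sum(x**2) - 0.5 * self.n_dim * jnp.log(2.0 * jnp.pi)

    def sample(self, key, n_samples):
        # Assumed signature: returns an array of shape (n_samples, n_dim).
        return jax.random.normal(key, (n_samples, self.n_dim))

    def __call__(self, x, key=None):
        # `key` is unused, matching the documented default behaviour.
        return self.log_prob(x)


dist = StandardNormal(n_dim=3)
samples = dist.sample(jax.random.PRNGKey(0), 5)
log_p = dist(samples[0])
log_p_origin = dist(jnp.zeros(3))
```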
Gaussian
Bases: Distribution
Multivariate Gaussian distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mean` | `Array` | Mean. | required |
| `cov` | `Array` | Covariance matrix. | required |
| `learnable` | `bool` | Whether the mean and covariance matrix are learnable parameters. | `False` |
Attributes:
| Name | Type | Description |
|---|---|---|
| `mean` | `Array` | Mean. |
| `cov` | `Array` | Covariance matrix. |
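For reference, the log-density of a multivariate Gaussian with a given `mean` and `cov` can be evaluated as in the sketch below. This is not the library's implementation; the helper `gaussian_log_prob` is hypothetical and uses a Cholesky factorization, one common numerically stable choice. The result is compared against `jax.scipy.stats.multivariate_normal.logpdf`.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal


def gaussian_log_prob(x, mean, cov):
    """Hypothetical helper: log N(x; mean, cov) via a Cholesky factor."""
    n_dim = mean.shape[0]
    chol = jnp.linalg.cholesky(cov)  # lower-triangular L with L @ L.T == cov
    # Solve L z = (x - mean), so that |z|^2 is the squared Mahalanobis distance.
    z = solve_triangular(chol, x - mean, lower=True)
    log_det_cov = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))
    return -0.5 * (jnp.sum(z**2) + log_det_cov + n_dim * jnp.log(2.0 * jnp.pi))


mean = jnp.array([0.5, -1.0])
cov = jnp.array([[2.0, 0.3], [0.3, 1.0]])
x = jnp.array([1.0, 0.0])

log_p = gaussian_log_prob(x, mean, cov)
log_p_reference = multivariate_normal.logpdf(x, mean, cov)
```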
MLP
Bases: Module
Multilayer perceptron.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `shape` | `List[int]` | Shape of the MLP. The first element is the input dimension, the last element is the output dimension. | required |
| `key` | `Key` | Random key. | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| `layers` | `List` | List of layers. |
| `activation` | `Callable` | Activation function. |
| `use_bias` | `bool` | Whether to use bias. |
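To make the role of `shape` concrete, the sketch below builds one weight matrix per consecutive pair of entries in `shape`, so `[4, 16, 16, 2]` yields three layers mapping 4 inputs to 2 outputs. The functions `init_mlp` and `apply_mlp` are hypothetical, written in plain JAX. The initialization scheme, the ReLU activation, and the absence of an activation on the final layer are all assumptions, not the library's documented behaviour.

```python
import jax
import jax.numpy as jnp


def init_mlp(shape, key):
    """Hypothetical initializer: one (weight, bias) pair per adjacent pair in shape."""
    keys = jax.random.split(key, len(shape) - 1)
    params = []
    for k, n_in, n_out in zip(keys, shape[:-1], shape[1:]):
        # Assumed scaling; the library may initialize differently.
        w = jax.random.normal(k, (n_out, n_in)) / jnp.sqrt(n_in)
        b = jnp.zeros(n_out)
        params.append((w, b))
    return params


def apply_mlp(params, x, activation=jax.nn.relu):
    # Apply the activation after every layer except the last (an assumption).
    for w, b in params[:-1]:
        x = activation(w @ x + b)
    w, b = params[-1]
    return w @ x + b


shape = [4, 16, 16, 2]
params = init_mlp(shape, jax.random.PRNGKey(0))
out = apply_mlp(params, jnp.ones(4))
```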
MaskedCouplingLayer
Bases: Bijection
Masked coupling layer.
$f(x) = (1 - m) \odot b(x;\, c(m \odot x;\, z)) + m \odot x$, where $b$ is the inner bijector, $m$ is the mask, $c$ is the conditioner, and $\odot$ denotes the elementwise product.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `bijector` | `Bijection` | Inner bijector in the masked coupling layer. | required |
| `mask` | `Array` | Mask. 0 for the input variables that are transformed, 1 for the input variables that are not transformed. | required |
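The masking formula above can be sketched in plain JAX as follows. This is an illustration under stated assumptions, not the library's implementation: the inner bijector is taken to be an elementwise affine map, `conditioner` is a hypothetical fixed function standing in for a learned network, and the extra conditioning variable `z` is omitted. Because the mask is binary, `mask * y` equals `mask * x`, so the inverse can recompute the same conditioner output from the transformed array.

```python
import jax.numpy as jnp


def conditioner(masked_x):
    """Hypothetical conditioner: derives affine parameters from the masked input."""
    s = jnp.sum(masked_x)
    log_scale = jnp.tanh(s) * jnp.ones_like(masked_x)
    shift = s * jnp.ones_like(masked_x)
    return log_scale, shift


def coupling_forward(x, mask):
    log_scale, shift = conditioner(mask * x)       # c(m * x)
    transformed = x * jnp.exp(log_scale) + shift   # b(x; c(m * x)), assumed affine
    y = (1.0 - mask) * transformed + mask * x      # the documented masking formula
    # Per-dimension log-det terms: zero where mask == 1 (identity).
    # Summing them gives the full log-determinant.
    log_det = (1.0 - mask) * log_scale
    return y, log_det


def coupling_inverse(y, mask):
    # Masked entries pass through unchanged, so mask * y == mask * x.
    log_scale, shift = conditioner(mask * y)
    restored = (y - shift) * jnp.exp(-log_scale)
    x = (1.0 - mask) * restored + mask * y
    log_det = -(1.0 - mask) * log_scale
    return x, log_det


x = jnp.array([0.2, -0.4, 1.0, 0.5])
mask = jnp.array([1.0, 0.0, 1.0, 0.0])  # 1 = left unchanged, 0 = transformed

y, log_det_fwd = coupling_forward(x, mask)
x_back, log_det_inv = coupling_inverse(y, mask)
```

Stacking several such layers with alternating masks is the usual way to ensure every variable is transformed at least once.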