Parallel tempering
flowMC.strategy.parallel_tempering
ParallelTempering
Bases: Strategy
Sample a tempered PDF, then perform one exchange step. In essence, this strategy is closer to TakeSteps than to global tuning. Because the tempered versions of the PDF are only there to help convergence, the extra information from temperatures other than 1 is not saved by default.
There should be a variant of this class that saves the extra information from temperatures other than 1, since it could be used for other purposes such as diagnostics or training.
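For reference, a tempered density in its most common form is shown below. This assumes the whole log density is scaled by the inverse temperature; some implementations temper only the likelihood and leave the prior untouched, and this page does not say which convention `TemperedPDF` follows.

```latex
\log p_T(x) = \frac{1}{T}\,\log p(x) + \text{const}, \qquad T \ge 1
```

At T = 1 this is the target itself, while larger T flattens the density so chains can move between modes more easily. Only the T = 1 chain samples p(x), which is why the other temperatures only serve to help convergence.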
n_steps: int = n_steps (instance attribute)
tempered_logpdf_name: str = tempered_logpdf_name (instance attribute)
kernel_name: str = kernel_name (instance attribute)
tempered_buffer_names: list[str] = tempered_buffer_names (instance attribute)
state_name = state_name (instance attribute)
__init__(n_steps: int, tempered_logpdf_name: str, kernel_name: str, tempered_buffer_names: list[str], state_name: str, verbose: bool = False) -> None
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_steps | int | Number of local kernel steps per temperature per call. | required |
| tempered_logpdf_name | str | Resource key for the `TemperedPDF`. | required |
| kernel_name | str | Resource key for the local proposal kernel. | required |
| tempered_buffer_names | list[str] | Resource keys for the tempered-position buffer and the temperature buffer, in that order. | required |
| state_name | str | Resource key for the sampler `State`. | required |
| verbose | bool | Enable debug logging. | False |
__call__(rng_key: Key, resources: dict[str, Resource], initial_position: Float[Array, 'n_chains n_dims'], data: dict) -> tuple[Key, dict[str, Resource], Float[Array, 'n_chains n_dims']]
Resources must contain:
- TemperedPDF
- Local kernel
- A buffer holding the tempered positions
- A buffer holding the temperatures
This strategy has three main steps:
1. Sample from the tempered PDF using the local kernel for n_steps steps.
2. Exchange samples between adjacent temperatures.
3. Adapt the temperatures based on the acceptance rate.
TODO: Add a way to turn off temperature adaptation so that detailed balance is maintained.
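To make the three steps concrete, here is a toy, plain-Python sketch of a parallel tempering loop on a bimodal 1-D target. It is not flowMC's implementation. It assumes a random-walk Metropolis local kernel, tempering of the form β·log p(x) with β = 1/T, one sequential sweep of adjacent-pair swaps per iteration, and it leaves out temperature adaptation. All function names are hypothetical.

```python
import math
import random


def log_target(x: float) -> float:
    """Unnormalized toy target: equal mixture of N(-4, 1) and N(4, 1)."""
    a, b = -0.5 * (x + 4.0) ** 2, -0.5 * (x - 4.0) ** 2
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def metropolis_accept(rng: random.Random, log_alpha: float) -> bool:
    """Accept with probability min(1, exp(log_alpha)) without overflowing."""
    return log_alpha >= 0.0 or rng.random() < math.exp(log_alpha)


def local_steps(rng, x, beta, n_steps, step_size=1.0):
    """Step 1: random-walk Metropolis on the tempered density beta * log_target."""
    for _ in range(n_steps):
        proposal = x + rng.gauss(0.0, step_size)
        if metropolis_accept(rng, beta * (log_target(proposal) - log_target(x))):
            x = proposal
    return x


def exchange(rng, xs, betas):
    """Step 2: propose a swap for each adjacent temperature pair, in order."""
    for i in range(len(xs) - 1):
        log_alpha = (betas[i] - betas[i + 1]) * (log_target(xs[i + 1]) - log_target(xs[i]))
        if metropolis_accept(rng, log_alpha):
            xs[i], xs[i + 1] = xs[i + 1], xs[i]
    return xs


def parallel_tempering(seed, temperatures, n_iter, n_steps):
    rng = random.Random(seed)
    betas = [1.0 / t for t in temperatures]
    xs = [-4.0] * len(temperatures)  # every chain starts in the left mode
    cold_samples = []
    for _ in range(n_iter):
        xs = [local_steps(rng, x, beta, n_steps) for x, beta in zip(xs, betas)]
        xs = exchange(rng, xs, betas)
        # Step 3 (temperature adaptation) is deliberately omitted here.
        cold_samples.append(xs[0])  # keep only the T = 1 chain
    return cold_samples
```

For example, `parallel_tempering(0, [1.0, 2.0, 4.0, 8.0], 4000, 5)` returns only the T = 1 chain, which should visit both modes even though every chain starts at x = -4. The hot chains cross the barrier easily, and the swaps carry those states down to T = 1.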
_individual_step_body(kernel: ProposalBase, carry: tuple[Key, Float[Array, ' n_dims'], Float[Array, 1], TemperedPDF, Float[Array, ' n_temps'], dict], aux) -> tuple[tuple[Key, Float[Array, ' n_dims'], Float[Array, 1], TemperedPDF, Float[Array, ' n_temps'], dict], tuple[Float[Array, ' n_dims'], Float[Array, 1], Int[Array, 1]]]
Take a step using the kernel and the tempered logpdf. This should not be called directly but instead used in a jax.lax.scan to take multiple steps.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kernel | ProposalBase | The kernel to use. | required |
| carry | tuple | The current state of the chain, with the fields listed below. | required |
| aux | None | Not used. | required |
The carry tuple contains:
- key (Key): jax random key.
- position (Float[Array, "n_dims"]): Current position of the chain.
- log_prob (Float[Array, "1"]): Current log probability of the chain.
- logpdf (TemperedPDF): The tempered LogPDF class.
- temperatures (Float[Array, "n_temps"]): Array of temperatures.
- data (dict): Additional data to pass to the logpdf.
Returns: tuple: Updated carry and the result of the kernel step.
- carry (tuple): Updated state of the chain.
  - key (Key): jax random key.
  - position (Float[Array, "n_dims"]): New position of the chain.
  - log_prob (Float[Array, "1"]): New log probability of the chain.
  - logpdf (TemperedPDF): The tempered LogPDF class.
  - temperatures (Float[Array, "n_temps"]): Array of temperatures.
  - data (dict): Additional data to pass to the logpdf.
- result (tuple): Result of the kernel step.
  - position (Float[Array, "n_dims"]): New position of the chain.
  - log_prob (Float[Array, "1"]): New log probability of the chain.
  - do_accept (Int[Array, "1"]): Whether the new position is accepted.
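The carry-threading pattern described above can be sketched in plain Python. This sketch makes several simplifying assumptions: the carry is reduced to (random generator, position, tempered log probability, inverse temperature) instead of the documented six-element tuple, a mutable `random.Random` stands in for a JAX key, and a small hypothetical `scan` function reproduces the `(carry, outputs)` contract of `jax.lax.scan`. `step_body` and `individual_step` are hypothetical names. Returning only the last step's outputs mirrors the single-position return type documented for `_individal_step`, but whether flowMC returns the last acceptance flag is an assumption.

```python
import math
import random


def scan(f, init, xs=None, length=None):
    """Pure-Python stand-in for jax.lax.scan: thread `carry` through `f`, collecting outputs."""
    if xs is None:
        xs = [None] * length
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def log_target(x):
    return -0.5 * x * x  # unnormalized standard normal


def step_body(carry, aux):
    """One random-walk Metropolis step on the tempered density; `aux` is unused."""
    rng, position, log_prob, beta = carry
    proposal = position + rng.gauss(0.0, 1.0)
    proposal_log_prob = beta * log_target(proposal)
    log_alpha = proposal_log_prob - log_prob
    do_accept = log_alpha >= 0.0 or rng.random() < math.exp(log_alpha)
    if do_accept:
        position, log_prob = proposal, proposal_log_prob
    # New carry first, then the per-step result, as in _individual_step_body.
    return (rng, position, log_prob, beta), (position, log_prob, int(do_accept))


def individual_step(rng, position, beta, n_steps):
    """Run n_steps of step_body; return the final position, log prob and last accept flag."""
    init = (rng, position, beta * log_target(position), beta)
    _, outputs = scan(step_body, init, length=n_steps)
    return outputs[-1]
```

In real JAX code the `if` would become a `jnp.where` and the key would be split at every step, because traced code cannot branch on array values or mutate a generator.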
_individal_step(kernel: ProposalBase, rng_key: Key, positions: Float[Array, ' n_dims'], logpdf: TemperedPDF, temperatures: Float[Array, ' n_temps'], data: dict) -> tuple[Float[Array, ' n_dims'], Float[Array, 1], Int[Array, 1]]
Perform a series of individual steps for a single chain using the kernel.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kernel | ProposalBase | The kernel to use for proposing new positions. | required |
| rng_key | Key | jax random key for reproducibility. | required |
| positions | Float[Array, n_dims] | Current positions of the chain. | required |
| logpdf | TemperedPDF | The tempered log probability density function. | required |
| temperatures | Float[Array, n_temps] | Array of temperatures. | required |
| data | dict | Additional data to pass to the logpdf. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| tuple | tuple[Float[Array, ' n_dims'], Float[Array, 1], Int[Array, 1]] | |
_ensemble_step(kernel: ProposalBase, rng_key: Key, positions: Float[Array, 'n_temps n_dims'], logpdf: TemperedPDF, temperatures: Float[Array, ' n_temps'], data: dict) -> tuple[Float[Array, 'n_temps n_dims'], Float[Array, ' n_temps'], Int[Array, ' n_temps']]
Perform ensemble steps for all chains and temperatures.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kernel | ProposalBase | The kernel to use for proposing new positions. | required |
| rng_key | Key | Random key for reproducibility. | required |
| positions | Float[Array, 'n_temps n_dims'] | Current positions for all temperatures. | required |
| logpdf | TemperedPDF | The tempered log probability density function. | required |
| temperatures | Float[Array, n_temps] | Array of temperatures. | required |
| data | dict | Additional data to pass to the logpdf. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| tuple | tuple[Float[Array, 'n_temps n_dims'], Float[Array, ' n_temps'], Int[Array, ' n_temps']] | |
_exchange_step_body(carry: tuple[Key, Float[Array, 'n_temps n_dims'], Float[Array, ' n_temps'], int, TemperedPDF, Float[Array, ' n_temps'], dict], aux: None) -> tuple[tuple[Key, Float[Array, 'n_temps n_dims'], Float[Array, ' n_temps'], int, TemperedPDF, Float[Array, ' n_temps'], dict], Int[Array, 1]]
_exchange(key: Key, positions: Float[Array, 'n_temps n_dims'], logpdf: TemperedPDF, temperatures: Float[Array, ' n_temps'], data: dict) -> tuple[Float[Array, 'n_temps n_dims'], Float[Array, ' n_temps'], Int[Array, ' n_temps - 1']]
Perform exchange steps between adjacent temperatures.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | Key | jax random key for reproducibility. | required |
| positions | Float[Array, 'n_temps n_dims'] | Current positions for all temperatures. | required |
| logpdf | TemperedPDF | The tempered log probability density function. | required |
| temperatures | Float[Array, n_temps] | Array of temperatures. | required |
| data | dict | Additional data to pass to the logpdf. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| tuple | tuple[Float[Array, 'n_temps n_dims'], Float[Array, ' n_temps'], Int[Array, ' n_temps - 1']] | |
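As a rough illustration of the exchange between adjacent temperatures, assume the tempered density is log p(x)/T. Then swapping the states x_i and x_{i+1}, held at inverse temperatures β_i = 1/T_i and β_{i+1}, is a Metropolis move with log acceptance ratio (β_i − β_{i+1})(log p(x_{i+1}) − log p(x_i)). The plain-Python sketch below applies this rule in one sequential sweep over the adjacent pairs and returns n_temps − 1 acceptance flags, matching the documented return shape. The sequential ordering, the hypothetical name `exchange`, and the choice to carry untempered log densities are assumptions, not flowMC's implementation.

```python
import math
import random


def exchange(rng, positions, log_ps, temperatures):
    """Sweep once over adjacent temperature pairs (i, i + 1), proposing a swap for each.

    positions[i] is the state held at temperatures[i]; log_ps[i] is its *untempered*
    log density. Returns updated positions, updated log densities and n_temps - 1 flags.
    """
    positions, log_ps = list(positions), list(log_ps)
    betas = [1.0 / t for t in temperatures]
    accepted = []
    for i in range(len(temperatures) - 1):
        log_alpha = (betas[i] - betas[i + 1]) * (log_ps[i + 1] - log_ps[i])
        accept = log_alpha >= 0.0 or rng.random() < math.exp(log_alpha)
        if accept:
            # Cached log densities travel with their states, so nothing is re-evaluated.
            positions[i], positions[i + 1] = positions[i + 1], positions[i]
            log_ps[i], log_ps[i + 1] = log_ps[i + 1], log_ps[i]
        accepted.append(int(accept))
    return positions, log_ps, accepted
```

A swap only reassigns existing states to different temperatures, so it needs no new density evaluations. Its acceptance rate depends on how much the adjacent tempered densities overlap, which is what temperature adaptation tries to control.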
_adapt_temperature(temperatures: Float[Array, ' n_temps'], do_accept: Int[Array, 'n_chains n_temps 1']) -> Float[Array, ' n_temps']
Adapt the temperatures based on the acceptance rates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| temperatures | Float[Array, n_temps] | Current temperatures. | required |
| do_accept | Int[Array, 'n_chains n_temps 1'] | Acceptance flags for each chain and temperature. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' n_temps'] | Updated temperatures. |
TODO: The adaptation currently lets the temperatures go above the maximum temperature. A check is needed to prevent this.
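The docstring does not spell out the update rule, so the following is one hypothetical scheme, not flowMC's. Each gap between adjacent temperatures is stretched when its swap acceptance rate is above a target and shrunk when it is below. The gaps are then rescaled so that the coldest and hottest temperatures stay fixed, which is one way to provide the guard the TODO asks for: the ladder can no longer leave [T_0, T_max]. The sketch takes per-pair swap acceptance flags of shape (n_chains, n_temps − 1) rather than the documented (n_chains, n_temps, 1) array, and `adapt_temperatures`, `kappa`, and `target` are illustrative names.

```python
import math


def adapt_temperatures(temperatures, accept_flags, kappa=0.5, target=0.25):
    """Hypothetical adaptation rule (illustration only, not flowMC's).

    temperatures: strictly increasing ladder, coldest first.
    accept_flags: accept_flags[c][i] is 1 if chain c accepted the swap between
        temperatures i and i + 1, else 0.
    """
    n_pairs = len(temperatures) - 1
    n_chains = len(accept_flags)
    rates = [sum(row[i] for row in accept_flags) / n_chains for i in range(n_pairs)]
    gaps = [temperatures[i + 1] - temperatures[i] for i in range(n_pairs)]
    # Widen gaps whose swap rate is above target, shrink those below it.
    gaps = [g * math.exp(kappa * (r - target)) for g, r in zip(gaps, rates)]
    # Rescale so the coldest and hottest temperatures stay fixed; since every gap
    # stays positive, the ladder remains increasing and never exceeds T_max.
    scale = (temperatures[-1] - temperatures[0]) / sum(gaps)
    new = [temperatures[0]]
    for g in gaps:
        new.append(new[-1] + g * scale)
    new[-1] = temperatures[-1]  # remove floating-point drift at the top
    return new
```

When every pair has the same acceptance rate, the ladder is left unchanged. As the TODO in `__call__` points out, changing temperatures during sampling breaks detailed balance, so an update like this would typically be applied only during a tuning phase.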