
Parallel tempering

flowMC.strategy.parallel_tempering

ParallelTempering

Bases: Strategy

Sample a tempered PDF with one exchange step. In essence, this strategy is closer to TakeSteps than to global tuning. Because the tempered versions of the PDF exist only to help convergence, by default the samples drawn at temperatures other than 1 are not saved.

There should be a variant of this class that saves the samples drawn at temperatures other than 1, which could be used for other purposes such as diagnostics or training.
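As a minimal illustration of why the tempered copies help convergence, the sketch below assumes the common convention log π_T(x) = log π(x) / T (flowMC's TemperedPDF may temper only part of the density, for example the likelihood) and uses a hypothetical bimodal 1D target. Dividing by T shrinks the log-density barrier between the modes, so chains at high temperature cross it more easily, while T = 1 recovers the original target.

```python
import jax.numpy as jnp

def log_target(x):
    # Hypothetical bimodal target (unnormalized): 1D Gaussians at -4 and +4.
    return jnp.logaddexp(-0.5 * (x - 4.0) ** 2, -0.5 * (x + 4.0) ** 2)

def tempered_log_target(x, temperature):
    # Assumed tempering convention: log pi_T(x) = log pi(x) / T.
    return log_target(x) / temperature

temperatures = jnp.array([1.0, 2.0, 4.0, 8.0])
# Log-density drop from a mode (x = 4) to the valley between the modes (x = 0),
# evaluated at every temperature of the ladder at once via broadcasting.
barrier = tempered_log_target(4.0, temperatures) - tempered_log_target(0.0, temperatures)
print(barrier)  # shrinks as the temperature grows
```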

n_steps: int = n_steps instance-attribute
tempered_logpdf_name: str = tempered_logpdf_name instance-attribute
kernel_name: str = kernel_name instance-attribute
tempered_buffer_names: list[str] = tempered_buffer_names instance-attribute
state_name = state_name instance-attribute

__init__(n_steps: int, tempered_logpdf_name: str, kernel_name: str, tempered_buffer_names: list[str], state_name: str, verbose: bool = False) -> None

Parameters:

  • n_steps (int, required): Number of local kernel steps per temperature per call.
  • tempered_logpdf_name (str, required): Resource key for the flowMC.resource.logPDF.TemperedPDF.
  • kernel_name (str, required): Resource key for the local proposal kernel.
  • tempered_buffer_names (list[str], required): Resource keys for the tempered-position buffer and the temperature buffer, in that order.
  • state_name (str, required): Resource key for the sampler flowMC.resource.states.State.
  • verbose (bool, default False): Enable debug logging.
__call__(rng_key: Key, resources: dict[str, Resource], initial_position: Float[Array, 'n_chains n_dims'], data: dict) -> tuple[Key, dict[str, Resource], Float[Array, 'n_chains n_dim']]
Resources must contain:
  • TemperedPDF
  • Local kernel
  • A buffer holding the tempered positions
  • A buffer holding the temperatures

This strategy has three main steps:
  1. Sample from the tempered PDF using the local kernel for n_steps steps.
  2. Exchange samples between the temperatures.
  3. Adapt the temperatures based on the acceptance rate.
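The three steps can be sketched end to end in plain JAX. This is not flowMC's implementation: the hypothetical helpers below stand in for the local kernel with a random-walk Metropolis step, assume the tempering convention log π(x) / T, propose swaps between adjacent temperatures one pair at a time, and keep the ladder fixed in step 3, which corresponds to running with temperature adaptation turned off.

```python
import jax
import jax.numpy as jnp

def log_target(x):
    # Hypothetical bimodal target (unnormalized): 1D Gaussians at -4 and +4.
    return jnp.logaddexp(-0.5 * (x[0] - 4.0) ** 2, -0.5 * (x[0] + 4.0) ** 2)

def local_steps(key, x, temperature, n_steps, step_size=1.0):
    # Step 1 (stand-in for the local kernel): random-walk Metropolis on log_target / T.
    for k in jax.random.split(key, n_steps):
        k_prop, k_acc = jax.random.split(k)
        proposal = x + step_size * jax.random.normal(k_prop, x.shape)
        log_ratio = (log_target(proposal) - log_target(x)) / temperature
        x = jnp.where(jnp.log(jax.random.uniform(k_acc)) < log_ratio, proposal, x)
    return x

def exchange_sweep(key, positions, temperatures):
    # Step 2: propose swapping each adjacent pair (i, i+1), coldest pair first.
    betas = 1.0 / temperatures
    accepted = []
    for i, k in enumerate(jax.random.split(key, temperatures.shape[0] - 1)):
        log_p_i, log_p_j = log_target(positions[i]), log_target(positions[i + 1])
        # Log acceptance ratio for exchanging the states at temperatures i and i+1.
        log_alpha = (betas[i] - betas[i + 1]) * (log_p_j - log_p_i)
        accept = jnp.log(jax.random.uniform(k)) < log_alpha
        swapped = positions.at[i].set(positions[i + 1]).at[i + 1].set(positions[i])
        positions = jnp.where(accept, swapped, positions)
        accepted.append(accept)
    return positions, jnp.array(accepted)

def adapt_temperatures(temperatures, accepted):
    # Step 3 placeholder: the ladder is kept fixed (adaptation turned off).
    return temperatures

def parallel_tempering_round(key, positions, temperatures, n_steps):
    key_local, key_swap = jax.random.split(key)
    local_keys = jax.random.split(key_local, temperatures.shape[0])
    positions = jnp.stack([
        local_steps(k, x, t, n_steps)
        for k, x, t in zip(local_keys, positions, temperatures)
    ])
    positions, accepted = exchange_sweep(key_swap, positions, temperatures)
    temperatures = adapt_temperatures(temperatures, accepted)
    return positions, temperatures, accepted

key = jax.random.PRNGKey(0)
temperatures = jnp.array([1.0, 2.0, 4.0])
positions = jnp.zeros((3, 1))  # one 1D state per temperature
positions, temperatures, accepted = parallel_tempering_round(key, positions, temperatures, n_steps=5)
```

Only the state at temperatures[0] == 1 targets the original PDF; the hotter states serve as the mixing aid described above.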

TODO: Add a way to turn off temperature adaptation so that detailed balance is maintained.

_individual_step_body(kernel: ProposalBase, carry: tuple[Key, Float[Array, ' n_dims'], Float[Array, 1], TemperedPDF, Float[Array, ' n_temps'], dict], aux) -> tuple[tuple[Key, Float[Array, ' n_dims'], Float[Array, 1], TemperedPDF, Float[Array, ' n_temps'], dict], tuple[Float[Array, ' n_dims'], Float[Array, 1], Int[Array, 1]]]

Take one step using the kernel and the tempered logpdf. This function is not meant to be called directly; it is intended as the body of a jax.lax.scan that takes multiple steps.

Parameters:

  • kernel (ProposalBase, required): The kernel to use.
  • carry (tuple, required): The current state of the chain:
      • key (Key): JAX random key.
      • position (Float[Array, "n_dims"]): Current position of the chain.
      • log_prob (Float[Array, "1"]): Current log probability of the chain.
      • logpdf (TemperedPDF): The tempered LogPDF class.
      • temperatures (Float[Array, "n_temps"]): Array of temperatures.
      • data (dict): Additional data to pass to the logpdf.
  • aux (None, required): Not used.

Returns:

  tuple: The updated carry and the result of the kernel step.
  • carry (tuple): Updated state of the chain:
      • key (Key): JAX random key.
      • position (Float[Array, "n_dims"]): New position of the chain.
      • log_prob (Float[Array, "1"]): New log probability of the chain.
      • logpdf (TemperedPDF): The tempered LogPDF class.
      • temperatures (Float[Array, "n_temps"]): Array of temperatures.
      • data (dict): Additional data to pass to the logpdf.
  • result (tuple): Result of the kernel step:
      • position (Float[Array, "n_dims"]): New position of the chain.
      • log_prob (Float[Array, "1"]): New log probability of the chain.
      • do_accept (Int[Array, "1"]): Whether the new position is accepted.
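The scan pattern can be sketched with a simplified carry (key, position, log_prob, temperature, data) instead of the documented one, which also carries the TemperedPDF object and the full temperature array. The target, the step size of 0.5, and the random-walk Metropolis proposal are assumptions standing in for the real kernel; log_prob is kept untempered and divided by the temperature only inside the acceptance ratio. The wrapper individual_steps plays the role that _individal_step plays for a single chain.

```python
import jax
import jax.numpy as jnp

def log_target(x, data):
    # Hypothetical target: isotropic Gaussian centred on data["mean"] (unnormalized).
    return -0.5 * jnp.sum((x - data["mean"]) ** 2)

def step_body(carry, aux):
    # Scan body: carry holds the chain state, aux is unused (scan passes None),
    # and the per-step (position, log_prob, do_accept) is emitted as output.
    key, position, log_prob, temperature, data = carry
    key, k_prop, k_acc = jax.random.split(key, 3)
    proposal = position + 0.5 * jax.random.normal(k_prop, position.shape)
    proposal_log_prob = log_target(proposal, data)
    # Metropolis test on the tempered density log_target / temperature.
    log_ratio = (proposal_log_prob - log_prob) / temperature
    do_accept = jnp.log(jax.random.uniform(k_acc)) < log_ratio
    position = jnp.where(do_accept, proposal, position)
    log_prob = jnp.where(do_accept, proposal_log_prob, log_prob)
    new_carry = (key, position, log_prob, temperature, data)
    return new_carry, (position, log_prob, do_accept.astype(jnp.int32))

def individual_steps(key, position, temperature, data, n_steps):
    # Take n_steps steps for one chain; length= replaces an xs sequence.
    init = (key, position, log_target(position, data), temperature, data)
    _, (positions, log_probs, do_accept) = jax.lax.scan(step_body, init, None, length=n_steps)
    return positions, log_probs, do_accept

key = jax.random.PRNGKey(1)
data = {"mean": jnp.array([1.0, -1.0])}
positions, log_probs, do_accept = individual_steps(key, jnp.zeros(2), 2.0, data, n_steps=100)
```

Because jax.lax.scan stacks the per-step outputs, the caller gets the full trajectory (shape (n_steps, n_dims)) together with the acceptance flags without writing a Python loop.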

_individal_step(kernel: ProposalBase, rng_key: Key, positions: Float[Array, ' n_dims'], logpdf: TemperedPDF, temperatures: Float[Array, ' n_temps'], data: dict) -> tuple[Float[Array, ' n_dims'], Float[Array, 1], Int[Array, 1]]

Perform a series of individual steps for a single chain using the kernel.

Parameters:

  • kernel (ProposalBase, required): The kernel to use for proposing new positions.
  • rng_key (Key, required): JAX random key for reproducibility.
  • positions (Float[Array, n_dims], required): Current positions of the chain.
  • logpdf (TemperedPDF, required): The tempered log probability density function.
  • temperatures (Float[Array, n_temps], required): Array of temperatures.
  • data (dict, required): Additional data to pass to the logpdf.

Returns:

  tuple[Float[Array, ' n_dims'], Float[Array, 1], Int[Array, 1]]:
  • positions (Float[Array, "n_dims"]): Updated positions of the chain.
  • log_probs (Float[Array, "1"]): Log probabilities of the chain.
  • do_accept (Int[Array, "1"]): Acceptance flag for the new position.

_ensemble_step(kernel: ProposalBase, rng_key: Key, positions: Float[Array, 'n_temps n_dims'], logpdf: TemperedPDF, temperatures: Float[Array, ' n_temps'], data: dict) -> tuple[Float[Array, 'n_temps n_dims'], Float[Array, ' n_temps'], Int[Array, ' n_temps']]

Perform ensemble steps for all chains and temperatures.

Parameters:

  • kernel (ProposalBase, required): The kernel to use for proposing new positions.
  • rng_key (Key, required): Random key for reproducibility.
  • positions (Float[Array, 'n_temps n_dims'], required): Current positions for all temperatures.
  • logpdf (TemperedPDF, required): The tempered log probability density function.
  • temperatures (Float[Array, n_temps], required): Array of temperatures.
  • data (dict, required): Additional data to pass to the logpdf.

Returns:

  tuple[Float[Array, 'n_temps n_dims'], Float[Array, ' n_temps'], Int[Array, ' n_temps']]:
  • positions (Float[Array, "n_temps n_dims"]): Updated positions for all temperatures.
  • log_probs (Float[Array, "n_temps"]): Log probabilities for all temperatures.
  • do_accept (Int[Array, "n_temps"]): Acceptance flags for each temperature.
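Stepping every temperature (and every chain) at once is a natural fit for jax.vmap. The sketch below is an assumption about the general pattern rather than flowMC's code: a hypothetical single_step advances one chain at one temperature with a random-walk Metropolis move, one vmap maps it over the temperature ladder, and a second vmap maps that over chains while sharing the ladder via in_axes=None.

```python
import jax
import jax.numpy as jnp

def log_target(x):
    # Hypothetical target: standard normal in n_dims dimensions (unnormalized).
    return -0.5 * jnp.sum(x ** 2)

def single_step(key, position, temperature):
    # One random-walk Metropolis step on log_target / temperature for one chain.
    k_prop, k_acc = jax.random.split(key)
    proposal = position + jax.random.normal(k_prop, position.shape)
    log_ratio = (log_target(proposal) - log_target(position)) / temperature
    do_accept = jnp.log(jax.random.uniform(k_acc)) < log_ratio
    new_position = jnp.where(do_accept, proposal, position)
    return new_position, log_target(new_position), do_accept.astype(jnp.int32)

# Map over temperatures: one key, one position row and one temperature each.
ensemble_step = jax.vmap(single_step, in_axes=(0, 0, 0))
# Map over chains as well; every chain shares the same ladder (in_axes=None).
all_chains_step = jax.vmap(ensemble_step, in_axes=(0, 0, None))

n_chains, n_temps, n_dims = 4, 3, 2
temperatures = jnp.array([1.0, 2.0, 4.0])
keys = jax.random.split(jax.random.PRNGKey(0), n_chains * n_temps)
# Reshape the flat batch of keys to (n_chains, n_temps, ...); works for raw and typed keys.
keys = keys.reshape(n_chains, n_temps, *keys.shape[1:])
positions = jnp.zeros((n_chains, n_temps, n_dims))
new_positions, log_probs, do_accept = all_chains_step(keys, positions, temperatures)
```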

_exchange_step_body(carry: tuple[Key, Float[Array, 'n_temps n_dims'], Float[Array, ' n_temps'], int, TemperedPDF, Float[Array, ' n_temps'], dict], aux: None) -> tuple[tuple[Key, Float[Array, 'n_temps n_dims'], Float[Array, ' n_temps'], int, TemperedPDF, Float[Array, ' n_temps'], dict], Int[Array, 1]]

_exchange(key: Key, positions: Float[Array, 'n_temps n_dims'], logpdf: TemperedPDF, temperatures: Float[Array, ' n_temps'], data: dict) -> tuple[Float[Array, 'n_temps n_dims'], Float[Array, ' n_temps'], Int[Array, ' n_temps - 1']]

Perform exchange steps between adjacent temperatures.

Parameters:

  • key (Key, required): JAX random key for reproducibility.
  • positions (Float[Array, 'n_temps n_dims'], required): Current positions for all temperatures.
  • logpdf (TemperedPDF, required): The tempered log probability density function.
  • temperatures (Float[Array, n_temps], required): Array of temperatures.
  • data (dict, required): Additional data to pass to the logpdf.

Returns:

  tuple[Float[Array, 'n_temps n_dims'], Float[Array, ' n_temps'], Int[Array, ' n_temps - 1']]:
  • positions (Float[Array, "n_temps n_dims"]): Updated positions for all temperatures.
  • log_probs (Float[Array, "n_temps"]): Log probabilities for all temperatures.
  • do_accept (Int[Array, "n_temps - 1"]): Acceptance flags for each pair of adjacent temperatures.
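For adjacent temperatures T_i < T_{i+1} with inverse temperatures β = 1/T, and again assuming the tempering convention log π(x) / T, exchanging the states x_i and x_{i+1} is a Metropolis move with log acceptance ratio (β_i - β_{i+1}) (log π(x_{i+1}) - log π(x_i)). The sketch below computes this for all n_temps - 1 adjacent pairs at once and cross-checks it against the definition (tempered joint density after the swap divided by the one before). Both function names are hypothetical, not flowMC methods.

```python
import jax.numpy as jnp

def swap_log_acceptance(log_probs, temperatures):
    # log_probs[i]: untempered log density of the state currently held at
    # temperatures[i]; tempering is assumed to be log_prob / T.
    betas = 1.0 / temperatures
    # (beta_i - beta_{i+1}) * (log p(x_{i+1}) - log p(x_i)), one entry per adjacent pair.
    return (betas[:-1] - betas[1:]) * (log_probs[1:] - log_probs[:-1])

def direct_log_ratio(log_probs, temperatures, i):
    # Same quantity from the definition: tempered joint log density after the
    # swap of pair (i, i+1) minus the tempered joint log density before it.
    b = 1.0 / temperatures
    after = b[i] * log_probs[i + 1] + b[i + 1] * log_probs[i]
    before = b[i] * log_probs[i] + b[i + 1] * log_probs[i + 1]
    return after - before

log_probs = jnp.array([-1.0, -5.0, -2.0])
temperatures = jnp.array([1.0, 2.0, 4.0])
log_alpha = swap_log_acceptance(log_probs, temperatures)
# Probability of accepting each proposed swap: min(1, exp(log_alpha)).
accept_prob = jnp.minimum(1.0, jnp.exp(log_alpha))
```

Here log_alpha is [-2.0, 0.75]: the second swap would move the higher-density state (log density -2) to the colder temperature, so it is always accepted. If the pairs are swapped one after another rather than all at once, log_probs has to be updated after each accepted swap before evaluating the next pair.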
_adapt_temperature(temperatures: Float[Array, ' n_temps'], do_accept: Int[Array, 'n_chains n_temps 1']) -> Float[Array, ' n_temps']

Adapt the temperatures based on the acceptance rates.

Parameters:

  • temperatures (Float[Array, n_temps], required): Current temperatures.
  • do_accept (Int[Array, 'n_chains n_temps 1'], required): Acceptance flags for each chain and temperature.

Returns:

  Float[Array, ' n_temps']: Updated temperatures.

TODO: The adaptation currently lets the temperatures go above the maximum temperature. A check is needed to prevent this.
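One illustrative way to rule out temperatures drifting past the maximum, which is neither flowMC's rule nor a published scheme, is to adapt the gaps between log temperatures and then rescale them so they still span [log T_min, log T_max]. Gaps whose pair swaps less often than average shrink, the others widen, and both end points stay fixed, so the ladder remains strictly increasing. The sketch assumes do_accept holds swap flags with shape (n_chains, n_temps - 1); kappa is a hypothetical adaptation rate.

```python
import jax.numpy as jnp

def adapt_temperatures(temperatures, do_accept, kappa=0.1):
    # temperatures: increasing ladder, shape (n_temps,).
    # do_accept: assumed swap flags, shape (n_chains, n_temps - 1).
    # Mean swap acceptance rate of each adjacent pair, averaged over chains.
    rates = jnp.mean(do_accept, axis=0)
    # Work with gaps in log temperature so that every gap stays positive.
    log_gaps = jnp.diff(jnp.log(temperatures))
    # Shrink gaps whose pair swaps rarely, widen gaps whose pair swaps often.
    new_gaps = log_gaps * jnp.exp(kappa * (rates - jnp.mean(rates)))
    # Rescale so the gaps still span [log T_min, log T_max]; the hottest
    # temperature therefore cannot move above its original value.
    total = jnp.log(temperatures[-1]) - jnp.log(temperatures[0])
    new_gaps = new_gaps * total / jnp.sum(new_gaps)
    log_temps = jnp.log(temperatures[0]) + jnp.concatenate([jnp.zeros(1), jnp.cumsum(new_gaps)])
    return jnp.exp(log_temps)

temperatures = jnp.array([1.0, 2.0, 4.0, 8.0])
do_accept = jnp.array([[1, 0, 1], [1, 0, 1]])  # the middle pair never swaps
new_temperatures = adapt_temperatures(temperatures, do_accept)
```

Because the ladder changes from call to call, any adaptation of this kind breaks detailed balance while it runs, which is why the TODO on __call__ asks for a way to switch it off.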