Parallel Tempering
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jaxtyping import Float, Array
from typing import Any
from flowMC.resource.kernel.MALA import MALA
from flowMC.resource.buffers import Buffer
from flowMC.resource.states import State
from flowMC.strategy.take_steps import TakeSerialSteps
from flowMC.strategy.parallel_tempering import ParallelTempering
from flowMC.Sampler import Sampler
from flowMC.resource.logPDF import TemperedPDF, LogPDF
def target_dual_moon(x: Float[Array, "n_dims"], data: dict[str, Any]) -> Float:
    """Log-density of the "dual moon" test distribution.

    The density has three parts:

    * A Gaussian penalty on the Euclidean norm of ``x``, centred at 2 with
      scale 0.1. This concentrates mass on a thin shell of radius 2.
    * A two-component mixture on ``x[0]`` (modes at +3 and -3, scale 0.8).
    * A two-component mixture on ``x[1]`` (modes at +3 and -3, scale 0.6).

    The two mixture terms split the shell into separate lobes and smear it
    along the first and second dimension.

    Args:
        x: A single point. Only ``x[0]`` and ``x[1]`` enter the mixture terms,
            so at least two dimensions are required.
        data: Unused. Kept so the function matches the ``(x, data)`` call
            signature used by the sampler.

    Returns:
        The log-probability at ``x``, with no normalising constant.
    """
    mode_offsets = jnp.array([-3.0, 3.0])
    shell_penalty = 0.5 * ((jnp.linalg.norm(x) - 2) / 0.1) ** 2
    first_dim_mixture = logsumexp(-0.5 * ((x[:1] + mode_offsets) / 0.8) ** 2)
    second_dim_mixture = logsumexp(-0.5 * ((x[1:2] + mode_offsets) / 0.6) ** 2)
    # Same evaluation order as -(a - b - c) so the floating-point result is unchanged.
    return -(shell_penalty - first_dim_mixture - second_dim_mixture)
Let's set up a normal run with a MALA sampler.
# Defining hyperparameters
n_chains = 5
rng_key = jax.random.PRNGKey(0)
n_steps = 400
n_dims = 5
step_size = jnp.full(n_dims, 0.1)
# target_dual_moon never reads `data`; it is only forwarded via sampler.sample.
data = {"data": jnp.arange(n_dims).astype(jnp.float32)}
# Setting up resources
MALA_sampler = MALA(step_size=step_size)
# One slot per (chain, step). The trailing `1` presumably marks the step axis
# that gets filled as sampling proceeds (TODO: confirm against Buffer's docs).
positions = Buffer("positions", (n_chains, n_steps, n_dims), 1)
log_prob = Buffer("log_prob", (n_chains, n_steps), 1)
acceptance = Buffer("acceptance", (n_chains, n_steps), 1)
sampler_state = State(
    {
        "positions": "positions",
        "log_prob": "log_prob",
        "acceptance": "acceptance",
    },
    name="sampler_state",
)
resource = {
    "positions": positions,
    "log_prob": log_prob,
    "acceptance": acceptance,
    "MALA": MALA_sampler,
    "logpdf": LogPDF(target_dual_moon, n_dims=n_dims),
    "state": sampler_state,
}
# Defining strategy
strategy = TakeSerialSteps(
    logpdf_name="logpdf",
    kernel_name="MALA",
    state_name="state",
    buffer_names=["positions", "log_prob", "acceptance"],
    n_steps=n_steps,
)
# Give the initial positions their own key. Reusing the key that is handed to
# the Sampler would correlate the starting points with the sampler's own
# random draws (JAX keys must never be consumed twice).
rng_key, init_key = jax.random.split(rng_key)
# Initializing sampler
sampler = Sampler(
    n_dim=n_dims,
    n_chains=n_chains,
    rng_key=rng_key,
    resources=resource,
    strategies={"take_steps": strategy},
    strategy_order=["take_steps"],
)
sampler.sample(
    initial_position=jax.random.normal(init_key, (n_chains, n_dims)),
    data=data,
)
import corner
import numpy as np
import matplotlib.pyplot as plt
chains = sampler.resources["positions"].data
labels = ["x" + str(i) for i in range(n_dims)]
# Plotting chain from flowMC
# Corner plot of every stored sample, pooling all chains and steps together.
fig = corner.corner(
    np.array(chains).reshape(-1, n_dims),
    fig=plt.figure(figsize=(6, 6)),
    labels=labels,
)
# Trajectory of each chain in the (x0, x1) plane; one line per chain.
fig = plt.figure(figsize=(6, 6))
plt.plot(chains[:, :, 0].T, chains[:, :, 1].T, alpha=0.5)
plt.xlabel("x0")
plt.ylabel("x1")
plt.show()
Now let's add parallel tempering to the mix.
import itertools
# Defining hyperparameters
n_chains = 5
n_dims = 5
n_temps = 5
n_steps = 10  # steps taken by each call of the take_steps strategy
n_loops = 40  # how many times the (take_steps, parallel_tempering) pair repeats
step_size = jnp.full(n_dims, 0.1)
data = {
    "data": jnp.arange(n_dims).astype(jnp.float32),
}
# Setting up resources
logpdf = LogPDF(target_dual_moon, n_dims=n_dims)
MALA_sampler = MALA(step_size=step_size)

key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
initial_position = jax.random.normal(subkey, shape=(n_chains, n_dims))

# Buffers hold every step of every loop for each chain.
n_total_steps = n_steps * n_loops
positions = Buffer("positions", (n_chains, n_total_steps, n_dims), 1)
log_prob = Buffer("log_prob", (n_chains, n_total_steps), 1)
acceptance = Buffer("acceptance", (n_chains, n_total_steps), 1)
sampler_state = State(
    {
        "positions": "positions",
        "log_prob": "log_prob",
        "acceptance": "acceptance",
        "training": True,
    },
    name="sampler_state",
)

# Pre-fill the buffers with the starting point and its log-density.
# `[:, None]` adds a length-1 step axis. The write presumably lands at step 0
# (update_buffer's default start); TODO: confirm.
positions.update_buffer(initial_position[:, None])
initial_log_prob = jax.vmap(logpdf, in_axes=(0, None))(initial_position, data)
log_prob.update_buffer(initial_log_prob[:, None])

take_steps = TakeSerialSteps(
    logpdf_name="logpdf",
    kernel_name="MALA",
    state_name="state",
    buffer_names=["positions", "log_prob", "acceptance"],
    n_steps=n_steps,
)
Below is the tempering-specific code that defines the extra resources and strategies.
key, subkey = jax.random.split(key)
tempered_logpdf = TemperedPDF(
    target_dual_moon,
    # Second callable is a constant-zero term; presumably the (flat) log-prior
    # that is not tempered. TODO: confirm against TemperedPDF's signature.
    lambda x, data: jnp.array(0.0),
    n_dims=n_dims,
    n_temps=n_temps,
)
# Only n_temps - 1 tempered copies per chain are stored here. The remaining
# temperature appears to correspond to the main `positions` chain (assumption;
# verify with ParallelTempering).
tempered_initial_position = jax.random.normal(
    subkey, shape=(n_chains, n_temps - 1, n_dims)
)
tempered_positions = Buffer("tempered_positions", (n_chains, n_temps - 1, n_dims), 2)
tempered_positions.update_buffer(tempered_initial_position)
# Linearly spaced temperature ladder starting at 1.0 with spacing 5
# (1, 6, 11, 16, 21 for n_temps = 5).
temperatures = Buffer("temperature", (n_temps,), 0)
temperatures.update_buffer(jnp.arange(n_temps) * 5 + 1.0)
resources = {
    "logpdf": logpdf,
    "MALA": MALA_sampler,
    "positions": positions,
    "log_prob": log_prob,
    "acceptance": acceptance,
    "tempered_logpdf": tempered_logpdf,
    "tempered_positions": tempered_positions,
    "temperatures": temperatures,
    "state": sampler_state,
}
parallel_tempering_strat = ParallelTempering(
    n_steps=n_steps,
    tempered_logpdf_name="tempered_logpdf",
    kernel_name="MALA",
    state_name="state",
    tempered_buffer_names=["tempered_positions", "temperatures"],
)
# Alternate a round of local steps with a tempering round, n_loops times:
# ["take_steps", "parallel_tempering", "take_steps", ...].
strategy_order = ["take_steps", "parallel_tempering"] * n_loops
# Seed the sampler from this section's own key chain. Previously the stale
# `rng_key` from the first example (PRNGKey(0)) was reused here.
key, sampler_key = jax.random.split(key)
sampler = Sampler(
    n_dim=n_dims,
    n_chains=n_chains,
    rng_key=sampler_key,
    resources=resources,
    strategies={
        "take_steps": take_steps,
        "parallel_tempering": parallel_tempering_strat,
    },
    strategy_order=strategy_order,
)
sampler.sample(
    initial_position=initial_position,
    data=data,
)
chains = sampler.resources["positions"].data
print(chains.shape)
labels = ["x" + str(i) for i in range(n_dims)]
# Plotting chain from flowMC
# Corner plot of every stored sample, pooling all chains and steps together.
fig = corner.corner(
    np.array(chains).reshape(-1, n_dims),
    fig=plt.figure(figsize=(6, 6)),
    labels=labels,
)
# Trajectory of each chain in the (x0, x1) plane; one line per chain.
fig = plt.figure(figsize=(6, 6))
plt.plot(chains[:, :, 0].T, chains[:, :, 1].T, alpha=0.5)
plt.xlabel("x0")
plt.ylabel("x1")
plt.show()
Instead of defining your sampler from scratch, you can also use the parallel tempering bundle. It is the original flowMC sampler (a rational-quadratic-spline normalizing flow combined with MALA) with the parallel tempering strategy added. Training the normalizing flow is computationally expensive, so expect this bundle to be slower than the lighter version shown above.
from flowMC.resource_strategy_bundle.RQSpline_MALA_PT import RQSpline_MALA_PT_Bundle
rng_key = jax.random.PRNGKey(0)
rng_key, subkey = jax.random.split(rng_key)
n_chains = 5
n_dims = 5
# Keyword settings for the bundle, kept in one place for easy tweaking.
bundle_settings = dict(
    n_chains=n_chains,
    n_dims=n_dims,
    logpdf=target_dual_moon,
    n_local_steps=2,
    n_global_steps=2,
    n_training_loops=100,
    n_production_loops=100,
    n_epochs=2,
    rq_spline_hidden_units=[32, 32],
    rq_spline_n_layers=2,
    max_temperature=20.0,
    n_temperatures=5,
    n_tempered_steps=3,
)
bundle = RQSpline_MALA_PT_Bundle(subkey, **bundle_settings)

rng_key, subkey = jax.random.split(rng_key)
initial_position = jax.random.normal(subkey, shape=(n_chains, n_dims))
sampler = Sampler(
    n_dim=n_dims,
    n_chains=n_chains,
    rng_key=rng_key,
    resource_strategy_bundles=bundle,
)
# `data` is reused from the previous section; target_dual_moon never reads it.
sampler.sample(
    initial_position=initial_position,
    data=data,
)
chains = sampler.resources["positions_training"].data
print(chains.shape)
labels = ["x" + str(i) for i in range(n_dims)]
# Plotting chain from flowMC
# Corner plot of every stored training-phase sample, pooling chains and steps.
fig = corner.corner(
    np.array(chains).reshape(-1, n_dims),
    fig=plt.figure(figsize=(6, 6)),
    labels=labels,
)
# Trajectory of each chain in the (x0, x1) plane; one line per chain.
fig = plt.figure(figsize=(6, 6))
plt.plot(chains[:, :, 0].T, chains[:, :, 1].T, alpha=0.5)
plt.xlabel("x0")
plt.ylabel("x1")
plt.show()