Generating Waveforms¤

ripple provides a set of waveform models as callable Python classes. Each model takes a frequency (or time) array and a dictionary of physical parameters, and returns the two gravitational-wave polarizations \(h_+\) and \(h_\times\) as JAX arrays.

Because ripple is built on JAX, every waveform is:

- Differentiable — gradients via jax.grad / jax.jacobian
- JIT-compilable — fast repeated evaluation via jax.jit
- Vectorisable — batch evaluation via jax.vmap

Example: IMRPhenomD¤

IMRPhenomD models a non-precessing, aligned-spin BBH merger. The only constructor argument is f_ref, the reference frequency (in Hz) at which the phase is defined (default: 20 Hz).

The parameter dictionary uses:

- M_c — chirp mass \(\mathcal{M}\) in solar masses
- eta — symmetric mass ratio \(\eta = m_1 m_2 / (m_1+m_2)^2 \in (0, 0.25]\)
- s1_z, s2_z — aligned spin components \(\chi_{1z}, \chi_{2z} \in [-1, 1]\)
- d_L — luminosity distance in Mpc
- phase_c — coalescence phase in radians
- iota — inclination angle in radians
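Many sources quote component masses rather than \(\mathcal{M}\) and \(\eta\); the conversion follows directly from the definitions above. A minimal helper sketch (the function name is ours, not part of ripple):

```python
import jax.numpy as jnp

def masses_to_mc_eta(m1, m2):
    """Convert component masses [Msun] to (chirp mass, symmetric mass ratio)."""
    M = m1 + m2
    eta = m1 * m2 / M**2
    M_c = (m1 * m2) ** (3.0 / 5.0) / M ** (1.0 / 5.0)
    return M_c, eta

# GW150914-like component masses give the values used below
M_c, eta = masses_to_mc_eta(36.0, 29.0)
print(M_c, eta)  # ~28.1, ~0.247
```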

```python
import jax
import jax.numpy as jnp
import ripplegw

jax.config.update("jax_enable_x64", True)  # use double precision for better accuracy

# Frequency grid: 20 Hz to 512 Hz at 0.25 Hz resolution
f_low = 20.0    # Hz
f_high = 512.0  # Hz
delta_f = 0.25  # Hz
frequency = jnp.arange(f_low, f_high, delta_f)

# BBH parameters
params_bbh = {
    "M_c": 28.3,   # chirp mass [Msun]
    "eta": 0.247,  # symmetric mass ratio
    "s1_z": 0.0,   # primary spin z-component
    "s2_z": 0.0,   # secondary spin z-component
    "d_L": 440.0,  # luminosity distance [Mpc]
    "phase_c": 0.0,
    "iota": 2.0,
}

waveform = ripplegw.IMRPhenomD(f_ref=20.0)
polarizations = waveform(frequency, params_bbh)

hp = polarizations["p"]  # h_+
hc = polarizations["c"]  # h_x

print(f"h+ shape: {hp.shape}, dtype: {hp.dtype}")
```

Visualise the frequency-domain waveform¤

```python
import matplotlib.pyplot as plt

plt.loglog(frequency, jnp.abs(hp), label=r"$|h_+|$")
plt.loglog(frequency, jnp.abs(hc), label=r"$|h_\times|$", ls="--")
plt.xlim(f_low, f_high)
plt.xlabel("Frequency [Hz]")
plt.ylabel("Strain amplitude")
plt.title("Amplitude spectrum")
plt.legend()

plt.suptitle("IMRPhenomD — GW150914-like BBH")
plt.tight_layout()
plt.show()
```

Differentiability¤

Because ripple uses JAX, you can compute gradients of the waveform with respect to any parameter. This is key for gradient-based inference.

Here we compute the gradient of the total strain power with respect to the chirp mass \(\mathcal{M}\).

```python
def total_power(M_c):
    p = {**params_bbh, "M_c": M_c}
    h = ripplegw.IMRPhenomD()(frequency, p)
    return jnp.sum(jnp.abs(h["p"]) ** 2)

print(f"power = {total_power(params_bbh['M_c']):.10e}")
grad_Mc = jax.grad(total_power)(params_bbh["M_c"])
print(f"d(power)/d(M_c) = {grad_Mc:.10e}")
```
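jax.grad also works on a whole parameter dictionary, since a dict is a JAX pytree: differentiating a scalar function of the dict returns a dictionary of partial derivatives with the same keys. A minimal sketch with a stand-in scalar function (the quadratic below is illustrative only, not a waveform quantity):

```python
import jax
import jax.numpy as jnp

def scalar_fn(p):
    # Stand-in for a scalar summary of the waveform, e.g. total strain power
    return p["M_c"] ** 2 + jnp.sin(p["iota"]) * p["d_L"]

params = {"M_c": 28.3, "iota": 2.0, "d_L": 440.0}
grads = jax.grad(scalar_fn)(params)

# grads is a dict with the same keys as params
print(grads["M_c"])  # d/dM_c = 2 * M_c = 56.6
```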

JIT compilation¤

Wrap the call in jax.jit for fast repeated evaluation — e.g., inside a likelihood or sampler loop.

```python
import time

fast_waveform = jax.jit(waveform)

# First call traces and compiles
_ = fast_waveform(frequency, params_bbh)

# Subsequent calls are fast
t0 = time.perf_counter()
for _ in range(100):
    pol = fast_waveform(frequency, params_bbh)
pol["p"].block_until_ready()
print(f"100 waveform evaluations: {time.perf_counter() - t0:.3f} s")
```

Switching waveform models¤

All waveform classes share the same __call__ interface, so swapping models only requires changing one line. waveform_preset is a convenience dictionary mapping name strings to classes.

```python
print("Available waveforms:", list(ripplegw.waveform_preset.keys()))

# Swap IMRPhenomD -> IMRPhenomXAS (same parameter keys)
for name, cls in [("IMRPhenomD", ripplegw.IMRPhenomD),
                  ("IMRPhenomXAS", ripplegw.IMRPhenomXAS)]:
    h = cls()(frequency, params_bbh)
    print(f"{name}: max|h+| = {jnp.max(jnp.abs(h['p'])):.3e}")
```

Summary¤

The ripple interface is:

```python
waveform = ripplegw.IMRPhenomD(f_ref=20.0)
polarizations = waveform(frequency, params)
hp, hc = polarizations["p"], polarizations["c"]
```

The same pattern applies to every model. Because all waveforms are pure JAX functions, they compose naturally with jax.grad, jax.jit, and jax.vmap for gradient-based inference and GPU-accelerated batch evaluation.
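For batch evaluation, jax.vmap maps over one leaf of the parameter dictionary while broadcasting the frequency grid and the remaining parameters via in_axes. A sketch with a stand-in for the waveform call (the stand-in function is ours; the in_axes structure is what carries over to a real waveform):

```python
import jax
import jax.numpy as jnp

def stand_in(freqs, p):
    # Stand-in for waveform(freqs, p): returns one array per frequency grid
    return p["M_c"] * jnp.ones_like(freqs)

freqs = jnp.linspace(20.0, 512.0, 8)

# Batch over 5 chirp masses; broadcast freqs and the remaining parameters
mcs = jnp.linspace(20.0, 40.0, 5)
batched = jax.vmap(stand_in, in_axes=(None, {"M_c": 0, "eta": None}))
out = batched(freqs, {"M_c": mcs, "eta": 0.247})
print(out.shape)  # (5, 8): one row per chirp mass
```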