
FAQ

Float precision

JAX defaults to float32, but gravitational-wave waveform computations typically require float64 precision. Always enable 64-bit mode at the top of your script, before any JAX operations:

import jax
jax.config.update("jax_enable_x64", True)

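Without this, JAX stores values in float32 even when you pass float64 inputs. This can cause numerical errors or inaccurate waveforms, particularly at high frequencies or for long signals. In those regimes the accumulated phase is large, and float32's roughly seven significant digits are not enough to resolve it.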
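To make the precision loss concrete, here is a minimal sketch that is not part of ripple. It rounds one large phase value in both precisions. The value `phase` (about 1e5 rad) is a hypothetical, illustrative number and does not come from a real waveform.

```python
import jax
import jax.numpy as jnp

# Must run before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

# Hypothetical accumulated phase in radians (illustrative value only).
phase = 1.0e5 + 0.123456789

phase64 = jnp.asarray(phase, dtype=jnp.float64)
phase32 = jnp.asarray(phase, dtype=jnp.float32)

# Near 1e5 rad, adjacent float32 values are 2**-7 = 0.0078125 rad apart,
# so rounding alone can shift the phase by up to about 0.004 rad.
err32 = abs(float(phase32) - phase)
err64 = abs(float(phase64) - phase)

print(f"float32 rounding error: {err32:.2e} rad")
print(f"float64 rounding error: {err64:.2e} rad")
print("default float dtype:", jnp.zeros(3).dtype)  # float64 once x64 is enabled
```

A millirad-level error from storage alone can compound once the phase enters trigonometric functions and inner products.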

JIT compilation time

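The first call to a JIT-compiled waveform (e.g. via jax.jit) triggers XLA compilation, which can take several seconds. This is normal. Subsequent calls with the same input shapes and dtypes reuse the compiled code and are much faster, while a call with new shapes or dtypes triggers a fresh compilation. If you are timing ripple for benchmarking purposes, discard the first call. Because JAX dispatches work asynchronously, also call .block_until_ready() on the result before stopping the timer.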
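A minimal benchmarking sketch follows. `toy_waveform` and `timed_call` are hypothetical names: `toy_waveform` is a stand-in with a crude chirp-like form, not a ripple waveform, so that the example does not depend on any particular ripple signature. The first timed call includes compilation, and the second reuses the compiled code.

```python
import time

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


# Hypothetical stand-in for a ripple waveform (not a physical model).
def toy_waveform(f, amp, tc):
    phase = 2.0 * jnp.pi * f * tc
    return amp * f ** (-7.0 / 6.0) * jnp.exp(1j * phase)


waveform_jit = jax.jit(toy_waveform)
f = jnp.linspace(20.0, 1024.0, 100_000)


def timed_call(fn, *args):
    start = time.perf_counter()
    # block_until_ready waits for JAX's asynchronous dispatch to finish,
    # so the timer measures the actual computation.
    out = fn(*args).block_until_ready()
    return out, time.perf_counter() - start


_, t_first = timed_call(waveform_jit, f, 1.0, 0.5)     # includes XLA compilation
out, t_second = timed_call(waveform_jit, f, 1.0, 0.5)  # same shapes/dtypes: cached
print(f"first call: {t_first:.3f} s, second call: {t_second:.3f} s")
```

For real benchmarks, repeat the second call several times (e.g. with `timeit`) and report a median rather than a single measurement.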

To disable JIT for debugging:

jax.config.update("jax_disable_jit", True)
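Besides the global flag, JAX provides the `jax.disable_jit()` context manager, which disables JIT only inside a `with` block. The sketch below uses a hypothetical toy function `scaled_sum` to show the practical difference. Under JIT, a `print` inside the function sees an abstract tracer. With JIT disabled, it sees concrete values.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@jax.jit
def scaled_sum(x):
    # Under jit this prints an abstract tracer; with JIT disabled it prints real values.
    print("x inside scaled_sum:", x)
    return jnp.sum(2.0 * x)


x = jnp.arange(3.0)

scaled_sum(x)  # traced once: the print shows a tracer, not numbers

# jax.disable_jit() turns JIT off only inside this block,
# a scoped alternative to the global "jax_disable_jit" flag.
with jax.disable_jit():
    result = scaled_sum(x)  # runs eagerly: the print shows the concrete array values
```

The scoped form is convenient when only one call in a larger script needs to be stepped through with print statements or a Python debugger.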

Compilation is slow for complex models

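If you wrap a ripple waveform inside a larger likelihood with many operations or Python-level loops, JAX may take a long time to compile the full computational graph. This is because Python loops are unrolled during tracing, so each iteration adds its own copy of the loop body to the graph. Where possible, replace Python loops with jax.lax.scan (for sequential iterations) or jax.vmap (for independent ones). This keeps the graph compact and can significantly reduce compilation time.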
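The sketch below illustrates the difference using a hypothetical toy accumulation, not an actual ripple likelihood. It writes the same loop twice and counts the primitive operations each version places in the traced graph via `jax.make_jaxpr`. `N_STEPS`, `loop_python` and `loop_scan` are illustrative names.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

N_STEPS = 200  # hypothetical number of loop iterations


def loop_python(x):
    # Python loop: unrolled while tracing, so the graph holds N_STEPS copies of the body.
    total = jnp.zeros_like(x)
    for i in range(N_STEPS):
        total = total + jnp.sin(x * i)
    return total


def loop_scan(x):
    # lax.scan: the body is traced once and XLA runs it N_STEPS times.
    def body(total, i):
        return total + jnp.sin(x * i), None

    total, _ = jax.lax.scan(body, jnp.zeros_like(x), jnp.arange(N_STEPS))
    return total


x = jnp.asarray(0.01)

# Count the primitive operations each version places in the traced graph.
n_python = len(jax.make_jaxpr(loop_python)(x).jaxpr.eqns)
n_scan = len(jax.make_jaxpr(loop_scan)(x).jaxpr.eqns)
print(f"Python loop: {n_python} ops, lax.scan: {n_scan} ops")
```

The unrolled graph grows linearly with `N_STEPS`, while the scan graph stays the same size. When iterations do not depend on each other, `jax.vmap` plays the same role by tracing the body once for a whole batch.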