Wildfire simulation using THRML
A spatio-temporal Potts model (two states, i.e. Ising-type) on an H×W grid. Each cell has a binary variable $x_{i,j,t} \in \{0,1\}$ indicating whether there is an active wildfire at location $(i,j)$ and time $t$.
We add our imports
from typing import List, Tuple
import sys
import time
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
# thrml dependencies
from thrml.block_management import Block
from thrml.block_sampling import BlockGibbsSpec, SamplingSchedule, sample_states
from thrml.factor import AbstractFactor, FactorSamplingProgram
from thrml.interaction import InteractionGroup
from thrml.conditional_samplers import AbstractParametricConditionalSampler
from thrml.pgm import AbstractNode
We declare our FireNodes, inheriting from AbstractNode (nodes in a PGM).
class FireNode(AbstractNode):
    """One grid cell of the wildfire model.

    The node carries no data of its own; it only provides a distinct identity in the
    THRML graph. Callers place nodes in row-major order (index ``r * W + c``).
    """
The wind kernel is a 3×3 kernel $w$ that weights each of the eight neighbors of a cell $(i,j)$ according to the wind direction. The wind vector is normalized, so only its direction matters, not its speed. For example, if the wind is blowing eastward, cells to the west (upwind) will have a higher influence on the probability of fire spreading to cell $(i,j)$.
def neighbor_kernel_from_wind(
    wind: jnp.ndarray,
    kappa: float = 2.3,
    base: float = 1.0,
    diag_scale: float = 0.7,
    normalize: bool = False,
) -> jnp.ndarray:
    """
    Build a 3x3 neighbor kernel that makes fire spread downwind.

    Grid convention: x -> east (cols), y -> south (rows). wind = [dx, dy], pointing
    in the direction the wind blows toward.

    The kernel is applied with ``conv2d_same``, which is a cross-correlation:
    ``neigh[i, j] = sum_{dy, dx} kernel[dy + 1, dx + 1] * fire[i + dy, j + dx]``.
    Entry ``(dy, dx)`` is therefore the influence of the neighbor at that offset on
    the centre cell. A neighbor lying *upwind* of the centre (its offset points
    against the wind) blows fire onto the centre. It gets the largest weight,
    about ``base * exp(kappa)``. The downwind neighbor gets about ``base * exp(-kappa)``.
    Diagonal neighbors are further scaled by ``diag_scale``, and the centre entry is 0.
    A zero wind vector yields an isotropic kernel (apart from ``diag_scale``).

    Args:
        wind: 2-vector [dx, dy]; only its direction is used.
        kappa: anisotropy strength (0 -> isotropic).
        base: overall multiplier for every neighbor weight.
        diag_scale: extra multiplier for the four diagonal neighbors.
        normalize: if True, rescale the kernel to sum to 1 (when its sum is positive).

    Returns:
        (3, 3) float32 kernel.
    """
    w = jnp.asarray(wind, dtype=jnp.float32)
    w = w / (jnp.linalg.norm(w) + 1e-6)
    # Reorder to (row, col) = (dy, dx) so it matches the offset vectors below.
    w_rc = jnp.array([w[1], w[0]], dtype=jnp.float32)
    kernel = jnp.zeros((3, 3), dtype=jnp.float32)
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            if dy == 0 and dx == 0:
                continue  # the centre cell does not feed itself through the kernel
            v = jnp.array([dy, dx], dtype=jnp.float32)
            v = v / (jnp.linalg.norm(v) + 1e-6)
            # Cosine between the direction *from* the neighbor *to* the centre (-v)
            # and the wind: +1 for the upwind neighbor, -1 for the downwind one.
            # (Using +v here would weight downwind neighbors and push fire upwind.)
            upwind_alignment = jnp.dot(-v, w_rc)
            scale = base * (diag_scale if (dy != 0 and dx != 0) else 1.0)
            kernel = kernel.at[dy + 1, dx + 1].set(scale * jnp.exp(kappa * upwind_alignment))
    if normalize:
        s = jnp.sum(kernel)
        kernel = jnp.where(s > 0, kernel / s, kernel)
    return kernel
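A tiny JAX helper: single-channel 2D cross-correlation with SAME padding via `jax.lax.conv_general_dilated`. The kernel is not flipped, so for a 3×3 kernel `out[i, j] = Σ_{a,b} k[a, b] · x[i + a − 1, j + b − 1]`.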
def conv2d_same(x: jnp.ndarray, kernel: jnp.ndarray) -> jnp.ndarray:
    """Cross-correlate a single-channel 2D image with ``kernel`` using SAME padding.

    No kernel flip is applied. For odd kernel sizes,
    ``out[i, j] = sum_{a, b} kernel[a, b] * x[i + a - kh // 2, j + b - kw // 2]``,
    with zeros outside the image. The output has the same (H, W) shape as ``x``
    and is float32.
    """
    image = jnp.asarray(x, dtype=jnp.float32)
    filt = jnp.asarray(kernel, dtype=jnp.float32)
    batched_image = jnp.expand_dims(image, (0, -1))  # (1, H, W, 1): NHWC
    filt_hwio = jnp.expand_dims(filt, (-2, -1))  # (kh, kw, 1, 1): HWIO
    response = jax.lax.conv_general_dilated(
        batched_image,
        filt_hwio,
        (1, 1),
        "SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return jnp.squeeze(response, axis=(0, 3))
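Interactions and factors (see 01_all_of_thrml.ipynb):
- `LinearInteraction(weights)` → adds `weight × (product of tail states)` to each head node's logit. With no tails, this is just a per-node bias.
- `PairInteraction(weights)` → intended to add `W @ tail_state` to the logits. It is declared but not handled by `FireSpinConditional`, which raises on it.
- `UnaryFactor(weights, block)` → exposes one `LinearInteraction` without tails.
- `NeighborCouplingFactor(weights, (head_block, tail_block))` → exposes two directional `LinearInteraction`s (head ← tail and tail ← head).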
class LinearInteraction(eqx.Module):
    """Linear term theta_i * x_i.

    When the interaction group has tail nodes, FireSpinConditional multiplies the
    weight by the product of the tail states before adding it to the head's logit.
    """

    weights: jnp.ndarray  # one weight per head node, shape [n]
class PairInteraction(eqx.Module):
    """Pair term sum_k W_{ik} * x_i * z_k: one head i coupled to several tails k.

    NOTE: FireSpinConditional does not handle this type (it raises ValueError).
    The neighbor couplings use LinearInteraction with tail nodes instead.
    """

    weights: jnp.ndarray  # [n, m] head -> tail coupling matrix
class UnaryFactor(AbstractFactor):
    r"""Unary factor \sum_i h_i x_i over the nodes of a single block.

    Args:
        weights: per-node field h, in the same order as ``block``'s nodes.
        block: the block whose nodes receive the field.
    """

    weights: jnp.ndarray

    def __init__(self, weights: jnp.ndarray, block: Block):
        super().__init__([block])
        self.weights = weights

    def to_interaction_groups(self) -> list[InteractionGroup]:
        """Expose the field as a single LinearInteraction group with no tail nodes."""
        field_term = LinearInteraction(self.weights)
        group = InteractionGroup(
            interaction=field_term,
            head_nodes=self.node_groups[0],
            tail_nodes=[],
        )
        return [group]
class NeighborCouplingFactor(AbstractFactor):
    r"""Symmetric pair factor \sum_k W_k x_{a_k} x_{b_k} over aligned node pairs.

    ``blocks = (head_block, tail_block)``: node k of the first block is paired with
    node k of the second (see neighbor_pair_blocks). The factor is exposed as two
    directed LinearInteraction groups with the same weights, so both endpoints of
    each pair see the coupling.

    NOTE(review): anisotropic_couplings builds the weights as [n_pairs, 1], and
    FireSpinConditional multiplies them elementwise with per-edge states. Confirm
    that the trailing singleton axis broadcasts as intended before enabling these factors.
    """

    weights: jnp.ndarray  # per-pair weights, built as [n_pairs, 1]

    def __init__(self, weights: jnp.ndarray, blocks: Tuple[Block, Block]):
        super().__init__(list(blocks))
        self.weights = weights

    def to_interaction_groups(self) -> list[InteractionGroup]:
        """Return the head<-tail group followed by the tail<-head group."""
        first, second = self.node_groups[0], self.node_groups[1]
        return [
            InteractionGroup(
                interaction=LinearInteraction(self.weights),
                head_nodes=receiver,
                tail_nodes=[source],
            )
            for receiver, source in ((first, second), (second, first))
        ]
Custom conditional: a Bernoulli sampler over $\{0,1\}$ whose logits are built from `LinearInteraction` terms (any other interaction type raises an error).
class FireSpinConditional(AbstractParametricConditionalSampler):
    """Bernoulli conditional for binary {0, 1} fire variables.

    Each node's logit is the sum over its LinearInteraction groups of
    ``weight * active * prod(tail states)``, summed over the last axis.
    Samples are drawn with P(x = 1) = sigmoid(logit). Any other interaction
    type raises ValueError.
    """

    def compute_parameters(self, key, interactions, active_flags, states, sampler_state, output_sd):
        """Accumulate per-node logits from all interactions; returns (logits, sampler_state)."""
        logits = jnp.zeros(output_sd.shape, dtype=jnp.float32)
        for interaction, active, tail_states in zip(interactions, active_flags, states):
            if not isinstance(interaction, LinearInteraction):
                raise ValueError(f"Unknown interaction type: {type(interaction)}")
            # A tail-less interaction (unary field) contributes the bare weight.
            if len(tail_states) > 0:
                tail_product = jnp.prod(jnp.stack(tail_states, -1), -1)
            else:
                tail_product = jnp.array(1.0, dtype=jnp.float32)
            logits = logits + jnp.sum(interaction.weights * active * tail_product, axis=-1)
        return logits, sampler_state

    def sample_given_parameters(self, key, parameters, sampler_state, output_sd):
        """Draw x ~ Bernoulli(sigmoid(logits)) as float32 0/1 values."""
        fire_prob = jax.nn.sigmoid(parameters)
        draws = jax.random.bernoulli(key, p=fire_prob)
        return draws.astype(jnp.float32), sampler_state

    def init(self):
        """No sampler state is carried between calls."""
        return None
Block builders for grid strips and neighbor alignment:
def create_grid_nodes(H: int, W: int) -> List[FireNode]:
    """Return H*W fresh FireNode objects, one per cell (callers index them as r*W + c)."""
    n_cells = H * W
    return [FireNode() for _cell in range(n_cells)]
def make_full_block(nodes: List[FireNode]) -> Block:
    """Wrap every given node in one Block (the target of the unary field)."""
    everything = Block(nodes)
    return everything
def make_strip_blocks(nodes: List[FireNode], H: int, W: int, width: int, vertical: bool) -> List[Block]:
    """Partition the grid into non-overlapping strips, one Block per strip.

    Args:
        nodes: row-major list of H*W nodes (node of cell (r, c) is ``nodes[r * W + c]``).
        H: grid height (rows).
        W: grid width (columns).
        width: strip thickness in cells; the last strip may be thinner.
        vertical: True -> column strips (``width`` columns x all rows);
            False -> row strips (``width`` rows x all columns).

    Returns:
        Strips ordered by increasing column (vertical) or row (horizontal). Nodes
        inside each strip are in row-major order.

    Raises:
        ValueError: if ``width`` is not a positive integer. A negative width
            would otherwise yield no blocks at all and leave the grid unsampled.
    """
    if width < 1:
        raise ValueError(f"width must be a positive integer, got {width}")
    if vertical:
        return [
            Block([nodes[r * W + c] for r in range(H) for c in range(c0, min(c0 + width, W))])
            for c0 in range(0, W, width)
        ]
    return [
        Block([nodes[r * W + c] for r in range(r0, min(r0 + width, H)) for c in range(W)])
        for r0 in range(0, H, width)
    ]
def neighbor_pair_blocks(nodes: List[FireNode], H: int, W: int, direction: str) -> Tuple[Block, Block]:
    """Return aligned (head, tail) blocks pairing each cell with its neighbor in ``direction``.

    ``direction`` is one of "N", "S", "E", "W" (rows grow southward, columns eastward).
    Element k of the head block and element k of the tail block form one pair. Heads
    whose neighbor would fall off the grid are skipped, and pairs are listed in
    row-major order of the head cell.

    Raises:
        ValueError: if ``direction`` is not one of N, S, E, W.
    """
    step = None
    for name, delta in (("E", (0, 1)), ("W", (0, -1)), ("S", (1, 0)), ("N", (-1, 0))):
        if direction == name:
            step = delta
            break
    if step is None:
        raise ValueError("direction must be one of N,S,E,W")

    d_row, d_col = step
    heads, tails = [], []
    for r in range(H):
        for c in range(W):
            nr, nc = r + d_row, c + d_col
            if 0 <= nr < H and 0 <= nc < W:
                heads.append(nodes[r * W + c])
                tails.append(nodes[nr * W + nc])
    return Block(heads), Block(tails)
Within each time slice, the spatial coupling $J$ discourages isolated flips and promotes compact fire fronts. The pairwise grid smoothness is anisotropic: couplings are stronger along the wind direction. (These pairwise factors are built in `run_wildfire` but are currently not passed to the sampler.)
def anisotropic_couplings(H: int, W: int, wind: jnp.ndarray, J0: float, kappa: float) -> dict:
    """Per-direction coupling weights J_d = J0 * exp(kappa * cos(angle(d, wind))).

    Directions are given in (dx, dy): E=(1,0), W=(-1,0), S=(0,1), N=(0,-1).

    Returns:
        dict with keys "E", "W", "S", "N" (in that order). Each value is a float32
        array filled with that direction's weight. The shape matches the pair count
        of neighbor_pair_blocks: (H*(W-1), 1) for E/W and ((H-1)*W, 1) for N/S.
    """
    unit_wind = wind / (jnp.linalg.norm(wind) + 1e-6)
    direction_vectors = (
        ("E", (1.0, 0.0)),
        ("W", (-1.0, 0.0)),
        ("S", (0.0, 1.0)),
        ("N", (0.0, -1.0)),
    )
    horizontal_pairs = H * (W - 1)
    vertical_pairs = (H - 1) * W

    couplings = {}
    for name, (ux, uy) in direction_vectors:
        vec = jnp.array([ux, uy], dtype=jnp.float32)
        strength = float(J0 * jnp.exp(kappa * (vec @ unit_wind)))
        count = horizontal_pairs if name in ("E", "W") else vertical_pairs
        couplings[name] = jnp.full((count, 1), strength, dtype=jnp.float32)
    return couplings
Data visualization: we draw the grid in the terminal and redraw it in place at each timestep.
States shown:
- Y: trees (unburned)
- /: barriers (roads, rivers)
- 🜂: active fire (newly burning)
- _: burned soil (previously burned, may regrow)
def visualize_grid(
    burn_state: jnp.ndarray,
    prev_burn_state: jnp.ndarray,
    barrier_map: jnp.ndarray,
    t: int,
    total: int,
    burn_frac: float,
    use_unicode: bool = True,
):
    """Draw one frame of the wildfire grid in the terminal using ANSI colours.

    Each cell is drawn using the first rule that applies:
    1. barrier: gray "/"
    2. burned now but not in ``prev_burn_state``: red fire glyph
    3. burned: gray "_"
    4. otherwise a tree: green "Y"
    Values > 0.5 count as set.

    Grids wider than 80 columns or taller than 40 rows are downsampled by striding.
    The stride is rounded up, so the displayed grid never exceeds those limits.

    Args:
        burn_state: (H, W) current burn map.
        prev_burn_state: (H, W) burn map to compare against when marking new fire.
        barrier_map: (H, W) barrier indicator.
        t: current timestep; t == 0 clears the screen, later frames redraw in place.
        total: total number of timesteps (shown in the header only).
        burn_frac: burned fraction shown in the header.
        use_unicode: draw fire as U+1F702 (alchemical symbol for fire) instead of "*".
    """
    H, W = burn_state.shape

    # ANSI escape codes
    CLEAR_SCREEN = "\033[2J\033[H"
    CURSOR_HOME = "\033[H"
    GREEN = "\033[32m"
    RED = "\033[91m"
    GRAY = "\033[90m"
    RESET = "\033[0m"

    max_display_size = 80  # max columns; rows are limited to half of this
    # Written as an escape sequence so the source file's encoding cannot garble it
    # (a mis-decoded literal printed 4 characters per cell and broke the grid).
    fire_char = "\U0001F702" if use_unicode else "*"

    out = sys.stdout
    # First frame clears the screen; later frames just move the cursor home.
    out.write(CLEAR_SCREEN if t == 0 else CURSOR_HOME)
    out.write(f"Timestep {t}/{total} | Burn fraction: {burn_frac:.2%}\n")
    out.write("=" * min(W, max_display_size) + "\n")
    out.write(f"{GREEN}Y{RESET}=trees {GRAY}/{RESET}=barriers {RED}{fire_char}{RESET}=fire {GRAY}_{RESET}=burned\n")

    # Ceil-divide so the strided grid fits the display limits (floor division left,
    # e.g., a 96x96 grid at 96 columns x 48 rows). Small grids get a stride of 1.
    step_c = max(1, -(-W // max_display_size))
    step_r = max(1, -(-H // (max_display_size // 2)))
    display_burn = np.asarray(burn_state)[::step_r, ::step_c] > 0.5
    display_prev = np.asarray(prev_burn_state)[::step_r, ::step_c] > 0.5
    display_barrier = np.asarray(barrier_map)[::step_r, ::step_c] > 0.5

    for burn_row, prev_row, barrier_row in zip(display_burn, display_prev, display_barrier):
        cells = []
        for is_burned, was_burned, is_barrier in zip(burn_row, prev_row, barrier_row):
            if is_barrier:
                cells.append(f"{GRAY}/{RESET}")
            elif is_burned and not was_burned:
                cells.append(f"{RED}{fire_char}{RESET}")
            elif is_burned:
                cells.append(f"{GRAY}_{RESET}")
            else:
                cells.append(f"{GREEN}Y{RESET}")
        out.write("".join(cells) + "\n")
    out.flush()
Core simulation:
Algorithm
For a single step, we define an energy over $x_t$ as: $$E(x_t|x_{t-1}) = -\sum_{i,j}h_{i,j}(t)x_{i,j,t} - \sum_{\langle(i,j),(p,q)\rangle} J_{(i,j),(p,q)} x_{i,j,t} x_{p,q,t}$$
where the first sum is the unary field (dryness, barriers, persistence and neighbor ignition; see below) and the second sum is the spatial smoothness term (anisotropic neighbor couplings).
Simulation parameters
We define a grid size (H, W), the number of frames T, a wind vector, the blocking (strip_width), and the model knobs.
alpha, beta, gamma, eta are the weights for dryness, barriers, persistence and neighbor ignition. They enter the unary field $h_{i,j}(t)$ as follows:
- Dryness/fuel: $+\alpha d_{i,j}$, where $d_{i,j} \in [0,1]$ is the dryness map.
- Barriers: $-\beta b_{i,j}$, where $b_{i,j} \in \{0,1\}$ indicates the presence of a barrier (e.g. water, road).
- Persistence: $+\gamma x_{i,j,t-1}$ encourages ongoing fires to continue burning.
- Neighbor ignition: $+\eta \sum_{(p,q)\in \mathcal{N}(i,j)} w^{dir}_{(i,j)\leftarrow (p,q)} x_{p,q,t-1}$ encourages ignition when neighbors were actively burning in the previous step.

The field also contains a constant negative ignition threshold (−5 in the code), so cells ignite only under strong neighbor influence.

Other parameters:
- J0, kappa → (reserved) spatial smoothness and anisotropy strength for the pairwise terms.
- rho → reforestation probability for burned cells, applied after every Gibbs sweep.
- burn_duration → how many sweeps a cell stays "actively burning" (able to ignite neighbors).
- sweeps_per_frame → how many Gibbs sweeps to run before emitting a frame.
- visualize, viz_* → optional terminal animation.
- seed → RNG determinism.
The THRML graph defines the nodes and the block construction. The strip_blocks are strip blocks that define the Gibbs update order: column strips when the wind is mostly east–west, row strips otherwise.
The THRML runtime consists of the block Gibbs spec, the samplers, and the schedule:
- Spec: which blocks are free (to be sampled), which are clamped (none here), and the expected node shapes/dtypes.
- Samplers: one conditional per block (the same instance is reused).
- Schedule: a full sweep through all blocks per sample.
We wrap the factors into a THRML program and run one scheduled Gibbs sweep across the strips. THRML handles gather/scatter, padding, and calling the conditional with the right neighbor states.
def run_wildfire(
    H: int = 96,
    W: int = 96,
    T: int = 60,
    wind: Tuple[float, float] = (1.0, 0.2),
    strip_width: int = 12,
    alpha: float = 1.2,  # dryness
    beta: float = 5.0,  # barrier penalty
    gamma: float = 1.5,  # persistence
    eta: float = 2.0,  # neighbor ignition influence
    J0: float = 0.6,  # spatial smoothness base (pairwise factors, currently disabled)
    kappa: float = 2.5,  # anisotropy strength of the pairwise couplings
    rho: float = 0.0,  # reforestation probability, applied after every Gibbs sweep
    burn_duration: int = 3,  # sweeps a cell actively burns before burning out
    sweeps_per_frame: int = 20,  # Gibbs sweeps per visualization frame
    visualize: bool = False,  # show terminal visualization
    viz_delay: float = 0.1,  # delay between frames (seconds)
    viz_unicode: bool = True,  # use unicode fire symbol (U+1F702) vs ASCII (*)
    seed: int = 0,
):
    """Simulate wildfire spread on an H x W grid with THRML block-Gibbs sampling.

    Every Gibbs sweep recomputes the unary field
    ``h = -5 + alpha*dryness - beta*barrier + gamma*x + eta*neigh``.
    Here ``x`` is the accumulated burn state, and ``neigh`` is the wind kernel
    cross-correlated with the actively burning cells (fire age in
    ``[1, burn_duration]``). Each cell is then sampled from
    ``Bernoulli(sigmoid(h))``, strip by strip. Sampled fires are OR-ed into ``x``.
    Burned cells only return to 0 through reforestation, with probability ``rho``
    per sweep.

    Returns:
        dict with
        - "frames": numpy array (T, H, W), the accumulated burn state after each frame;
        - "metrics": {"burn_frac": (T,), "frontier": (T,)}, the burned fraction and a
          perimeter proxy (total |gradient| of the burn map / (2 (H + W)));
        - "params": the simulation parameters.
    """
    key = jax.random.PRNGKey(seed)
    wind_vec = jnp.array(wind, dtype=jnp.float32)

    # --- synthetic landscape ---------------------------------------------------
    yy, xx = jnp.meshgrid(jnp.linspace(0, 1, H), jnp.linspace(0, 1, W), indexing="ij")
    # The raw sin/cos pattern spans [-0.2, 1.0]; clip so dryness stays in the
    # documented range d_{i,j} in [0, 1].
    dryness = jnp.clip(0.4 + 0.6 * (0.5 * jnp.sin(6 * xx) + 0.5 * jnp.cos(4 * yy)), 0.0, 1.0)
    barrier = ((jnp.sin(10 * yy) > 0.8) & (xx > 0.3) & (xx < 0.7)).astype(jnp.float32)  # toy rivers

    # Single ignition point near the top-left (north-west) corner: row 0 is the
    # northern edge and column 0 the western edge.
    r_ignite = int(H * 0.02)
    c_ignite = int(W * 0.02)
    ignite0 = jnp.zeros((H, W), dtype=jnp.float32).at[r_ignite, c_ignite].set(1.0)

    # --- THRML graph -------------------------------------------------------------
    grid_nodes = create_grid_nodes(H, W)
    node_to_idx = {node: idx for idx, node in enumerate(grid_nodes)}
    # Column strips when the wind is mostly east-west, row strips otherwise.
    vertical: bool = abs(wind[0]) >= abs(wind[1])
    strip_blocks = make_strip_blocks(grid_nodes, H, W, strip_width, vertical=vertical)
    full_block = make_full_block(grid_nodes)
    # Flat grid indices of each strip's nodes. They never change, so they are
    # computed once here instead of twice per sweep.
    block_idxs = [jnp.array([node_to_idx[node] for node in block.nodes]) for block in strip_blocks]
    all_idxs = jnp.concatenate(block_idxs)

    # Pairwise couplings: built, but not currently passed to the sampler (see `factors`).
    directions = ("E", "W", "S", "N")
    pair_W = anisotropic_couplings(H, W, wind_vec, J0=J0, kappa=kappa)
    pair_blocks = {d: neighbor_pair_blocks(grid_nodes, H, W, d) for d in directions}
    pair_factors = [NeighborCouplingFactor(pair_W[d], pair_blocks[d]) for d in directions]

    # --- THRML runtime -------------------------------------------------------------
    conditional = FireSpinConditional()
    node_shape_dtypes = {FireNode: jax.ShapeDtypeStruct((), jnp.float32)}
    gibbs_spec = BlockGibbsSpec(strip_blocks, [], node_shape_dtypes)
    samplers = [conditional] * len(strip_blocks)  # one (shared) sampler per strip
    # NOTE(review): the original author found n_warmup >= len(blocks) necessary for
    # sampling to actually occur; confirm against thrml's SamplingSchedule semantics.
    schedule = SamplingSchedule(n_warmup=len(strip_blocks), n_samples=1, steps_per_sample=len(strip_blocks))

    # Neighbor-ignition kernel; uses its own fixed anisotropy, independent of `kappa`.
    K = neighbor_kernel_from_wind(wind_vec, kappa=2.3, base=1.0, diag_scale=0.7, normalize=False)
    ignition_threshold = -5.0  # negative baseline: no spontaneous ignition

    frames = []
    metrics_burn = []
    metrics_frontier = []
    x_prev = ignite0  # accumulated burn state: once burned stays 1 (until regrowth)
    # Fire age: 0 = unburned, 1..burn_duration = actively burning, > burn_duration = burned out.
    fire_age = jnp.where(ignite0 > 0, 1.0, 0.0)

    for t in range(T):
        print(f"Frame {t}/{T}, burned cells: {jnp.sum(x_prev):.1f}")
        x_frame_start = x_prev  # jax arrays are immutable, so no copy is needed

        for sweep in range(sweeps_per_frame):
            key, subkey = jax.random.split(key)

            # Only actively burning cells spread fire to their neighbors.
            active_fire_mask = (fire_age >= 1) & (fire_age <= burn_duration)
            active_fires = x_prev * active_fire_mask.astype(jnp.float32)
            neigh = conv2d_same(active_fires, K)
            h = ignition_threshold + alpha * dryness - beta * barrier + gamma * x_prev + eta * neigh
            if t == 0 and sweep == 0:
                print(f" h stats: min={jnp.min(h):.2f}, max={jnp.max(h):.2f}, mean={jnp.mean(h):.2f}")
                print(f" h at ignition: {h[jnp.where(ignite0 > 0)][0]:.2f}")

            # The field changes every sweep, so the unary factor and the program are
            # rebuilt. Pairwise factors (pair_factors) are currently disabled.
            factors = [UnaryFactor(h.reshape(-1), full_block)]
            program = FactorSamplingProgram(
                gibbs_spec=gibbs_spec,
                samplers=samplers,
                factors=factors,
                other_interaction_groups=[],
            )

            x_flat = x_prev.reshape(-1)
            init_state = [x_flat[idxs] for idxs in block_idxs]
            sampled_states = sample_states(
                key=subkey,
                program=program,
                schedule=schedule,
                init_state_free=init_state,
                state_clamp=[],
                nodes_to_sample=strip_blocks,
            )
            # sampled_states[b] has shape [n_samples, n_nodes_in_block]. The strips
            # partition the grid, so a single scatter reassembles the full field.
            sampled_flat = jnp.concatenate([block_samples[0] for block_samples in sampled_states])
            x_t = jnp.zeros(H * W, dtype=jnp.float32).at[all_idxs].set(sampled_flat).reshape(H, W)

            # New ignitions start at age 1; cells burned before this sweep age by one.
            newly_ignited = (x_t > 0.5) & (x_prev < 0.5)
            fire_age = jnp.where(newly_ignited, 1.0, fire_age)
            fire_age = jnp.where(x_prev > 0.5, fire_age + 1, fire_age)
            x_prev = jnp.clip(x_prev + x_t, 0, 1)  # once burned, stays 1

            # Reforestation: only burned cells may regrow, each with probability rho.
            if rho > 0:
                key, subkey = jax.random.split(key)
                regrow_mask = (x_prev > 0.5) & (jax.random.uniform(subkey, x_prev.shape) < rho)
                x_prev = x_prev * (1 - regrow_mask.astype(jnp.float32))
                fire_age = jnp.where(regrow_mask, 0.0, fire_age)

        frames.append(x_prev)
        burn_frac = float(jnp.mean(x_prev))
        # Perimeter proxy: total absolute gradient of the burn map, normalised by grid size.
        gx = jnp.abs(jnp.diff(x_prev, axis=1)).sum()
        gy = jnp.abs(jnp.diff(x_prev, axis=0)).sum()
        frontier = (gx + gy) / (2 * (H + W))
        metrics_burn.append(burn_frac)
        metrics_frontier.append(float(frontier))

        if visualize:
            # Compare against the state at the start of *this* frame, so only cells
            # that ignited during this frame are drawn as fire.
            visualize_grid(
                burn_state=x_prev,
                prev_burn_state=x_frame_start,
                barrier_map=barrier,
                t=t,
                total=T,
                burn_frac=burn_frac,
                use_unicode=viz_unicode,
            )
            time.sleep(viz_delay)

    # Guard T == 0: jnp.stack cannot stack an empty list.
    frames_arr = np.array(jnp.stack(frames, axis=0)) if frames else np.zeros((0, H, W), dtype=np.float32)
    return {
        "frames": frames_arr,
        "metrics": {
            "burn_frac": np.array(metrics_burn),
            "frontier": np.array(metrics_frontier),
        },
        "params": dict(
            H=H, W=W, T=T, wind=wind, strip_width=strip_width,
            alpha=alpha, beta=beta, gamma=gamma, eta=eta, J0=J0, kappa=kappa,
            rho=rho, burn_duration=burn_duration, sweeps_per_frame=sweeps_per_frame, seed=seed,
        ),
    }
Run it with
demo_kwargs = dict(
    H=96,
    W=96,
    T=60,
    alpha=0.05,
    gamma=0.0,
    eta=2.0,
    beta=10.0,
    rho=0.01,  # 1% chance burned cells regrow per sweep
    burn_duration=3,  # cells burn for 3 sweeps before becoming ash
    sweeps_per_frame=1,  # one sweep per frame for very gradual spread
    visualize=True,
    viz_delay=0.1,  # 100 ms between frames
    viz_unicode=True,
)
out = run_wildfire(**demo_kwargs)

metrics = out["metrics"]
print("\n" + "=" * 80)
print("Simulation complete!")
print("Frames:", out["frames"].shape)
for label, values in (
    ("Burn fraction (first 5):", metrics["burn_frac"][:5]),
    ("Burn fraction (last 5):", metrics["burn_frac"][-5:]),
    ("Frontier (first 5):", metrics["frontier"][:5]),
):
    print(label, values)