01 About 02 Experience 03 Projects
03

Wildfire simulation using THRML¶

Spatio-temporal Potts model on a HxW grid. Each cell has a binary variable $x_{i,j,t} \in \{0,1\}$ indicating whether there is an active wildfire at location (i,j) and time t.

We add our imports

In [2]:
from typing import List, Tuple
import sys
import time
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
In [3]:
# thrml dependencies
from thrml.block_management import Block
from thrml.block_sampling import BlockGibbsSpec, SamplingSchedule, sample_states
from thrml.factor import AbstractFactor, FactorSamplingProgram
from thrml.interaction import InteractionGroup
from thrml.conditional_samplers import AbstractParametricConditionalSampler
from thrml.pgm import AbstractNode

We declare our FireNodes, inheriting from AbstractNode (nodes in a PGM).

In [4]:
class FireNode(AbstractNode):
    """One grid cell of the wildfire model, used as a node of the THRML graph.

    The class adds nothing to ``AbstractNode``; it exists so that grid cells
    have their own node type (``run_wildfire`` keys the per-node shape/dtype
    table on it).
    """

The wind kernel is a field $w_{i,j}$ that modulates the influence of neighboring cells based on wind direction and speed. For example, if the wind is blowing eastward, cells to the west will have a higher influence on the probability of fire spread to cell $(i,j)$.

In [5]:
def neighbor_kernel_from_wind(
    wind: jnp.ndarray,
    kappa: float = 2.3,
    base: float = 1.0,
    diag_scale: float = 0.7,
    normalize: bool = False,
) -> jnp.ndarray:
    """
    Build a 3x3 float32 stencil whose weights grow with alignment to the wind.

    Grid convention: x -> east (cols), y -> south (rows). wind = [dx, dy].

    Entry ``[dy + 1, dx + 1]`` belongs to the offset ``(dy, dx)`` and equals
    ``base * exp(kappa * cos(angle between offset and wind))``; diagonal
    offsets are additionally multiplied by ``diag_scale``. The centre entry is
    0. With ``normalize=True`` the stencil is divided by its sum (when that
    sum is positive).
    """
    heading = jnp.asarray(wind, dtype=jnp.float32)
    heading = heading / (jnp.linalg.norm(heading) + 1e-6)  # epsilon guards a zero wind
    # Re-express the wind in (row, col) order so it can be dotted with offsets.
    heading_rc = jnp.array([heading[1], heading[0]], dtype=jnp.float32)

    def _weight(d_row: int, d_col: int) -> jnp.ndarray:
        # The centre cell is not its own neighbour.
        if d_row == 0 and d_col == 0:
            return jnp.zeros((), dtype=jnp.float32)
        offset = jnp.array([d_row, d_col], dtype=jnp.float32)
        offset = offset / (jnp.linalg.norm(offset) + 1e-6)
        alignment = jnp.dot(offset, heading_rc)
        is_diagonal = d_row != 0 and d_col != 0
        gain = base * (diag_scale if is_diagonal else 1.0)
        return gain * jnp.exp(kappa * alignment)

    # Row-major enumeration, so position k maps to [k // 3, k % 3].
    entries = [_weight(d_row, d_col) for d_row in (-1, 0, 1) for d_col in (-1, 0, 1)]
    kernel = jnp.stack(entries).astype(jnp.float32).reshape(3, 3)

    if not normalize:
        return kernel
    total = jnp.sum(kernel)
    return jnp.where(total > 0, kernel / total, kernel)

Tiny JAX convolution, single-channel 2D cross-correlation (SAME padding) via jax.lax.conv_general_dilated.

In [6]:
def conv2d_same(x: jnp.ndarray, kernel: jnp.ndarray) -> jnp.ndarray:
    """Cross-correlate a single-channel 2D image with ``kernel`` (SAME padding).

    Both inputs are cast to float32. The kernel is NOT flipped, so kernel
    entry ``[a, b]`` multiplies the input sample offset by
    ``(a - kh // 2, b - kw // 2)`` from the output location (for odd kernels).
    The result has the same spatial shape as ``x``.
    """
    image = jnp.asarray(x, dtype=jnp.float32)
    stencil = jnp.asarray(kernel, dtype=jnp.float32)

    # Add the batch/channel axes that the lax convolution expects.
    batched = image.reshape((1, *image.shape, 1))        # (1, H, W, 1)
    filters = stencil.reshape((*stencil.shape, 1, 1))    # (kh, kw, 1, 1)

    out = jax.lax.conv_general_dilated(
        batched,
        filters,
        (1, 1),
        "SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    # Strip the singleton batch and channel axes again.
    return out[0, ..., 0]

Interactions and Factors (see 01_all_of_thrml.ipynb).

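  • LinearInteraction(weights) → adds weights (multiplied by the tail states, when the group has tails) to the logits; with no tails this is a per-node bias.
  • PairInteraction(weights) → intended to add W @ tail_state to the logits; it is declared here for reference but is not consumed by the conditional sampler below.
  • UnaryFactor(weights, block) → exposes one LinearInteraction with no tails.
  • NeighborCouplingFactor(weights, (head_block, tail_block)) → exposes two directed LinearInteractions (head←tail and tail←head), each using the other block as its tail.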
In [7]:
class LinearInteraction(eqx.Module):
    """Interaction of the form theta_i * x_i

    Parameter container only: it stores the coefficients theta. The
    conditional sampler in this file (FireSpinConditional) multiplies these
    weights by the group's activity flags and, when the interaction group has
    tail nodes, by the product of the tail states before summing into logits.
    """
    # Per-head coefficients. Shape [n] for unary use; NeighborCouplingFactor
    # in this file also passes [n_pairs, 1] arrays through this same field.
    weights: jnp.ndarray  # shape [n]

class PairInteraction(eqx.Module):
    """Interaction of the form sum_k W_{ik} * x_i * z_k  (one head i; multiple tails k)

    Parameter container only. NOTE(review): nothing visible in this file
    constructs a PairInteraction, and FireSpinConditional.compute_parameters
    raises ValueError for any interaction that is not a LinearInteraction, so
    this type is presumably kept for reference/future use -- confirm.
    """
    # Coupling matrix W; rows index head nodes, columns index tail nodes.
    weights: jnp.ndarray  # shape [n, m] mapping head->tail


class UnaryFactor(AbstractFactor):
    r"""Single-block bias factor contributing \sum_i h_i x_i.

    ``weights`` holds one field value h_i per node of ``block``; the caller is
    responsible for ordering them like the block's nodes.
    """
    weights: jnp.ndarray

    def __init__(self, weights: jnp.ndarray, block: Block):
        # The factor touches exactly one node group: the block itself.
        super().__init__([block])
        self.weights = weights

    def to_interaction_groups(self) -> list[InteractionGroup]:
        """Return the one tail-less LinearInteraction group for this factor."""
        bias_group = InteractionGroup(
            interaction=LinearInteraction(self.weights),
            head_nodes=self.node_groups[0],
            tail_nodes=[],  # no neighbours: pure per-node bias
        )
        return [bias_group]

class NeighborCouplingFactor(AbstractFactor):
    r"""Symmetric pairwise factor \sum_{i,k} W_{ik} x_i z_k over aligned blocks.

    The two blocks are expected to list neighbour pairs position by position
    (entry p of the first block is coupled to entry p of the second).
    """
    weights: jnp.ndarray  # shape [n_pairs, 1] for aligned pair blocks

    def __init__(self, weights: jnp.ndarray, blocks: Tuple[Block, Block]):
        super().__init__([*blocks])
        self.weights = weights

    def to_interaction_groups(self) -> list[InteractionGroup]:
        """Return both directed influences (first<-second and second<-first).

        Each direction is a LinearInteraction carrying the same weights, with
        the opposite block supplied as its single tail group.
        """
        first, second = self.node_groups[0], self.node_groups[1]
        directed_pairs = ((first, second), (second, first))
        return [
            InteractionGroup(
                interaction=LinearInteraction(self.weights),
                head_nodes=head,
                tail_nodes=[tail],
            )
            for head, tail in directed_pairs
        ]

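Custom conditional: Bernoulli over $\{0,1\}$ whose logits are accumulated from LinearInteraction terms (weights multiplied by the tail/neighbor states when a group has tails); any other interaction type raises an error.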

In [8]:
class FireSpinConditional(AbstractParametricConditionalSampler):
    """Bernoulli conditional for {0,1} fire variables.

    ``compute_parameters`` turns the block's interactions into logits and
    ``sample_given_parameters`` draws from ``sigmoid(logits)``.
    """

    def compute_parameters(self, key, interactions, active_flags, states, sampler_state, output_sd):
        """Accumulate logits from every LinearInteraction acting on the block.

        Raises:
            ValueError: if an interaction is not a ``LinearInteraction``.
        """
        logits = jnp.zeros(output_sd.shape, dtype=jnp.float32)

        for interaction, active, state in zip(interactions, active_flags, states):
            if not isinstance(interaction, LinearInteraction):
                raise ValueError(f"Unknown interaction type: {type(interaction)}")

            if len(state) > 0:
                # Multiply the tail states together. Per the original author's
                # note, for pairwise groups state[0] is [n_head, k_max].
                tail_product = jnp.prod(jnp.stack(state, axis=-1), axis=-1)
            else:
                # No tails: the interaction is a plain bias.
                tail_product = jnp.array(1.0, dtype=jnp.float32)

            weighted = interaction.weights * active * tail_product
            logits = logits + jnp.sum(weighted, axis=-1)

        return logits, sampler_state

    def sample_given_parameters(self, key, parameters, sampler_state, output_sd):
        """Draw float32 {0,1} samples with success probability sigmoid(parameters)."""
        fire_prob = jax.nn.sigmoid(parameters)
        draws = jax.random.bernoulli(key, p=fire_prob)
        return draws.astype(jnp.float32), sampler_state

    def init(self):
        """This sampler keeps no state."""
        return None

Block builders for grid strips and neighbor alignment:

In [9]:
def create_grid_nodes(H: int, W: int) -> List[FireNode]:
    """Return ``H * W`` fresh FireNode objects, one per grid cell.

    Callers in this file address the list row-major, i.e. cell (r, c) is
    element ``r * W + c``.
    """
    cell_count = H * W
    return [FireNode() for _cell in range(cell_count)]

def make_full_block(nodes: List[FireNode]) -> Block:
    """Wrap all given nodes, in the given order, in one Block."""
    whole_grid = Block(nodes)
    return whole_grid

def make_strip_blocks(nodes: List[FireNode], H: int, W: int, width: int, vertical: bool) -> List[Block]:
    """Partition the row-major grid into non-overlapping strips of ``width``.

    With ``vertical=True`` each strip spans all rows and up to ``width``
    adjacent columns; otherwise it spans all columns and up to ``width``
    adjacent rows. The last strip is narrower when the grid size is not a
    multiple of ``width``. Within a strip, nodes are listed row by row.
    """
    strips: List[Block] = []

    if vertical:
        for col_start in range(0, W, width):
            col_stop = min(col_start + width, W)
            members = [
                nodes[row * W + col]
                for row in range(H)
                for col in range(col_start, col_stop)
            ]
            strips.append(Block(members))
        return strips

    for row_start in range(0, H, width):
        row_stop = min(row_start + width, H)
        members = [
            nodes[row * W + col]
            for row in range(row_start, row_stop)
            for col in range(W)
        ]
        strips.append(Block(members))
    return strips

def neighbor_pair_blocks(nodes: List[FireNode], H: int, W: int, direction: str) -> Tuple[Block, Block]:
    """Return (head, tail) blocks pairing each cell with its neighbour in ``direction``.

    Position p of the head block and position p of the tail block form one
    pair; the tail is the head's neighbour one step to the "E", "W", "S" or
    "N" (rows grow southwards). Cells whose neighbour would fall outside the
    grid are skipped. Pairs are emitted in row-major order of the head cell.

    Raises:
        ValueError: if ``direction`` is not one of N, S, E, W.
    """
    if direction not in ("N", "S", "E", "W"):
        raise ValueError("direction must be one of N,S,E,W")

    # (row step, column step) from a head cell to its tail cell.
    steps = {"E": (0, 1), "W": (0, -1), "S": (1, 0), "N": (-1, 0)}
    d_row, d_col = steps[direction]

    head_nodes = []
    tail_nodes = []
    for row in range(H):
        tail_row = row + d_row
        if not 0 <= tail_row < H:
            continue
        for col in range(W):
            tail_col = col + d_col
            if not 0 <= tail_col < W:
                continue
            head_nodes.append(nodes[row * W + col])
            tail_nodes.append(nodes[tail_row * W + tail_col])

    return Block(head_nodes), Block(tail_nodes)

Within each time slice, spatial coupling $J$ discourages isolated flips and promotes compact fire fronts. The pairwise grid smoothness is anisotropic, which makes couplings stronger in the wind direction.

In [10]:
def anisotropic_couplings(H: int, W: int, wind: jnp.ndarray, J0: float, kappa: float) -> dict:
    """Returns dict of coupling weights per direction, shaped to match neighbor_pair_blocks order.

    Args:
        H, W: grid height and width.
        wind: wind vector ``[dx, dy]`` (x -> east, y -> south). Any array-like
            is accepted (tuple, list, NumPy or JAX array), consistent with
            ``neighbor_kernel_from_wind``.
        J0: base coupling strength.
        kappa: anisotropy strength; the weight for a direction ``v`` is
            ``J0 * exp(kappa * <v, wind / |wind|>)``.

    Returns:
        Dict with keys "E", "W", "S", "N" (in that order), each a float32
        array of shape ``(n_pairs, 1)`` filled with that direction's weight.
        ``n_pairs`` is ``H * (W - 1)`` for E/W and ``(H - 1) * W`` for S/N,
        clamped at zero for degenerate grids.
    """
    # Direction unit vectors in (dx, dy): E=(1,0), W=(-1,0), S=(0,1), N=(0,-1)
    dirs = {
        "E": jnp.array([1.0, 0.0], dtype=jnp.float32),
        "W": jnp.array([-1.0, 0.0], dtype=jnp.float32),
        "S": jnp.array([0.0, 1.0], dtype=jnp.float32),
        "N": jnp.array([0.0, -1.0], dtype=jnp.float32),
    }
    # Convert first: jnp.linalg.norm rejects plain tuples/lists.
    wind_vec = jnp.asarray(wind, dtype=jnp.float32)
    w = wind_vec / (jnp.linalg.norm(wind_vec) + 1e-6)  # epsilon guards a zero wind
    weights = {d: float(J0 * jnp.exp(kappa * (v @ w))) for d, v in dirs.items()}

    # Pair counts per axis. Clamp each factor at zero so an empty or
    # single-row/column grid yields empty arrays (as neighbor_pair_blocks
    # yields empty blocks) instead of a negative-shape error.
    n_horizontal = max(H, 0) * max(W - 1, 0)
    n_vertical = max(H - 1, 0) * max(W, 0)
    pair_counts = {"E": n_horizontal, "W": n_horizontal, "S": n_vertical, "N": n_vertical}

    return {
        d: jnp.full((pair_counts[d], 1), weights[d], dtype=jnp.float32)
        for d in ("E", "W", "S", "N")
    }

Dataviz code: we want to display the grid in the terminal, updating at each timestep.

States shown:

  • Y: trees (unburned)
  • /: barriers (roads, rivers)
  • 🜂: active fire (newly burning)
  • _: burned soil (previously burned, may regrow)
In [11]:
def visualize_grid(
    burn_state: jnp.ndarray,
    prev_burn_state: jnp.ndarray,
    barrier_map: jnp.ndarray,
    t: int,
    total: int,
    burn_frac: float,
    use_unicode: bool = True,
):
    """Draw one frame of the fire grid on the terminal using ANSI colours.

    Args:
        burn_state: (H, W) array; values > 0.5 mean the cell has burned.
        prev_burn_state: (H, W) array of the reference state; a cell burned in
            ``burn_state`` but not here is drawn as active (new) fire.
        barrier_map: (H, W) array; values > 0.5 are barriers and take
            precedence over every other symbol.
        t: frame index; ``t == 0`` clears the screen, later frames only move
            the cursor home and overdraw.
        total: total number of frames, shown in the header.
        burn_frac: burned fraction, shown in the header as a percentage.
        use_unicode: draw fire as U+1F702 (alchemical fire symbol) instead of ``*``.

    Grids larger than 80 columns or 40 rows are subsampled (every k-th
    row/column) so the drawn grid never exceeds those limits.
    """
    H, W = burn_state.shape

    # ANSI escape codes
    CLEAR_SCREEN = "\033[2J\033[H"
    CURSOR_HOME = "\033[H"
    GREEN = "\033[32m"
    RED = "\033[91m"
    GRAY = "\033[90m"
    RESET = "\033[0m"

    # Written as an escape so the glyph survives any re-encoding of this file
    # (the literal had previously been corrupted into four mojibake characters).
    fire_char = "\U0001F702" if use_unicode else "*"

    # For first call, clear screen; otherwise just move cursor to home
    sys.stdout.write(CLEAR_SCREEN if t == 0 else CURSOR_HOME)

    # Header
    sys.stdout.write(f"Timestep {t}/{total} | Burn fraction: {burn_frac:.2%}\n")
    sys.stdout.write("=" * min(W, 80) + "\n")
    sys.stdout.write(f"{GREEN}Y{RESET}=trees  {GRAY}/{RESET}=barriers  {RED}{fire_char}{RESET}=fire  {GRAY}_{RESET}=burned\n")

    # Downsample if grid is too large for terminal. Ceiling division is
    # required: with floor division a 96-wide grid would get stride 1 and
    # still overflow the 80-column limit.
    max_display_size = 80
    max_cols = max_display_size
    max_rows = max_display_size // 2
    scale_w = max(1, -(-W // max_cols))
    scale_h = max(1, -(-H // max_rows))
    display_burn = np.array(burn_state)[::scale_h, ::scale_w]
    display_prev = np.array(prev_burn_state)[::scale_h, ::scale_w]
    display_barrier = np.array(barrier_map)[::scale_h, ::scale_w]

    # Classify every displayed cell once, vectorised.
    is_barrier = display_barrier > 0.5
    is_burned = display_burn > 0.5
    is_new_fire = is_burned & ~(display_prev > 0.5)

    barrier_cell = f"{GRAY}/{RESET}"
    fire_cell = f"{RED}{fire_char}{RESET}"
    burned_cell = f"{GRAY}_{RESET}"
    tree_cell = f"{GREEN}Y{RESET}"

    # Draw grid
    n_rows, n_cols = is_burned.shape
    for r in range(n_rows):
        cells = []
        for c in range(n_cols):
            if is_barrier[r, c]:
                cells.append(barrier_cell)
            elif is_new_fire[r, c]:
                cells.append(fire_cell)
            elif is_burned[r, c]:
                cells.append(burned_cell)
            else:
                cells.append(tree_cell)
        sys.stdout.write("".join(cells) + "\n")

    sys.stdout.flush()

Core simulation:

Algorithm¶

For a single step, we define an energy over $x_t$ as: $$E(x_t|x_{t-1}) = -\sum_{i,j}h_{i,j}(t)x_{i,j,t} - \sum_{\langle(i,j),(p,q)\rangle} J_{(i,j),(p,q)} x_{i,j,t} x_{p,q,t}$$

Where the first sum is the unary (dryness, seeds, previous fire), and the second sum is the spatial smoothness (anisotropic neighbor couplings).

Simulation parameters¶

We define a grid size (H, W), number of frames T, wind vector, blocking (strip_width), and model knobs:

  • alpha, beta, gamma, eta: weights for dryness, barriers, persistence, neighbor ignition. They are evaluated as follows:

    • Dryness/fuel is $\alpha d_{i,j}$ where $d_{i,j} \in [0,1]$ is the dryness map.
    • Barriers $-\beta b_{i,j}$ where $b_{i,j} \in \{0,1\}$ indicates presence of a barrier (e.g. water, road).
    • Persistence $+\gamma x_{i,j,t-1}$ encourages ongoing fires to continue burning.
    • Neighbor ignition $+\eta \sum_{(p,q)\in \mathcal{N}(i,j)} w^{dir}_{(i,j)\leftarrow (p,q)} x_{p,q,t-1}$ encourages ignition if any neighbor was burning in the previous timestep.
  • J0, kappa → (reserved) spatial smoothness & anisotropy strength for pairwise terms.

  • rho → reforestation probability.

  • burn_duration → how long a cell is “actively burning” (able to ignite neighbors).

  • sweeps_per_frame → how many Gibbs sweeps you do before emitting a frame.

  • visualize + viz_* → optional terminal animation.

  • seed → RNG determinism.

The THRML graph defines the nodes & block construction for the graph. The strip_blocks are wind-aligned strip blocks that define the Gibbs update order.

The THRML "runtime" are the Block Gibbs spec, samplers, and schedule:

  • Spec: which blocks are free (to be sampled), which are clamped (none here), and expected shapes/dtypes.
  • Samplers: one conditional per block (same instance reused).
  • Schedule: a full sweep through all blocks per sample.

We wrap the factors into a THRML program and run one scheduled Gibbs sweep across strips. THRML handles gather/scatter, padding, and calling the conditional with the right neighbor states.

In [12]:
def run_wildfire(
    H: int = 96,
    W: int = 96,
    T: int = 60,
    wind: Tuple[float, float] = (1.0, 0.2),
    strip_width: int = 12,
    alpha: float = 1.2,    # dryness
    beta: float = 5.0,     # barrier penalty
    gamma: float = 1.5,    # persistence
    eta: float = 2.0,      # neighbor ignition influence
    J0: float = 0.6,       # spatial smoothness base
    kappa: float = 2.5,    # anisotropy strength
    rho: float = 0.0,      # reforestation probability per timestep
    burn_duration: int = 3,   # sweeps a cell actively burns before burning out
    sweeps_per_frame: int = 20,  # Gibbs sweeps per visualization frame
    visualize: bool = False,  # show terminal visualization
    viz_delay: float = 0.1,   # delay between frames (seconds)
    viz_unicode: bool = True, # use unicode fire symbol (U+1F702) vs ASCII (*)
    seed: int = 0,
):
    """Simulate wildfire spread on an H x W grid with THRML block Gibbs sampling.

    Each frame runs ``sweeps_per_frame`` sweeps. Every sweep rebuilds the
    per-cell field ``h`` from dryness, barriers, the accumulated burn state and
    a wind-weighted sum over actively burning neighbours, samples all strip
    blocks once, ORs the sample into the accumulated burn state, ages the
    fires and (if ``rho > 0``) lets burned cells regrow.

    Args:
        H, W: grid height and width.
        T: number of frames to record.
        wind: wind vector ``(dx, dy)``; x -> east (cols), y -> south (rows).
        strip_width: width of the strip blocks used as the Gibbs update order.
        alpha, beta, gamma, eta: weights of dryness, barrier penalty,
            persistence and neighbour ignition in the field ``h``.
        J0, kappa: parameters of the pairwise couplings. The pair factors are
            built but currently NOT added to the sampling program.
        rho: per-sweep probability that a burned cell regrows.
        burn_duration: number of sweeps a cell can ignite its neighbours.
        sweeps_per_frame: Gibbs sweeps between recorded frames.
        visualize, viz_delay, viz_unicode: optional terminal animation.
        seed: PRNG seed.

    Returns:
        Dict with ``"frames"`` (NumPy array of shape (T, H, W)),
        ``"metrics"`` (``"burn_frac"`` and ``"frontier"`` arrays of length T)
        and ``"params"`` (the arguments used).
    """
    key = jax.random.PRNGKey(seed)
    wind_vec = jnp.array(wind, dtype=jnp.float32)

    # synthetic maps
    yy, xx = jnp.meshgrid(jnp.linspace(0, 1, H), jnp.linspace(0, 1, W), indexing="ij")
    dryness = 0.4 + 0.6 * (0.5 * jnp.sin(6*xx) + 0.5 * jnp.cos(4*yy))  # [0,1]ish
    barrier = ((jnp.sin(10*yy) > 0.8) & (xx > 0.3) & (xx < 0.7)).astype(jnp.float32)  # toy rivers
    # Realistic ignition: single point (one tree/campfire) near the grid origin.
    # Row 0 is the top row (y -> south), so the cell at (2%, 2%) sits in the
    # TOP-left corner of the displayed grid.
    r_ignite = int(H * 0.02)
    c_ignite = int(W * 0.02)
    ignite0 = jnp.zeros((H, W), dtype=jnp.float32)
    ignite0 = ignite0.at[r_ignite, c_ignite].set(1.0)

    # Create grid nodes and index mapping
    grid_nodes = create_grid_nodes(H, W)
    node_to_idx = {node: idx for idx, node in enumerate(grid_nodes)}

    # blocks
    vertical: bool = abs(wind[0]) >= abs(wind[1])
    strip_blocks = make_strip_blocks(grid_nodes, H, W, strip_width, vertical=vertical)
    full_block = make_full_block(grid_nodes)

    # Flat grid indices of every strip's nodes. They never change, so build
    # them once instead of twice per sweep.
    block_node_idxs = [
        jnp.array([node_to_idx[node] for node in block.nodes])
        for block in strip_blocks
    ]

    # pairwise couplings
    pair_W = anisotropic_couplings(H, W, wind_vec, J0=J0, kappa=kappa)
    pair_blocks = {d: neighbor_pair_blocks(grid_nodes, H, W, d) for d in ["E", "W", "S", "N"]}

    # prebuild factors that do not change over time (pairwise)
    # NOTE: these are reserved for the spatial-smoothness term and are not
    # part of the factor list assembled below.
    pair_factors = []
    for d in ["E", "W", "S", "N"]:
        # weights shape [n_head, n_tail]; here n_tail=1 per pair (we pack as (n_pairs, 1))
        blocks = pair_blocks[d]
        pair_factors.append(NeighborCouplingFactor(pair_W[d], blocks))

    # conditional sampler
    conditional = FireSpinConditional()

    # Block Gibbs spec: defines which blocks to sample (strip blocks cover all nodes)
    node_shape_dtypes = {FireNode: jax.ShapeDtypeStruct((), jnp.float32)}
    gibbs_spec = BlockGibbsSpec(strip_blocks, [], node_shape_dtypes)

    # One sampler per block (all use the same conditional)
    samplers = [conditional] * len(strip_blocks)

    # Schedule: 1 sweep through all blocks per Gibbs iteration
    # IMPORTANT: n_warmup must be >= len(blocks) to ensure actual sampling occurs!
    schedule = SamplingSchedule(n_warmup=len(strip_blocks), n_samples=1, steps_per_sample=len(strip_blocks))

    # frames buffer
    frames = []
    metrics_burn = []
    metrics_frontier = []

    x_prev = ignite0  # once-burned stays burned; we OR with new fires
    x_prev_prev = jnp.zeros_like(x_prev)  # reference state for fire detection in the visualization

    # Track fire age for 3-state model: 0=unburned, 1 to burn_duration=actively burning, >burn_duration=burned out
    fire_age = jnp.where(ignite0 > 0, 1.0, 0.0)  # initial ignition starts at age 1

    K = neighbor_kernel_from_wind(wind_vec, kappa=2.3, base=1.0, diag_scale=0.7, normalize=False)
    # conv2d_same is a cross-correlation: kernel entry (dy, dx) multiplies the
    # neighbour at (r + dy, c + dx). K puts its largest weights on DOWNWIND
    # offsets, so using K as-is would let a cell be ignited mostly by its
    # downwind neighbours and the fire would creep upwind. Mirroring K makes
    # every cell gather from its upwind neighbours, so fire spreads downwind.
    K_gather = K[::-1, ::-1]

    # Field with negative baseline to prevent spontaneous ignition
    # Trees only ignite from strong neighbor influence (eta * neigh)
    ignition_threshold = -5.0  # Cells need strong positive field to ignite
    # Time-invariant part of the field (same evaluation order as the full sum).
    static_field = ignition_threshold + alpha * dryness - beta * barrier

    for t in range(T):
        print(f"Frame {t}/{T}, burned cells: {jnp.sum(x_prev):.1f}")

        # Track state at start of frame for fire detection in visualization
        x_frame_start = x_prev.copy()

        # Multiple Gibbs sweeps per frame for gradual spread
        for sweep in range(sweeps_per_frame):
            key, subkey = jax.random.split(key)

            # Recompute field based on current state (allows fire to propagate gradually)
            # IMPORTANT: Only actively burning cells (age 1 to burn_duration) spread fire
            active_fire_mask = (fire_age >= 1) & (fire_age <= burn_duration)
            active_fires = x_prev * active_fire_mask.astype(jnp.float32)
            neigh = conv2d_same(active_fires, K_gather)
            h = static_field + gamma * x_prev + eta * neigh

            if t == 0 and sweep == 0:
                print(f"  h stats: min={jnp.min(h):.2f}, max={jnp.max(h):.2f}, mean={jnp.mean(h):.2f}")
                print(f"  h at ignition: {h[jnp.where(ignite0 > 0)][0]:.2f}")

            # build unary factor for this sweep
            unary_factor = UnaryFactor(h.reshape(-1), full_block)

            # assemble factor list - try without pairwise factors first
            factors = [unary_factor]  # + pair_factors

            # Create sampling program for this sweep (h changes every sweep)
            program = FactorSamplingProgram(
                gibbs_spec=gibbs_spec,
                samplers=samplers,
                factors=factors,
                other_interaction_groups=[],
            )

            # Initialize state: one array per block containing the state of nodes in that block
            x_flat = x_prev.reshape(-1)
            init_state = [x_flat[idxs] for idxs in block_node_idxs]

            # Sample and collect state from all blocks
            sampled_states = sample_states(
                key=subkey,
                program=program,
                schedule=schedule,
                init_state_free=init_state,
                state_clamp=[],
                nodes_to_sample=strip_blocks,
            )

            # Reassemble full grid from strip blocks (sampled_states has one entry per strip)
            x_t_flat = jnp.zeros(H * W, dtype=jnp.float32)
            for block_idx, idxs in enumerate(block_node_idxs):
                # sampled_states[block_idx] has shape [n_samples, n_nodes_in_block]
                x_t_flat = x_t_flat.at[idxs].set(sampled_states[block_idx][0])

            x_t = x_t_flat.reshape(H, W).astype(jnp.float32)

            # Update fire age: increment for burning cells, keep at 0 for unburned
            # New ignitions (x_t=1 where x_prev was 0) start at age 1
            newly_ignited = (x_t > 0.5) & (x_prev < 0.5)
            fire_age = jnp.where(newly_ignited, 1.0, fire_age)
            # Increment age for cells that were already burning
            fire_age = jnp.where(x_prev > 0.5, fire_age + 1, fire_age)

            # Update accumulated burn state
            x_prev = jnp.clip(x_prev + x_t, 0, 1)  # once burned, stays 1

            # Reforestation: ONLY burned cells can regrow with probability rho
            if rho > 0:
                key, subkey = jax.random.split(key)
                # Only apply regrowth to cells that are actually burned (x_prev > 0.5)
                regrow_candidates = x_prev > 0.5
                regrow_dice = jax.random.uniform(subkey, x_prev.shape) < rho
                regrow_mask = regrow_candidates & regrow_dice
                x_prev = x_prev * (1 - regrow_mask.astype(jnp.float32))
                # Reset fire age for regrown cells
                fire_age = jnp.where(regrow_mask, 0.0, fire_age)

        # After all sweeps, record final state for this frame
        frames.append(x_prev.copy())

        # Metrics (based on final state after all sweeps)
        burn_frac = jnp.mean(x_prev)
        # frontier = count of burning neighbors transitioning; simple proxy: perimeter via gradient
        gx = jnp.abs(jnp.diff(x_prev, axis=1)).sum()
        gy = jnp.abs(jnp.diff(x_prev, axis=0)).sum()
        frontier = (gx + gy) / (2 * (H + W))

        metrics_burn.append(float(burn_frac))
        metrics_frontier.append(float(frontier))

        # Visualization (if enabled)
        if visualize:
            # x_prev_prev still holds the state from the start of the PREVIOUS
            # frame here (zeros on the first frame), so a cell is drawn as
            # active fire for the frame it ignites in and the one after.
            visualize_grid(
                burn_state=x_prev,
                prev_burn_state=x_prev_prev,
                barrier_map=barrier,
                t=t,
                total=T,
                burn_frac=float(jnp.mean(x_prev)),
                use_unicode=viz_unicode,
            )
            time.sleep(viz_delay)

        # Update state tracking for next frame
        x_prev_prev = x_frame_start

    # jnp.stack raises on an empty list, so handle T == 0 explicitly.
    if frames:
        frames_out = np.array(jnp.stack(frames, axis=0))
    else:
        frames_out = np.zeros((0, H, W), dtype=np.float32)

    return {
        "frames": frames_out,
        "metrics": {
            "burn_frac": np.array(metrics_burn),
            "frontier": np.array(metrics_frontier),
        },
        "params": dict(H=H, W=W, T=T, wind=wind, strip_width=strip_width,
                       alpha=alpha, beta=beta, gamma=gamma, eta=eta, J0=J0, kappa=kappa,
                       rho=rho, burn_duration=burn_duration,
                       sweeps_per_frame=sweeps_per_frame, seed=seed),
    }

Run it with

In [ ]:
# Demo configuration: weak dryness term, no persistence, strong barriers, so
# ignition is driven almost entirely by burning neighbours.
demo_config = dict(
    H=96,
    W=96,
    T=60,
    alpha=0.05,
    gamma=0.0,
    eta=2.0,
    beta=10.0,
    rho=0.01,            # 1% chance per sweep that a burned cell regrows
    burn_duration=3,     # a cell can ignite neighbours for 3 sweeps
    sweeps_per_frame=1,  # a single sweep per frame -> very gradual spread
    visualize=True,
    viz_delay=0.1,       # seconds between frames
    viz_unicode=True,
)
out = run_wildfire(**demo_config)

# Summary of the run.
metrics = out["metrics"]
print("\n" + "=" * 80)
print("Simulation complete!")
print("Frames:", out["frames"].shape)
print("Burn fraction (first 5):", metrics["burn_frac"][:5])
print("Burn fraction (last 5):", metrics["burn_frac"][-5:])
print("Frontier (first 5):", metrics["frontier"][:5])
In [ ]: