Skip to content

Add wave equation example#37

Open
gpartin wants to merge 1 commit intoPredictiveIntelligenceLab:mainfrom
gpartin:feature/add-wave-equation-example
Open

Add wave equation example#37
gpartin wants to merge 1 commit intoPredictiveIntelligenceLab:mainfrom
gpartin:feature/add-wave-equation-example

Conversation

@gpartin
Copy link

@gpartin gpartin commented Mar 11, 2026

Summary

Adds a complete example for solving the 1D wave equation using jax-pi:

feature/add-wave-equation-exampleu_{tt} = c^2 u_{xx}feature/add-wave-equation-example

on $[0,1] \times [0,1]$ with fixed-end boundary conditions and a sinusoidal initial displacement. The wave equation is a fundamental hyperbolic PDE, but the current example suite only includes parabolic PDEs (Burgers, Allen-Cahn) and incompressible flow. This fills the gap for wave/hyperbolic PDEs.

What's included

  • models.py: \Wave\ class inheriting from \ForwardIVP\ with:

    • PDE residual: \u_tt - c^2 * u_xx\ via nested \jax.grad\
    • Initial displacement loss (\u(x,0) = sin(pi*x))
    • Initial velocity loss (\u_t(x,0) = 0)
    • Dirichlet boundary conditions at both ends
    • NTK computation and causal training support
    • \WaveEvaluator\ for logging errors and predictions
  • train.py / eval.py / main.py: Standard training/evaluation following the same pattern as Burgers

  • configs/default.py: Default config with grad_norm weighting, causal training, Fourier features, weight factorization

  • generate_data.py: Script to create the analytical reference solution

  • data/wave.mat: Pre-generated reference on a 256x201 grid

Technical notes

  • The wave equation is second-order in time, so it requires enforcing both initial displacement and initial velocity as separate loss terms (4 loss components: \ics, \ics_vel, \�cs,
    es)
  • Causal training is particularly beneficial here since the wave equation has finite propagation speed
  • No periodicity used (fixed-end BCs, not periodic)
  • Analytical solution: (x,t) = \sin(\pi x)\cos(\pi c t)$ for validation

Add a complete example for the 1D wave equation:

  u_tt = c^2 * u_xx

on [0, 1] x [0, 1] with fixed-end boundaries and sinusoidal initial
displacement (zero initial velocity). Includes:

- models.py: Wave class with PDE residual (u_tt - c^2 u_xx), initial
  displacement/velocity losses, boundary condition losses, NTK support,
  and causal training
- train.py / eval.py / main.py: Standard training and evaluation loop
- configs/default.py: Configuration with grad_norm weighting and causal
  training enabled, 4-layer MLP with Fourier features
- generate_data.py: Script to create analytical reference solution
- data/wave.mat: Pre-generated reference (256x201 grid)

The wave equation is second-order in time, so the loss includes both
initial displacement and initial velocity conditions, plus Dirichlet
boundary conditions at both ends.
Copilot AI review requested due to automatic review settings March 11, 2026 19:27
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new examples/wave example demonstrating a 1D wave-equation PINN workflow in jax-pi, complementing the existing example suite (which currently focuses on parabolic PDEs and fluid problems).

Changes:

  • Introduces a Wave ForwardIVP model with residual/IC(velocity+displacement)/BC losses, optional causal weighting, and an evaluator for logging predictions and error.
  • Adds training, evaluation, and CLI entrypoint scripts following the existing examples pattern.
  • Adds default configuration and a reference-solution generator + dataset loader for data/wave.mat.

Reviewed changes

Copilot reviewed 7 out of 8 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
examples/wave/utils.py Loads the wave reference dataset from data/wave.mat.
examples/wave/train.py Training loop for the wave example, including sampling, logging, and checkpoint saving.
examples/wave/models.py Wave PDE model definition (losses/NTK/L2 error) and evaluator for logging.
examples/wave/main.py Abseil-based entrypoint selecting train vs eval.
examples/wave/generate_data.py Script to generate and save the analytical reference solution to a .mat file.
examples/wave/eval.py Restores a checkpoint, computes error, and saves comparison plots.
examples/wave/configs/default.py Default hyperparameters/config for the wave example (architecture, weighting, logging).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

You can also share your feedback on Copilot code review. Take the survey.

Comment on lines +38 to +41


if __name__ == "__main__":
app.run(main)
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This main entrypoint is missing flags.mark_flags_as_required(["config", "workdir"]) which is present in other examples’ main.py files. Without it, the example can run with unintended defaults and diverges from the repo’s established example CLI pattern.

Copilot uses AI. Check for mistakes.
if (step + 1) % config.saving.save_every_steps == 0 or (
step + 1
) == config.training.max_steps:
ckpt_path = os.path.join(os.getcwd(), config.wandb.name, "ckpt")
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checkpoint save directory is derived from os.getcwd() and wandb.name, ignoring the provided workdir, and it does not match the path that eval.py uses for restore (workdir/ckpt/). As written, running train then eval with the same workdir will not find the checkpoint. Consider saving under workdir (e.g., workdir/ckpt/<wandb.name>) to make train/eval consistent and respect the CLI flag.

Suggested change
ckpt_path = os.path.join(os.getcwd(), config.wandb.name, "ckpt")
ckpt_path = os.path.join(workdir, "ckpt", config.wandb.name)

Copilot uses AI. Check for mistakes.

# Restore model
model = models.Wave(config, u0, t_star, x_star, c=config.c)
ckpt_path = os.path.join(workdir, "ckpt", config.wandb.name)
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checkpoint restore path (workdir/ckpt/<wandb.name>) does not match the path used by train.py (cwd/<wandb.name>/ckpt). With default flags this will fail to restore immediately after training. Align the restore path with the save path once the training side is corrected (ideally both use workdir/ckpt/<wandb.name>).

Suggested change
ckpt_path = os.path.join(workdir, "ckpt", config.wandb.name)
ckpt_path = os.path.join(workdir, config.wandb.name, "ckpt")

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants