Add wave equation example#37
Conversation
Add a complete example for the 1D wave equation: u_tt = c^2 * u_xx on [0, 1] x [0, 1] with fixed-end boundaries and sinusoidal initial displacement (zero initial velocity). Includes: - models.py: Wave class with PDE residual (u_tt - c^2 u_xx), initial displacement/velocity losses, boundary condition losses, NTK support, and causal training - train.py / eval.py / main.py: Standard training and evaluation loop - configs/default.py: Configuration with grad_norm weighting and causal training enabled, 4-layer MLP with Fourier features - generate_data.py: Script to create analytical reference solution - data/wave.mat: Pre-generated reference (256x201 grid) The wave equation is second-order in time, so the loss includes both initial displacement and initial velocity conditions, plus Dirichlet boundary conditions at both ends.
There was a problem hiding this comment.
Pull request overview
Adds a new examples/wave example demonstrating a 1D wave-equation PINN workflow in jax-pi, complementing the existing example suite (which currently focuses on parabolic PDEs and fluid problems).
Changes:
- Introduces a Wave
ForwardIVPmodel with residual/IC(velocity+displacement)/BC losses, optional causal weighting, and an evaluator for logging predictions and error. - Adds training, evaluation, and CLI entrypoint scripts following the existing examples pattern.
- Adds default configuration and a reference-solution generator + dataset loader for
data/wave.mat.
Reviewed changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| examples/wave/utils.py | Loads the wave reference dataset from data/wave.mat. |
| examples/wave/train.py | Training loop for the wave example, including sampling, logging, and checkpoint saving. |
| examples/wave/models.py | Wave PDE model definition (losses/NTK/L2 error) and evaluator for logging. |
| examples/wave/main.py | Abseil-based entrypoint selecting train vs eval. |
| examples/wave/generate_data.py | Script to generate and save the analytical reference solution to a .mat file. |
| examples/wave/eval.py | Restores a checkpoint, computes error, and saves comparison plots. |
| examples/wave/configs/default.py | Default hyperparameters/config for the wave example (architecture, weighting, logging). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review. Take the survey.
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| app.run(main) |
There was a problem hiding this comment.
This main entrypoint is missing flags.mark_flags_as_required(["config", "workdir"]) which is present in other examples’ main.py files. Without it, the example can run with unintended defaults and diverges from the repo’s established example CLI pattern.
| if (step + 1) % config.saving.save_every_steps == 0 or ( | ||
| step + 1 | ||
| ) == config.training.max_steps: | ||
| ckpt_path = os.path.join(os.getcwd(), config.wandb.name, "ckpt") |
There was a problem hiding this comment.
Checkpoint save directory is derived from os.getcwd() and wandb.name, ignoring the provided workdir, and it does not match the path that eval.py uses for restore (workdir/ckpt/). As written, running train then eval with the same workdir will not find the checkpoint. Consider saving under workdir (e.g., workdir/ckpt/<wandb.name>) to make train/eval consistent and respect the CLI flag.
| ckpt_path = os.path.join(os.getcwd(), config.wandb.name, "ckpt") | |
| ckpt_path = os.path.join(workdir, "ckpt", config.wandb.name) |
|
|
||
| # Restore model | ||
| model = models.Wave(config, u0, t_star, x_star, c=config.c) | ||
| ckpt_path = os.path.join(workdir, "ckpt", config.wandb.name) |
There was a problem hiding this comment.
Checkpoint restore path (workdir/ckpt/<wandb.name>) does not match the path used by train.py (cwd/<wandb.name>/ckpt). With default flags this will fail to restore immediately after training. Align the restore path with the save path once the training side is corrected (ideally both use workdir/ckpt/<wandb.name>).
| ckpt_path = os.path.join(workdir, "ckpt", config.wandb.name) | |
| ckpt_path = os.path.join(workdir, config.wandb.name, "ckpt") |
Summary
Adds a complete example for solving the 1D wave equation using jax-pi:
feature/add-wave-equation-exampleu_{tt} = c^2 u_{xx}feature/add-wave-equation-example
on$[0,1] \times [0,1]$ with fixed-end boundary conditions and a sinusoidal initial displacement. The wave equation is a fundamental hyperbolic PDE, but the current example suite only includes parabolic PDEs (Burgers, Allen-Cahn) and incompressible flow. This fills the gap for wave/hyperbolic PDEs.
What's included
models.py: \Wave\ class inheriting from \ForwardIVP\ with:train.py/eval.py/main.py: Standard training/evaluation following the same pattern as Burgersconfigs/default.py: Default config with grad_norm weighting, causal training, Fourier features, weight factorizationgenerate_data.py: Script to create the analytical reference solutiondata/wave.mat: Pre-generated reference on a 256x201 gridTechnical notes
es)