diff --git a/.gitignore b/.gitignore index a065fce..52c006f 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,4 @@ examples/.ipynb_checkpoints/ conda-output/ examples/cache/ *gtfs.json +.claude/worktrees/ diff --git a/cuda/common.h b/cuda/common.h index 2580a7f..8d09bdd 100644 --- a/cuda/common.h +++ b/cuda/common.h @@ -68,4 +68,6 @@ struct Params int _pad0; // padding for pointer alignment // --- point cloud fields (offset 88) --- float* point_colors; // per-point RGBA (4 floats per point, indexed by primitive_id) + // --- smooth normal fields (offset 96) --- + unsigned long long* smooth_normal_table; // [2*instanceId]=normals_ptr, [2*instanceId+1]=indices_ptr }; diff --git a/cuda/kernel.cu b/cuda/kernel.cu index ca5c454..f2bb2bb 100644 --- a/cuda/kernel.cu +++ b/cuda/kernel.cu @@ -99,17 +99,46 @@ extern "C" __global__ void __closesthit__chit() float3 n; if (optixIsTriangleHit()) { - float3 data[3]; - // Always use the 4-parameter overload for backward compatibility. - // The parameterless overload (OptiX 9.1+) requires ABI version 99, - // which needs driver 570+. The 4-param form works on all versions. 
- OptixTraversableHandle gas = optixGetGASTraversableHandle(); - unsigned int sbtIdx = optixGetSbtGASIndex(); - float time = optixGetRayTime(); - optixGetTriangleVertexData(gas, primIdx, sbtIdx, time, data); - float3 AB = data[1] - data[0]; - float3 AC = data[2] - data[0]; - n = normalize(cross(AB, AC)); + // Check for per-vertex smooth normals for this instance + bool has_smooth = false; + if (params.smooth_normal_table != 0) { + unsigned int instId = optixGetInstanceId(); + unsigned long long normals_ptr = params.smooth_normal_table[2 * instId]; + if (normals_ptr != 0) { + has_smooth = true; + unsigned long long indices_ptr = params.smooth_normal_table[2 * instId + 1]; + const float* norms = reinterpret_cast(normals_ptr); + const int* idx = reinterpret_cast(indices_ptr); + + // Barycentric interpolation of vertex normals + const float2 bary = optixGetTriangleBarycentrics(); + const float w0 = 1.0f - bary.x - bary.y; + const float w1 = bary.x; + const float w2 = bary.y; + + const int i0 = idx[primIdx * 3]; + const int i1 = idx[primIdx * 3 + 1]; + const int i2 = idx[primIdx * 3 + 2]; + + n = normalize(make_float3( + w0 * norms[i0 * 3] + w1 * norms[i1 * 3] + w2 * norms[i2 * 3], + w0 * norms[i0 * 3 + 1] + w1 * norms[i1 * 3 + 1] + w2 * norms[i2 * 3 + 1], + w0 * norms[i0 * 3 + 2] + w1 * norms[i1 * 3 + 2] + w2 * norms[i2 * 3 + 2] + )); + } + } + + if (!has_smooth) { + // Flat shading fallback: face normal from triangle vertices + float3 data[3]; + OptixTraversableHandle gas = optixGetGASTraversableHandle(); + unsigned int sbtIdx = optixGetSbtGASIndex(); + float time = optixGetRayTime(); + optixGetTriangleVertexData(gas, primIdx, sbtIdx, time, data); + float3 AB = data[1] - data[0]; + float3 AC = data[2] - data[0]; + n = normalize(cross(AB, AC)); + } } else { // Round curve tube: use face-up normal for terrain roads/rivers n = make_float3(0.0f, 0.0f, 1.0f); diff --git a/examples/explore_zarr.py b/examples/explore_zarr.py index 8077681..5c111d6 100644 --- 
def _build_lod_arrays(zarr_path, array_name, max_factor=64):
    """Build cached box-filter downsampled LOD arrays in the zarr store.

    Creates arrays named ``{array_name}_lod{factor}`` for each power-of-2
    factor from 2 up to ``max_factor``.  Each level is the NaN-aware mean
    of 2x2 pixel groups of the previous level, so the full-resolution
    array is only read while the first missing level is built from it.

    Levels whose name already exists in the store are skipped, which
    makes repeated calls cheap.  NOTE(review): existence is the only
    check, so a level left half-written by an interrupted run would be
    treated as complete; delete it (``--clear-lods``) and rebuild.

    Parameters
    ----------
    zarr_path : str
        Path to the zarr store; it is opened in ``"r+"`` mode.
    array_name : str
        Name of the full-resolution 2-D array inside the store.
    max_factor : int
        Largest downsample factor to build (default 64).  Building stops
        earlier once a level would be smaller than 2x2.

    Returns
    -------
    zarr group
        The opened root group holding the source and LOD arrays.

    Notes
    -----
    LOD arrays are written as float32 with NaN as the fill value.  The
    source's ``scale_factor`` / ``add_offset`` attributes and its fill
    value are applied when reading the full-resolution level, so LOD
    arrays hold decoded values and must not be CF-decoded again.
    """
    import warnings

    root = zarr.open_group(zarr_path, mode="r+")
    source = root[array_name]

    # CF packing of the source; only applied when reading the full-res level.
    scale = float(source.attrs.get("scale_factor", 1.0))
    add_offset = float(source.attrs.get("add_offset", 0.0))
    fill_value = source.fill_value

    rows_per_pass = 512  # output rows produced per read/write pass

    factor = 2
    prev_name = array_name
    while factor <= max_factor:
        lod_name = f"{array_name}_lod{factor}"
        if lod_name in root:
            # Already cached: use it as the parent of the next level.
            prev_name = lod_name
            factor *= 2
            continue

        prev = root[prev_name]
        pH, pW = prev.shape
        oH, oW = pH // 2, pW // 2
        if oH < 2 or oW < 2:
            break

        out = root.create_array(
            lod_name,
            shape=(oH, oW),
            chunks=(min(512, oH), min(512, oW)),
            dtype=np.float32,
            fill_value=np.nan,
            overwrite=True,
        )

        is_source = prev_name == array_name
        t0 = time.time()
        for out_r0 in range(0, oH, rows_per_pass):
            out_r1 = min(out_r0 + rows_per_pass, oH)
            in_r0 = out_r0 * 2
            in_r1 = min(out_r1 * 2, pH)

            block = np.array(prev[in_r0:in_r1, :], dtype=np.float32)
            if is_source:
                if fill_value is not None:
                    block[block == fill_value] = np.nan
                block = block * scale + add_offset

            # Box-filter 2x2: trim to even dimensions, then average each
            # 2x2 group while ignoring NaNs.
            hh = block.shape[0] // 2 * 2
            ww = block.shape[1] // 2 * 2
            if hh == 0 or ww == 0:
                break
            quads = block[:hh, :ww].reshape(hh // 2, 2, ww // 2, 2)
            with warnings.catch_warnings():
                # All-NaN groups emit a RuntimeWarning and yield NaN,
                # which is the value wanted for missing data.
                warnings.simplefilter("ignore", RuntimeWarning)
                ds = np.nanmean(quads, axis=(1, 3)).astype(np.float32)
            out[out_r0:out_r0 + ds.shape[0], :ds.shape[1]] = ds

        dt = time.time() - t0
        print(f" Built {lod_name}: {oH:,}×{oW:,} ({dt:.1f}s)")

        prev_name = lod_name
        factor *= 2

    return root


def _get_lod_array(root, array_name, subsample):
    """Find the best cached LOD array for a given subsample factor.

    Probes ``{array_name}_lod{factor}`` for every power-of-2 ``factor``
    up to ``subsample`` and keeps the largest one present in ``root``.
    When none is present the full-resolution array is returned.

    Parameters
    ----------
    root : mapping-like
        Opened zarr group (anything supporting ``in`` and ``[]`` by name).
    array_name : str
        Name of the full-resolution array.
    subsample : int
        Requested subsample factor relative to full resolution.

    Returns
    -------
    tuple
        ``(array, remaining_stride, needs_cf_decode, lod_factor)``:
        the array to read, the additional stride to apply when slicing
        it (``subsample // lod_factor``), whether the values still need
        CF decoding (True only for the full-resolution array), and the
        downsample factor of the returned array (1 for full resolution).
        When ``subsample`` is not a multiple of ``lod_factor`` the floor
        division makes ``lod_factor * remaining_stride`` smaller than
        ``subsample``.
    """
    best_factor = 1
    best_name = array_name
    factor = 2
    while factor <= subsample:
        lod_name = f"{array_name}_lod{factor}"
        if lod_name in root:
            best_factor, best_name = factor, lod_name
        factor *= 2

    remaining = subsample // best_factor
    needs_cf = best_name == array_name
    return root[best_name], remaining, needs_cf, best_factor


def _as_float32(data):
    """Return a float32 NumPy copy of ``data`` (moved off the GPU if needed)."""
    if not isinstance(data, np.ndarray):
        data = cp.asnumpy(data)
    return data.astype(np.float32)


def _cf_decode(data, scale, offset, fill):
    """Replace ``fill`` with NaN in place, then apply ``scale`` and ``offset``."""
    data[data == fill] = np.nan
    return data * scale + offset


def _lod_axis_slice(i0, i1, lod_factor, stride, size):
    """Map the full-res pixel range ``[i0, i1)`` onto one LOD array axis.

    Returns ``(slice, count)``.  ``count`` is the number of samples that
    stepping ``lod_factor * stride`` full-res pixels from ``i0`` yields
    inside the range; the slice selects those samples from an axis of
    length ``size`` and may cover fewer than ``count`` of them when the
    LOD array ends early.
    """
    step = lod_factor * stride
    count = max(0, -(-(i1 - i0) // step))  # ceil((i1 - i0) / step)
    start = max(0, min(i0 // lod_factor, size - 1))
    if count == 0:
        return slice(start, start, stride), 0
    stop = min(start + (count - 1) * stride + 1, size)
    return slice(start, stop, stride), count


def _read_raw_window(z, yi0, yi1, xi0, xi1, subsample, scale, offset, fill,
                     lod_root=None, array_name=None):
    """Read a raw pixel window from the zarr array and CF-decode it.

    Without ``lod_root`` / ``array_name`` (or when ``subsample <= 1``)
    the full-resolution array ``z`` is read with a stride of
    ``subsample`` and decoded.  Otherwise the best cached LOD array for
    ``subsample`` is read instead of striding the full-res data.

    Parameters
    ----------
    z : array-like
        Full-resolution 2-D array supporting strided slicing.
    yi0, yi1, xi0, xi1 : int
        Half-open window bounds in full-resolution pixel indices.
    subsample : int
        Requested subsample factor.
    scale, offset, fill
        CF decoding parameters: ``fill`` becomes NaN, then
        ``value * scale + offset`` is applied.  LOD arrays are not
        decoded again.
    lod_root : mapping-like, optional
        Opened zarr group used to look up cached LOD arrays.
    array_name : str, optional
        Name of the full-resolution array inside ``lod_root``.

    Returns
    -------
    numpy.ndarray
        float32 window.  On the LOD path the sample count per axis is
        ``ceil(span / (lod_factor * stride))``, which equals the shape
        of the strided read whenever ``subsample`` is a multiple of the
        chosen LOD factor.  Samples the LOD array cannot supply at its
        far edge are NaN, so the shape stays consistent with coordinates
        computed as ``arange(i0, i1, subsample)``.
    """
    use_lod = (lod_root is not None and array_name is not None
               and subsample > 1)
    if not use_lod:
        # Fallback: stride the full-res array directly.
        window = _as_float32(z[yi0:yi1:subsample, xi0:xi1:subsample])
        return _cf_decode(window, scale, offset, fill)

    lod_arr, stride, needs_cf, lod_factor = _get_lod_array(
        lod_root, array_name, subsample)
    lH, lW = lod_arr.shape
    # Take the sample count from the window span rather than flooring the
    # upper bound, which would drop the last, partially covered sample.
    rows, n_rows = _lod_axis_slice(yi0, yi1, lod_factor, stride, lH)
    cols, n_cols = _lod_axis_slice(xi0, xi1, lod_factor, stride, lW)

    data = _as_float32(lod_arr[rows, cols])
    if needs_cf:
        data = _cf_decode(data, scale, offset, fill)

    if data.shape != (n_rows, n_cols):
        # The LOD array ended before the window did; pad with NaN so the
        # shape still matches the requested window.
        padded = np.full((n_rows, n_cols), np.nan, dtype=data.dtype)
        padded[:data.shape[0], :data.shape[1]] = data
        data = padded
    return data
Returns ------- @@ -108,43 +254,131 @@ def load_window(zarr_path, center_lon, center_lat, size_deg, subsample): z, fmt, (x_origin, dx, y_origin, dy), (scale, offset, fill), crs_wkt = \ _open_zarr_array(zarr_path) + ARRAY_NAME = "usgs10m_dem" H, W = z.shape - lon_min = center_lon - size_deg - lon_max = center_lon + size_deg - lat_min = center_lat - size_deg - lat_max = center_lat + size_deg - - # Pixel indices from GeoTransform (x ascending, y descending) - xi0 = max(int((lon_min - x_origin) / dx), 0) - xi1 = min(int((lon_max - x_origin) / dx) + 1, W) - yi0 = max(int((lat_max - y_origin) / dy), 0) # dy is negative - yi1 = min(int((lat_min - y_origin) / dy) + 1, H) - - raw_w = xi1 - xi0 - raw_h = yi1 - yi0 - out_w = raw_w // subsample - out_h = raw_h // subsample - print(f"Window: lon [{lon_min:.3f}, {lon_max:.3f}], " - f"lat [{lat_min:.3f}, {lat_max:.3f}]") - print(f"Pixel window: {raw_w:,} x {raw_h:,} " - f"(subsample {subsample}x -> {out_w:,} x {out_h:,})") - - # Read only the chunks that overlap the window - t0 = time.time() - data = z[yi0:yi1:subsample, xi0:xi1:subsample] - dt_read = time.time() - t0 - - # Ensure numpy for CF decoding + reprojection - if not isinstance(data, np.ndarray): - data = cp.asnumpy(data) - data = data.astype(np.float32) - data[data == fill] = np.nan - data = data * scale + offset - print(f"Read: {dt_read:.2f}s") - # Build geographic DataArray with source CRS - x_coords = np.arange(xi0, xi1, subsample) * dx + x_origin - y_coords = np.arange(yi0, yi1, subsample) * dy + y_origin + # Use cached LOD arrays for fast horizon reads when available. + # If they don't exist, _read_raw_window falls back to strided reads. + # Run with --build-lods to pre-build the full LOD cache explicitly. 
+ lod_root = None + if horizon_rings > 0: + lod_root = zarr.open_group(zarr_path, mode="r") + # Check if any LOD arrays exist + has_lods = any(f"{ARRAY_NAME}_lod{2**k}" in lod_root + for k in range(1, 7)) + if has_lods: + print("Using cached LOD arrays for horizon reads") + else: + print("No LOD cache found — using strided reads " + "(run with --build-lods to pre-build)") + + if horizon_rings > 0: + # Build a composite: coarse wide canvas with high-res center. + # Work outward from the widest ring, then paste finer rings on top. + s = subsample + sz = size_deg + for _ in range(horizon_rings): + s *= horizon_multiplier + sz *= 2 + # Outermost ring defines the canvas + canvas_sub = s + rings_spec = [] + cur_s, cur_sz = canvas_sub, sz + while cur_s >= subsample: + rings_spec.append((cur_sz, cur_s)) + cur_s //= horizon_multiplier + cur_sz /= 2 + + # Read outermost (coarsest) ring as the canvas + outer_sz, outer_sub = rings_spec[0] + lon_min_c = center_lon - outer_sz + lon_max_c = center_lon + outer_sz + lat_min_c = center_lat - outer_sz + lat_max_c = center_lat + outer_sz + xi0_c = max(int((lon_min_c - x_origin) / dx), 0) + xi1_c = min(int((lon_max_c - x_origin) / dx) + 1, W) + yi0_c = max(int((lat_max_c - y_origin) / dy), 0) + yi1_c = min(int((lat_min_c - y_origin) / dy) + 1, H) + + print(f"Horizon: {len(rings_spec)} rings, " + f"outermost {outer_sz:.3f}° at {outer_sub}x subsample") + t0 = time.time() + canvas = _read_raw_window(z, yi0_c, yi1_c, xi0_c, xi1_c, + outer_sub, scale, offset, fill, + lod_root=lod_root, array_name=ARRAY_NAME) + # Canvas geo coords + x_coords_c = np.arange(xi0_c, xi1_c, outer_sub) * dx + x_origin + y_coords_c = np.arange(yi0_c, yi1_c, outer_sub) * dy + y_origin + + # Paste inner rings (each finer, smaller footprint) + for ring_sz, ring_sub in rings_spec[1:]: + lon_min_r = center_lon - ring_sz + lon_max_r = center_lon + ring_sz + lat_min_r = center_lat - ring_sz + lat_max_r = center_lat + ring_sz + xi0_r = max(int((lon_min_r - x_origin) / dx), 
0) + xi1_r = min(int((lon_max_r - x_origin) / dx) + 1, W) + yi0_r = max(int((lat_max_r - y_origin) / dy), 0) + yi1_r = min(int((lat_min_r - y_origin) / dy) + 1, H) + + ring_data = _read_raw_window(z, yi0_r, yi1_r, xi0_r, xi1_r, + ring_sub, scale, offset, fill, + lod_root=lod_root, + array_name=ARRAY_NAME) + ring_x = np.arange(xi0_r, xi1_r, ring_sub) * dx + x_origin + ring_y = np.arange(yi0_r, yi1_r, ring_sub) * dy + y_origin + + # Upsample ring_data to canvas resolution via nearest-neighbor + # and paste into the canvas where it overlaps. + ratio = outer_sub // ring_sub + cx0 = np.searchsorted(x_coords_c, ring_x[0]) + cy0 = np.searchsorted(-y_coords_c, -ring_y[0]) # y descending + # Upsample: repeat each ring pixel ratio×ratio times + upsampled = np.repeat(np.repeat(ring_data, ratio, axis=0), + ratio, axis=1) + # Clip to canvas bounds + paste_h = min(upsampled.shape[0], canvas.shape[0] - cy0) + paste_w = min(upsampled.shape[1], canvas.shape[1] - cx0) + if paste_h > 0 and paste_w > 0: + canvas[cy0:cy0 + paste_h, cx0:cx0 + paste_w] = \ + upsampled[:paste_h, :paste_w] + + dt_read = time.time() - t0 + print(f"Horizon read: {dt_read:.2f}s, " + f"canvas {canvas.shape[1]}x{canvas.shape[0]}") + + data = canvas + x_coords = x_coords_c + y_coords = y_coords_c + out_h, out_w = data.shape + else: + lon_min = center_lon - size_deg + lon_max = center_lon + size_deg + lat_min = center_lat - size_deg + lat_max = center_lat + size_deg + + xi0 = max(int((lon_min - x_origin) / dx), 0) + xi1 = min(int((lon_max - x_origin) / dx) + 1, W) + yi0 = max(int((lat_max - y_origin) / dy), 0) + yi1 = min(int((lat_min - y_origin) / dy) + 1, H) + + raw_w = xi1 - xi0 + raw_h = yi1 - yi0 + out_w = raw_w // subsample + out_h = raw_h // subsample + print(f"Window: lon [{lon_min:.3f}, {lon_max:.3f}], " + f"lat [{lat_min:.3f}, {lat_max:.3f}]") + print(f"Pixel window: {raw_w:,} x {raw_h:,} " + f"(subsample {subsample}x -> {out_w:,} x {out_h:,})") + + t0 = time.time() + data = _read_raw_window(z, yi0, 
def make_tile_data_fn(zarr_path, center_lon, center_lat):
    """Create a tile data callback for streaming terrain from zarr.

    The callback converts UTM world-space bounds to geographic
    coordinates, reads the corresponding slice from the zarr store
    (through cached LOD arrays when they exist), CF-decodes the
    elevation data, and returns it as a float32 array.

    Parameters
    ----------
    zarr_path : str
        Path to the zarr store.
    center_lon, center_lat : float
        Center of the initial view in WGS84 degrees (used to pick the
        UTM zone for the inverse projection).

    Returns
    -------
    callable
        ``fn(x_min, y_min, x_max, y_max, target_samples)``
        → ``np.ndarray`` of shape ``(H, W)`` (float32 elevations,
        row 0 = northernmost) or ``None`` when the bounds cannot be
        projected or do not overlap the array.
    """
    import math

    from pyproj import Transformer

    z, _fmt, (x_origin, dx, y_origin, dy), \
        (scale, offset, fill), _crs_wkt = _open_zarr_array(zarr_path)
    H, W = z.shape

    target_epsg = _utm_epsg(center_lon, center_lat)
    to_lonlat = Transformer.from_crs(
        f"EPSG:{target_epsg}", "EPSG:4326", always_xy=True)

    # Opened once and shared by every callback invocation.
    lod_root = zarr.open_group(zarr_path, mode="r")
    ARRAY_NAME = "usgs10m_dem"

    def tile_data_fn(x_min, y_min, x_max, y_max, target_samples):
        # Convert UTM corners to lon/lat.
        # NOTE(review): only the SW and NE corners are projected; a
        # UTM-aligned rectangle is not exactly lon/lat-aligned, so this
        # box is an approximation. Confirm against the tile consumer.
        lon_min, lat_min = to_lonlat.transform(x_min, y_min)
        lon_max, lat_max = to_lonlat.transform(x_max, y_max)

        # pyproj reports points it cannot transform as inf; int() on a
        # non-finite value would raise instead of yielding "no data".
        corners = (lon_min, lat_min, lon_max, lat_max)
        if not all(math.isfinite(v) for v in corners):
            return None

        # Zarr pixel indices (dy < 0: row 0 = northernmost latitude)
        xi0 = int((lon_min - x_origin) / dx)
        xi1 = int((lon_max - x_origin) / dx) + 1
        yi0 = int((lat_max - y_origin) / dy)      # north → smaller index
        yi1 = int((lat_min - y_origin) / dy) + 1  # south → larger index

        # Clamp to array bounds
        xi0 = max(0, xi0)
        xi1 = min(W, xi1)
        yi0 = max(0, yi0)
        yi1 = min(H, yi1)

        if xi1 <= xi0 or yi1 <= yi0:
            return None

        # Compute stride to achieve approximately target_samples per
        # side. The sample count is forced to a positive integer so a
        # zero, negative or float request cannot break the division or
        # produce an invalid slice step.
        raw_size = max(xi1 - xi0, yi1 - yi0)
        samples = max(1, int(target_samples))
        stride = max(1, raw_size // samples)

        return _read_raw_window(z, yi0, yi1, xi0, xi1, stride,
                                scale, offset, fill,
                                lod_root=lod_root,
                                array_name=ARRAY_NAME)

    return tile_data_fn
@@ -231,61 +537,47 @@ def loader(cam_x, cam_y): help="Subsample factor (default: 4)") parser.add_argument("--zarr", type=str, default=ZARR_PATH, help="Path to zarr store") + parser.add_argument("--horizon", type=int, default=2, + help="Number of low-res horizon rings (default: 2, 0=off)") + parser.add_argument("--build-lods", action="store_true", + help="Pre-build LOD cache arrays and exit") + parser.add_argument("--clear-lods", action="store_true", + help="Delete cached LOD arrays from the zarr store and exit") args = parser.parse_args() + if args.clear_lods: + root = zarr.open_group(args.zarr, mode="r+") + removed = [] + for key in list(root.keys()): + if "_lod" in key: + del root[key] + removed.append(key) + if removed: + print(f"Removed {len(removed)} LOD arrays: {', '.join(removed)}") + else: + print("No LOD arrays found.") + raise SystemExit(0) + + if args.build_lods: + max_sub = args.subsample + for _ in range(max(args.horizon, 2)): + max_sub *= 4 + print(f"Building LOD arrays up to {max_sub}x...") + _build_lod_arrays(args.zarr, "usgs10m_dem", max_factor=max_sub) + print("Done.") + raise SystemExit(0) + + # Tile streaming handles the horizon — no need for baked multi-res + # rings in the initial window (they produce coarse pixels in LOD 0). 
terrain = load_window( args.zarr, args.lon, args.lat, args.size, args.subsample, + horizon_rings=0, ) - loader = make_terrain_loader(args.zarr, args.size, args.subsample, args.lon, args.lat) - - # --- Hydro flow data (off by default, Shift+Y to toggle) --------------- - hydro = None - try: - from xrspatial import fill as _fill - from xrspatial import flow_direction as _flow_direction - from xrspatial import flow_accumulation as _flow_accumulation - from xrspatial import stream_order as _stream_order - from scipy.ndimage import uniform_filter as _uniform_filter - - print("Conditioning DEM for hydrological flow...") - _elev = cp.asnumpy(terrain.data).astype(np.float32) - _nodata = (_elev == 0.0) | np.isnan(_elev) - _elev[_nodata] = -100.0 - - _smoothed = _uniform_filter(_elev, size=15, mode='nearest') - _smoothed[_nodata] = -100.0 - - _sm = cp.asarray(_smoothed) - _filled = _fill(terrain.copy(data=_sm)) - _fd = _filled.data - _sm - _resolved = _filled.data + _fd * 0.01 - cp.random.seed(0) - _resolved += cp.random.uniform(0, 0.001, _resolved.shape, - dtype=cp.float32) - _resolved[cp.asarray(_nodata)] = -100.0 - - fd = _flow_direction(terrain.copy(data=_resolved)) - fa = _flow_accumulation(fd) - so = _stream_order(fd, fa, threshold=50) - - fd_out, fa_out, so_out = fd.data, fa.data, so.data - _nodata_gpu = cp.asarray(_nodata) - fd_out[_nodata_gpu] = cp.nan - fa_out[_nodata_gpu] = cp.nan - so_out[_nodata_gpu] = cp.nan - - hydro = { - 'flow_dir': fd_out, - 'flow_accum': fa_out, - 'stream_order': so_out, - 'accum_threshold': 50, - 'enabled': False, - } - print(f" Flow direction + accumulation computed on " - f"{terrain.shape[0]}x{terrain.shape[1]} grid") - except Exception as e: - print(f"Skipping hydro: {e}") + tile_fn = make_tile_data_fn(args.zarr, args.lon, args.lat) + + # Hydro flow: computed lazily on GPU when first enabled (Shift+Y) + hydro = {'enabled': False} print(f"\nLaunching explore (Shift+Y to toggle hydro)...\n") terrain.rtx.explore( @@ -293,7 +585,7 @@ def 
loader(cam_x, cam_y): height=1600, render_scale=0.5, color_stretch='cbrt', - terrain_loader=loader, + tile_data_fn=tile_fn, hydro_data=hydro, repl=True, ) diff --git a/examples/playground.py b/examples/playground.py index ab86fdb..a6d98f5 100644 --- a/examples/playground.py +++ b/examples/playground.py @@ -244,92 +244,14 @@ def load_terrain(): except Exception as e: print(f"Skipping weather: {e}") - # --- Hydro flow data (off by default, Shift+Y to toggle) --------------- - hydro = None - try: - from xrspatial import fill as _fill - from xrspatial import flow_direction as _flow_direction - from xrspatial import flow_accumulation as _flow_accumulation - from xrspatial import stream_order as _stream_order - from xrspatial import stream_link as _stream_link - from scipy.ndimage import uniform_filter as _uniform_filter - - print("Conditioning DEM for hydrological flow...") - _data = terrain.data - _is_cupy = hasattr(_data, 'get') - _elev = _data.get() if _is_cupy else np.array(_data) - _elev = _elev.astype(np.float32) - _ocean = (_elev == 0.0) | np.isnan(_elev) - _elev[_ocean] = -100.0 - - _smoothed = _uniform_filter(_elev, size=15, mode='nearest') - _smoothed[_ocean] = -100.0 - - if _is_cupy: - import cupy as _cp - _sm = _cp.asarray(_smoothed) - else: - _sm = _smoothed - _filled = _fill(terrain.copy(data=_sm)) - _fd = _filled.data - _sm - _resolved = _filled.data + _fd * 0.01 - if _is_cupy: - _cp.random.seed(0) - _resolved += _cp.random.uniform(0, 0.001, _resolved.shape, - dtype=_cp.float32) - _resolved[_cp.asarray(_ocean)] = -100.0 - else: - np.random.seed(0) - _resolved += np.random.uniform( - 0, 0.001, _resolved.shape).astype(np.float32) - _resolved[_ocean] = -100.0 - - fd = _flow_direction(terrain.copy(data=_resolved)) - fa = _flow_accumulation(fd) - so = _stream_order(fd, fa, threshold=50) - sl = _stream_link(fd, fa, threshold=50) - - fd_out, fa_out, so_out = fd.data, fa.data, so.data - if _is_cupy: - _ocean_gpu = _cp.asarray(_ocean) - fd_out[_ocean_gpu] = 
_cp.nan - fa_out[_ocean_gpu] = _cp.nan - so_out[_ocean_gpu] = _cp.nan - else: - fd_out[_ocean] = np.nan - fa_out[_ocean] = np.nan - so_out[_ocean] = np.nan - - _sl_out = sl.data - if _is_cupy: - _sl_out[_ocean_gpu] = _cp.nan - else: - _sl_out[_ocean] = np.nan - _sl_np = _sl_out.get() if _is_cupy else np.asarray(_sl_out) - _sl_clean = np.nan_to_num(_sl_np, nan=0.0).astype(np.float32) - if _is_cupy: - _sl_clean = _cp.asarray(_sl_clean) - ds['stream_link'] = terrain.copy(data=_sl_clean).rename(None) - - hydro = { - 'flow_dir': fd_out, - 'flow_accum': fa_out, - 'stream_order': so_out, - 'stream_link': _sl_out, - 'accum_threshold': 50, - 'enabled': False, - } - print(f" Flow direction + accumulation computed on " - f"{terrain.shape[0]}x{terrain.shape[1]} grid") - except Exception as e: - print(f"Skipping hydro: {e}") + # Hydro flow: computed lazily on GPU when first enabled (Shift+Y) + hydro = {'enabled': False} print("\nLaunching explore (press G to cycle layers, " "Shift+W for wind, Shift+N for clouds, Shift+Y for hydro)...\n") ds.rtx.explore( z='elevation', scene_zarr=ZARR, - mesh_type='voxel', width=1024, height=768, render_scale=0.5, diff --git a/rtxpy/accessor.py b/rtxpy/accessor.py index f39719a..8779dbb 100644 --- a/rtxpy/accessor.py +++ b/rtxpy/accessor.py @@ -2592,10 +2592,10 @@ def place_tiles(self, url='osm', zoom=None): def explore(self, width=800, height=600, render_scale=0.5, start_position=None, look_at=None, key_repeat_interval=0.05, pixel_spacing_x=None, pixel_spacing_y=None, - mesh_type='heightfield', color_stretch='linear', title=None, + color_stretch='linear', title=None, subsample=1, wind_data=None, weather_data=None, hydro_data=None, gtfs_data=None, - terrain_loader=None, + terrain_loader=None, tile_data_fn=None, scene_zarr=None, ao_samples=0, gi_bounces=1, denoise=False, fog_density=0.0, fog_color=(0.7, 0.8, 0.9), colormap=None, sun_azimuth=None, sun_altitude=None, @@ -2632,9 +2632,6 @@ def explore(self, width=800, height=600, render_scale=0.5, 
pixel_spacing_y : float, optional Y spacing between pixels in world units. If None, uses the value from the last triangulate() call (default 1.0). - mesh_type : str, optional - Mesh generation method: 'tin' or 'voxel'. - Default is 'tin'. subsample : int, optional Initial terrain subsample factor (1, 2, 4, 8). Full-resolution data is preserved; press Shift+R / R to change at runtime. @@ -2695,23 +2692,6 @@ def explore(self, width=800, height=600, render_scale=0.5, spacing_x = pixel_spacing_x if pixel_spacing_x is not None else self._pixel_spacing_x spacing_y = pixel_spacing_y if pixel_spacing_y is not None else self._pixel_spacing_y - # Rebuild terrain geometry if mesh_type doesn't match current state - current_mesh_type = getattr(self, '_terrain_mesh_type', 'heightfield') - if mesh_type != current_mesh_type and 'terrain' in (self._rtx.list_geometries() or []): - self._rtx.remove_geometry('terrain') - if mesh_type == 'heightfield': - self.heightfield(geometry_id='terrain', - pixel_spacing_x=spacing_x, - pixel_spacing_y=spacing_y) - elif mesh_type == 'voxel': - self.voxelate(geometry_id='terrain', - pixel_spacing_x=spacing_x, - pixel_spacing_y=spacing_y) - else: - self.triangulate(geometry_id='terrain', - pixel_spacing_x=spacing_x, - pixel_spacing_y=spacing_y) - # Pass geometry color builder if any colors are set geometry_colors_builder = None if self._geometry_colors: @@ -2728,7 +2708,6 @@ def explore(self, width=800, height=600, render_scale=0.5, rtx=self._rtx, pixel_spacing_x=spacing_x, pixel_spacing_y=spacing_y, - mesh_type=mesh_type, color_stretch=color_stretch, title=title, tile_service=getattr(self, '_tile_service', None), @@ -2741,6 +2720,7 @@ def explore(self, width=800, height=600, render_scale=0.5, gtfs_data=gtfs_data, accessor=self, terrain_loader=terrain_loader, + tile_data_fn=tile_data_fn, scene_zarr=scene_zarr, ao_samples=ao_samples, gi_bounces=gi_bounces, @@ -3064,12 +3044,12 @@ def load_meshes(self, zarr_path, chunks=None, z=None): def explore(self, 
z, width=800, height=600, render_scale=0.5, start_position=None, look_at=None, key_repeat_interval=0.05, pixel_spacing_x=None, pixel_spacing_y=None, - mesh_type='heightfield', color_stretch='linear', title=None, + color_stretch='linear', title=None, subtitle=None, legend=None, subsample=1, wind_data=None, weather_data=None, hydro_data=None, gtfs_data=None, - terrain_loader=None, + terrain_loader=None, tile_data_fn=None, scene_zarr=None, ao_samples=0, gi_bounces=1, denoise=False, fog_density=0.0, fog_color=(0.7, 0.8, 0.9), @@ -3108,8 +3088,6 @@ def explore(self, z, width=800, height=600, render_scale=0.5, X spacing between pixels in world units. Default is 1.0. pixel_spacing_y : float, optional Y spacing between pixels in world units. Default is 1.0. - mesh_type : str, optional - Mesh generation method: 'tin' or 'voxel'. Default is 'tin'. repl : bool, optional If True, start an interactive Python REPL alongside the viewer. Default is False. @@ -3167,7 +3145,6 @@ def explore(self, z, width=800, height=600, render_scale=0.5, rtx=terrain_da.rtx._rtx, pixel_spacing_x=spacing_x, pixel_spacing_y=spacing_y, - mesh_type=mesh_type, overlay_layers=overlay_layers, color_stretch=color_stretch, title=title, @@ -3183,6 +3160,7 @@ def explore(self, z, width=800, height=600, render_scale=0.5, gtfs_data=gtfs_data, accessor=terrain_da.rtx, terrain_loader=terrain_loader, + tile_data_fn=tile_data_fn, scene_zarr=scene_zarr, ao_samples=ao_samples, gi_bounces=gi_bounces, diff --git a/rtxpy/analysis/render.py b/rtxpy/analysis/render.py index ff31593..a0b6d31 100644 --- a/rtxpy/analysis/render.py +++ b/rtxpy/analysis/render.py @@ -1098,8 +1098,10 @@ def _shade_terrain_kernel( pixel_spacing_x, pixel_spacing_y, color_stretch, rgb_texture, + rgb_texture_offset_y, rgb_texture_offset_x, overlay_data, overlay_alpha, overlay_min, overlay_range, overlay_as_water, overlay_color_lut, + overlay_offset_y, overlay_offset_x, instance_ids, geometry_colors, primitive_ids, point_colors, point_color_offsets, 
ao_factor, gi_color, gi_intensity, @@ -1213,10 +1215,12 @@ def _shade_terrain_kernel( if tex_h > 1: # Sample RGB directly from tile texture - if elev_y >= 0 and elev_y < tex_h and elev_x >= 0 and elev_x < tex_w: - base_r = rgb_texture[elev_y, elev_x, 0] - base_g = rgb_texture[elev_y, elev_x, 1] - base_b = rgb_texture[elev_y, elev_x, 2] + tex_y = elev_y - rgb_texture_offset_y + tex_x = elev_x - rgb_texture_offset_x + if tex_y >= 0 and tex_y < tex_h and tex_x >= 0 and tex_x < tex_w: + base_r = rgb_texture[tex_y, tex_x, 0] + base_g = rgb_texture[tex_y, tex_x, 1] + base_b = rgb_texture[tex_y, tex_x, 2] else: base_r = 0.3 base_g = 0.3 @@ -1261,8 +1265,10 @@ def _shade_terrain_kernel( ov_h = overlay_data.shape[0] ov_w = overlay_data.shape[1] if ov_h > 1 and overlay_alpha > 0.0: - if elev_y >= 0 and elev_y < ov_h and elev_x >= 0 and elev_x < ov_w: - ov_val = overlay_data[elev_y, elev_x] + ov_y = elev_y - overlay_offset_y + ov_x = elev_x - overlay_offset_x + if ov_y >= 0 and ov_y < ov_h and ov_x >= 0 and ov_x < ov_w: + ov_val = overlay_data[ov_y, ov_x] if not math.isnan(ov_val): if overlay_as_water and ov_val > 0.5: # Flood water shader — same look as ocean @@ -1417,10 +1423,12 @@ def _shade_terrain_kernel( tex_h = rgb_texture.shape[0] tex_w = rgb_texture.shape[1] - if tex_h > 1 and refl_ey >= 0 and refl_ey < tex_h and refl_ex >= 0 and refl_ex < tex_w: - refl_r = rgb_texture[refl_ey, refl_ex, 0] - refl_g = rgb_texture[refl_ey, refl_ex, 1] - refl_b = rgb_texture[refl_ey, refl_ex, 2] + refl_tex_y = refl_ey - rgb_texture_offset_y + refl_tex_x = refl_ex - rgb_texture_offset_x + if tex_h > 1 and refl_tex_y >= 0 and refl_tex_y < tex_h and refl_tex_x >= 0 and refl_tex_x < tex_w: + refl_r = rgb_texture[refl_tex_y, refl_tex_x, 0] + refl_g = rgb_texture[refl_tex_y, refl_tex_x, 1] + refl_b = rgb_texture[refl_tex_y, refl_tex_x, 2] elif refl_ey >= 0 and refl_ey < elev_h and refl_ex >= 0 and refl_ex < elev_w: refl_elev = elevation_data[refl_ey, refl_ex] if elev_range > 0: @@ -1889,10 
+1897,12 @@ def _shade_terrain( color_stretch=0, sky_color=(-1.0, 0.0, 0.0), rgb_texture=None, + rgb_texture_offset_y=0, rgb_texture_offset_x=0, overlay_data=None, overlay_alpha=0.5, overlay_min=0.0, overlay_range=1.0, overlay_as_water=False, overlay_color_lut=None, + overlay_offset_y=0, overlay_offset_x=0, instance_ids=None, geometry_colors=None, primitive_ids=None, point_colors=None, point_color_offsets=None, ao_factor=None, gi_color=None, gi_intensity=2.0, @@ -2012,8 +2022,10 @@ def _shade_terrain( pixel_spacing_x, pixel_spacing_y, color_stretch, rgb_texture, + np.int32(rgb_texture_offset_y), np.int32(rgb_texture_offset_x), overlay_data, overlay_alpha, overlay_min, overlay_range, overlay_as_water, overlay_color_lut, + np.int32(overlay_offset_y), np.int32(overlay_offset_x), instance_ids, geometry_colors, primitive_ids, point_colors, point_color_offsets, ao_factor, gi_color, np.float32(gi_intensity), @@ -2144,10 +2156,14 @@ def render( color_stretch: str = 'linear', sky_color: Optional[Tuple[float, float, float]] = None, rgb_texture=None, + rgb_texture_offset_y: int = 0, + rgb_texture_offset_x: int = 0, overlay_data=None, overlay_alpha: float = 0.5, overlay_as_water: bool = False, overlay_color_lut=None, + overlay_offset_y: int = 0, + overlay_offset_x: int = 0, geometry_colors=None, ao_samples: int = 0, ao_radius: Optional[float] = None, @@ -2537,10 +2553,14 @@ def render( stretch_int, sky_color=(-1.0, 0.0, 0.0) if sky_color is None else sky_color, rgb_texture=rgb_texture, + rgb_texture_offset_y=rgb_texture_offset_y, + rgb_texture_offset_x=rgb_texture_offset_x, overlay_data=d_overlay, overlay_alpha=overlay_alpha, overlay_min=ov_min, overlay_range=ov_range, overlay_as_water=overlay_as_water, overlay_color_lut=overlay_color_lut, + overlay_offset_y=overlay_offset_y, + overlay_offset_x=overlay_offset_x, instance_ids=d_instance_ids, geometry_colors=geometry_colors, primitive_ids=d_primitive_ids, point_colors=d_point_colors, diff --git a/rtxpy/engine.py 
b/rtxpy/engine.py index 642d3b4..3815614 100644 --- a/rtxpy/engine.py +++ b/rtxpy/engine.py @@ -29,6 +29,7 @@ from .viewer.wind import WindState from .viewer.cloud import CloudState from .viewer.hydro import HydroState +from .viewer.hydro_manager import HydroManager from .viewer.hud import HUDState from .viewer.keybindings import MOVEMENT_KEYS, SHIFT_BINDINGS, KEY_BINDINGS, SPECIAL_BINDINGS @@ -162,166 +163,6 @@ def _wind_splat_kernel( cuda.atomic.add(output, (py, px, 2), contrib * color_b) - @cuda.jit - def _hydro_splat_kernel( - trails, # (N*T, 2) float32 — (row, col) per trail point - ages, # (N,) int32 — per-particle age - lifetimes, # (N,) int32 — per-particle lifetime - colors, # (N, 3) float32 — per-particle (r, g, b) - radii, # (N,) int32 — per-particle splat radius - trail_len, # int32 scalar — trail points per particle - base_alpha, # float32 scalar — base alpha intensity - min_vis_age, # int32 scalar — minimum visible age - ref_depth, # float32 scalar — depth-scaling reference distance - terrain, # (tH, tW) float32 — terrain elevation - depth_t, # (sh, sw) float32 — ray-trace t-values for occlusion - output, # (sh, sw, 3) float32 — frame buffer (atomic add) - # Camera basis — scalar args to avoid tiny GPU allocations - cam_x, cam_y, cam_z, - fwd_x, fwd_y, fwd_z, - rgt_x, rgt_y, rgt_z, - up_x, up_y, up_z, - # Projection params - fov_scale, aspect_ratio, - # Terrain/world params - psx, psy, ve, subsample_f, min_depth, - ): - idx = cuda.grid(1) - if idx >= trails.shape[0]: - return - - # Compute alpha on-GPU from per-particle ages/lifetimes - pidx = idx // trail_len - tidx = idx % trail_len - age = ages[pidx] - lifetime = lifetimes[pidx] - - # Trail point not yet laid down - if age <= tidx: - return - - # Fade in / fade out / trail decay - fade_in = (age - min_vis_age) * 0.1 - if fade_in < 0.0: - fade_in = 0.0 - elif fade_in > 1.0: - fade_in = 1.0 - fade_out = (lifetime - age) * 0.05 - if fade_out < 0.0: - fade_out = 0.0 - elif fade_out > 1.0: - fade_out = 
1.0 - # Quadratic trail decay — comet-tail effect - t = float(tidx) / float(trail_len) - trail_fade = (1.0 - t) * (1.0 - t) - a = base_alpha * fade_in * fade_out * trail_fade - - # Head glow: bright spark at particle position - if tidx == 0: - a = a * 1.5 - - if a < 1e-6: - return - - row = trails[idx, 0] - col = trails[idx, 1] - - # Terrain Z lookup (nearest-neighbor, clamped) - tH = terrain.shape[0] - tW = terrain.shape[1] - sr = int(row / subsample_f) - sc = int(col / subsample_f) - if sr < 0: - sr = 0 - elif sr >= tH: - sr = tH - 1 - if sc < 0: - sc = 0 - elif sc >= tW: - sc = tW - 1 - z_raw = terrain[sr, sc] - if z_raw != z_raw: # NaN check - z_raw = 0.0 - z_val = z_raw * ve + 3.0 - - # World position - wx = col * psx - wy = row * psy - - # Camera-relative - dx = wx - cam_x - dy = wy - cam_y - dz = z_val - cam_z - - # Depth along forward axis - depth = dx * fwd_x + dy * fwd_y + dz * fwd_z - if depth <= min_depth: - return - - # Depth-scaled alpha: closer = brighter, farther = fainter. - # Prevents zoomed-out over-saturation from dense overlapping particles. - depth_scale = ref_depth / (depth + ref_depth) - a = a * depth_scale - - if a < 1e-6: - return - - inv_depth = 1.0 / (depth + 1e-10) - u_cam = dx * rgt_x + dy * rgt_y + dz * rgt_z - v_cam = dx * up_x + dy * up_y + dz * up_z - u_ndc = u_cam * inv_depth / (fov_scale * aspect_ratio) - v_ndc = v_cam * inv_depth / fov_scale - - sh = output.shape[0] - sw = output.shape[1] - sx = int((u_ndc + 1.0) * 0.5 * sw) - sy = int((1.0 - v_ndc) * 0.5 * sh) - - if sx < 0 or sx >= sw or sy < 0 or sy >= sh: - return - - # Depth test: cull particles occluded by terrain. - # Convert ray t-value at this pixel to forward depth, then compare - # to the particle's forward depth (already computed as `depth`). 
- if depth_t.shape[0] > 0: - t_val = depth_t[sy, sx] - if t_val > 0.0 and t_val < 1.0e20: - # Forward depth = t / sqrt(1 + u_cam^2 + v_cam^2) - u_px = (2.0 * float(sx) / float(sw) - 1.0) * fov_scale * aspect_ratio - v_px = (1.0 - 2.0 * float(sy) / float(sh)) * fov_scale - inv_cos = math.sqrt(1.0 + u_px * u_px + v_px * v_px) - terrain_fwd = t_val / inv_cos - if depth > terrain_fwd: - return - - # Per-particle color and radius - color_r = colors[pidx, 0] - color_g = colors[pidx, 1] - color_b = colors[pidx, 2] - r = radii[pidx] - if r < 1: - r = 1 - # Head glow: +1px radius halo at particle position - if tidx == 0: - r = r + 1 - - # Circular stamp splat - for offy in range(-r, r + 1): - for offx in range(-r, r + 1): - dist_sq = offx * offx + offy * offy - if dist_sq > r * r: - continue - falloff = 1.0 - math.sqrt(dist_sq) / r - px = sx + offx - py = sy + offy - if px < 0 or px >= sw or py < 0 or py >= sh: - continue - contrib = a * falloff - cuda.atomic.add(output, (py, px, 0), contrib * color_r) - cuda.atomic.add(output, (py, px, 1), contrib * color_g) - cuda.atomic.add(output, (py, px, 2), contrib * color_b) - - @cuda.jit def _rain_splat_kernel( pts, # (N, 2) float32 — (row, col) per rain particle @@ -538,6 +379,25 @@ def __init__(self, zarr_path, psx, psy): self.radius = 2 self._zarr_path = zarr_path + # Distance-aware loading parameters + self._chunk_world_w = self._chunk_w * psx + self._chunk_world_h = self._chunk_h * psy + self.max_distance = None # None = use radius-based fallback + self.per_tick_load_limit = 2 # max new zarr reads per tick + self.max_chunks = 25 # max visible chunks + + # LOD-aware loading state + self._lod_distances = None # set from LOD manager + self._tile_lods = None # {(tr,tc): lod_level} from aligned LOD manager + self._last_cam_pos = None # for movement detection + self._cam_moving = False + + # Mesh simplification for placed geometry at higher LOD levels. + # LOD 0 = full detail (ratio 1.0), LOD 1 = 50%, LOD 2 = 25%, LOD 3+ = 10%. 
+ self._simplify_ratios = (1.0, 0.5, 0.25, 0.1) + # Cache: (cr, cc, gid, lod) -> (simplified_verts, simplified_indices) + self._simplify_cache = {} + def _load_chunk(self, cr, cc): """Load a single chunk from zarr into cache.""" if (cr, cc) in self._cache: @@ -553,59 +413,190 @@ def _load_chunk(self, cr, cc): combined[gid] = data # (verts, widths, indices) self._cache[(cr, cc)] = combined + def _chunk_center(self, cr, cc): + """World-coordinate center of chunk (cr, cc).""" + cx = (cc * self._chunk_w + self._chunk_w * 0.5) * self._psx + cy = (cr * self._chunk_h + self._chunk_h * 0.5) * self._psy + return cx, cy + + def _get_simplified(self, cr, cc, gid, lod, verts, indices): + """Return (possibly simplified) mesh for a chunk at a given LOD. + + LOD 0 returns the original mesh. LOD 1+ applies quadric + decimation with ratio from ``_simplify_ratios``, caching the + result for reuse across frames. + """ + ratio_idx = min(lod, len(self._simplify_ratios) - 1) + ratio = self._simplify_ratios[ratio_idx] + if ratio >= 1.0: + return verts, indices + key = (cr, cc, gid, lod) + cached = self._simplify_cache.get(key) + if cached is not None: + return cached + from .lod import simplify_mesh + sv, si = simplify_mesh(verts, indices, ratio) + self._simplify_cache[key] = (sv, si) + return sv, si + def update(self, cam_x, cam_y, viewer): """Called per tick. 
Returns True if meshes changed.""" - # Camera world pos -> chunk coord - cc_cam = int(cam_x / self._psx) // self._chunk_w - cr_cam = int(cam_y / self._psy) // self._chunk_h - - # Compute visible ring clamped to grid - cr0 = max(cr_cam - self.radius, 0) - cr1 = min(cr_cam + self.radius, self._n_chunk_rows - 1) - cc0 = max(cc_cam - self.radius, 0) - cc1 = min(cc_cam + self.radius, self._n_chunk_cols - 1) - - new_visible = set() - for cr in range(cr0, cr1 + 1): - for cc in range(cc0, cc1 + 1): + import math + from .lod import compute_lod_level + + # Detect camera movement for LOD-aware load deferral + move_thresh = self._chunk_world_w * 0.1 + if self._last_cam_pos is not None: + dx = cam_x - self._last_cam_pos[0] + dy = cam_y - self._last_cam_pos[1] + self._cam_moving = (dx * dx + dy * dy) > move_thresh * move_thresh + self._last_cam_pos = (cam_x, cam_y) + + max_dist = self.max_distance + lod_dists = self._lod_distances + tile_lods = self._tile_lods # {(tr,tc): lod} when grids aligned + chunk_dists = {} + chunk_lods = {} # per-chunk LOD level + + if tile_lods is not None: + # Grids aligned: reuse LOD manager's tile assignments directly. + # tile_lods keys are (tile_row, tile_col) which map 1:1 to + # chunk (cr, cc) when tile_size == chunk_size. 
+ new_visible = set() + max_lod = len(lod_dists) if lod_dists else 999 + for (cr, cc), lod in tile_lods.items(): + if cr >= self._n_chunk_rows or cc >= self._n_chunk_cols: + continue + if lod > max_lod: + continue + chunk_lods[(cr, cc)] = lod + cx, cy = self._chunk_center(cr, cc) + chunk_dists[(cr, cc)] = math.sqrt( + (cam_x - cx) ** 2 + (cam_y - cy) ** 2) new_visible.add((cr, cc)) - - if new_visible == self._visible: + # Cap at max_chunks, keeping closest + if len(new_visible) > self.max_chunks: + by_dist = sorted(new_visible, key=lambda c: chunk_dists[c]) + new_visible = set(by_dist[:self.max_chunks]) + elif max_dist is not None: + # Distance-aware: compute visible chunks from world-rect + from .mesh_store import chunks_for_world_rect + x0 = cam_x - max_dist + y0 = cam_y - max_dist + x1 = cam_x + max_dist + y1 = cam_y + max_dist + candidates = chunks_for_world_rect( + x0, y0, x1, y1, + self._psx, self._psy, + self._chunk_h, self._chunk_w, + self._elev_shape) + for cr, cc in candidates: + cx, cy = self._chunk_center(cr, cc) + chunk_dists[(cr, cc)] = math.sqrt( + (cam_x - cx) ** 2 + (cam_y - cy) ** 2) + if lod_dists: + max_lod = len(lod_dists) + candidates = [ + c for c in candidates + if compute_lod_level(chunk_dists[c], lod_dists) <= max_lod + ] + candidates.sort(key=lambda c: chunk_dists[c]) + new_visible = set(candidates[:self.max_chunks]) + # Compute per-chunk LOD for mesh simplification + if lod_dists: + for cr, cc in new_visible: + chunk_lods[(cr, cc)] = compute_lod_level( + chunk_dists[(cr, cc)], lod_dists) + else: + # Legacy radius-based ring + cc_cam = int(cam_x / self._psx) // self._chunk_w + cr_cam = int(cam_y / self._psy) // self._chunk_h + cr0 = max(cr_cam - self.radius, 0) + cr1 = min(cr_cam + self.radius, self._n_chunk_rows - 1) + cc0 = max(cc_cam - self.radius, 0) + cc1 = min(cc_cam + self.radius, self._n_chunk_cols - 1) + new_visible = set() + for cr in range(cr0, cr1 + 1): + for cc in range(cc0, cc1 + 1): + new_visible.add((cr, cc)) + cx, cy = 
self._chunk_center(cr, cc) + chunk_dists[(cr, cc)] = math.sqrt( + (cam_x - cx) ** 2 + (cam_y - cy) ** 2) + + # Check if any visible chunks are uncached (deferred from prior tick) + has_deferred = any((cr, cc) not in self._cache for cr, cc in new_visible) + + if new_visible == self._visible and not has_deferred: return False + # Evict simplification cache entries for chunks leaving visible set + departed = self._visible - new_visible + if departed and self._simplify_cache: + for k in [k for k in self._simplify_cache + if (k[0], k[1]) in departed]: + del self._simplify_cache[k] + self._visible = new_visible - # Load any uncached chunks - for cr, cc in new_visible: + # Load uncached chunks, prioritized by distance (closest first). + # Limited to per_tick_load_limit new zarr reads per tick. + uncached = [(cr, cc) for cr, cc in new_visible + if (cr, cc) not in self._cache] + if uncached and chunk_dists: + uncached.sort(key=lambda c: chunk_dists.get(c, 0)) + loads = 0 + for cr, cc in uncached: + if loads >= self.per_tick_load_limit: + break + # When moving, defer distant (LOD 1+) chunks + lod = chunk_lods.get((cr, cc)) + if lod is None and lod_dists and (cr, cc) in chunk_dists: + lod = compute_lod_level(chunk_dists[(cr, cc)], lod_dists) + if self._cam_moving and lod is not None and lod > 0: + continue self._load_chunk(cr, cc) - - # Merge visible chunks per gid - merged = {} - for gid in self._gids: - all_verts = [] - all_widths = [] - all_indices = [] - vert_offset = 0 - is_curve = False - for cr, cc in sorted(new_visible): - chunk_data = self._cache.get((cr, cc), {}) - if gid not in chunk_data: - continue - data = chunk_data[gid] + loads += 1 + + # Merge visible chunks per gid. Iterate chunks first so we only + # touch gids that actually have data (skips empty lookups). 
+ # Per-gid accumulators: {gid: (all_verts, all_widths, all_indices, + # vert_offset, is_curve)} + merge_acc = {} + for cr, cc in sorted(new_visible): + chunk_data = self._cache.get((cr, cc)) + if not chunk_data: + continue + clod = chunk_lods.get((cr, cc), 0) + for gid, data in chunk_data.items(): if len(data) == 3: - # Curve geometry: (verts, widths, indices) verts, widths, indices = data - is_curve = True if len(indices) == 0: continue - all_widths.append(widths) + acc = merge_acc.get(gid) + if acc is None: + acc = ([], [], [], [0], True) + merge_acc[gid] = acc + acc[0].append(verts) + acc[1].append(widths) + acc[2].append(indices + acc[3][0]) + acc[3][0] += len(verts) // 3 else: verts, indices = data if len(indices) == 0: continue - all_indices.append(indices + vert_offset) - all_verts.append(verts) - vert_offset += len(verts) // 3 + if clod > 0: + verts, indices = self._get_simplified( + cr, cc, gid, clod, verts, indices) + acc = merge_acc.get(gid) + if acc is None: + acc = ([], [], [], [0], False) + merge_acc[gid] = acc + acc[0].append(verts) + acc[2].append(indices + acc[3][0]) + acc[3][0] += len(verts) // 3 + + merged = {} + for gid, (all_verts, all_widths, all_indices, _, is_curve) in merge_acc.items(): if all_verts: if is_curve: merged[gid] = (np.concatenate(all_verts), @@ -663,31 +654,23 @@ def update(self, cam_x, cam_y, viewer): else: verts, indices = data - # Re-snap Z coordinates to current terrain surface + VE. - # Meshes from zarr have Z computed from the full-res terrain. - # When terrain is subsampled, the rendered surface differs from - # the full-res values, so we re-anchor each vertex's height - # offset onto the current terrain using bilinear interpolation. + # Apply VE to Z coordinates and cache base_z for VE rescaling. + # orig_base_z and new_base_z both sample the same full-res terrain + # at the same XY positions, so (new_base_z + z_offset) == stored_z. + # The only transformation needed is: final_z = stored_z * ve. 
+ # We still compute base_z once for the baked mesh cache (used by + # _rebuild_vertical_exaggeration to rescale without re-reading zarr). n_verts = len(verts) // 3 use_gpu = (gpu_terrain is not None and gpu_base_terrain is not None - and n_verts > 1000) + and n_verts > 10000) if use_gpu: - vx = cp.asarray(verts[0::3]) - vy = cp.asarray(verts[1::3]) - vz_stored = cp.asarray(verts[2::3]) - - orig_base_z_gpu = _bilinear_terrain_z( - gpu_base_terrain, vx, vy, base_psx, base_psy) - z_offset = vz_stored - orig_base_z_gpu - - new_base_z = _bilinear_terrain_z( - gpu_terrain, vx, vy, - viewer.pixel_spacing_x, viewer.pixel_spacing_y) - - updated_verts_gpu = cp.asarray(verts.copy()) - updated_verts_gpu[2::3] = (new_base_z + z_offset) * ve + # cp.asarray copies H→D (verts is numpy), so we can + # mutate it in-place without an extra GPU copy. + updated_verts_gpu = cp.asarray(verts) + if ve != 1.0: + updated_verts_gpu[2::3] *= ve if is_curve: rtx.add_curve_geometry( @@ -699,27 +682,20 @@ def update(self, cam_x, cam_y, viewer): if accessor is not None: accessor._geometry_colors[gid] = self._colors.get(gid, (0.6, 0.6, 0.6)) - orig_base_z_np = orig_base_z_gpu.get() + vx = cp.asarray(verts[0::3]) + vy = cp.asarray(verts[1::3]) + orig_base_z_np = _bilinear_terrain_z( + gpu_base_terrain, vx, vy, + base_psx, base_psy).get() if is_curve: accessor._baked_meshes[gid] = ( verts.copy(), widths.copy(), indices.copy(), orig_base_z_np) else: accessor._baked_meshes[gid] = (verts.copy(), indices.copy(), orig_base_z_np) else: - vx = verts[0::3] - vy = verts[1::3] - vz_stored = verts[2::3].copy() - - orig_base_z = _bilinear_terrain_z( - base_terrain_np, vx, vy, base_psx, base_psy) - z_offset = vz_stored - orig_base_z - - new_base_z = _bilinear_terrain_z( - terrain_np, vx, vy, - viewer.pixel_spacing_x, viewer.pixel_spacing_y) - updated_verts = verts.copy() - updated_verts[2::3] = (new_base_z + z_offset) * ve + if ve != 1.0: + updated_verts[2::3] *= ve if is_curve: rtx.add_curve_geometry(gid, 
updated_verts, widths, indices) @@ -729,6 +705,9 @@ def update(self, cam_x, cam_y, viewer): if accessor is not None: accessor._geometry_colors[gid] = self._colors.get(gid, (0.6, 0.6, 0.6)) + orig_base_z = _bilinear_terrain_z( + base_terrain_np, verts[0::3], verts[1::3], + base_psx, base_psy) if is_curve: accessor._baked_meshes[gid] = ( verts.copy(), widths.copy(), indices.copy(), orig_base_z) @@ -763,8 +742,6 @@ def update(self, cam_x, cam_y, viewer): viewer._geometry_layer_idx = 0 for geom_id in viewer._all_geometries: - if geom_id == 'terrain': - continue if layer_name == 'none': rtx.set_geometry_visible(geom_id, False) elif layer_name == 'all': @@ -970,29 +947,28 @@ def fn(v): _add_overlay(v, name, data) self._submit(fn) - def add_hydro(self, flow_dir, flow_accum, **kwargs): + def add_hydro(self, flow_accum, **kwargs): """Add hydrological flow particle visualization. - The flow grids should be computed from a depression-filled DEM - (e.g. ``xrspatial.fill()``) so particles follow coherent drainage - paths instead of getting trapped in pits. + Uses MFD (Multiple Flow Direction) to compute flow vectors from + terrain elevation so particles follow natural drainage paths + distributed across all downhill neighbors. Parameters ---------- - flow_dir : array-like, shape (H, W) - D8 flow direction grid (1=E, 2=SE, 4=S, 8=SW, 16=W, 32=NW, - 64=N, 128=NE). Compute with ``xrspatial.flow_direction()``. flow_accum : array-like, shape (H, W) Flow accumulation grid (cell counts or area). Compute with ``xrspatial.flow_accumulation()``. **kwargs Optional overrides: n_particles, max_age, trail_len, speed, - accum_threshold, color, alpha, dot_radius, min_visible_age. + accum_threshold, color, alpha, dot_radius, min_visible_age, + flow_dir_mfd (xrspatial MFD fractions), + elevation (conditioned DEM for manual MFD fallback). 
""" stream_order = kwargs.get('stream_order') stream_link = kwargs.get('stream_link') def fn(v): - v._init_hydro(flow_dir, flow_accum, **kwargs) + v._init_hydro(flow_accum, **kwargs) v._hydro_enabled = True # Add stream link overlay with palette-matched colors if stream_link is not None: @@ -1377,10 +1353,8 @@ class InteractiveViewer: - [/]: Decrease/increase observer height - R: Decrease terrain resolution (coarser, up to 8x subsample) - Shift+R: Increase terrain resolution (finer, down to 1x) - - Shift+A: Toggle distance-based terrain LOD - Z: Decrease vertical exaggeration - Shift+Z: Increase vertical exaggeration - - B: Toggle mesh type (TIN / voxel) - Y: Cycle color stretch (linear, sqrt, cbrt, log) - T: Toggle shadows - 0: Toggle ambient occlusion (progressive) @@ -1406,7 +1380,6 @@ def __init__(self, raster, width: int = 800, height: int = 600, render_scale: float = 0.5, key_repeat_interval: float = 0.05, rtx: 'RTX' = None, pixel_spacing_x: float = 1.0, pixel_spacing_y: float = 1.0, - mesh_type: str = 'heightfield', overlay_layers: dict = None, title: str = None, subtitle: str = None, @@ -1438,8 +1411,6 @@ def __init__(self, raster, width: int = 800, height: int = 600, Must match the spacing used when triangulating terrain. Default 1.0. pixel_spacing_y : float, optional Y spacing between pixels in world units. Default 1.0. - mesh_type : str, optional - Mesh generation method: 'tin' or 'voxel'. Default is 'tin'. 
""" if not has_cupy: raise ImportError( @@ -1450,7 +1421,7 @@ def __init__(self, raster, width: int = 800, height: int = 600, # Terrain state (raster, spacing, elevation stats, mesh caches) self.terrain = TerrainState( raster, pixel_spacing_x=pixel_spacing_x, - pixel_spacing_y=pixel_spacing_y, mesh_type=mesh_type, + pixel_spacing_y=pixel_spacing_y, subsample=subsample, skirt=skirt, ) @@ -1557,6 +1528,12 @@ def __init__(self, raster, width: int = 800, height: int = 600, # Hydro flow particle state self.hydro = HydroState() + self.hydro_mgr = HydroManager(self.hydro) + + # Per-tile overlay compositing (created when LOD is enabled) + self._overlay_tile_mgr = None + # Per-tile basemap texture compositing (created when LOD is enabled) + self._texture_tile_mgr = None # GTFS-RT realtime vehicle overlay state self._gtfs_rt_url = None @@ -1621,7 +1598,7 @@ def __init__(self, raster, width: int = 800, height: int = 600, new_data = self._base_raster.data.copy() new_data[cp.asarray(ocean_fill)] = cp.nan else: - new_data = self._base_raster.data.copy() + new_data = base_np.copy() new_data[ocean_fill] = np.nan self._base_raster = self._base_raster.copy(data=new_data) # Re-derive working raster from updated base @@ -1684,64 +1661,10 @@ def __init__(self, raster, width: int = 800, height: int = 600, else: self._overlay_layers = dict(self._base_overlay_layers) - # Build terrain geometry if RTX exists but has no terrain. - # Without this, render() falls into the auto-VE / prepare_mesh path - # which computes vertical_exaggeration from pixel dimensions (not world - # units), producing wrong results when pixel_spacing != 1. - if rtx is not None and not rtx.has_geometry('terrain'): - from . 
import mesh as mesh_mod - if mesh_type == 'heightfield': - rtx.add_heightfield_geometry( - 'terrain', terrain_np, H, W, - spacing_x=self.pixel_spacing_x, - spacing_y=self.pixel_spacing_y, - ve=1.0, - ) - if self.terrain_skirt: - sv, si = mesh_mod.build_terrain_skirt( - terrain_np, H, W, scale=1.0, - pixel_spacing_x=self.pixel_spacing_x, - pixel_spacing_y=self.pixel_spacing_y) - rtx.add_geometry('terrain_skirt', sv, si) - cache_key = (self.subsample_factor, mesh_type) - self._terrain_mesh_cache[cache_key] = ( - None, None, terrain_np.copy(), - ) - else: - if mesh_type == 'voxel': - nv = H * W * 8 - nt = H * W * 12 - verts = np.zeros(nv * 3, dtype=np.float32) - idxs = np.zeros(nt * 3, dtype=np.int32) - base_elev = float(np.nanmin(terrain_np)) - mesh_mod.voxelate_terrain(verts, idxs, raster, scale=1.0, - base_elevation=base_elev) - else: - nv = H * W - nt = (H - 1) * (W - 1) * 2 - verts = np.zeros(nv * 3, dtype=np.float32) - idxs = np.zeros(nt * 3, dtype=np.int32) - mesh_mod.triangulate_terrain(verts, idxs, raster, scale=1.0) - - # Add skirt for TIN meshes - if self.terrain_skirt: - verts, idxs = mesh_mod.add_terrain_skirt( - verts, idxs, H, W) - - if self.pixel_spacing_x != 1.0 or self.pixel_spacing_y != 1.0: - verts[0::3] *= self.pixel_spacing_x - verts[1::3] *= self.pixel_spacing_y - - cache_key = (self.subsample_factor, mesh_type) - self._terrain_mesh_cache[cache_key] = ( - verts.copy(), idxs.copy(), terrain_np.copy(), - ) - - # Only pass grid_dims for TIN meshes without skirt — - # cluster GAS requires regular grid triangle layout. - gd = (H, W) if mesh_type != 'voxel' and not self.terrain_skirt else None - rtx.add_geometry('terrain', verts, idxs, - grid_dims=gd) + # Enable LOD terrain immediately — no single-GAS terrain is built. + # LOD tiles are created lazily during the first update() call. 
+ if rtx is not None: + self._enable_terrain_lod() # ------------------------------------------------------------------ # Delegation properties — InputState @@ -2307,6 +2230,14 @@ def _minimap_rect(self): def _minimap_rect(self, value): self.hud.minimap_rect = value + @property + def _minimap_world_extent(self): + return self.hud.minimap_world_extent + + @_minimap_world_extent.setter + def _minimap_world_extent(self, value): + self.hud.minimap_world_extent = value + @property def _minimap_style(self): return self.hud.minimap_style @@ -2331,6 +2262,22 @@ def _minimap_colors(self): def _minimap_colors(self, value): self.hud.minimap_colors = value + @property + def _minimap_bg_extent(self): + return self.hud.minimap_bg_extent + + @_minimap_bg_extent.setter + def _minimap_bg_extent(self, value): + self.hud.minimap_bg_extent = value + + @property + def _minimap_last_stream_time(self): + return self.hud.minimap_last_stream_time + + @_minimap_last_stream_time.setter + def _minimap_last_stream_time(self, value): + self.hud.minimap_last_stream_time = value + # ------------------------------------------------------------------ # Delegation properties — WindState # ------------------------------------------------------------------ @@ -2639,6 +2586,14 @@ def _hydro_data(self): def _hydro_data(self, value): self.hydro.hydro_data = value + @property + def _hydro_lazy(self): + return self.hydro.hydro_lazy + + @_hydro_lazy.setter + def _hydro_lazy(self, value): + self.hydro.hydro_lazy = value + @property def _hydro_enabled(self): return self.hydro.hydro_enabled @@ -2771,6 +2726,14 @@ def _hydro_min_depth(self): def _hydro_min_depth(self, value): self.hydro.hydro_min_depth = value + @property + def _hydro_max_depth(self): + return self.hydro.hydro_max_depth + + @_hydro_max_depth.setter + def _hydro_max_depth(self, value): + self.hydro.hydro_max_depth = value + @property def _hydro_ref_depth(self): return self.hydro.hydro_ref_depth @@ -2911,6 +2874,94 @@ def _d_hydro_radii(self): 
def _d_hydro_radii(self, value): self.hydro.d_hydro_radii = value + @property + def _d_hydro_particles(self): + return self.hydro.d_hydro_particles + + @_d_hydro_particles.setter + def _d_hydro_particles(self, value): + self.hydro.d_hydro_particles = value + + @property + def _d_hydro_particle_accum(self): + return self.hydro.d_hydro_particle_accum + + @_d_hydro_particle_accum.setter + def _d_hydro_particle_accum(self, value): + self.hydro.d_hydro_particle_accum = value + + @property + def _d_hydro_particle_raw_order(self): + return self.hydro.d_hydro_particle_raw_order + + @_d_hydro_particle_raw_order.setter + def _d_hydro_particle_raw_order(self, value): + self.hydro.d_hydro_particle_raw_order = value + + @property + def _d_hydro_flow_u(self): + return self.hydro.d_hydro_flow_u + + @_d_hydro_flow_u.setter + def _d_hydro_flow_u(self, value): + self.hydro.d_hydro_flow_u = value + + @property + def _d_hydro_flow_v(self): + return self.hydro.d_hydro_flow_v + + @_d_hydro_flow_v.setter + def _d_hydro_flow_v(self, value): + self.hydro.d_hydro_flow_v = value + + @property + def _d_hydro_slope_mag(self): + return self.hydro.d_hydro_slope_mag + + @_d_hydro_slope_mag.setter + def _d_hydro_slope_mag(self, value): + self.hydro.d_hydro_slope_mag = value + + @property + def _d_hydro_stream_order(self): + return self.hydro.d_hydro_stream_order + + @_d_hydro_stream_order.setter + def _d_hydro_stream_order(self, value): + self.hydro.d_hydro_stream_order = value + + @property + def _d_hydro_stream_order_raw(self): + return self.hydro.d_hydro_stream_order_raw + + @_d_hydro_stream_order_raw.setter + def _d_hydro_stream_order_raw(self, value): + self.hydro.d_hydro_stream_order_raw = value + + @property + def _d_hydro_accum_norm(self): + return self.hydro.d_hydro_accum_norm + + @_d_hydro_accum_norm.setter + def _d_hydro_accum_norm(self, value): + self.hydro.d_hydro_accum_norm = value + + @property + def _d_hydro_palette(self): + return self.hydro.d_hydro_palette + + 
@_d_hydro_palette.setter + def _d_hydro_palette(self, value): + self.hydro.d_hydro_palette = value + + @property + def _d_hydro_respawn_flags(self): + return self.hydro.d_hydro_respawn_flags + + @_d_hydro_respawn_flags.setter + def _d_hydro_respawn_flags(self, value): + self.hydro.d_hydro_respawn_flags = value + @property def _hydro_particle_colors(self): return self.hydro.hydro_particle_colors @@ -3115,14 +3166,6 @@ def subsample_factor(self): def subsample_factor(self, value): self.terrain.subsample_factor = value - @property - def _terrain_mesh_cache(self): - return self.terrain._terrain_mesh_cache - - @_terrain_mesh_cache.setter - def _terrain_mesh_cache(self, value): - self.terrain._terrain_mesh_cache = value - @property def _baked_mesh_cache(self): return self.terrain._baked_mesh_cache @@ -3147,13 +3190,6 @@ def _gpu_base_terrain(self): def _gpu_base_terrain(self, value): self.terrain._gpu_base_terrain = value - @property - def mesh_type(self): - return self.terrain.mesh_type - - @mesh_type.setter - def mesh_type(self, value): - self.terrain.mesh_type = value @property def terrain_skirt(self): @@ -3370,9 +3406,6 @@ def _build_title(self): res += f" ({self.subsample_factor}\u00d7 sub)" parts.append(res) - # Mesh type - parts.append(self.mesh_type.upper()) - # Terrain color layer terrain_name = self._terrain_layer_order[self._terrain_layer_idx] if terrain_name != 'elevation' and terrain_name in self._overlay_layers: @@ -3434,6 +3467,91 @@ def _build_title(self): return ' \u2502 '.join(parts) + @staticmethod + def _elevation_to_minimap_rgba(terrain_small): + """Convert a 2D elevation array to an RGBA minimap image. + + Parameters + ---------- + terrain_small : ndarray, shape (H, W), float32 + Elevation data (may contain NaN for water/nodata). + + Returns + ------- + rgba : ndarray, shape (H, W, 4), float32 + RGBA minimap image with hillshade and water coloring. + water : ndarray, shape (H, W), bool + Water mask (NaN pixels). 
+ """ + new_h, new_w = terrain_small.shape + + # Water mask: only NaN (not <= 0, which catches valid terrain) + water = np.isnan(terrain_small) + + # Fill NaNs for gradient computation — use edge extrapolation + # to avoid step-change gradient artifacts at the data boundary + if water.any(): + terrain_small = terrain_small.copy() + # Simple iterative nearest-neighbor fill: propagate valid + # values into NaN regions to avoid gradient discontinuities. + # Cap iterations — minimap is small (~200px), so 50 is plenty. + filled = terrain_small + for _ in range(min(50, max(new_h, new_w))): + still_nan = np.isnan(filled) + if not still_nan.any(): + break + # Shift in 4 directions and average available neighbors + padded = np.pad(filled, 1, mode='edge') + neighbors = np.stack([ + padded[:-2, 1:-1], # up + padded[2:, 1:-1], # down + padded[1:-1, :-2], # left + padded[1:-1, 2:], # right + ], axis=0) + with np.errstate(all='ignore'): + fill_vals = np.nanmean(neighbors, axis=0) + filled = np.where(still_nan & np.isfinite(fill_vals), + fill_vals, filled) + terrain_small = filled + + # Hillshade (sun from upper-left) + dy, dx = np.gradient(terrain_small) + az_rad = np.radians(315) + alt_rad = np.radians(45) + slp = np.sqrt(dx**2 + dy**2) + asp = np.arctan2(-dy, dx) + shaded = (np.sin(alt_rad) * np.cos(np.arctan(slp)) + + np.cos(alt_rad) * np.sin(np.arctan(slp)) * + np.cos(az_rad - asp)) + shaded = np.clip(shaded, 0, 1) + + # Elevation tint: normalise to [0,1] for colour ramp + land = ~water + emin = np.nanmin(terrain_small[land]) if land.any() else 0 + emax = np.nanmax(terrain_small[land]) if land.any() else 1 + erng = emax - emin if emax > emin else 1.0 + elev_norm = np.clip((terrain_small - emin) / erng, 0, 1) + + # Build RGBA image + rgba = np.zeros((new_h, new_w, 4), dtype=np.float32) + + # Grayscale hillshade base for all land + grey = shaded * 0.5 + elev_norm * 0.3 + 0.1 + grey = np.clip(grey, 0, 1) + for c in range(3): + rgba[:, :, c] = grey + rgba[:, :, 3] = 1.0 + + # 
Water (NaN): dark blue-black + rgba[water, 0] = 0.08 + rgba[water, 1] = 0.10 + rgba[water, 2] = 0.18 + rgba[water, 3] = 0.7 + + rgba[:, :, :3] = np.clip(rgba[:, :, :3], 0, 1) + + return rgba, water + def _compute_minimap_background(self): """Compute a stylised RGBA minimap image. @@ -3462,34 +3580,7 @@ def _compute_minimap_background(self): terrain_small = terrain_np.copy() new_h, new_w = H, W - # Water mask: NaN or <= 0 - water = np.isnan(terrain_small) | (terrain_small <= 0) - - # Fill NaNs for gradient computation - if water.any(): - med = np.nanmedian(terrain_small) - terrain_small = terrain_small.copy() - terrain_small[water] = med if np.isfinite(med) else 0.0 - - # Hillshade (sun from upper-left) - dy, dx = np.gradient(terrain_small) - az_rad = np.radians(315) - alt_rad = np.radians(45) - slp = np.sqrt(dx**2 + dy**2) - asp = np.arctan2(-dy, dx) - shaded = (np.sin(alt_rad) * np.cos(np.arctan(slp)) + - np.cos(alt_rad) * np.sin(np.arctan(slp)) * - np.cos(az_rad - asp)) - shaded = np.clip(shaded, 0, 1) - - # Elevation tint: normalise to [0,1] for colour ramp - emin = np.nanmin(terrain_small[~water]) if (~water).any() else 0 - emax = np.nanmax(terrain_small[~water]) if (~water).any() else 1 - erng = emax - emin if emax > emin else 1.0 - elev_norm = np.clip((terrain_small - emin) / erng, 0, 1) - - # Build RGBA image - rgba = np.zeros((new_h, new_w, 4), dtype=np.float32) + rgba, water = self._elevation_to_minimap_rgba(terrain_small) # Check for categorical overlay layer coloring _layer_data = None @@ -3497,21 +3588,28 @@ def _compute_minimap_background(self): and self._minimap_layer in self._base_overlay_layers): ld = self._base_overlay_layers[self._minimap_layer] if hasattr(ld, 'get'): - ld = ld.get() - ld = np.asarray(ld, dtype=np.float64) - if longest > max_dim: - _layer_data = ld[np.ix_(y_idx, x_idx)] - else: - _layer_data = ld.copy() - - # Grayscale hillshade base for all land - grey = shaded * 0.5 + elev_norm * 0.3 + 0.1 - grey = np.clip(grey, 0, 1) - for c in 
range(3): - rgba[:, :, c] = grey - rgba[:, :, 3] = 1.0 + ld = ld.get() + ld = np.asarray(ld, dtype=np.float64) + if longest > max_dim: + _layer_data = ld[np.ix_(y_idx, x_idx)] + else: + _layer_data = ld.copy() if _layer_data is not None: + # Recompute hillshade for categorical overlay blending + ts_filled = terrain_small.copy() + if water.any(): + med = np.nanmedian(ts_filled) + ts_filled[water] = med if np.isfinite(med) else 0.0 + dy, dx = np.gradient(ts_filled) + az_rad = np.radians(315) + alt_rad = np.radians(45) + slp = np.sqrt(dx**2 + dy**2) + asp = np.arctan2(-dy, dx) + shaded = np.clip( + np.sin(alt_rad) * np.cos(np.arctan(slp)) + + np.cos(alt_rad) * np.sin(np.arctan(slp)) * + np.cos(az_rad - asp), 0, 1) # Overlay risk colours on matched pixels; unmatched stays grey for val, (r, g, b) in self._minimap_colors.items(): mask = np.isclose(_layer_data, float(val), atol=0.1) @@ -3519,14 +3617,6 @@ def _compute_minimap_background(self): rgba[:, :, c] = np.where( mask, cv * (shaded * 0.5 + 0.5), rgba[:, :, c]) - # Water: dark blue-black - rgba[water, 0] = 0.08 - rgba[water, 1] = 0.10 - rgba[water, 2] = 0.18 - rgba[water, 3] = 0.7 - - rgba[:, :, :3] = np.clip(rgba[:, :, :3], 0, 1) - # Blend satellite imagery if tile service has fetched tiles if (self._tile_service is not None and getattr(self._tile_service, '_fetched', None)): @@ -3552,6 +3642,71 @@ def _compute_minimap_background(self): self._minimap_scale_x = new_w / W self._minimap_scale_y = new_h / H + def _compute_streaming_minimap(self, wx_min, wy_min, wx_max, wy_max): + """Fetch elevation for a world extent and build a minimap image. + + Uses ``_tile_data_fn`` to fetch elevation covering the full + minimap extent (initial terrain + streaming area), then runs + the standard hillshade pipeline on it. + + Parameters + ---------- + wx_min, wy_min, wx_max, wy_max : float + World-space extent to cover in the minimap. 
+ """ + tile_data_fn = getattr(self, '_tile_data_fn', None) + crs_tf = getattr(self, '_minimap_crs_transform', None) + if tile_data_fn is None or crs_tf is None: + return + + crs_x0, crs_y0, crs_dx, crs_dy = crs_tf + psx, psy = self.pixel_spacing_x, self.pixel_spacing_y + + # Convert world coords to CRS coordinates + # world coord: wx = col * psx + offset_x + # col = (wx - offset_x) / psx + # crs_x = crs_x0 + col * crs_dx + lod_mgr = getattr(self.terrain, '_terrain_lod_manager', None) + ox = lod_mgr._offset_x if lod_mgr is not None else 0.0 + oy = lod_mgr._offset_y if lod_mgr is not None else 0.0 + + col0 = (wx_min - ox) / psx + col1 = (wx_max - ox) / psx + row0 = (wy_min - oy) / psy + row1 = (wy_max - oy) / psy + + cx0 = crs_x0 + col0 * crs_dx + cx1 = crs_x0 + col1 * crs_dx + cy0 = crs_y0 + row0 * crs_dy + cy1 = crs_y0 + row1 * crs_dy + + # tile_data_fn expects (x_min, y_min, x_max, y_max, target_samples) + x_min, x_max = min(cx0, cx1), max(cx0, cx1) + y_min, y_max = min(cy0, cy1), max(cy0, cy1) + + try: + elev = tile_data_fn(x_min, y_min, x_max, y_max, 200) + except Exception: + return + if elev is None: + return + + elev = np.asarray(elev, dtype=np.float32) + if elev.ndim != 2 or elev.size == 0: + return + + rgba, water = self._elevation_to_minimap_rgba(elev) + + if self._minimap_style == 'cyberpunk': + rgba = self._apply_cyberpunk_minimap(rgba, water) + + self._minimap_background = rgba + self._minimap_bg_extent = (wx_min, wy_min, wx_max, wy_max) + # Scale factors map the full bg image to the world extent + self._minimap_scale_x = rgba.shape[1] / max(1, wx_max - wx_min) + self._minimap_scale_y = rgba.shape[0] / max(1, wy_max - wy_min) + self._minimap_last_stream_time = time.monotonic() + def _apply_cyberpunk_minimap(self, rgba, water): """Apply a neon-edge cyberpunk filter to the minimap RGBA image. 
@@ -3636,11 +3791,10 @@ def _apply_cyberpunk_minimap(self, rgba, water): return result def _enable_terrain_lod(self): - """Switch terrain rendering from single-GAS to per-tile LOD. + """Set up per-tile LOD terrain rendering. - Removes the single ``'terrain'`` (and ``'terrain_skirt'``) - geometry and creates a :class:`TerrainLODManager` that renders - each tile at a distance-appropriate resolution. + Creates a :class:`TerrainLODManager` that renders each tile at + a distance-appropriate resolution. Called once during __init__. """ from .viewer.terrain_lod import TerrainLODManager @@ -3652,15 +3806,51 @@ def _enable_terrain_lod(self): else: terrain_np = np.asarray(terrain_data) + # Fill NaN at raster edges (common from UTM reprojection) so both + # the LOD tile builder and the render kernel see clean elevation. + # Without this, NaN pixels render as blue ocean water. + if np.any(np.isnan(terrain_np)): + terrain_np = terrain_np.copy() + for _ in range(20): + still_nan = np.isnan(terrain_np) + if not still_nan.any(): + break + padded = np.pad(terrain_np, 1, mode='edge') + neighbors = np.stack([ + padded[:-2, 1:-1], padded[2:, 1:-1], + padded[1:-1, :-2], padded[1:-1, 2:], + ], axis=0) + with np.errstate(all='ignore'): + fill_vals = np.nanmean(neighbors, axis=0) + terrain_np = np.where( + still_nan & np.isfinite(fill_vals), + fill_vals, terrain_np) + # Update raster so render kernel sees clean data too + is_cupy = hasattr(base.data, 'get') + if is_cupy: + import cupy + self.raster = self.raster.copy( + data=cupy.asarray(terrain_np)) + else: + self.raster = self.raster.copy(data=terrain_np) + # Remove the single terrain geometry if self.rtx.has_geometry('terrain'): self.rtx.remove_geometry('terrain') if self.rtx.has_geometry('terrain_skirt'): self.rtx.remove_geometry('terrain_skirt') - # Choose tile size: aim for ~8-16 tiles across largest dimension + # Choose tile size. 
When a zarr chunk manager is active, align + # to the zarr elevation chunk size so terrain tiles and mesh chunks + # share the same spatial grid — one distance lookup drives both. H, W = terrain_np.shape - tile_size = max(32, min(256, max(H, W) // 8)) + if (self._chunk_manager is not None + and self._chunk_manager._chunk_h == self._chunk_manager._chunk_w): + tile_size = self._chunk_manager._chunk_h + print(f"LOD tile size {tile_size} (aligned to zarr chunk grid)") + else: + tile_size = max(32, min(256, max(H, W) // 8)) + print(f"LOD tile size {tile_size}") mgr = TerrainLODManager( terrain_np, @@ -3670,12 +3860,181 @@ def _enable_terrain_lod(self): max_lod=3, base_subsample=self.subsample_factor, ) + # Carry forward any world offset from a previous terrain reload + ox = self.terrain._world_offset_x + oy = self.terrain._world_offset_y + if ox != 0.0 or oy != 0.0: + mgr.set_offset(ox, oy) + # Enable tile streaming if a data callback was provided + tile_data_fn = getattr(self, '_tile_data_fn', None) + if tile_data_fn is not None: + mgr.set_tile_data_fn(tile_data_fn) + # Pass CRS coordinate transform so tile_data_fn receives + # actual CRS coordinates (e.g. UTM) instead of viewer + # world-space coords (pixel * abs(spacing) + offset). 
+ try: + x = base.coords['x'].values + y = base.coords['y'].values + if len(x) >= 2 and len(y) >= 2: + crs_dx = float(x[1] - x[0]) + crs_dy = float(y[1] - y[0]) + mgr.set_crs_transform(float(x[0]), float(y[0]), + crs_dx, crs_dy) + self._minimap_crs_transform = ( + float(x[0]), float(y[0]), crs_dx, crs_dy) + except (KeyError, AttributeError): + pass self._terrain_lod_manager = mgr self.lod_enabled = True - # Force initial tile build - mgr.update(self.position, self.rtx, - ve=self.vertical_exaggeration, force=True) + # Create per-tile overlay and texture managers for LOD-aware compositing + from .viewer.overlay_tiles import OverlayTileManager, TextureTileManager + self._overlay_tile_mgr = OverlayTileManager(tile_size) + self._texture_tile_mgr = TextureTileManager(tile_size) + # Register tile lifecycle callbacks so both managers stay + # in sync with the LOD tile set automatically. + otm = self._overlay_tile_mgr + ttm = self._texture_tile_mgr + + # With LOD active, basemap goes through per-tile lazy fetch — stop + # any monolithic XYZ tile fetch that was started before LOD enable. + if self._tile_service is not None: + self._tile_service._generation += 1 # cancel in-flight fetches + + # Lazy per-tile basemap fetching. Each tile's CRS bounds are + # converted to WGS84 and XYZ map tiles are fetched in background + # threads, then stored in the TextureTileManager. + _crs_origin = mgr._crs_origin # (crs_x0, crs_y0) or None + _crs_spacing = mgr._crs_spacing # (crs_dx, crs_dy) or None + _ts = tile_size + _viewer = self + from concurrent.futures import ThreadPoolExecutor as _TPE + _basemap_executor = _TPE(max_workers=4) + _basemap_pending = set() # tiles currently being fetched + + def _on_tile_added(tr, tc, elev): + # Don't call otm/ttm.invalidate() here — set_tile() already + # marks dirty when actual data arrives. Invalidating on + # every terrain LOD change forces composite rebuild + GPU + # upload every frame during camera movement. 
+ + # Lazy-fetch basemap for this tile in background. + # Use elev.shape to get actual tile dimensions (edge tiles + # may be smaller than tile_size). + if (_crs_origin is not None + and _crs_spacing is not None + and _viewer._tiles_enabled + and _viewer._tile_service is not None + and (tr, tc) not in _basemap_pending + and not ttm.has_tile(tr, tc)): + th = elev.shape[0] if elev is not None else _ts + tw = elev.shape[1] if elev is not None else _ts + _basemap_pending.add((tr, tc)) + _basemap_executor.submit( + _fetch_tile_basemap, tr, tc, th, tw) + + def _fetch_tile_basemap(tr, tc, th, tw): + """Fetch basemap RGB for a single LOD tile and store it.""" + try: + crs_x0, crs_y0 = _crs_origin + crs_dx, crs_dy = _crs_spacing + c0 = tc * _ts + r0 = tr * _ts + # CRS bounds: pixel-center to pixel-center so linspace + # produces exact pixel coordinates. Using c0+tw would + # overshoot by one pixel and stretch the basemap. + cx0 = crs_x0 + c0 * crs_dx + cy0 = crs_y0 + r0 * crs_dy + cx1 = crs_x0 + (c0 + tw - 1) * crs_dx + cy1 = crs_y0 + (r0 + th - 1) * crs_dy + x_min, x_max = min(cx0, cx1), max(cx0, cx1) + y_min, y_max = min(cy0, cy1), max(cy0, cy1) + rgb = _viewer._tile_service.fetch_rgb_for_bounds( + x_min, y_min, x_max, y_max, th, tw) + if rgb is not None and not np.all(rgb == 0): + ttm.set_tile(tr, tc, rgb) + except Exception: + pass + finally: + _basemap_pending.discard((tr, tc)) + + def _on_tile_removed(tr, tc): + # Only remove basemap texture — it's re-fetched lazily + # via _on_tile_added when the tile returns. Overlay data + # is bulk-populated from populate_from_array() and won't + # be re-created on re-add, so we must keep it. 
+ ttm.remove_tile(tr, tc) + + mgr.set_tile_callbacks( + on_added=_on_tile_added, + on_removed=_on_tile_removed, + ) + # If an overlay already exists, slice it into per-tile chunks + if self._active_overlay_data is not None: + n_tr = (H + tile_size - 1) // tile_size + n_tc = (W + tile_size - 1) // tile_size + active_name = None + if self._terrain_layer_idx < len(self._terrain_layer_order): + active_name = self._terrain_layer_order[ + self._terrain_layer_idx] + for name, data in self._overlay_layers.items(): + if name == active_name: + self._overlay_tile_mgr.populate_from_array( + data, tile_size, n_tr, n_tc) + lut = self._overlay_color_luts.get(name) + if lut is not None: + self._overlay_tile_mgr.set_color_lut(lut) + + # Basemap tiles are fetched lazily via _on_tile_added callback — + # no monolithic texture slicing needed. + + # Wire LOD manager into HydroManager for streaming support + self.hydro_mgr.set_lod_manager(mgr) + if tile_data_fn is not None: + self.hydro_mgr.set_tile_data_fn(tile_data_fn) + try: + x = base.coords['x'].values + y = base.coords['y'].values + if len(x) >= 2 and len(y) >= 2: + self.hydro_mgr.set_crs_transform( + float(x[0]), float(y[0]), + float(x[1] - x[0]), float(y[1] - y[0])) + except (KeyError, AttributeError): + pass + + # When streaming is active, use TIN meshes for ALL tiles (including + # LOD 0) so initial-extent and streaming tiles render identically. + # Without streaming, heightfield gives better quality for LOD 0. + if not mgr._streaming: + mgr.enable_heightfield_lod0() + + # Force initial tile build — no build limit so all in-bounds + # tiles appear on the first frame (no progressive pop-in on + # enable). Streaming tiles build progressively after launch. + # Use terrain center as fallback if camera position isn't set yet + # (called from __init__ before run() sets the start position). 
+ cam_pos = self.position + if cam_pos is None: + H, W = self.terrain_shape + cx = W * self.pixel_spacing_x * 0.5 + cy = H * self.pixel_spacing_y * 0.5 + cam_pos = np.array([cx, cy, 0.0]) + saved_limit = mgr.per_tick_build_limit + saved_streaming = mgr._streaming + mgr._streaming = False # only in-bounds tiles on initial build + mgr.per_tick_build_limit = 10000 + mgr.update(cam_pos, self.rtx, + ve=self.vertical_exaggeration, force=True, + camera_front=self._get_front(), fov=self.camera.fov) + mgr.per_tick_build_limit = saved_limit + mgr._streaming = saved_streaming + # Force one more update so streaming tiles begin building + if saved_streaming: + mgr._last_update_pos = None + # Enable threaded mesh building for subsequent ticks + mgr.enable_threaded_building() + # Batch same-LOD tiles into single GAS entries to reduce IAS count + mgr.enable_batched_upload() self._update_frame() def _rebuild_at_resolution(self, factor): @@ -3694,15 +4053,13 @@ def _rebuild_at_resolution(self, factor): self.subsample_factor = factor - # If LOD is active, update the LOD manager's base subsample - # and force a full tile rebuild instead of the single-GAS path. - if self.lod_enabled and self._terrain_lod_manager is not None: + # Update LOD manager's base subsample and force tile rebuild + if self._terrain_lod_manager is not None: self._terrain_lod_manager.set_base_subsample(factor) self._terrain_lod_manager.update( self.position, self.rtx, - ve=self.vertical_exaggeration, force=True) - # Still need to update raster/spacing for overlays and re-snapping - # but skip the single-terrain GAS rebuild below. + ve=self.vertical_exaggeration, force=True, + camera_front=self._get_front(), fov=self.camera.fov) base = self._base_raster @@ -3728,102 +4085,13 @@ def _rebuild_at_resolution(self, factor): self.pixel_spacing_x = self._base_pixel_spacing_x * factor self.pixel_spacing_y = self._base_pixel_spacing_y * factor - # 3. Build or retrieve cached terrain mesh + # 3. 
Get terrain_np for elevation stats ve = self.vertical_exaggeration - cache_key = (factor, self.mesh_type) - - # When LOD is active, tiles are built by the LOD manager. - # We still need terrain_np for elevation stats (computed below). - _lod_active = (self.lod_enabled and self._terrain_lod_manager is not None) - if _lod_active: - terrain_data = sub.data - if hasattr(terrain_data, 'get'): - terrain_np = terrain_data.get() - else: - terrain_np = np.asarray(terrain_data) - elif self.mesh_type == 'heightfield': - # Heightfield path: no triangle mesh needed - if cache_key in self._terrain_mesh_cache: - _, _, terrain_np = self._terrain_mesh_cache[cache_key] - else: - terrain_data = sub.data - if hasattr(terrain_data, 'get'): - terrain_np = terrain_data.get() - else: - terrain_np = np.asarray(terrain_data) - self._terrain_mesh_cache[cache_key] = ( - None, None, terrain_np.copy(), - ) - - if self.rtx is not None: - self.rtx.add_heightfield_geometry( - 'terrain', terrain_np, H, W, - spacing_x=self.pixel_spacing_x, - spacing_y=self.pixel_spacing_y, - ve=ve, - ) - if self.terrain_skirt: - sv, si = mesh_mod.build_terrain_skirt( - terrain_np, H, W, scale=ve, - pixel_spacing_x=self.pixel_spacing_x, - pixel_spacing_y=self.pixel_spacing_y) - self.rtx.add_geometry('terrain_skirt', sv, si) - elif self.rtx.has_geometry('terrain_skirt'): - self.rtx.remove_geometry('terrain_skirt') + terrain_data = sub.data + if hasattr(terrain_data, 'get'): + terrain_np = terrain_data.get() else: - if cache_key in self._terrain_mesh_cache: - # Cache hit — reuse pre-built mesh (stored at scale=1.0) - verts_base, indices, terrain_np = self._terrain_mesh_cache[cache_key] - vertices = verts_base.copy() - if ve != 1.0: - vertices[2::3] *= ve - else: - # Cache miss — build mesh at scale=1.0 and cache it - terrain_data = sub.data - if hasattr(terrain_data, 'get'): - terrain_np = terrain_data.get() - else: - terrain_np = np.asarray(terrain_data) - - if self.mesh_type == 'voxel': - num_verts = H * W * 8 - 
num_tris = H * W * 12 - vertices = np.zeros(num_verts * 3, dtype=np.float32) - indices = np.zeros(num_tris * 3, dtype=np.int32) - base_elev = float(np.nanmin(terrain_np)) - mesh_mod.voxelate_terrain(vertices, indices, sub, scale=1.0, - base_elevation=base_elev) - else: - num_verts = H * W - num_tris = (H - 1) * (W - 1) * 2 - vertices = np.zeros(num_verts * 3, dtype=np.float32) - indices = np.zeros(num_tris * 3, dtype=np.int32) - mesh_mod.triangulate_terrain(vertices, indices, sub, scale=1.0) - - if self.terrain_skirt: - vertices, indices = mesh_mod.add_terrain_skirt( - vertices, indices, H, W) - - # Scale x,y to world units - if self.pixel_spacing_x != 1.0 or self.pixel_spacing_y != 1.0: - vertices[0::3] *= self.pixel_spacing_x - vertices[1::3] *= self.pixel_spacing_y - - # Store in cache (scale=1.0, x/y already scaled) - self._terrain_mesh_cache[cache_key] = ( - vertices.copy(), indices.copy(), terrain_np.copy() - ) - - # Apply VE to this copy - if ve != 1.0: - vertices[2::3] *= ve - - # 4. Replace terrain geometry (add_geometry overwrites existing key - # in-place, preserving dict insertion order and instance IDs) - if self.rtx is not None: - gd = (H, W) if self.mesh_type != 'voxel' and not self.terrain_skirt else None - self.rtx.add_geometry('terrain', vertices, indices, - grid_dims=gd) + terrain_np = np.asarray(terrain_data) self.elev_min = float(np.nanmin(terrain_np)) * ve self.elev_max = float(np.nanmax(terrain_np)) * ve @@ -3861,18 +4129,22 @@ def _rebuild_at_resolution(self, factor): self._active_overlay_color_lut = self._overlay_color_luts.get( terrain_name) - # 6. Invalidate chunk manager cache (meshes need new Z coords) + # 6. Invalidate chunk manager scene state (re-snap Z at new resolution). + # Raw zarr data in _cache is resolution-independent (Phase 1 ensures + # Z re-snap always uses full-res terrain), so we keep it and only + # clear the baked/active/visible state so update() re-merges cheaply. 
if self._chunk_manager is not None: - # Clear chunk cache and baked mesh entries for chunk-loaded geometries for gid in list(self._chunk_manager._active_gids): if hasattr(self, '_baked_meshes'): self._baked_meshes.pop(gid, None) if self._accessor is not None: self._accessor._baked_meshes.pop(gid, None) - self._chunk_manager._cache.clear() + # Remove stale geometry from RTX scene + if self.rtx is not None and self.rtx.has_geometry(gid): + self.rtx.remove_geometry(gid) self._chunk_manager._visible.clear() self._chunk_manager._active_gids.clear() - # Force immediate reload at new resolution + # Force immediate re-merge from cache at new resolution if hasattr(self, 'position'): self._chunk_manager.update(self.position[0], self.position[1], self) @@ -3886,8 +4158,6 @@ def _rebuild_at_resolution(self, factor): self._gpu_terrain = gpu_terrain from .viewer.terrain_lod import is_terrain_lod_gid for geom_id in self.rtx.list_geometries(): - if geom_id == 'terrain' or geom_id == 'terrain_skirt': - continue if is_terrain_lod_gid(geom_id): continue # Baked meshes — re-snap Z to new terrain surface + VE @@ -3981,6 +4251,7 @@ def _rebuild_at_resolution(self, factor): self._update_observer_drone_for(obs) # 9. Recompute minimap + self._minimap_bg_extent = None self._compute_minimap_background() # 10. Clear viewshed cache (no longer matches terrain) @@ -4007,98 +4278,19 @@ def _rebuild_vertical_exaggeration(self, ve): from . 
import mesh as mesh_mod self.vertical_exaggeration = ve - H, W = self.terrain_shape - # If LOD is active, force a full tile rebuild with the new VE - if self.lod_enabled and self._terrain_lod_manager is not None: - self._terrain_lod_manager._tile_cache.clear() + # Force re-upload of all LOD tiles with new VE + if self._terrain_lod_manager is not None: self._terrain_lod_manager._tile_lods.clear() self._terrain_lod_manager.update( - self.position, self.rtx, ve=ve, force=True) - # Still need terrain_np for elevation stats below - terrain_data = self.raster.data - if hasattr(terrain_data, 'get'): - terrain_np = terrain_data.get() - else: - terrain_np = np.asarray(terrain_data) - elif self.mesh_type == 'heightfield': - # Heightfield path: rebuild GAS with new VE - if cache_key in self._terrain_mesh_cache: - _, _, terrain_np = self._terrain_mesh_cache[cache_key] - else: - terrain_data = self.raster.data - if hasattr(terrain_data, 'get'): - terrain_np = terrain_data.get() - else: - terrain_np = np.asarray(terrain_data) - self._terrain_mesh_cache[cache_key] = ( - None, None, terrain_np.copy(), - ) + self.position, self.rtx, ve=ve, force=True, + camera_front=self._get_front(), fov=self.camera.fov) - if self.rtx is not None: - self.rtx.add_heightfield_geometry( - 'terrain', terrain_np, H, W, - spacing_x=self.pixel_spacing_x, - spacing_y=self.pixel_spacing_y, - ve=ve, - ) - if self.terrain_skirt: - sv, si = mesh_mod.build_terrain_skirt( - terrain_np, H, W, scale=ve, - pixel_spacing_x=self.pixel_spacing_x, - pixel_spacing_y=self.pixel_spacing_y) - self.rtx.add_geometry('terrain_skirt', sv, si) - elif self.rtx.has_geometry('terrain_skirt'): - self.rtx.remove_geometry('terrain_skirt') + terrain_data = self.raster.data + if hasattr(terrain_data, 'get'): + terrain_np = terrain_data.get() else: - if cache_key in self._terrain_mesh_cache: - verts_base, indices, terrain_np = self._terrain_mesh_cache[cache_key] - vertices = verts_base.copy() - if ve != 1.0: - vertices[2::3] *= ve - 
else: - terrain_data = self.raster.data - if hasattr(terrain_data, 'get'): - terrain_np = terrain_data.get() - else: - terrain_np = np.asarray(terrain_data) - - if self.mesh_type == 'voxel': - nv = H * W * 8 - nt = H * W * 12 - vertices = np.zeros(nv * 3, dtype=np.float32) - indices = np.zeros(nt * 3, dtype=np.int32) - base_elev = float(np.nanmin(terrain_np)) - mesh_mod.voxelate_terrain(vertices, indices, self.raster, - scale=1.0, base_elevation=base_elev) - else: - nv = H * W - nt = (H - 1) * (W - 1) * 2 - vertices = np.zeros(nv * 3, dtype=np.float32) - indices = np.zeros(nt * 3, dtype=np.int32) - mesh_mod.triangulate_terrain(vertices, indices, self.raster, - scale=1.0) - - if self.terrain_skirt: - vertices, indices = mesh_mod.add_terrain_skirt( - vertices, indices, H, W) - - if self.pixel_spacing_x != 1.0 or self.pixel_spacing_y != 1.0: - vertices[0::3] *= self.pixel_spacing_x - vertices[1::3] *= self.pixel_spacing_y - - self._terrain_mesh_cache[cache_key] = ( - vertices.copy(), indices.copy(), terrain_np.copy() - ) - - if ve != 1.0: - vertices[2::3] *= ve - - # Replace terrain geometry (preserves dict insertion order) - if self.rtx is not None: - gd = (H, W) if self.mesh_type != 'voxel' and not self.terrain_skirt else None - self.rtx.add_geometry('terrain', vertices, indices, - grid_dims=gd) + terrain_np = np.asarray(terrain_data) # Update elevation stats (scaled) self.elev_min = float(np.nanmin(terrain_np)) * ve @@ -4123,8 +4315,6 @@ def _rebuild_vertical_exaggeration(self, ve): self._gpu_terrain = gpu_terrain from .viewer.terrain_lod import is_terrain_lod_gid for geom_id in self.rtx.list_geometries(): - if geom_id == 'terrain' or geom_id == 'terrain_skirt': - continue if is_terrain_lod_gid(geom_id): continue # Baked meshes (merged buildings/curves) — re-snap Z to terrain + VE @@ -4287,23 +4477,82 @@ def _blit_minimap_on_frame(self, img): mm_h, mm_w = mm_bg.shape[:2] fh, fw = img.shape[:2] - # Size minimap: match legend height if available, else ~20% of frame + 
# Size minimap: match legend height if available, else ~20% of frame. + # Use initial terrain aspect ratio (not background image shape) + # so dimensions stay fixed even when streaming changes the bg. + H_t, W_t = self.terrain_shape + terrain_aspect = W_t / max(1, H_t) if self._legend_rgba is not None: target_h = self._legend_rgba.shape[0] else: target_w = max(40, int(fw * 0.2)) - scale = target_w / mm_w - target_h = max(20, int(mm_h * scale)) - # Derive width from height to preserve aspect ratio - aspect = mm_w / mm_h - target_w = max(20, int(target_h * aspect)) + target_h = max(20, int(target_w / terrain_aspect)) + target_w = max(20, int(target_h * terrain_aspect)) target_w = min(target_w, fw) target_h = min(target_h, fh) - # Nearest-neighbour resize + # --- World extent for coordinate mapping --- + # When LOD streaming is active, extend beyond initial terrain + # to include camera position so the dot stays visible. + H, W = self.terrain_shape + psx, psy = self.pixel_spacing_x, self.pixel_spacing_y + terrain_wx = W * psx + terrain_wy = H * psy + + lod_mgr = getattr(self.terrain, '_terrain_lod_manager', None) + streaming = (lod_mgr is not None + and getattr(lod_mgr, '_streaming', False)) + + if streaming: + cam_x, cam_y = self.position[0], self.position[1] + # Desired extent: union of terrain bounds and camera + margin + margin = max(terrain_wx, terrain_wy) * 0.5 + wx_min = min(0.0, cam_x - margin) + wy_min = min(0.0, cam_y - margin) + wx_max = max(terrain_wx, cam_x + margin) + wy_max = max(terrain_wy, cam_y + margin) + + # Check if streaming minimap recompute is needed: + # - No extent yet (first time) + # - Camera outside inner 60% of current extent AND 2s throttle + bg_ext = self._minimap_bg_extent + now = time.monotonic() + throttle_ok = (now - self._minimap_last_stream_time >= 2.0) + need_recompute = (bg_ext is None) + if not need_recompute and throttle_ok: + # Recompute if camera is outside inner 60% of extent + bw = bg_ext[2] - bg_ext[0] + bh = bg_ext[3] - 
bg_ext[1] + inner_margin_x = bw * 0.2 + inner_margin_y = bh * 0.2 + need_recompute = ( + cam_x < bg_ext[0] + inner_margin_x + or cam_x > bg_ext[2] - inner_margin_x + or cam_y < bg_ext[1] + inner_margin_y + or cam_y > bg_ext[3] - inner_margin_y) + + if need_recompute: + self._compute_streaming_minimap( + wx_min, wy_min, wx_max, wy_max) + # Re-read background after streaming recompute + mm_bg = self._minimap_background + mm_h, mm_w = mm_bg.shape[:2] + + # Use streaming extent for coordinate mapping + if self._minimap_bg_extent is not None: + wx_min, wy_min, wx_max, wy_max = self._minimap_bg_extent + else: + wx_min, wy_min = 0.0, 0.0 + wx_max, wy_max = terrain_wx, terrain_wy + + wx_range = wx_max - wx_min + wy_range = wy_max - wy_min + + # Resize background to target dimensions (fixed size, no aspect + # ratio adjustment — keeps minimap dimensions stable) y_idx = np.linspace(0, mm_h - 1, target_h).astype(int) x_idx = np.linspace(0, mm_w - 1, target_w).astype(int) - bg_resized = mm_bg[np.ix_(y_idx, x_idx)].copy() # (th, tw, 4) + bg_resized = mm_bg[np.ix_(y_idx, x_idx)].copy() # Placement: flush bottom-right y0 = fh - target_h @@ -4315,16 +4564,13 @@ def _blit_minimap_on_frame(self, img): region = img[y0:y0+target_h, x0:x0+target_w] region[:] = region * (1 - alpha) + rgb * alpha - # Store minimap rect for click-to-teleport + # Store minimap rect and world extent for click-to-teleport self._minimap_rect = (x0, y0, target_w, target_h) + self._minimap_world_extent = (wx_min, wy_min, wx_max, wy_max) - # --- Terrain footprint (visible area quad) --- - H, W = self.terrain_shape - cam_col = self.position[0] / self.pixel_spacing_x - cam_row = self.position[1] / self.pixel_spacing_y - # Minimap local coords - lx = cam_col / W * target_w - ly = cam_row / H * target_h + # --- Camera position in minimap-local coords --- + lx = (self.position[0] - wx_min) / wx_range * target_w + ly = (self.position[1] - wy_min) / wy_range * target_h ve = self.vertical_exaggeration terrain_z = 
self._get_terrain_z(self.position[0], self.position[1]) * ve @@ -4366,8 +4612,8 @@ def _blit_minimap_on_frame(self, img): mm_corners = [] break # Convert world XY to minimap-local coords - mcol = hit[0] / self.pixel_spacing_x / W * target_w - mrow = hit[1] / self.pixel_spacing_y / H * target_h + mcol = (hit[0] - wx_min) / wx_range * target_w + mrow = (hit[1] - wy_min) / wy_range * target_h mm_corners.append((mcol, mrow)) if len(mm_corners) == 4: @@ -4405,8 +4651,8 @@ def _blit_minimap_on_frame(self, img): if obs.position is None: continue obs_x, obs_y = obs.position - obs_lx = (obs_x / self.pixel_spacing_x) / W * target_w - obs_ly = (obs_y / self.pixel_spacing_y) / H * target_h + obs_lx = (obs_x - wx_min) / wx_range * target_w + obs_ly = (obs_y - wy_min) / wy_range * target_h r = 4 if slot == self._active_observer else 2 self._draw_dot(img, obs_lx, obs_ly, x0, y0, target_w, target_h, color=np.array(obs.color), radius=r) @@ -5438,8 +5684,14 @@ def _splat_wind_gpu(self, d_frame): def _toggle_hydro(self): """Toggle hydro flow particles + stream_link water overlay together.""" if self._hydro_data is None: - print("No hydro data. Use v.add_hydro(flow_dir, flow_accum).") - return + # Lazy mode: compute hydro from terrain on first enable + if self._hydro_lazy: + print("Computing hydro from terrain (first enable)...") + if not self._compute_hydro_from_terrain(): + return + else: + print("No hydro data. 
Use v.add_hydro(flow_accum).") + return self._hydro_enabled = not self._hydro_enabled if self._hydro_enabled: @@ -5479,509 +5731,138 @@ def _toggle_hydro(self): def _action_toggle_hydro(self): self._toggle_hydro() - _STREAM_ORDER_PALETTE = np.array([ - [0.0, 0.0, 0.0 ], # 0: unused - [0.50, 0.80, 1.00], # 1: pale sky blue (headwaters) - [0.38, 0.68, 0.98], # 2: light blue - [0.28, 0.55, 0.95], # 3: sky blue - [0.18, 0.42, 0.90], # 4: medium blue - [0.10, 0.30, 0.85], # 5: royal blue - [0.06, 0.20, 0.78], # 6: deep blue - [0.03, 0.12, 0.70], # 7: dark blue - [0.01, 0.06, 0.60], # 8: navy (major rivers) - ], dtype=np.float32) - - @staticmethod - def _build_stream_palette_lut(max_order): - """Build a 256-entry color LUT for stream order overlay rendering. - - Maps normalized [0,1] overlay values back to integer orders - and looks up the categorical palette color. - """ - palette = InteractiveViewer._STREAM_ORDER_PALETTE - lut = np.zeros((256, 3), dtype=np.float32) - denom = max(max_order - 1, 1) - for i in range(256): - # Reverse the normalization: order = 1 + norm * (max_order - 1) - order = int(round(1 + (i / 255.0) * denom)) - order = max(1, min(8, order)) - lut[i] = palette[order] - return lut - - @staticmethod - def _hydro_color_from_order(order_norm, raw_order=None): - """Map stream order → (R, G, B) per particle. - - If *raw_order* is provided, uses categorical palette keyed by - integer Strahler order. Otherwise falls back to continuous - blue gradient from normalized [0,1] order. 
- """ - if raw_order is not None: - idx = np.clip(raw_order, 1, 8).astype(int) - colors = InteractiveViewer._STREAM_ORDER_PALETTE[idx].copy() - else: - colors = np.empty((len(order_norm), 3), dtype=np.float32) - colors[:, 0] = 0.02 + order_norm * 0.43 # R: 0.02 → 0.45 - colors[:, 1] = 0.10 + order_norm * 0.65 # G: 0.10 → 0.75 - colors[:, 2] = 0.55 + order_norm * 0.40 # B: 0.55 → 0.95 - colors = np.clip(colors, 0.0, 1.0) - return colors + # Delegate palette/helpers to hydro_manager module + from .viewer.hydro_manager import ( + STREAM_ORDER_PALETTE as _STREAM_ORDER_PALETTE, + build_stream_palette_lut as _build_stream_palette_lut_fn, + color_from_order as _hydro_color_from_order_fn, + radius_from_order as _hydro_radius_from_order_fn, + ) - @staticmethod - def _hydro_radius_from_order(order_norm, raw_order=None): - """Map stream order → radius (2–5) per particle. + _build_stream_palette_lut = staticmethod(_build_stream_palette_lut_fn) + _hydro_color_from_order = staticmethod(_hydro_color_from_order_fn) + _hydro_radius_from_order = staticmethod(_hydro_radius_from_order_fn) - If *raw_order* is provided, uses integer order + 1 directly - (clamped 2–5). Otherwise uses continuous mapping. - """ - if raw_order is not None: - return np.clip(raw_order + 1, 2, 5).astype(np.int32) - return np.clip(2 + (order_norm * 3).astype(np.int32), - 2, 5).astype(np.int32) + def _init_hydro(self, flow_accum, **kwargs): + """Initialize hydro flow particles using MFD flow direction. - def _init_hydro(self, flow_dir, flow_accum, **kwargs): - """Initialize hydro flow particles from D8 flow direction and accumulation grids. + Delegates to ``self.hydro_mgr.init_from_flow()``. Parameters ---------- - flow_dir : array-like, shape (H, W) - D8 flow direction grid (1=E, 2=SE, 4=S, 8=SW, 16=W, 32=NW, - 64=N, 128=NE). flow_accum : array-like, shape (H, W) Flow accumulation grid (cell counts or area). 
**kwargs Optional overrides: n_particles, max_age, trail_len, speed, accum_threshold, color, alpha, dot_radius, min_visible_age, - stream_order (array). + stream_order (array), flow_dir_mfd (xrspatial MFD fractions, + shape (8,H,W)), elevation (conditioned DEM for manual MFD). """ - # Accept CuPy or NumPy arrays — particle advection runs on CPU - if hasattr(flow_dir, 'get'): - flow_dir = flow_dir.get() - if hasattr(flow_accum, 'get'): - flow_accum = flow_accum.get() - flow_dir = np.asarray(flow_dir, dtype=np.int32) - flow_accum = np.asarray(flow_accum, dtype=np.float64) - H, W = flow_dir.shape - - # Stream order grid (optional but strongly recommended) - stream_order = kwargs.pop('stream_order', None) - if stream_order is not None: - if hasattr(stream_order, 'get'): - stream_order = stream_order.get() - stream_order = np.asarray(stream_order, dtype=np.float64) - # Replace NaN with 0 (non-stream cells) - stream_order = np.nan_to_num(stream_order, nan=0.0) - has_stream_order = stream_order is not None and (stream_order > 0).any() - - # Mark hydro as initialised (don't hold the full grids — - # particle advection uses _hydro_flow_u_px / _hydro_flow_v_px). 
- self._hydro_data = True - - # Apply optional overrides - for key, attr, conv in [ - ('n_particles', '_hydro_n_particles', int), - ('max_age', None, int), - ('trail_len', '_hydro_trail_len', int), - ('speed', '_hydro_speed', float), - ('accum_threshold', '_hydro_accum_threshold', int), - ('color', '_hydro_color', tuple), - ('alpha', '_hydro_alpha', float), - ('dot_radius', '_hydro_dot_radius', int), - ('min_visible_age', None, int), - ]: - if key in kwargs: - val = conv(kwargs[key]) - if attr: - setattr(self, attr, val) - elif key == 'max_age': - self.hydro.hydro_max_age = val - elif key == 'min_visible_age': - self.hydro.hydro_min_visible_age = val - - # D8 code → (drow, dcol) unit vectors - # Row increases downward (south), col increases rightward (east) - sqrt2_inv = 1.0 / np.sqrt(2.0) - d8_to_drow_dcol = { - 1: (0.0, 1.0), # E - 2: (sqrt2_inv, sqrt2_inv), # SE - 4: (1.0, 0.0), # S - 8: (sqrt2_inv, -sqrt2_inv),# SW - 16: (0.0, -1.0), # W - 32: (-sqrt2_inv, -sqrt2_inv),# NW - 64: (-1.0, 0.0), # N - 128: (-sqrt2_inv, sqrt2_inv),# NE - } - - flow_u = np.zeros((H, W), dtype=np.float32) # col direction - flow_v = np.zeros((H, W), dtype=np.float32) # row direction - for code, (dr, dc) in d8_to_drow_dcol.items(): - mask = flow_dir == code - flow_v[mask] = dr - flow_u[mask] = dc - valid_flow = np.isin(flow_dir, list(d8_to_drow_dcol.keys())) - - self._hydro_flow_u_px = flow_u - self._hydro_flow_v_px = flow_v - - # Normalize accumulation: log10(clip(fa, 1)), scale to [0,1] - fa_clipped = np.clip(flow_accum, 1, None) - log_fa = np.log10(fa_clipped) - threshold = np.log10(max(self._hydro_accum_threshold, 1)) - log_max = log_fa.max() - if log_max > threshold: - accum_norm = np.clip((log_fa - threshold) / (log_max - threshold), 0, 1) - else: - accum_norm = np.zeros_like(log_fa) - self._hydro_flow_accum_norm = accum_norm.astype(np.float32) - - # Store stream order grids - if has_stream_order: - max_order = stream_order.max() - so_norm = (stream_order / max(max_order, 
1)).astype(np.float32) - self._hydro_stream_order = so_norm - self._hydro_stream_order_raw = stream_order.astype(np.int32) - print(f" Stream order: max {int(max_order)}, " - f"{int((stream_order > 0).sum())} stream cells") - else: - self._hydro_stream_order = None - self._hydro_stream_order_raw = None - - # Accept and store stream_link grid - stream_link_grid = kwargs.pop('stream_link', None) - if stream_link_grid is not None: - if hasattr(stream_link_grid, 'get'): - stream_link_grid = stream_link_grid.get() - self._hydro_stream_link = np.nan_to_num( - np.asarray(stream_link_grid, dtype=np.float64), nan=0.0 - ).astype(np.int32) - else: - self._hydro_stream_link = None - - # Build spawn probabilities — stream-order weighted if available - if has_stream_order: - # Spawn on stream cells, weighted by sqrt(order) — much - # flatter than order^2, giving real coverage to headwaters - spawn_weights = np.where(stream_order > 0, - np.sqrt(stream_order), 0.0) - spawn_weights[~valid_flow] = 0.0 - else: - spawn_weights = accum_norm.copy() - spawn_weights[~valid_flow] = 0.0 - - # Rasterize Overture waterway LineStrings into spawn pool - # and stream_link overlay (for unified water shader rendering). - waterway_geojson = kwargs.pop('waterway_geojson', None) - if waterway_geojson is not None and has_stream_order: - _WATERWAY_ORDER = { - 'river': (5, 3.0), 'canal': (4, 2.5), - 'stream': (2, 1.5), 'drain': (1, 1.0), 'ditch': (1, 1.0), - } - so_raw = self._hydro_stream_order_raw - sl_grid = self._hydro_stream_link - n_ww_cells = 0 - # Use a synthetic link ID for waterway cells not already - # in the stream network. 
- _ww_link_id = (int(sl_grid.max()) + 1) if sl_grid is not None else 1 - from .geojson import ( - _geojson_to_world_coords, _build_transformer, - ) - terrain_data_np = self.raster.data - if hasattr(terrain_data_np, 'get'): - terrain_data_np = terrain_data_np.get() - terrain_data_np = np.asarray(terrain_data_np, dtype=np.float32) - try: - transformer = _build_transformer(self.raster) - except Exception: - transformer = None - - def _burn_pixels(rows, cols, eq_order, eq_weight): - """Burn a set of (row, col) pixels into hydro grids.""" - nonlocal n_ww_cells - for rr, cc in zip(rows, cols): - if 0 <= rr < H and 0 <= cc < W: - # Upgrade raw order (don't downgrade) - if so_raw[rr, cc] < eq_order: - so_raw[rr, cc] = eq_order - # Upgrade spawn weight - if eq_weight > spawn_weights[rr, cc]: - spawn_weights[rr, cc] = eq_weight - # Ensure cell appears in stream_link overlay - if sl_grid is not None and sl_grid[rr, cc] <= 0: - sl_grid[rr, cc] = _ww_link_id - n_ww_cells += 1 - - def _coords_to_pixels(coords): - """Convert lon/lat coords to (col, row) pixel pairs.""" - try: - _, px = _geojson_to_world_coords( - coords, self.raster, terrain_data_np, - self._base_pixel_spacing_x, - self._base_pixel_spacing_y, - transformer=transformer, - return_pixel_coords=True) - return px - except Exception: - return [] - - def _densify_line(pixel_coords): - """Walk a polyline at 1-pixel steps, return (rows, cols).""" - rows, cols = [], [] - for i in range(len(pixel_coords) - 1): - c0, r0 = pixel_coords[i] - c1, r1 = pixel_coords[i + 1] - dc, dr = c1 - c0, r1 - r0 - n_steps = max(int(max(abs(dr), abs(dc))), 1) - for s in range(n_steps + 1): - t = s / n_steps - rows.append(int(round(r0 + dr * t))) - cols.append(int(round(c0 + dc * t))) - return rows, cols - - for feat in waterway_geojson.get('features', []): - geom = feat.get('geometry', {}) - gtype = geom.get('type', '') - subtype = (feat.get('properties') or {}).get('subtype', '') - eq_order, eq_weight = _WATERWAY_ORDER.get( - subtype, (2, 
1.5)) - - if gtype == 'LineString': - coords = geom.get('coordinates', []) - if len(coords) < 2: - continue - px = _coords_to_pixels(coords) - if len(px) < 2: - continue - rs, cs = _densify_line(px) - _burn_pixels(rs, cs, eq_order, eq_weight) - - elif gtype in ('Polygon', 'MultiPolygon'): - # Water bodies: burn outline + filled interior - rings = [] - if gtype == 'Polygon': - rings = geom.get('coordinates', []) - else: - for poly in geom.get('coordinates', []): - rings.extend(poly) - # Lakes/reservoirs get high order - poly_order = max(eq_order, 5) - poly_weight = max(eq_weight, 3.0) - for ring in rings: - if len(ring) < 3: - continue - px = _coords_to_pixels(ring) - if len(px) < 3: - continue - # Outline - rs, cs = _densify_line(px) - _burn_pixels(rs, cs, poly_order, poly_weight) - # Fill interior via scanline - pr = np.array([p[1] for p in px]) - pc = np.array([p[0] for p in px]) - r_min = max(int(pr.min()), 0) - r_max = min(int(pr.max()), H - 1) - for row in range(r_min, r_max + 1): - # Find x-intersections of scanline with edges - xings = [] - n_verts = len(pr) - for j in range(n_verts): - j1 = (j + 1) % n_verts - r0, r1 = pr[j], pr[j1] - if (r0 <= row < r1) or (r1 <= row < r0): - t = (row - r0) / (r1 - r0) - xings.append(pc[j] + t * (pc[j1] - pc[j])) - xings.sort() - # Fill between pairs - for k in range(0, len(xings) - 1, 2): - c_lo = max(int(round(xings[k])), 0) - c_hi = min(int(round(xings[k + 1])), W - 1) - if c_lo <= c_hi: - fill_rs = [row] * (c_hi - c_lo + 1) - fill_cs = list(range(c_lo, c_hi + 1)) - _burn_pixels(fill_rs, fill_cs, - poly_order, poly_weight) - if n_ww_cells > 0: - print(f" Waterway rasterization: {n_ww_cells} cells injected") - - flat_weights = spawn_weights.ravel() - valid_mask = flat_weights > 0 - valid_indices = np.nonzero(valid_mask)[0] - if len(valid_indices) > 0: - valid_probs = flat_weights[valid_indices].astype(np.float64) - valid_probs /= valid_probs.sum() - else: - valid_flow_flat = valid_flow.ravel() - valid_indices = 
np.nonzero(valid_flow_flat)[0] - if len(valid_indices) > 0: - valid_probs = np.ones(len(valid_indices), dtype=np.float64) - valid_probs /= valid_probs.sum() - else: - valid_indices = np.arange(H * W) - valid_probs = np.ones(H * W, dtype=np.float64) / (H * W) - self._hydro_spawn_indices = valid_indices - self._hydro_spawn_valid_probs = valid_probs - - # Spawn initial particles - N = self._hydro_n_particles - chosen = np.random.choice(len(valid_indices), N, p=valid_probs) - indices = valid_indices[chosen] - rows = (indices // W).astype(np.float32) + np.random.uniform(-0.5, 0.5, N).astype(np.float32) - cols = (indices % W).astype(np.float32) + np.random.uniform(-0.5, 0.5, N).astype(np.float32) - rows = np.clip(rows, 0, H - 1) - cols = np.clip(cols, 0, W - 1) - - self._hydro_particles = np.column_stack([rows, cols]).astype(np.float32) - self._hydro_ages = np.random.randint(0, self._hydro_max_age, N).astype(np.int32) - self._hydro_lifetimes = np.random.randint( - self._hydro_max_age // 2, self._hydro_max_age, N).astype(np.int32) - self._hydro_trails = np.zeros( - (N, self._hydro_trail_len, 2), dtype=np.float32) - for t in range(self._hydro_trail_len): - self._hydro_trails[:, t, :] = self._hydro_particles - - # Compute terrain slope magnitude for speed modulation terrain_data = self.raster.data - if hasattr(terrain_data, 'get'): - elev = terrain_data.get().astype(np.float64) - else: - elev = np.asarray(terrain_data, dtype=np.float64) - grad_row, grad_col = np.gradient(np.nan_to_num(elev, nan=0.0)) - slope_mag = np.sqrt(grad_row**2 + grad_col**2).astype(np.float32) - p95 = np.percentile(slope_mag[slope_mag > 0], 95) if (slope_mag > 0).any() else 1.0 - slope_norm = np.clip(slope_mag / max(p95, 1e-6), 0, 1).astype(np.float32) - self._hydro_slope_mag = slope_norm - - # Per-particle visual properties from stream order (or accum fallback) - r_idx = np.clip(np.floor(rows).astype(int), 0, H - 1) - c_idx = np.clip(np.floor(cols).astype(int), 0, W - 1) - if has_stream_order: - 
order_val = so_norm[r_idx, c_idx].astype(np.float32) - raw_order = self._hydro_stream_order_raw[r_idx, c_idx] - else: - order_val = accum_norm[r_idx, c_idx].astype(np.float32) - raw_order = None - self._hydro_particle_accum = order_val - self._hydro_particle_raw_order = raw_order - self._hydro_particle_colors = self._hydro_color_from_order( - order_val, raw_order=raw_order) - self._hydro_particle_radii = self._hydro_radius_from_order( - order_val, raw_order=raw_order) - - # Min render distance and depth-scaled alpha reference - world_diag = np.sqrt((W * self._base_pixel_spacing_x)**2 + - (H * self._base_pixel_spacing_y)**2) - self._hydro_min_depth = 1.0 # metres — allow building-level zoom - self._hydro_ref_depth = world_diag * 0.15 + self.hydro_mgr.init_from_flow( + flow_accum, terrain_data, + self._base_pixel_spacing_x, self._base_pixel_spacing_y, + **kwargs) - print(f" Hydro flow initialized on {H}x{W} grid " - f"({N} particles, threshold={self._hydro_accum_threshold})") + def _compute_hydro_from_terrain(self): + """Compute hydrological flow from current terrain on GPU. - def _update_hydro_particles(self): - """Advect hydro particles one tick using D8 flow direction lookup.""" - if self._hydro_flow_u_px is None or self._hydro_particles is None: - return + Delegates to ``self.hydro_mgr.compute_from_terrain()``. + Returns True on success, False on failure. 
+ """ + self.hydro_mgr.set_terrain_ref( + None, self._base_pixel_spacing_x, self._base_pixel_spacing_y) + try: + result = self.hydro_mgr.compute_from_terrain(self.raster) + except Exception as exc: + print(f"Hydro computation failed: {exc}") + return False + if result is None: + return False - H, W = self._hydro_flow_u_px.shape - pts = self._hydro_particles # (N, 2) — (row, col) + # Register stream_link overlay with palette coloring + if 'stream_link_overlay' in result: + overlay = result['stream_link_overlay'] + _add_overlay(self, 'stream_link', overlay, + color_lut=result['palette_lut']) + # Populate per-tile overlay manager for LOD rendering + if self._overlay_tile_mgr is not None: + mgr = self._terrain_lod_manager + ts = mgr._tile_size if mgr is not None else 128 + H, W = overlay.shape + n_tr = (H + ts - 1) // ts + n_tc = (W + ts - 1) // ts + self._overlay_tile_mgr.populate_from_array( + overlay, ts, n_tr, n_tc) + self._overlay_tile_mgr.set_color_lut( + result['palette_lut']) + n_tiles = len(self._overlay_tile_mgr._tile_overlays) + print(f" Overlay tiles: {n_tiles}/{n_tr*n_tc} " + f"(ts={ts}, grid={n_tr}x{n_tc}, " + f"overlay={H}x{W})") + else: + print(" Warning: no stream_link_overlay in hydro result") + return True - # Shift trail buffer (drop oldest, prepend current position) - self._hydro_trails[:, 1:, :] = self._hydro_trails[:, :-1, :] - self._hydro_trails[:, 0, :] = pts + def _transfer_streaming_overlay(self): + """Transfer pending streaming stream overlay to the overlay tile mgr. - # Nearest-neighbor D8 lookup (discrete — no interpolation) - rows = pts[:, 0] - cols = pts[:, 1] - r_idx = np.clip(np.floor(np.nan_to_num(rows, nan=0.0)).astype(int), 0, H - 1) - c_idx = np.clip(np.floor(np.nan_to_num(cols, nan=0.0)).astype(int), 0, W - 1) - - u_val = self._hydro_flow_u_px[r_idx, c_idx] - v_val = self._hydro_flow_v_px[r_idx, c_idx] - - # Max-track per-particle visual weight (stream order or accum). 
- # Particles only get brighter as they flow into bigger streams. - so = self._hydro_stream_order - so_raw = self._hydro_stream_order_raw - if so is not None: - current_val = so[r_idx, c_idx] - else: - current_val = self._hydro_flow_accum_norm[r_idx, c_idx] - old_val = self._hydro_particle_accum.copy() - np.maximum(old_val, current_val, out=self._hydro_particle_accum) - # Track raw integer order alongside normalized value - if so_raw is not None and self._hydro_particle_raw_order is not None: - current_raw = so_raw[r_idx, c_idx] - np.maximum(self._hydro_particle_raw_order, current_raw, - out=self._hydro_particle_raw_order) - changed = self._hydro_particle_accum > old_val - if changed.any(): - a = self._hydro_particle_accum[changed] - raw_o = (self._hydro_particle_raw_order[changed] - if self._hydro_particle_raw_order is not None - else None) - self._hydro_particle_colors[changed] = \ - self._hydro_color_from_order(a, raw_order=raw_o) - self._hydro_particle_radii[changed] = \ - self._hydro_radius_from_order(a, raw_order=raw_o) - - # Small random jitter for visual variety - jitter = np.random.uniform(-0.1, 0.1, pts.shape).astype(np.float32) - - # Slope-based speed: steeper terrain → faster flow - # Base speed 0.3 + 0.7 * slope so even flat areas move - slope_factor = np.ones(len(r_idx), dtype=np.float32) - if self._hydro_slope_mag is not None: - slope_factor = 0.3 + 0.7 * self._hydro_slope_mag[r_idx, c_idx] - - # Advect - speed = self._hydro_speed - dt_scale = getattr(self, '_dt_scale', 1.0) - pts[:, 0] += (v_val + jitter[:, 0]) * speed * dt_scale * slope_factor - pts[:, 1] += (u_val + jitter[:, 1]) * speed * dt_scale * slope_factor + Called after hydro_mgr.check_streaming_result(). Only sets tiles + that are outside the initial terrain grid (streaming tiles) to + avoid overwriting the more accurate xrspatial-computed overlay. 
+ """ + otm = self._overlay_tile_mgr + if otm is None: + return + # Only transfer when overlay is the active layer + if self._active_overlay_data is None: + return + overlay, win_r0, win_c0 = self.hydro_mgr.pop_streaming_overlay() + if overlay is None: + return - # Age particles - self._hydro_ages += 1 + mgr = self._terrain_lod_manager + if mgr is None: + return - # Respawn: OOB, aged-out, or stuck (zero velocity = pit/sink) - nan_pos = np.isnan(pts[:, 0]) | np.isnan(pts[:, 1]) - oob = nan_pos | (pts[:, 0] < 0) | (pts[:, 0] >= H) | (pts[:, 1] < 0) | (pts[:, 1] >= W) - old = self._hydro_ages >= self._hydro_lifetimes - stuck = (u_val == 0) & (v_val == 0) - respawn = oob | old | stuck + ts = mgr._tile_size + n_tr_initial = mgr._n_tile_rows + n_tc_initial = mgr._n_tile_cols + ov_h, ov_w = overlay.shape + + # Slice overlay into per-tile chunks, only for streaming tiles + count = 0 + # Tile range covered by this window + tr_start = win_r0 // ts + tc_start = win_c0 // ts + tr_end = (win_r0 + ov_h + ts - 1) // ts + tc_end = (win_c0 + ov_w + ts - 1) // ts + + for tr in range(tr_start, tr_end): + for tc in range(tc_start, tc_end): + # Skip initial-grid tiles — they have accurate overlay + if 0 <= tr < n_tr_initial and 0 <= tc < n_tc_initial: + continue + # Extract tile region from overlay + r0 = tr * ts - win_r0 + c0 = tc * ts - win_c0 + r1 = min(r0 + ts, ov_h) + c1 = min(c0 + ts, ov_w) + if r0 < 0 or c0 < 0 or r1 <= r0 or c1 <= c0: + continue + tile_data = overlay[r0:r1, c0:c1] + if np.all(np.isnan(tile_data)): + continue + otm.set_tile(tr, tc, tile_data.copy()) + count += 1 - n_respawn = int(respawn.sum()) - if n_respawn > 0: - chosen = np.random.choice( - len(self._hydro_spawn_indices), n_respawn, - p=self._hydro_spawn_valid_probs) - indices = self._hydro_spawn_indices[chosen] - pts[respawn, 0] = (indices // W).astype(np.float32) + np.random.uniform(-0.5, 0.5, n_respawn).astype(np.float32) - pts[respawn, 1] = (indices % W).astype(np.float32) + np.random.uniform(-0.5, 0.5, 
n_respawn).astype(np.float32) - pts[respawn, 0] = np.clip(pts[respawn, 0], 0, H - 1) - pts[respawn, 1] = np.clip(pts[respawn, 1], 0, W - 1) - self._hydro_ages[respawn] = 0 - self._hydro_lifetimes[respawn] = np.random.randint( - self._hydro_max_age // 2, self._hydro_max_age, n_respawn) - for t in range(self._hydro_trail_len): - self._hydro_trails[respawn, t, :] = pts[respawn] - # Reset visual weight, color, radius for respawned particles - r_new = np.clip(np.floor(pts[respawn, 0]).astype(int), 0, H - 1) - c_new = np.clip(np.floor(pts[respawn, 1]).astype(int), 0, W - 1) - so = self._hydro_stream_order - so_raw = self._hydro_stream_order_raw - if so is not None: - new_val = so[r_new, c_new] - else: - new_val = self._hydro_flow_accum_norm[r_new, c_new] - self._hydro_particle_accum[respawn] = new_val - # Reset raw order for respawned particles - if so_raw is not None and self._hydro_particle_raw_order is not None: - self._hydro_particle_raw_order[respawn] = so_raw[r_new, c_new] - raw_o = self._hydro_particle_raw_order[respawn] - else: - raw_o = None - self._hydro_particle_colors[respawn] = \ - self._hydro_color_from_order(new_val, raw_order=raw_o) - self._hydro_particle_radii[respawn] = \ - self._hydro_radius_from_order(new_val, raw_order=raw_o) + if count > 0: + print(f" Streaming overlay: {count} tiles added") + + def _update_hydro_particles(self): + """Advect hydro particles one tick on GPU. Delegates to HydroManager.""" + dt_scale = float(getattr(self, '_dt_scale', 1.0)) + self.hydro_mgr.update_particles(dt_scale=dt_scale) def _draw_hydro_on_frame(self, img): """Project hydro particles to screen space and draw on rendered frame. @@ -6140,96 +6021,24 @@ def _draw_hydro_on_frame(self, img): return img def _splat_hydro_gpu(self, d_frame): - """Project and splat hydro particles on GPU via Numba CUDA kernel. - - Alpha is computed entirely on GPU from per-particle ages/lifetimes — - no CPU tile/repeat/clip overhead. Colors/radii are N-sized - (per-particle). 
Only trails (N*T) need per-frame upload; everything - else is N-sized (~60KB). - - Parameters - ---------- - d_frame : cupy.ndarray, shape (H, W, 3) - GPU frame buffer (float32 0-1). Modified in-place via atomic add. - """ - if self._hydro_particles is None or self._hydro_trails is None: - return - - from .analysis.render import _compute_camera_basis - - N = self._hydro_particles.shape[0] - trail_len = self._hydro_trail_len - total = N * trail_len - - cam_pos = self.position - look_at = self._get_look_at() - forward, right, cam_up = _compute_camera_basis( - tuple(cam_pos), tuple(look_at), (0, 0, 1), - ) - fov_scale = math.tan(math.radians(self.fov) / 2.0) - aspect_ratio = d_frame.shape[1] / d_frame.shape[0] - - # Flatten trails: (N, T, 2) → (N*T, 2) — the only large upload - all_pts = self._hydro_trails.reshape(-1, 2) - - # Allocate / resize GPU buffers - if self._d_hydro_trails is None or self._d_hydro_trails.shape[0] != total: - self._d_hydro_trails = cp.empty((total, 2), dtype=cp.float32) - if self._d_hydro_ages is None or self._d_hydro_ages.shape[0] != N: - self._d_hydro_ages = cp.empty(N, dtype=cp.int32) - self._d_hydro_lifetimes = cp.empty(N, dtype=cp.int32) - self._d_hydro_colors = cp.empty((N, 3), dtype=cp.float32) - self._d_hydro_radii = cp.empty(N, dtype=cp.int32) - - # Upload — trails are N*T (~3MB), rest is N-sized (~60KB each) - self._d_hydro_trails.set(all_pts) - self._d_hydro_ages.set(self._hydro_ages) - self._d_hydro_lifetimes.set(self._hydro_lifetimes) - self._d_hydro_colors.set(self._hydro_particle_colors) - self._d_hydro_radii.set(self._hydro_particle_radii) - - # GPU terrain + """Project and splat hydro particles on GPU. 
Delegates to HydroManager.""" terrain_data = self.raster.data - if not isinstance(terrain_data, cp.ndarray): + if has_cupy and not isinstance(terrain_data, cp.ndarray): terrain_data = cp.asarray(terrain_data) - # Depth buffer for occlusion culling (populated by _update_frame) depth_t = getattr(self, '_d_depth_t', None) - if depth_t is None: - depth_t = cp.empty((0, 0), dtype=cp.float32) - - # Single kernel launch — alpha computed on GPU - threadsperblock = 256 - blockspergrid = (total + threadsperblock - 1) // threadsperblock - _hydro_splat_kernel[blockspergrid, threadsperblock]( - self._d_hydro_trails, - self._d_hydro_ages, - self._d_hydro_lifetimes, - self._d_hydro_colors, - self._d_hydro_radii, - trail_len, - float(self._hydro_alpha), - int(self._hydro_min_visible_age), - float(self._hydro_ref_depth), - terrain_data, - depth_t, + self.hydro_mgr.splat_gpu( d_frame, - float(cam_pos[0]), float(cam_pos[1]), float(cam_pos[2]), - float(forward[0]), float(forward[1]), float(forward[2]), - float(right[0]), float(right[1]), float(right[2]), - float(cam_up[0]), float(cam_up[1]), float(cam_up[2]), - float(fov_scale), float(aspect_ratio), - float(self._base_pixel_spacing_x), - float(self._base_pixel_spacing_y), - float(self.vertical_exaggeration), - float(self.subsample_factor), - float(self._hydro_min_depth), + camera_pos=self.position, + look_at=self._get_look_at(), + fov=self.fov, + ve=self.vertical_exaggeration, + subsample_factor=self.subsample_factor, + terrain_gpu=terrain_data, + depth_t=depth_t, ) - # Clamp output - cp.clip(d_frame, 0, 1, out=d_frame) - # ------------------------------------------------------------------ # GTFS-RT realtime vehicle overlay # ------------------------------------------------------------------ @@ -6458,24 +6267,15 @@ def _action_toggle_wind(self): def _action_toggle_terrain_vis(self): from .viewer.terrain_lod import is_terrain_lod_gid - entry = self.rtx._geom_state.gas_entries.get('terrain') - if entry is not None: - vis = not 
entry.visible - self.rtx.set_geometry_visible('terrain', vis) - print(f"Terrain {'shown' if vis else 'hidden'}") - self._needs_render = True - elif self.lod_enabled: - # Determine current visibility from any LOD tile - vis = True - for gid in self.rtx.list_geometries(): - if is_terrain_lod_gid(gid): + # Toggle all LOD terrain tiles together + vis = None + for gid in self.rtx.list_geometries(): + if is_terrain_lod_gid(gid): + if vis is None: e = self.rtx._geom_state.gas_entries.get(gid) - if e is not None: - vis = not e.visible - break - for gid in self.rtx.list_geometries(): - if is_terrain_lod_gid(gid): - self.rtx.set_geometry_visible(gid, vis) + vis = not e.visible if e is not None else True + self.rtx.set_geometry_visible(gid, vis) + if vis is not None: print(f"Terrain {'shown' if vis else 'hidden'}") self._needs_render = True @@ -6565,12 +6365,6 @@ def _action_cycle_color_stretch(self): print(f"Color stretch: {self.color_stretch}") self._update_frame() - def _action_cycle_mesh_type(self): - cycle = {'tin': 'voxel', 'voxel': 'heightfield', 'heightfield': 'tin'} - self.mesh_type = cycle.get(self.mesh_type, 'tin') - self._rebuild_vertical_exaggeration(self.vertical_exaggeration) - print(f"Mesh type: {self.mesh_type}") - def _action_cycle_basemap_fwd(self): self._cycle_basemap() @@ -6608,25 +6402,6 @@ def _action_resolution_finer(self): if new_factor != self.subsample_factor: self._rebuild_at_resolution(new_factor) - def _action_toggle_terrain_lod(self): - """Toggle distance-based terrain LOD on/off.""" - if self.rtx is None: - return - - if self.lod_enabled: - # Disable LOD — remove tile geometries, restore single terrain - if self._terrain_lod_manager is not None: - self._terrain_lod_manager.remove_all(self.rtx) - self._terrain_lod_manager = None - self.lod_enabled = False - # Rebuild the single terrain geometry - self._rebuild_at_resolution(self.subsample_factor) - print("Terrain LOD: OFF") - else: - # Enable LOD — replace single terrain with tiled LOD - 
self._enable_terrain_lod() - print(f"Terrain LOD: ON ({self._terrain_lod_manager.get_stats()})") - def _action_ve_down(self): new_ve = max(0.1, round(self.vertical_exaggeration - 0.1, 1)) if new_ve != self.vertical_exaggeration: @@ -6782,15 +6557,24 @@ def _sync_drone_from_pos_for(self, obs, pos): self._calculate_viewshed(quiet=True) def _check_terrain_reload(self): - """Check if camera is near terrain edge and reload a new window. + """Check if camera is near terrain edge and prefetch the next window. The terrain loader runs in a background thread so it doesn't block the render loop (erosion/hydro can take many seconds). Each tick we either (a) submit a new loader job if near-edge, or (b) poll for a completed result and swap in the new terrain. + + Prefetch strategy: triggers at 40% from any edge (not 20%) so the + load starts well before the camera reaches the boundary. The load + center is offset in the camera's direction of travel so the new + terrain extends further ahead. """ if self._terrain_loader is None: return + # Streaming LOD handles edge loading — skip terrain replacement + if (getattr(self, '_tile_data_fn', None) is not None + and self.lod_enabled): + return # --- Phase 2: check for completed background load --- future = self.terrain._terrain_reload_future @@ -6820,12 +6604,17 @@ def _check_terrain_reload(self): return H, W = self.terrain_shape - cam_col = self.position[0] / self.pixel_spacing_x - cam_row = self.position[1] / self.pixel_spacing_y - - # Check if camera is within 20% of any edge - margin_x = W * 0.2 - margin_y = H * 0.2 + # Camera position relative to the terrain grid (accounting for + # any world offset from previous reloads). + ox = self.terrain._world_offset_x + oy = self.terrain._world_offset_y + cam_col = (self.position[0] - ox) / self.pixel_spacing_x + cam_row = (self.position[1] - oy) / self.pixel_spacing_y + + # Prefetch at 40% from any edge — starts loading well before + # the camera reaches the boundary. 
+ margin_x = W * 0.4 + margin_y = H * 0.4 near_edge = (cam_col < margin_x or cam_col > W - margin_x or cam_row < margin_y or cam_row > H - margin_y) if not near_edge: @@ -6835,6 +6624,19 @@ def _check_terrain_reload(self): cam_lon = self._coord_origin_x + cam_col * self._coord_step_x cam_lat = self._coord_origin_y + cam_row * self._coord_step_y + # Offset load center in the direction of camera travel so the + # new terrain extends further ahead of the camera. + front = self._get_front() + fx, fy = float(front[0]), float(front[1]) + flen = np.sqrt(fx * fx + fy * fy) + if flen > 0.01: + # Offset by 25% of the window in the camera's forward direction, + # converted from pixel space to geographic coordinates. + offset_px_x = (W * 0.25) * (fx / flen) + offset_px_y = (H * 0.25) * (fy / flen) + cam_lon += offset_px_x * self._coord_step_x + cam_lat += offset_px_y * self._coord_step_y + # Submit loader to background thread from concurrent.futures import ThreadPoolExecutor pool = self.terrain._terrain_reload_pool @@ -6852,13 +6654,22 @@ def _bg_load(lon, lat): self._last_reload_time = now + 999999 def _apply_terrain_reload(self, result, cam_lon, cam_lat): - """Apply a completed terrain reload result (runs on main thread).""" + """Apply a completed terrain reload result (runs on main thread). + + The camera position is kept stable — instead of teleporting the + camera to its new-grid coordinates, we offset the terrain vertices + so the same geographic point maps to the same world-space position. + This eliminates the jarring jump that would otherwise occur. 
+ """ new_hydro = None if isinstance(result, tuple): new_raster, new_hydro = result else: new_raster = result + # --- Compute world offset to keep camera stable --- + old_pos_x = self.position[0] + old_pos_y = self.position[1] cam_z = self.position[2] # Extract coordinate metadata from new raster @@ -6867,10 +6678,28 @@ def _apply_terrain_reload(self, result, cam_lon, cam_lat): new_step_x = float(new_raster.x.values[1] - new_raster.x.values[0]) new_step_y = float(new_raster.y.values[1] - new_raster.y.values[0]) - # Compute camera position in new window's pixel space + # Where the camera would land in the new grid (pixel coords) new_col = (cam_lon - new_origin_x) / new_step_x new_row = (cam_lat - new_origin_y) / new_step_y + # Offset = current world position minus where new grid would + # place the camera. Adding this to all vertices keeps the camera + # at (old_pos_x, old_pos_y) without moving it. + psx = self.pixel_spacing_x + psy = self.pixel_spacing_y + offset_x = old_pos_x - new_col * psx + offset_y = old_pos_y - new_row * psy + self.terrain._world_offset_x = offset_x + self.terrain._world_offset_y = offset_y + + # Update coordinate mapping so world-to-geo still works: + # lon = coord_origin_x + (pos_x / psx) * coord_step_x + # We need this to produce cam_lon when pos_x = old_pos_x. 
+ self._coord_origin_x = cam_lon - (old_pos_x / psx) * new_step_x + self._coord_origin_y = cam_lat - (old_pos_y / psy) * new_step_y + self._coord_step_x = new_step_x + self._coord_step_y = new_step_y + # Replace rasters self._base_raster = new_raster self.raster = new_raster @@ -6880,12 +6709,6 @@ def _apply_terrain_reload(self, result, cam_lon, cam_lat): self._d_wind_scratch = None self._d_depth_t = None # invalidate depth buffer - # Update coordinate tracking - self._coord_origin_x = new_origin_x - self._coord_origin_y = new_origin_y - self._coord_step_x = new_step_x - self._coord_step_y = new_step_y - # Recompute terrain stats new_H, new_W = new_raster.shape self.terrain_shape = (new_H, new_W) @@ -6925,92 +6748,40 @@ def _apply_terrain_reload(self, result, cam_lon, cam_lat): self._land_color_range = (float(np.nanmin(land_pixels)) * ve, float(np.nanmax(land_pixels)) * ve) - # Clear terrain mesh cache (old window geometry is stale) - self._terrain_mesh_cache.clear() - self._baked_mesh_cache.clear() - - # Rebuild terrain mesh - from . 
import mesh as mesh_mod - - H, W = new_H, new_W - cache_key = (self.subsample_factor, self.mesh_type) - - if self.mesh_type == 'heightfield': - if self.rtx is not None: - self.rtx.add_heightfield_geometry( - 'terrain', terrain_np, H, W, - spacing_x=self.pixel_spacing_x, - spacing_y=self.pixel_spacing_y, - ve=ve, - ) - if self.terrain_skirt: - sv, si = mesh_mod.build_terrain_skirt( - terrain_np, H, W, scale=ve, - pixel_spacing_x=self.pixel_spacing_x, - pixel_spacing_y=self.pixel_spacing_y) - self.rtx.add_geometry('terrain_skirt', sv, si) - elif self.rtx.has_geometry('terrain_skirt'): - self.rtx.remove_geometry('terrain_skirt') - self._terrain_mesh_cache[cache_key] = (None, None, terrain_np.copy()) - else: - if self.mesh_type == 'voxel': - num_verts = H * W * 8 - num_tris = H * W * 12 - vertices = np.zeros(num_verts * 3, dtype=np.float32) - indices = np.zeros(num_tris * 3, dtype=np.int32) - base_elev = float(np.nanmin(terrain_np)) - mesh_mod.voxelate_terrain(vertices, indices, new_raster, scale=1.0, - base_elevation=base_elev) - else: - num_verts = H * W - num_tris = (H - 1) * (W - 1) * 2 - vertices = np.zeros(num_verts * 3, dtype=np.float32) - indices = np.zeros(num_tris * 3, dtype=np.int32) - mesh_mod.triangulate_terrain(vertices, indices, new_raster, scale=1.0) - - if self.terrain_skirt: - vertices, indices = mesh_mod.add_terrain_skirt( - vertices, indices, H, W) - - # Scale x,y to world units - if self.pixel_spacing_x != 1.0 or self.pixel_spacing_y != 1.0: - vertices[0::3] *= self.pixel_spacing_x - vertices[1::3] *= self.pixel_spacing_y - - # Apply vertical exaggeration - if ve != 1.0: - vertices[2::3] *= ve - - # Cache the new mesh - base_verts = vertices.copy() - if ve != 1.0: - base_verts[2::3] /= ve - self._terrain_mesh_cache[cache_key] = (base_verts, indices.copy(), terrain_np.copy()) - - # Replace terrain geometry - if self.rtx is not None: - gd = (H, W) if self.mesh_type != 'voxel' and not self.terrain_skirt else None - self.rtx.add_geometry('terrain', 
vertices, indices, - grid_dims=gd) - - # Reinitialize hydro if the loader provided new flow data + # Clear all terrain caches (old window geometry is stale) + self.terrain.clear_all_caches() + + # --- Rebuild terrain mesh via LOD manager --- + if self._terrain_lod_manager is not None: + mgr = self._terrain_lod_manager + mgr.set_terrain(terrain_np, offset_x=offset_x, offset_y=offset_y) + # Force immediate full tile rebuild + saved_limit = mgr.per_tick_build_limit + mgr.per_tick_build_limit = 10000 + mgr.update(self.position, self.rtx, + ve=ve, force=True, + camera_front=self._get_front(), fov=self.camera.fov) + mgr.per_tick_build_limit = saved_limit + + # Reinitialize hydro for new terrain if new_hydro is not None and self._hydro_data is not None: was_enabled = self._hydro_enabled - flow_dir = new_hydro['flow_dir'] flow_accum = new_hydro['flow_accum'] hydro_opts = {k: v for k, v in new_hydro.items() - if k not in ('flow_dir', 'flow_accum', 'enabled')} - self._init_hydro(flow_dir, flow_accum, **hydro_opts) + if k not in ('flow_accum', 'enabled')} + self._init_hydro(flow_accum, **hydro_opts) + self._hydro_enabled = was_enabled + elif self._hydro_lazy and self._hydro_data is not None: + was_enabled = self._hydro_enabled + self._compute_hydro_from_terrain() self._hydro_enabled = was_enabled - # Reposition camera in new window - self.position = np.array([ - new_col * self.pixel_spacing_x, - new_row * self.pixel_spacing_y, - cam_z - ], dtype=float) + # Camera stays at its current position — no jump. + # Only update Z if the terrain height changed significantly + # under the camera (keeps altitude above ground consistent). 
# Refresh minimap + self._minimap_bg_extent = None self._compute_minimap_background() self._last_reload_time = time.time() @@ -7115,15 +6886,30 @@ def _tick(self): # --- Simulation (terrain reload, chunk loading, AO accumulation) --- self._check_terrain_reload() - if self._chunk_manager is not None: - if self._chunk_manager.update(self.position[0], self.position[1], self): - self._geometry_colors_builder = self._accessor._build_geometry_colors_gpu - self._render_needed = True - # Terrain LOD: update tile resolutions based on camera distance + # Terrain LOD runs first so tile_lods are fresh for chunk manager if self.lod_enabled and self._terrain_lod_manager is not None: if self._terrain_lod_manager.update( self.position, self.rtx, - ve=self.vertical_exaggeration): + ve=self.vertical_exaggeration, + camera_front=self._get_front(), fov=self.camera.fov): + self._render_needed = True + if self._chunk_manager is not None: + # When LOD is active, sync distance parameters from LOD manager + if (self.lod_enabled and self._terrain_lod_manager is not None + and self._terrain_lod_manager._lod_distances): + self._chunk_manager.max_distance = ( + self._terrain_lod_manager._lod_distances[-1]) + self._chunk_manager._lod_distances = ( + self._terrain_lod_manager._lod_distances) + # When grids are aligned, pass tile LOD assignments directly + # so the chunk manager skips its own distance computation. 
+ self._chunk_manager._tile_lods = ( + self._terrain_lod_manager._tile_lods) + else: + self._chunk_manager._lod_distances = None + self._chunk_manager._tile_lods = None + if self._chunk_manager.update(self.position[0], self.position[1], self): + self._geometry_colors_builder = self._accessor._build_geometry_colors_gpu self._render_needed = True # AO/DOF: keep accumulating samples when camera is stationary if ((self.ao_enabled or self.dof_enabled) and not self._held_keys @@ -7145,6 +6931,11 @@ def _tick(self): self._update_wind_particles() self._splat_wind_gpu(self._d_wind_scratch) if self._hydro_enabled and self._hydro_particles is not None: + self.hydro_mgr.check_streaming_result() + self._transfer_streaming_overlay() + cam_r = self.position[1] / self._base_pixel_spacing_y + cam_c = self.position[0] / self._base_pixel_spacing_x + self.hydro_mgr.update_streaming_window(cam_r, cam_c) self._update_hydro_particles() self._splat_hydro_gpu(self._d_wind_scratch) if self._clouds_enabled and self._rain_particles is not None: @@ -7198,6 +6989,8 @@ def _cycle_basemap(self, reverse=False): if provider == 'none': self._tiles_enabled = False + if self._texture_tile_mgr is not None: + self._texture_tile_mgr.clear() print("Basemap: none") else: from .tiles import XYZTileService @@ -7208,14 +7001,26 @@ def _cycle_basemap(self, reverse=False): self._tile_service = XYZTileService( url_template=provider, raster=self._base_raster, ) - self._tile_service.fetch_visible_tiles() else: self._tile_service = XYZTileService( url_template=provider, raster=self._base_raster, ) - self._tile_service.fetch_visible_tiles() self._tiles_enabled = True print(f"Basemap: {provider}") + # When LOD active, clear and re-fetch basemap lazily per tile. + # When LOD not active, use monolithic fetch. 
+ if (self._texture_tile_mgr is not None + and self.lod_enabled + and self._terrain_lod_manager is not None): + self._texture_tile_mgr.clear() + # Re-fire tile callbacks for all visible tiles so basemap + # gets fetched for them in background threads. + mgr = self._terrain_lod_manager + if mgr._on_tile_added is not None: + for (tr, tc) in list(mgr._tile_lods.keys()): + mgr._fire_tile_added(tr, tc) + else: + self._tile_service.fetch_visible_tiles() self._update_frame() @@ -7252,7 +7057,7 @@ def _cycle_geometry_layer(self): for geom_id in self._all_geometries: parts = geom_id.rsplit('_', 1) base_name = parts[0] if len(parts) == 2 and parts[1].isdigit() else geom_id - if base_name == layer_name or geom_id == layer_name or geom_id == 'terrain': + if base_name == layer_name or geom_id == layer_name: self.rtx.set_geometry_visible(geom_id, True) visible_count += 1 else: @@ -8102,9 +7907,21 @@ def _save_screenshot(self): if self.viewshed_enabled and self._viewshed_cache is not None: viewshed_data = self._viewshed_cache - # Get tile texture for screenshot if enabled + # Basemap texture for screenshot rgb_texture = None - if self._tiles_enabled and self._tile_service is not None: + _tex_off_y = 0 + _tex_off_x = 0 + if (self._texture_tile_mgr is not None + and self.lod_enabled + and self._terrain_lod_manager is not None): + visible = set(self._terrain_lod_manager._tile_lods.keys()) + d_tex, tex_r, tex_c = self._texture_tile_mgr.get_composite( + visible) + if d_tex is not None: + rgb_texture = d_tex + _tex_off_y = tex_r + _tex_off_x = tex_c + elif self._tiles_enabled and self._tile_service is not None: rgb_texture = self._tile_service.get_gpu_texture() if rgb_texture is not None and self.subsample_factor > 1: f = self.subsample_factor @@ -8137,6 +7954,27 @@ def _save_screenshot(self): _cloud_fog_map = self._d_cloud_fog_map _cloud_fog_density = 12.0 / self._scene_diagonal + # Resolve overlay: per-tile composite when LOD active, + # else monolithic array. 
Used by both screenshot and live paths. + _ov_data = self._active_overlay_data + _ov_lut = self._active_overlay_color_lut + _ov_off_y = 0 + _ov_off_x = 0 + if (self._overlay_tile_mgr is not None + and self._active_overlay_data is not None + and self.lod_enabled + and self._terrain_lod_manager is not None): + visible = set(self._terrain_lod_manager._tile_lods.keys()) + d_comp, off_r, off_c = self._overlay_tile_mgr.get_composite( + visible) + if d_comp is not None: + _ov_data = d_comp + _ov_off_y = off_r + _ov_off_x = off_c + lut = self._overlay_tile_mgr.color_lut + if lut is not None: + _ov_lut = lut + render_kwargs = dict( camera_position=tuple(self.position), look_at=tuple(self._get_look_at()), @@ -8159,10 +7997,14 @@ def _save_screenshot(self): color_stretch=self.color_stretch, color_range=self._land_color_range, rgb_texture=rgb_texture, - overlay_data=self._active_overlay_data, + rgb_texture_offset_y=_tex_off_y, + rgb_texture_offset_x=_tex_off_x, + overlay_data=_ov_data, overlay_alpha=self._overlay_alpha, overlay_as_water=self._overlay_as_water, - overlay_color_lut=self._active_overlay_color_lut, + overlay_color_lut=_ov_lut, + overlay_offset_y=_ov_off_y, + overlay_offset_x=_ov_off_x, geometry_colors=geometry_colors, cloud_fog_map=_cloud_fog_map, cloud_fog_density=_cloud_fog_density, @@ -8251,12 +8093,23 @@ def _render_frame(self): if self.frame_count % 100 == 0: # Only print occasionally print(f"[DEBUG] Viewshed enabled but cache is None") - # Get GPU texture from tile service if enabled + # Basemap texture — per-tile composite when LOD active, + # else monolithic GPU texture from tile service. 
rgb_texture = None - if self._tiles_enabled and self._tile_service is not None: + _tex_off_y = 0 + _tex_off_x = 0 + if (self._texture_tile_mgr is not None + and self.lod_enabled + and self._terrain_lod_manager is not None): + visible = set(self._terrain_lod_manager._tile_lods.keys()) + d_tex, tex_r, tex_c = self._texture_tile_mgr.get_composite( + visible) + if d_tex is not None: + rgb_texture = d_tex + _tex_off_y = tex_r + _tex_off_x = tex_c + elif self._tiles_enabled and self._tile_service is not None: rgb_texture = self._tile_service.get_gpu_texture() - # Tile texture is always at base resolution — stride-subsample - # to match the current (possibly subsampled) raster if rgb_texture is not None and self.subsample_factor > 1: f = self.subsample_factor rgb_texture = rgb_texture[::f, ::f, :] @@ -8312,6 +8165,27 @@ def _render_frame(self): # Spatial frequency: ~12 cloud cells across scene _cloud_fog_density = 12.0 / self._scene_diagonal + # Resolve overlay: per-tile composite when LOD active, + # else monolithic array. 
+ _ov_data = self._active_overlay_data + _ov_lut = self._active_overlay_color_lut + _ov_off_y = 0 + _ov_off_x = 0 + if (self._overlay_tile_mgr is not None + and self._active_overlay_data is not None + and self.lod_enabled + and self._terrain_lod_manager is not None): + visible = set(self._terrain_lod_manager._tile_lods.keys()) + d_comp, off_r, off_c = self._overlay_tile_mgr.get_composite( + visible) + if d_comp is not None: + _ov_data = d_comp + _ov_off_y = off_r + _ov_off_x = off_c + lut = self._overlay_tile_mgr.color_lut + if lut is not None: + _ov_lut = lut + d_output = render( self.raster, camera_position=tuple(self.position), @@ -8332,15 +8206,19 @@ def _render_frame(self): observer_position=observer_pos, pixel_spacing_x=self.pixel_spacing_x, pixel_spacing_y=self.pixel_spacing_y, - mesh_type=self.mesh_type, + mesh_type='heightfield', color_data=self._active_color_data, color_stretch=self.color_stretch, color_range=self._land_color_range, rgb_texture=rgb_texture, - overlay_data=self._active_overlay_data, + rgb_texture_offset_y=_tex_off_y, + rgb_texture_offset_x=_tex_off_x, + overlay_data=_ov_data, overlay_alpha=self._overlay_alpha, overlay_as_water=self._overlay_as_water, - overlay_color_lut=self._active_overlay_color_lut, + overlay_color_lut=_ov_lut, + overlay_offset_y=_ov_off_y, + overlay_offset_x=_ov_off_x, geometry_colors=geometry_colors, ao_samples=ao_samples, ao_radius=self.ao_radius, @@ -8492,8 +8370,13 @@ def _update_frame(self): self._update_wind_particles() self._splat_wind_gpu(d_display) - # GPU hydro: advect on CPU, splat on GPU + # GPU hydro: advect + splat on GPU if self._hydro_enabled and self._hydro_particles is not None: + self.hydro_mgr.check_streaming_result() + self._transfer_streaming_overlay() + cam_r = self.position[1] / self._base_pixel_spacing_y + cam_c = self.position[0] / self._base_pixel_spacing_x + self.hydro_mgr.update_streaming_window(cam_r, cam_c) self._update_hydro_particles() self._splat_hydro_gpu(d_display) @@ -8639,14 
+8522,18 @@ def _handle_mouse_press(self, button, xpos, ypos): frame_x = xpos * self.render_width / max(1, self.width) frame_y = ypos * self.render_height / max(1, self.height) if (mx0 <= frame_x < mx0 + mw and my0 <= frame_y < my0 + mh): - # Convert minimap-local → terrain pixel → world XY + # Convert minimap-local → world XY local_x = frame_x - mx0 local_y = frame_y - my0 - H, W = self.terrain_shape - terrain_col = local_x / mw * W - terrain_row = local_y / mh * H - world_x = terrain_col * self.pixel_spacing_x - world_y = terrain_row * self.pixel_spacing_y + ext = self._minimap_world_extent + if ext is not None: + wx_min, wy_min, wx_max, wy_max = ext + world_x = wx_min + local_x / mw * (wx_max - wx_min) + world_y = wy_min + local_y / mh * (wy_max - wy_min) + else: + H, W = self.terrain_shape + world_x = local_x / mw * W * self.pixel_spacing_x + world_y = local_y / mh * H * self.pixel_spacing_y self.position[0] = world_x self.position[1] = world_y self._update_frame() @@ -8972,7 +8859,6 @@ def _render_help_text(self): ("Y", "Cycle color stretch"), (", / .", "Overlay alpha"), ("R / Shift+R", "Resolution down / up"), - ("Shift+A", "Toggle terrain LOD"), ("Z / Shift+Z", "Vert. exag. 
down / up"), ("B", "Toggle TIN / Voxel"), ("T", "Toggle shadows"), @@ -9324,6 +9210,7 @@ def run(self, start_position: Optional[Tuple[float, float, float]] = None, self._render_help_text() # --- Initialize minimap --- + self._minimap_bg_extent = None self._compute_minimap_background() # --- GLFW callbacks --- @@ -9473,6 +9360,10 @@ def _run_repl(): sys.stdout.write('\033[?25h') # show cursor sys.stdout.flush() + # Clean up LOD thread pool + if self._terrain_lod_manager is not None: + self._terrain_lod_manager.shutdown() + # Clean up terrain reload thread pool pool = self.terrain._terrain_reload_pool if pool is not None: @@ -9496,7 +9387,6 @@ def explore(raster, width: int = 800, height: int = 600, key_repeat_interval: float = 0.05, rtx: 'RTX' = None, pixel_spacing_x: float = 1.0, pixel_spacing_y: float = 1.0, - mesh_type: str = 'heightfield', overlay_layers: dict = None, color_stretch: str = 'linear', title: str = None, @@ -9512,6 +9402,7 @@ def explore(raster, width: int = 800, height: int = 600, gtfs_data=None, accessor=None, terrain_loader=None, + tile_data_fn=None, scene_zarr=None, ao_samples: int = 0, gi_bounces: int = 1, @@ -9563,8 +9454,6 @@ def explore(raster, width: int = 800, height: int = 600, Must match the spacing used when triangulating terrain. Default 1.0. pixel_spacing_y : float, optional Y spacing between pixels in world units. Default 1.0. - mesh_type : str, optional - Mesh generation method: 'tin' or 'voxel'. Default is 'tin'. scene_zarr : str or Path, optional Path to a zarr store with a ``meshes/`` group. When provided, mesh chunks are loaded dynamically based on camera position @@ -9575,10 +9464,12 @@ def explore(raster, width: int = 800, height: int = 600, wind_data : dict, optional Wind data from ``fetch_wind()``. If provided, Shift+W toggles wind particle animation. - hydro_data : dict, optional - Hydrological flow data with keys ``'flow_dir'`` (D8 direction - grid) and ``'flow_accum'`` (flow accumulation grid). 
If provided, - Shift+Y toggles hydro flow particle animation. Optional keys: + hydro_data : dict or True, optional + Hydrological flow data. Pass ``True`` or ``{'enabled': False}`` + for lazy mode: MFD flow analysis is computed on GPU from the + current terrain when Shift+Y is first pressed. Or pass a dict + with key ``'flow_accum'`` for pre-computed data. Optional keys: + ``'flow_dir_mfd'`` (xrspatial MFD fractions, shape (8,H,W)), ``'n_particles'``, ``'max_age'``, ``'trail_len'``, ``'speed'``, ``'accum_threshold'``, ``'color'``, ``'alpha'``, ``'dot_radius'``. gtfs_data : dict, optional @@ -9622,10 +9513,8 @@ def explore(raster, width: int = 800, height: int = 600, - [/]: Decrease/increase observer height - R: Decrease terrain resolution (coarser, up to 8x subsample) - Shift+R: Increase terrain resolution (finer, down to 1x) - - Shift+A: Toggle distance-based terrain LOD - Z: Decrease vertical exaggeration - Shift+Z: Increase vertical exaggeration - - B: Toggle mesh type (TIN / voxel) - Y: Cycle color stretch (linear, sqrt, cbrt, log) - T: Toggle shadows - 0: Toggle ambient occlusion (progressive) @@ -9668,7 +9557,6 @@ def explore(raster, width: int = 800, height: int = 600, rtx=rtx, pixel_spacing_x=pixel_spacing_x, pixel_spacing_y=pixel_spacing_y, - mesh_type=mesh_type, overlay_layers=overlay_layers, title=title, subtitle=subtitle, @@ -9684,6 +9572,7 @@ def explore(raster, width: int = 800, height: int = 600, viewer._info_text = info_text viewer._accessor = accessor viewer._terrain_loader = terrain_loader + viewer._tile_data_fn = tile_data_fn if scene_zarr is not None: viewer._chunk_manager = _MeshChunkManager( scene_zarr, pixel_spacing_x, pixel_spacing_y) @@ -9709,40 +9598,50 @@ def explore(raster, width: int = 800, height: int = 600, # Hydro flow initialization if hydro_data is not None: - hydro_start_enabled = hydro_data.get('enabled', True) - flow_dir = hydro_data['flow_dir'] - flow_accum = hydro_data['flow_accum'] - hydro_opts = {k: v for k, v in 
hydro_data.items() - if k not in ('flow_dir', 'flow_accum', 'enabled')} - viewer._init_hydro(flow_dir, flow_accum, **hydro_opts) - # Re-register stream_link overlay with NaN + palette coloring - if (viewer._hydro_stream_order_raw is not None - and 'stream_link' in viewer._overlay_layers): - max_order = int(viewer._hydro_stream_order_raw.max()) - palette_lut = InteractiveViewer._build_stream_palette_lut( - max_order) - sl_data = viewer._base_overlay_layers['stream_link'] - if hasattr(sl_data, 'get'): - sl_data = sl_data.get() - sl_data = np.asarray(sl_data, dtype=np.float32) - so_raw = viewer._hydro_stream_order_raw.astype(np.float32) - sl_color = np.where( - (sl_data <= 0) | (so_raw <= 0), - np.float32(np.nan), so_raw) - _add_overlay(viewer, 'stream_link', sl_color, - color_lut=palette_lut) - if hydro_start_enabled: - viewer._hydro_enabled = True - # Stream cells rendered with water reflection shader - if 'stream_link' in viewer._overlay_layers: - viewer._overlay_as_water = True + if hydro_data is True or 'flow_accum' not in hydro_data: + # Lazy mode: compute MFD hydro from terrain on first enable + viewer._hydro_lazy = True + hydro_start_enabled = ( + hydro_data.get('enabled', False) + if isinstance(hydro_data, dict) else False) + if hydro_start_enabled: + viewer._compute_hydro_from_terrain() + viewer._hydro_enabled = True + if 'stream_link' in viewer._overlay_layers: + viewer._overlay_as_water = True else: - viewer._hydro_enabled = False - # Switch back to elevation (don't leave stream_link active) - viewer._terrain_layer_idx = 0 - viewer._active_overlay_data = None - viewer._overlay_as_water = False - viewer._active_overlay_color_lut = None + # Pre-computed hydro data provided + hydro_start_enabled = hydro_data.get('enabled', True) + flow_accum = hydro_data['flow_accum'] + hydro_opts = {k: v for k, v in hydro_data.items() + if k not in ('flow_accum', 'enabled')} + viewer._init_hydro(flow_accum, **hydro_opts) + # Re-register stream_link overlay with NaN + palette 
coloring + if (viewer._hydro_stream_order_raw is not None + and 'stream_link' in viewer._overlay_layers): + max_order = int(viewer._hydro_stream_order_raw.max()) + palette_lut = InteractiveViewer._build_stream_palette_lut( + max_order) + sl_data = viewer._base_overlay_layers['stream_link'] + if hasattr(sl_data, 'get'): + sl_data = sl_data.get() + sl_data = np.asarray(sl_data, dtype=np.float32) + so_raw = viewer._hydro_stream_order_raw.astype(np.float32) + sl_color = np.where( + (sl_data <= 0) | (so_raw <= 0), + np.float32(np.nan), so_raw) + _add_overlay(viewer, 'stream_link', sl_color, + color_lut=palette_lut) + if hydro_start_enabled: + viewer._hydro_enabled = True + if 'stream_link' in viewer._overlay_layers: + viewer._overlay_as_water = True + else: + viewer._hydro_enabled = False + viewer._terrain_layer_idx = 0 + viewer._active_overlay_data = None + viewer._overlay_as_water = False + viewer._active_overlay_color_lut = None # GTFS-RT initialization if gtfs_data is not None: diff --git a/rtxpy/kernel.ptx b/rtxpy/kernel.ptx index df7cdd1..99f3539 100644 --- a/rtxpy/kernel.ptx +++ b/rtxpy/kernel.ptx @@ -7,11 +7,11 @@ // .version 9.1 -.target sm_75 +.target sm_86 .address_size 64 // .globl __raygen__main -.const .align 8 .b8 params[96]; +.const .align 8 .b8 params[104]; .visible .entry __raygen__main() { @@ -130,103 +130,174 @@ $L__BB0_4: // .globl __closesthit__chit .visible .entry __closesthit__chit() { - .reg .pred %p<3>; - .reg .f32 %f<38>; - .reg .b32 %r<34>; - .reg .b64 %rd<3>; + .reg .pred %p<5>; + .reg .f32 %f<77>; + .reg .b32 %r<33>; + .reg .b64 %rd<21>; // begin inline asm - call (%f2), _optix_get_ray_tmax, (); + call (%f11), _optix_get_ray_tmax, (); // end inline asm // begin inline asm - call (%r8), _optix_read_primitive_idx, (); + call (%r2), _optix_read_primitive_idx, (); // end inline asm // begin inline asm - call (%r9), _optix_get_hit_kind, (); + call (%r3), _optix_get_hit_kind, (); // end inline asm - setp.eq.s32 %p1, %r9, 254; + setp.eq.s32 %p1, 
%r3, 254; @%p1 bra $L__BB2_2; // begin inline asm - call (%r10), _optix_get_hit_kind, (); + call (%r4), _optix_get_hit_kind, (); // end inline asm - setp.ne.s32 %p2, %r10, 255; - mov.u32 %r33, 1065353216; - mov.u32 %r31, 0; - mov.u32 %r32, %r31; - @%p2 bra $L__BB2_3; + setp.ne.s32 %p2, %r4, 255; + mov.f32 %f76, 0f3F800000; + mov.f32 %f74, 0f00000000; + mov.f32 %f75, %f74; + @%p2 bra $L__BB2_6; $L__BB2_2: - // begin inline asm - call (%rd1), _optix_get_gas_traversable_handle, (); - // end inline asm - // begin inline asm - call (%r14), _optix_read_sbt_gas_idx, (); - // end inline asm - // begin inline asm - call (%f3), _optix_get_ray_time, (); - // end inline asm - // begin inline asm - call (%f4, %f5, %f6, %f7, %f8, %f9, %f10, %f11, %f12), _optix_get_triangle_vertex_data, (%rd1, %r8, %r14, %f3); - // end inline asm - sub.ftz.f32 %f14, %f7, %f4; - sub.ftz.f32 %f15, %f8, %f5; - sub.ftz.f32 %f16, %f9, %f6; - sub.ftz.f32 %f17, %f10, %f4; - sub.ftz.f32 %f18, %f11, %f5; - sub.ftz.f32 %f19, %f12, %f6; - mul.ftz.f32 %f20, %f15, %f19; - mul.ftz.f32 %f21, %f16, %f18; - sub.ftz.f32 %f22, %f20, %f21; - mul.ftz.f32 %f23, %f14, %f19; - mul.ftz.f32 %f24, %f16, %f17; - sub.ftz.f32 %f25, %f23, %f24; - mul.ftz.f32 %f26, %f14, %f18; - mul.ftz.f32 %f27, %f15, %f17; - sub.ftz.f32 %f28, %f26, %f27; - mul.ftz.f32 %f29, %f25, %f25; - fma.rn.ftz.f32 %f30, %f22, %f22, %f29; - fma.rn.ftz.f32 %f31, %f28, %f28, %f30; - rsqrt.approx.ftz.f32 %f32, %f31; - mul.ftz.f32 %f33, %f32, %f22; - mul.ftz.f32 %f34, %f25, %f32; - neg.ftz.f32 %f35, %f34; - mul.ftz.f32 %f36, %f32, %f28; - mov.b32 %r31, %f33; - mov.b32 %r32, %f35; - mov.b32 %r33, %f36; + ld.const.u64 %rd1, [params+96]; + setp.eq.s64 %p3, %rd1, 0; + @%p3 bra $L__BB2_5; -$L__BB2_3: - cvt.rzi.ftz.u32.f32 %r30, %f2; - cvt.rn.f32.u32 %f37, %r30; - mov.b32 %r18, %f37; - mov.u32 %r17, 0; - // begin inline asm - call _optix_set_payload, (%r17, %r18); - // end inline asm - mov.u32 %r19, 1; - // begin inline asm - call _optix_set_payload, (%r19, %r31); 
- // end inline asm - mov.u32 %r21, 2; - // begin inline asm - call _optix_set_payload, (%r21, %r32); - // end inline asm - mov.u32 %r23, 3; // begin inline asm - call _optix_set_payload, (%r23, %r33); - // end inline asm - mov.u32 %r25, 4; - // begin inline asm - call _optix_set_payload, (%r25, %r8); + call (%r5), _optix_read_instance_id, (); // end inline asm + shl.b32 %r6, %r5, 1; + cvta.to.global.u64 %rd4, %rd1; + mul.wide.u32 %rd5, %r6, 8; + add.s64 %rd2, %rd4, %rd5; + ld.global.u64 %rd3, [%rd2]; + setp.eq.s64 %p4, %rd3, 0; + @%p4 bra $L__BB2_5; + + ld.global.u64 %rd6, [%rd2+8]; // begin inline asm - call (%r27), _optix_read_instance_id, (); + call (%f15, %f16), _optix_get_triangle_barycentrics, (); // end inline asm - mov.u32 %r28, 5; - // begin inline asm - call _optix_set_payload, (%r28, %r27); + mov.f32 %f17, 0f3F800000; + sub.ftz.f32 %f18, %f17, %f15; + sub.ftz.f32 %f19, %f18, %f16; + mul.lo.s32 %r7, %r2, 3; + mul.wide.u32 %rd7, %r7, 4; + add.s64 %rd8, %rd6, %rd7; + add.s32 %r8, %r7, 1; + mul.wide.u32 %rd9, %r8, 4; + add.s64 %rd10, %rd6, %rd9; + add.s32 %r9, %r7, 2; + mul.wide.u32 %rd11, %r9, 4; + add.s64 %rd12, %rd6, %rd11; + ld.u32 %r10, [%rd8]; + mul.lo.s32 %r11, %r10, 3; + mul.wide.s32 %rd13, %r11, 4; + add.s64 %rd14, %rd3, %rd13; + ld.f32 %f20, [%rd14]; + mul.ftz.f32 %f21, %f20, %f19; + ld.u32 %r12, [%rd10]; + mul.lo.s32 %r13, %r12, 3; + mul.wide.s32 %rd15, %r13, 4; + add.s64 %rd16, %rd3, %rd15; + ld.f32 %f22, [%rd16]; + fma.rn.ftz.f32 %f23, %f15, %f22, %f21; + ld.u32 %r14, [%rd12]; + mul.lo.s32 %r15, %r14, 3; + mul.wide.s32 %rd17, %r15, 4; + add.s64 %rd18, %rd3, %rd17; + ld.f32 %f24, [%rd18]; + fma.rn.ftz.f32 %f25, %f16, %f24, %f23; + ld.f32 %f26, [%rd14+4]; + ld.f32 %f27, [%rd16+4]; + mul.ftz.f32 %f28, %f15, %f27; + fma.rn.ftz.f32 %f29, %f19, %f26, %f28; + ld.f32 %f30, [%rd18+4]; + fma.rn.ftz.f32 %f31, %f16, %f30, %f29; + ld.f32 %f32, [%rd14+8]; + ld.f32 %f33, [%rd16+8]; + mul.ftz.f32 %f34, %f15, %f33; + fma.rn.ftz.f32 %f35, %f19, %f32, %f34; + 
ld.f32 %f36, [%rd18+8]; + fma.rn.ftz.f32 %f37, %f16, %f36, %f35; + mul.ftz.f32 %f38, %f31, %f31; + fma.rn.ftz.f32 %f39, %f25, %f25, %f38; + fma.rn.ftz.f32 %f40, %f37, %f37, %f39; + rsqrt.approx.ftz.f32 %f41, %f40; + mul.ftz.f32 %f74, %f41, %f25; + mul.ftz.f32 %f75, %f41, %f31; + mul.ftz.f32 %f76, %f41, %f37; + bra.uni $L__BB2_6; + +$L__BB2_5: + // begin inline asm + call (%rd19), _optix_get_gas_traversable_handle, (); + // end inline asm + // begin inline asm + call (%r16), _optix_read_sbt_gas_idx, (); + // end inline asm + // begin inline asm + call (%f42), _optix_get_ray_time, (); + // end inline asm + // begin inline asm + call (%f43, %f44, %f45, %f46, %f47, %f48, %f49, %f50, %f51), _optix_get_triangle_vertex_data, (%rd19, %r2, %r16, %f42); + // end inline asm + sub.ftz.f32 %f53, %f46, %f43; + sub.ftz.f32 %f54, %f47, %f44; + sub.ftz.f32 %f55, %f48, %f45; + sub.ftz.f32 %f56, %f49, %f43; + sub.ftz.f32 %f57, %f50, %f44; + sub.ftz.f32 %f58, %f51, %f45; + mul.ftz.f32 %f59, %f54, %f58; + mul.ftz.f32 %f60, %f55, %f57; + sub.ftz.f32 %f61, %f59, %f60; + mul.ftz.f32 %f62, %f53, %f58; + mul.ftz.f32 %f63, %f55, %f56; + sub.ftz.f32 %f64, %f62, %f63; + mul.ftz.f32 %f65, %f53, %f57; + mul.ftz.f32 %f66, %f54, %f56; + sub.ftz.f32 %f67, %f65, %f66; + mul.ftz.f32 %f68, %f64, %f64; + fma.rn.ftz.f32 %f69, %f61, %f61, %f68; + fma.rn.ftz.f32 %f70, %f67, %f67, %f69; + rsqrt.approx.ftz.f32 %f71, %f70; + mul.ftz.f32 %f74, %f71, %f61; + mul.ftz.f32 %f72, %f64, %f71; + neg.ftz.f32 %f75, %f72; + mul.ftz.f32 %f76, %f71, %f67; + +$L__BB2_6: + cvt.rzi.ftz.u32.f32 %r32, %f11; + cvt.rn.f32.u32 %f73, %r32; + mov.b32 %r20, %f73; + mov.u32 %r19, 0; + // begin inline asm + call _optix_set_payload, (%r19, %r20); + // end inline asm + mov.b32 %r22, %f74; + mov.u32 %r21, 1; + // begin inline asm + call _optix_set_payload, (%r21, %r22); + // end inline asm + mov.b32 %r24, %f75; + mov.u32 %r23, 2; + // begin inline asm + call _optix_set_payload, (%r23, %r24); + // end inline asm + mov.b32 %r26, %f76; + 
mov.u32 %r25, 3; + // begin inline asm + call _optix_set_payload, (%r25, %r26); + // end inline asm + mov.u32 %r27, 4; + // begin inline asm + call _optix_set_payload, (%r27, %r2); + // end inline asm + // begin inline asm + call (%r29), _optix_read_instance_id, (); + // end inline asm + mov.u32 %r30, 5; + // begin inline asm + call _optix_set_payload, (%r30, %r29); // end inline asm ret; diff --git a/rtxpy/lod.py b/rtxpy/lod.py index c25d050..eac7723 100644 --- a/rtxpy/lod.py +++ b/rtxpy/lod.py @@ -39,6 +39,56 @@ def compute_lod_level(distance, lod_distances): return len(lod_distances) +def compute_lod_level_with_hysteresis(distance, lod_distances, prev_lod, + hysteresis=0.2): + """Return the LOD level with hysteresis to prevent popping. + + Uses wider thresholds for downgrading (increasing LOD level) than + upgrading (decreasing LOD level), creating a dead zone around each + transition boundary. + + Parameters + ---------- + distance : float + Distance from camera to object or tile center. + lod_distances : list of float + Ascending distance thresholds defining LOD transitions. + prev_lod : int + Previous LOD level for this tile (-1 if never assigned). + hysteresis : float + Fractional band width. A tile must move ``hysteresis`` past the + threshold before switching. E.g. 0.2 means 20% beyond. + + Returns + ------- + int + LOD level (0 = highest detail). + """ + if prev_lod < 0: + return compute_lod_level(distance, lod_distances) + + # Upgrade (reduce LOD number = more detail): require distance to be + # clearly inside the better band (threshold * (1 - hysteresis)). + # Downgrade (increase LOD number = less detail): require distance to + # exceed threshold * (1 + hysteresis). 
+ new_lod = compute_lod_level(distance, lod_distances) + if new_lod < prev_lod: + # Upgrading — use tighter threshold + lod_tight = compute_lod_level( + distance / (1.0 - hysteresis), lod_distances) + if lod_tight < prev_lod: + return new_lod + return prev_lod + elif new_lod > prev_lod: + # Downgrading — use looser threshold + lod_loose = compute_lod_level( + distance / (1.0 + hysteresis), lod_distances) + if lod_loose > prev_lod: + return new_lod + return prev_lod + return prev_lod + + def compute_lod_distances(tile_diagonal, factor=3.0, max_lod=3): """Compute LOD distance thresholds from tile geometry. @@ -61,6 +111,50 @@ def compute_lod_distances(tile_diagonal, factor=3.0, max_lod=3): return [tile_diagonal * factor * (2 ** i) for i in range(max_lod)] +def compute_tile_roughness(tile_2d): + """Compute terrain roughness as std deviation from bilinear fit. + + Fits a bilinear surface through the tile's four corners and + measures how much the actual terrain deviates. Flat or planar + tiles score near zero; jagged ridgelines score high. + + Parameters + ---------- + tile_2d : np.ndarray + 2D elevation array, shape ``(H, W)``. May contain NaN. + + Returns + ------- + float + Standard deviation of elevation residuals (world-space units). + Returns 0.0 for degenerate tiles (all NaN or < 2×2). 
+ """ + h, w = tile_2d.shape + if h < 2 or w < 2: + return 0.0 + + # Corner elevations — fall back to tile nanmean if a corner is NaN + corners = np.array([tile_2d[0, 0], tile_2d[0, -1], + tile_2d[-1, 0], tile_2d[-1, -1]]) + valid = ~np.isnan(corners) + if not np.any(valid): + return 0.0 + fill = float(np.mean(corners[valid])) + c00 = corners[0] if valid[0] else fill + c01 = corners[1] if valid[1] else fill + c10 = corners[2] if valid[2] else fill + c11 = corners[3] if valid[3] else fill + + # Bilinear interpolation between corners + ys = np.linspace(0.0, 1.0, h, dtype=np.float32).reshape(h, 1) + xs = np.linspace(0.0, 1.0, w, dtype=np.float32).reshape(1, w) + bilinear = (c00 * (1 - xs) * (1 - ys) + c01 * xs * (1 - ys) + + c10 * (1 - xs) * ys + c11 * xs * ys) + + residuals = tile_2d - bilinear + return float(np.nanstd(residuals)) + + def simplify_mesh(vertices, indices, ratio): """Simplify a triangle mesh using quadric decimation. diff --git a/rtxpy/mesh.py b/rtxpy/mesh.py index 8102e93..9878e32 100644 --- a/rtxpy/mesh.py +++ b/rtxpy/mesh.py @@ -133,6 +133,152 @@ def triangulate_terrain(verts, triangles, terrain, scale=1.0): return 0 +def compute_terrain_normals(terrain, H, W, psx=1.0, psy=1.0): + """Compute smooth vertex normals for a regular-grid terrain. + + Uses central differences on the elevation grid to compute per-vertex + normals analytically. This is much faster than the generic + face-weighted accumulation in :func:`compute_vertex_normals` and + produces identical results for regular grids. + + Parameters + ---------- + terrain : array-like + 2-D elevation array of shape ``(H, W)``. + H, W : int + Grid dimensions. + psx, psy : float + World-space pixel spacing in X and Y. + + Returns + ------- + np.ndarray + Flat float32 normal buffer, shape ``(H * W * 3,)``, with the + same vertex ordering as :func:`triangulate_terrain`. 
+ """ + data = np.asarray(terrain, dtype=np.float32) + + # Central differences with forward/backward at edges + # dz/dx: gradient in column direction (maps to X) + dz_dx = np.empty_like(data) + dz_dx[:, 1:-1] = (data[:, 2:] - data[:, :-2]) / (2.0 * psx) + dz_dx[:, 0] = (data[:, 1] - data[:, 0]) / psx + dz_dx[:, -1] = (data[:, -1] - data[:, -2]) / psx + + # dz/dy: gradient in row direction (maps to Y) + dz_dy = np.empty_like(data) + dz_dy[1:-1, :] = (data[2:, :] - data[:-2, :]) / (2.0 * psy) + dz_dy[0, :] = (data[1, :] - data[0, :]) / psy + dz_dy[-1, :] = (data[-1, :] - data[-2, :]) / psy + + # Normal = normalize(-dz/dx, -dz/dy, 1) + normals = np.empty(H * W * 3, dtype=np.float32) + nx = -dz_dx.ravel() + ny = -dz_dy.ravel() + nz = np.ones(H * W, dtype=np.float32) + + # Handle NaN elevations: use flat up normal + nan_mask = np.isnan(data.ravel()) + nx[nan_mask] = 0.0 + ny[nan_mask] = 0.0 + + length = np.sqrt(nx * nx + ny * ny + nz * nz) + length[length < 1e-10] = 1.0 + normals[0::3] = nx / length + normals[1::3] = ny / length + normals[2::3] = nz / length + + return normals + + +def compute_vertex_normals(vertices, indices): + """Compute area-weighted smooth vertex normals for a triangle mesh. + + For each vertex, sums the face normals of all adjacent triangles + (weighted by triangle area), then normalizes. Vertices with no + adjacent triangles get a default up-facing normal ``(0, 0, 1)``. + + Parameters + ---------- + vertices : np.ndarray + Flat float32 vertex buffer, shape ``(N * 3,)``. + indices : np.ndarray + Flat int32 index buffer, shape ``(M * 3,)``. + + Returns + ------- + np.ndarray + Flat float32 normal buffer, shape ``(N * 3,)``. 
+ """ + verts = np.asarray(vertices, dtype=np.float32) + idx = np.asarray(indices, dtype=np.int32) + num_verts = len(verts) // 3 + num_tris = len(idx) // 3 + + # Reshape for vectorized access + v = verts.reshape(-1, 3) + f = idx.reshape(-1, 3) + + # Triangle vertices + v0 = v[f[:, 0]] + v1 = v[f[:, 1]] + v2 = v[f[:, 2]] + + # Face normals (area-weighted — cross product magnitude = 2 * area) + e1 = v1 - v0 + e2 = v2 - v0 + fn = np.cross(e1, e2) # (num_tris, 3) + + # Accumulate onto vertices using bincount (much faster than np.add.at) + all_idx = np.concatenate([f[:, 0], f[:, 1], f[:, 2]]) + all_fn = np.tile(fn, (3, 1)) # same face normal for each corner vertex + normals = np.zeros((num_verts, 3), dtype=np.float64) + for c in range(3): + normals[:, c] = np.bincount(all_idx, weights=all_fn[:, c], + minlength=num_verts) + + # Normalize + length = np.sqrt(np.sum(normals * normals, axis=1)) + length[length < 1e-10] = 1.0 + normals /= length[:, np.newaxis] + + # Default up-facing for isolated vertices + zero_mask = np.all(normals == 0.0, axis=1) + normals[zero_mask] = [0.0, 0.0, 1.0] + + return normals.astype(np.float32).ravel() + + +def compute_skirt_normals(H, W): + """Compute outward-facing normals for skirt bottom vertices. + + The perimeter layout matches :func:`add_terrain_skirt`: top row + (W verts), right column (H-1), + bottom row reversed (W-1), left column reversed (H-2). + + Parameters + ---------- + H, W : int + Grid dimensions of the terrain tile. + + Returns + ------- + np.ndarray + Flat float32 normal buffer, shape ``((2*(H+W)-4) * 3,)``. 
+ """ + n_perim = 2 * (H + W) - 4 + normals = np.zeros((n_perim, 3), dtype=np.float32) + off = 0 + normals[off:off + W, 1] = -1.0 # top edge: (0, -1, 0) + off += W + normals[off:off + H - 1, 0] = 1.0 # right edge: (1, 0, 0) + off += H - 1 + normals[off:off + W - 1, 1] = 1.0 # bottom edge: (0, 1, 0) + off += W - 1 + normals[off:off + H - 2, 0] = -1.0 # left edge: (-1, 0, 0) + return normals.ravel() + + def add_terrain_skirt(vertices, indices, H, W, skirt_depth=None): """Add a vertical skirt around the edges of a TIN terrain mesh. diff --git a/rtxpy/mesh_store.py b/rtxpy/mesh_store.py index 0ae7cd6..6db5590 100644 --- a/rtxpy/mesh_store.py +++ b/rtxpy/mesh_store.py @@ -56,6 +56,39 @@ def chunks_for_pixel_window(yi0, yi1, xi0, xi1, chunk_h, chunk_w): for cc in range(cc0, cc1 + 1)] +def chunks_for_world_rect(x0, y0, x1, y1, psx, psy, chunk_h, chunk_w, + elev_shape): + """Map a world-coordinate rectangle to overlapping chunk indices. + + Parameters + ---------- + x0, y0 : float + Lower-left corner in world coordinates. + x1, y1 : float + Upper-right corner in world coordinates. + psx, psy : float + Pixel spacing (world units per pixel). + chunk_h, chunk_w : int + Chunk size in pixels (rows, cols). + elev_shape : tuple of int + ``(H, W)`` of the elevation grid. + + Returns + ------- + list of (int, int) + List of ``(chunk_row, chunk_col)`` tuples overlapping the rect. 
+ """ + H, W = elev_shape + # World coords -> pixel coords (clamp to grid) + xi0 = max(int(x0 / psx), 0) + xi1 = min(int(x1 / psx) + 1, W) + yi0 = max(int(y0 / psy), 0) + yi1 = min(int(y1 / psy) + 1, H) + if xi0 >= xi1 or yi0 >= yi1: + return [] + return chunks_for_pixel_window(yi0, yi1, xi0, xi1, chunk_h, chunk_w) + + def save_meshes_to_zarr(zarr_path, meshes, colors, pixel_spacing, elevation_shape, elevation_chunks, curves=None, spheres=None): diff --git a/rtxpy/quickstart.py b/rtxpy/quickstart.py index 58525dc..88ed74c 100644 --- a/rtxpy/quickstart.py +++ b/rtxpy/quickstart.py @@ -8,7 +8,7 @@ def _rasterize_waterways_to_dem(water_geojson, terrain, elev_np, ocean): """Rasterize Overture waterway features into a burn-depth grid. Carves LineStrings (rivers/streams) and fills Polygons (lakes) into - the DEM so the D8 algorithm routes flow through known channels. + the DEM so the MFD algorithm routes flow through known channels. Returns a float32 array of burn depths (positive = carve down), or None if no features were rasterized. @@ -441,9 +441,9 @@ def quickstart( wind : bool Fetch live wind data from Open-Meteo. Default ``True``. hydro : bool - Compute D8 flow direction and flow accumulation from the terrain - using xarray-spatial and enable hydro flow particle animation - (Shift+Y). Default ``False``. + Compute flow accumulation from the terrain using xarray-spatial + and enable MFD hydro flow particle animation (Shift+Y). + Default ``False``. coast_distance : bool Compute terrain-aware surface distance from the coast using xrspatial's ``surface_distance`` (3-D Dijkstra). 
Adds a @@ -599,10 +599,10 @@ def quickstart( try: import gc from xrspatial import fill as _fill - from xrspatial import flow_direction as _flow_direction - from xrspatial import flow_accumulation as _flow_accumulation - from xrspatial import stream_order as _stream_order - from xrspatial import stream_link as _stream_link + from xrspatial import flow_direction_mfd as _flow_direction_mfd + from xrspatial import flow_accumulation_mfd as _flow_accumulation_mfd + from xrspatial import stream_order_mfd as _stream_order_mfd + from xrspatial import stream_link_mfd as _stream_link_mfd print("Conditioning DEM for hydrological flow...") from scipy.ndimage import uniform_filter as _uniform_filter @@ -630,7 +630,7 @@ def quickstart( print(f" Waterway GeoJSON load failed: {_e}") # 1b. Burn Overture waterways into the DEM — carve known - # river/stream channels so D8 routes through them. + # river/stream channels so MFD routes through them. if _water_geojson is not None: try: _ww_burn = _rasterize_waterways_to_dem( @@ -675,27 +675,30 @@ def quickstart( resolved[ocean] = -100.0 del ocean_gradient - # 3c. Channel burning — compute an initial flow + # 3c. Channel burning — compute an initial MFD flow # accumulation, then lower high-accumulation cells # to carve channels into the DEM. Re-fill and # re-compute so streams connect into a network. - _fd0 = _flow_direction(terrain.copy(data=resolved)) - _fa0 = _flow_accumulation(_fd0) - del _fd0 - _fa0_np = _fa0.data.get() if is_cupy else np.asarray(_fa0.data) - _fa0_np = np.nan_to_num(_fa0_np, nan=0.0) + xp = _cp if is_cupy else np # array module for GPU/CPU + _fd_mfd0 = _flow_direction_mfd(terrain.copy(data=resolved)) + _fa0 = _flow_accumulation_mfd(_fd_mfd0) + del _fd_mfd0 + _fa0_data = xp.nan_to_num(_fa0.data, nan=0.0) del _fa0 # Burn proportional to log(accumulation): cells with # more upstream area get carved deeper (up to ~2 m). 
- _log_acc = np.log10(np.clip(_fa0_np, 1, None)) - _log_max = max(_log_acc.max(), 1.0) + _log_acc = xp.log10(xp.clip(_fa0_data, 1, None)) + _log_max = max(float(_log_acc.max()), 1.0) _burn = (_log_acc / _log_max) * 2.0 # 0–2 m carve - _burn[ocean] = 0.0 - del _fa0_np, _log_acc + if is_cupy: + _burn[_cp.asarray(ocean)] = 0.0 + else: + _burn[ocean] = 0.0 + del _fa0_data, _log_acc if is_cupy: - resolved = resolved - _cp.asarray(_burn.astype(np.float32)) + resolved = resolved - _burn.astype(_cp.float32) resolved[_cp.asarray(ocean)] = -100.0 else: resolved -= _burn.astype(np.float32) @@ -725,44 +728,38 @@ def quickstart( _cp.get_default_memory_pool().free_all_blocks() gc.collect() - # 4. Compute final D8 flow direction and accumulation. - fd = _flow_direction(terrain.copy(data=resolved)) - del resolved - fa = _flow_accumulation(fd) + # 4. Compute final MFD flow direction and accumulation. + _resolved_da = terrain.copy(data=resolved) + fd_mfd = _flow_direction_mfd(_resolved_da) + fa_mfd = _flow_accumulation_mfd(fd_mfd) + del _resolved_da, resolved - # 5. Compute Strahler stream order — only stream cells - # (accum >= threshold) get an order; rest are NaN. - so = _stream_order(fd, fa, threshold=50) + # 5. Compute Strahler stream order (MFD) — only stream + # cells (accum >= threshold) get an order; rest are NaN. + so = _stream_order_mfd(fd_mfd, fa_mfd, threshold=50) - # 5b. Compute stream link — unique segment IDs per reach. - sl = _stream_link(fd, fa, threshold=50) + # 5b. Compute stream link (MFD) — unique segment IDs. + sl = _stream_link_mfd(fd_mfd, fa_mfd, threshold=50) - # 6. Mask ocean back to NaN/0 in the output grids. - fd_out = fd.data - fa_out = fa.data + # 6. Mask ocean back to NaN in the output grids. 
+ fa_mfd_out = fa_mfd.data + fd_mfd_out = fd_mfd.data # (8, H, W) so_out = so.data + _sl_out = sl.data if is_cupy: ocean_gpu = _cp.asarray(ocean) - fd_out[ocean_gpu] = _cp.nan - fa_out[ocean_gpu] = _cp.nan + fa_mfd_out[ocean_gpu] = _cp.nan + fd_mfd_out[:, ocean_gpu] = _cp.nan so_out[ocean_gpu] = _cp.nan - else: - fd_out[ocean] = np.nan - fa_out[ocean] = np.nan - so_out[ocean] = np.nan - - # Add stream_link to the dataset so it shows up as an - # overlay layer (G key) with palette-matched colors. - _sl_out = sl.data - if is_cupy: _sl_out[ocean_gpu] = _cp.nan else: + fa_mfd_out[ocean] = np.nan + fd_mfd_out[:, ocean] = np.nan + so_out[ocean] = np.nan _sl_out[ocean] = np.nan - _sl_np = _sl_out.get() if is_cupy else np.asarray(_sl_out) - _sl_clean = np.nan_to_num(_sl_np, nan=0.0).astype(np.float32) - if is_cupy: - _sl_clean = _cp.asarray(_sl_clean) - del _sl_np + + # Clean stream_link for overlay (stay on GPU when possible) + _sl_clean = xp.nan_to_num(_sl_out, nan=0.0).astype(xp.float32) # 6b. Burn Overture waterways into stream_link + stream_order # so they appear in the overlay with the water shader. @@ -792,11 +789,20 @@ def quickstart( except Exception as _e2: print(f" Waterway overlay burn skipped: {_e2}") - # 6c. Trace tributary network via flow_path + # 6c. Trace tributary network via flow_path (D8-only). + # Compute a lightweight D8 fd/fa from the raw terrain + # since the conditioned DEM was already released. 
if _water_geojson is not None: try: + from xrspatial import ( + flow_direction as _flow_direction, + flow_accumulation as _flow_accumulation, + ) + _fd_d8 = _flow_direction(terrain) + _fa_d8 = _flow_accumulation(_fd_d8) _trib, _ww_net = _trace_tributaries_flow_path( - fd, fa, _water_geojson, terrain, ocean) + _fd_d8, _fa_d8, _water_geojson, terrain, ocean) + del _fd_d8, _fa_d8 if _trib is not None: _so_np3 = so_out.get() if is_cupy else np.asarray(so_out) _so_np3 = np.nan_to_num(_so_np3, nan=0.0).astype( @@ -805,13 +811,11 @@ def quickstart( _sl_clean) _sl_np3 = np.array(_sl_np3, dtype=np.float32) _max_link3 = int(_sl_np3.max()) + 1 - # New waterway-channel cells → stream_order 3 _new_ww = _ww_net & (_sl_np3 == 0) & (~ocean) _sl_np3[_new_ww] = _max_link3 _so_np3[_new_ww] = np.maximum( _so_np3[_new_ww], 3.0) _max_link3 += 1 - # New tributary cells → stream_order 1 _new_trib = _trib & (_sl_np3 == 0) & (~ocean) _sl_np3[_new_trib] = _max_link3 _so_np3[_new_trib] = np.maximum( @@ -831,18 +835,18 @@ def quickstart( print(f" flow_path tracing skipped: {_e3}") # Drop xrspatial DataArray wrappers — we only need .data - del fd, fa, so, sl + del fd_mfd, fa_mfd, so, sl ds['stream_link'] = terrain.copy(data=_sl_clean).rename(None) hydro_data = { - 'flow_dir': fd_out, - 'flow_accum': fa_out, + 'flow_accum': fa_mfd_out, + 'flow_dir_mfd': fd_mfd_out, 'stream_order': so_out, 'stream_link': _sl_out, 'accum_threshold': 50, } - del fd_out, fa_out, _sl_out + del fa_mfd_out, fd_mfd_out, _sl_out # Pass overrides from explore_kwargs if present for key in ('n_particles', 'max_age', 'trail_len', 'speed', 'accum_threshold', 'color', 'alpha', 'dot_radius'): diff --git a/rtxpy/rtx.py b/rtxpy/rtx.py index d5d7356..c32ab20 100644 --- a/rtxpy/rtx.py +++ b/rtxpy/rtx.py @@ -43,6 +43,8 @@ class _GASEntry: is_curve: bool = False # True for round curve tube GAS is_heightfield: bool = False # True for heightfield custom primitive GAS is_sphere: bool = False # True for sphere primitive GAS (point 
clouds) + d_normals: Optional[cupy.ndarray] = None # GPU per-vertex normals (N*3 float32) + d_indices: Optional[cupy.ndarray] = None # GPU index buffer (M*3 int32) for normal lookup # ----------------------------------------------------------------------------- @@ -173,6 +175,9 @@ def __init__(self): self.point_colors = None # concatenated GPU buffer (built on demand) self.point_color_offsets = None # GPU int32 per-instance offsets + # Smooth normal table — GPU uint64 array [2*N], built in _build_ias + self.d_smooth_normal_table = None + # Device buffers for CPU->GPU transfers (per-instance) self.d_rays = None self.d_rays_size = 0 @@ -208,6 +213,9 @@ def clear(self): self.point_colors = None self.point_color_offsets = None + # Clear smooth normal table + self.d_smooth_normal_table = None + # Reset to single-GAS mode self.single_gas_mode = True @@ -693,8 +701,8 @@ def _init_optix(device: Optional[int] = None): # Create shader binding table _create_sbt() - # Allocate params buffer: 48 + 40 (heightfield) + 8 (point_colors) = 96 - _state.d_params = cupy.zeros(96, dtype=cupy.uint8) + # Allocate params buffer: 48 + 40 (heightfield) + 8 (point_colors) + 8 (smooth_normal_table) = 104 + _state.d_params = cupy.zeros(104, dtype=cupy.uint8) _state.initialized = True atexit.register(_cleanup_at_exit) @@ -1274,7 +1282,7 @@ def _build_gas_for_curves(vertices, widths, indices, num_segments): return gas_handle, gas_buffer -def _build_gas_for_heightfield(elevation_data, H, W, spacing_x, spacing_y, ve, tile_size): +def _build_gas_for_heightfield(elevation_data, H, W, spacing_x, spacing_y, ve, tile_size, active_mask=None): """ Build a GAS for heightfield terrain using custom AABB primitives. @@ -1289,6 +1297,9 @@ def _build_gas_for_heightfield(elevation_data, H, W, spacing_x, spacing_y, ve, t spacing_y: World-space pixel spacing in Y ve: Vertical exaggeration factor tile_size: Tile dimension (e.g. 32) + active_mask: Optional numpy bool array of length num_tiles. 
+ When provided, inactive tiles get zero-volume AABBs so only + a subset of the heightfield grid participates in ray tracing. Returns: Tuple of (gas_handle, gas_buffer, d_elevation, num_tiles_x, num_tiles_y) @@ -1316,6 +1327,12 @@ def _build_gas_for_heightfield(elevation_data, H, W, spacing_x, spacing_y, ve, t for tx in range(num_tiles_x): tile_idx = ty * num_tiles_x + tx + # Skip inactive tiles (LOD-managed heightfield mode) + if active_mask is not None and not active_mask[tile_idx]: + base = tile_idx * 6 + aabbs[base:base + 6] = 0.0 + continue + # Cell range for this tile c0 = tx * tile_size r0 = ty * tile_size @@ -1592,6 +1609,19 @@ def _build_ias(geom_state: _GeometryState): geom_state.ias_dirty = False + # Build smooth normal lookup table: [2*i]=normals_ptr, [2*i+1]=indices_ptr + has_any_normals = any( + e.d_normals is not None for e in geom_state.gas_entries.values()) + if has_any_normals: + table = np.zeros(2 * num_instances, dtype=np.uint64) + for i, (gid, entry) in enumerate(geom_state.gas_entries.items()): + if entry.d_normals is not None and entry.d_indices is not None: + table[2 * i] = entry.d_normals.data.ptr + table[2 * i + 1] = entry.d_indices.data.ptr + geom_state.d_smooth_normal_table = cupy.asarray(table) + else: + geom_state.d_smooth_normal_table = None + def _build_accel(geom_state: _GeometryState, hash_value: int, vertices, indices) -> int: """ @@ -1852,8 +1882,13 @@ def _trace_rays(geom_state: _GeometryState, rays, hits, num_rays: int, if geom_state.point_colors is not None: pc_colors_ptr = geom_state.point_colors.data.ptr + # Smooth normal table pointer + sn_table_ptr = 0 + if geom_state.d_smooth_normal_table is not None: + sn_table_ptr = geom_state.d_smooth_normal_table.data.ptr + params_data = struct.pack( - 'QQQQQIIQiifffiiIQ', + 'QQQQQIIQiifffiiIQQ', trace_handle, # 8 d_rays.data.ptr, # 8 d_hits.data.ptr, # 8 @@ -1871,6 +1906,7 @@ def _trace_rays(geom_state: _GeometryState, rays, hits, num_rays: int, hf_ntx, # 4 0, # 4 padding for 
alignment pc_colors_ptr, # 8 + sn_table_ptr, # 8 ) _state.d_params[:] = cupy.frombuffer(np.frombuffer(params_data, dtype=np.uint8), dtype=cupy.uint8) @@ -1879,7 +1915,7 @@ def _trace_rays(geom_state: _GeometryState, rays, hits, num_rays: int, _state.pipeline, 0, # stream _state.d_params.data.ptr, - 96, # sizeof(Params) + 104, # sizeof(Params) _state.sbt, num_rays, # width 1, # height @@ -2078,7 +2114,8 @@ def pick(self, origin, direction) -> dict: def add_geometry(self, geometry_id: str, vertices, indices, transform: Optional[List[float]] = None, - grid_dims: Optional[tuple] = None) -> int: + grid_dims: Optional[tuple] = None, + normals=None) -> int: """ Add a geometry (GAS) to the scene with an optional transform. @@ -2095,6 +2132,10 @@ def add_geometry(self, geometry_id: str, vertices, indices, grid_dims: Optional (H, W) grid dimensions for cluster-accelerated builds. When provided and OptiX 9+ clusters are available, uses the CLAS pipeline for faster BVH builds. + normals: Optional per-vertex normal buffer (flattened float32 array, + 3 floats per vertex). When provided, the closest-hit + shader interpolates smooth normals using barycentrics + instead of computing flat face normals. 
Returns: 0 on success, non-zero on error @@ -2120,10 +2161,17 @@ def add_geometry(self, geometry_id: str, vertices, indices, existing = self._geom_state.gas_entries.get(geometry_id) if existing is not None and existing.vertices_hash == vertices_hash: - # GAS already built for identical vertices — update transform only + # GAS already built for identical vertices — update transform/normals only if transform is not None: existing.transform = list(transform) self._geom_state.ias_dirty = True + if normals is not None: + existing.d_normals = cupy.asarray( + np.asarray(normals, dtype=np.float32)) + if existing.d_indices is None: + existing.d_indices = cupy.asarray( + np.asarray(indices, dtype=np.int32)) + self._geom_state.ias_dirty = True return 0 # Build the GAS for this geometry @@ -2157,6 +2205,14 @@ def add_geometry(self, geometry_id: str, vertices, indices, indices_np = indices.get() if isinstance(indices, cupy.ndarray) else np.asarray(indices) num_triangles = len(indices_np.ravel()) // 3 + # Upload smooth normals and index buffer if provided + d_normals_gpu = None + d_indices_gpu = None + if normals is not None: + d_normals_gpu = cupy.asarray( + np.asarray(normals, dtype=np.float32)) + d_indices_gpu = cupy.asarray(indices_np) + # Create or update the GAS entry self._geom_state.gas_entries[geometry_id] = _GASEntry( gas_id=geometry_id, @@ -2166,6 +2222,8 @@ def add_geometry(self, geometry_id: str, vertices, indices, transform=transform, num_vertices=num_vertices, num_triangles=num_triangles, + d_normals=d_normals_gpu, + d_indices=d_indices_gpu, ) # Mark IAS as needing rebuild @@ -2258,7 +2316,9 @@ def add_heightfield_geometry(self, geometry_id: str, elevation, H: int, W: int, spacing_x: float, spacing_y: float, ve: float = 1.0, - tile_size: int = 32) -> int: + tile_size: int = 32, + active_mask=None, + transform=None) -> int: """ Add a heightfield terrain as a custom-primitive GAS. 
@@ -2277,6 +2337,10 @@ def add_heightfield_geometry(self, geometry_id: str, elevation, spacing_y: World-space pixel spacing in Y. ve: Vertical exaggeration. Default 1.0. tile_size: Tile dimension for AABB grouping. Default 32. + active_mask: Optional bool array (one per AABB tile). When + provided, inactive tiles get zero-volume AABBs. + transform: Optional 12-float affine transform (3x4 row-major). + Defaults to identity. Returns: 0 on success, non-zero on error. @@ -2300,7 +2364,7 @@ def add_heightfield_geometry(self, geometry_id: str, elevation, elev_np = np.asarray(elevation, dtype=np.float32) gas_handle, gas_buffer, d_elevation, num_tiles_x, num_tiles_y = \ - _build_gas_for_heightfield(elev_np, H, W, spacing_x, spacing_y, ve, tile_size) + _build_gas_for_heightfield(elev_np, H, W, spacing_x, spacing_y, ve, tile_size, active_mask) if gas_handle == 0: return -1 @@ -2315,12 +2379,12 @@ def add_heightfield_geometry(self, geometry_id: str, elevation, self._geom_state.hf_tile_size = tile_size self._geom_state.hf_num_tiles_x = num_tiles_x - # Identity transform - transform = [ - 1.0, 0.0, 0.0, 0.0, - 0.0, 1.0, 0.0, 0.0, - 0.0, 0.0, 1.0, 0.0, - ] + if transform is None: + transform = [ + 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, + ] # Compute hash for cache invalidation vertices_hash = hash(elev_np.tobytes()) diff --git a/rtxpy/tests/test_lod.py b/rtxpy/tests/test_lod.py index 5e475c4..4de22ab 100644 --- a/rtxpy/tests/test_lod.py +++ b/rtxpy/tests/test_lod.py @@ -5,15 +5,17 @@ from rtxpy.lod import ( compute_lod_level, + compute_lod_level_with_hysteresis, compute_lod_distances, + compute_tile_roughness, simplify_mesh, build_lod_chain, ) from rtxpy.viewer.terrain_lod import ( TerrainLODManager, is_terrain_lod_gid, - _add_tile_skirt, _tile_gid, + _batch_gid, ) @@ -51,6 +53,87 @@ def test_negative_distance(self): assert compute_lod_level(-10, [100, 200]) == 0 +# --------------------------------------------------------------------------- +# 
compute_lod_level_with_hysteresis +# --------------------------------------------------------------------------- + +class TestComputeLodLevelWithHysteresis: + """Tests for hysteresis-aware LOD selection.""" + + def test_no_prev_lod_matches_basic(self): + """First assignment (prev_lod=-1) should match compute_lod_level.""" + thresholds = [100, 200, 400] + assert compute_lod_level_with_hysteresis(50, thresholds, -1) == 0 + assert compute_lod_level_with_hysteresis(150, thresholds, -1) == 1 + assert compute_lod_level_with_hysteresis(500, thresholds, -1) == 3 + + def test_stays_at_current_lod_near_boundary(self): + """Camera near threshold boundary should not flip LOD.""" + thresholds = [100, 200, 400] + # At LOD 0, distance 105 is just past the 100 threshold + # but within the hysteresis band — should stay at LOD 0 + assert compute_lod_level_with_hysteresis( + 105, thresholds, prev_lod=0, hysteresis=0.2) == 0 + + def test_downgrades_past_hysteresis(self): + """Camera well past threshold should downgrade.""" + thresholds = [100, 200, 400] + # distance 125 / (1+0.2) = 104.2 → LOD 1 at base → downgrade + assert compute_lod_level_with_hysteresis( + 125, thresholds, prev_lod=0, hysteresis=0.2) == 1 + + def test_upgrades_past_hysteresis(self): + """Camera well inside better band should upgrade.""" + thresholds = [100, 200, 400] + # At LOD 1, distance 75 / (1-0.2) = 93.75 → LOD 0 → upgrade + assert compute_lod_level_with_hysteresis( + 75, thresholds, prev_lod=1, hysteresis=0.2) == 0 + + def test_stays_at_current_lod_upgrade_boundary(self): + """Camera near threshold from above should not upgrade.""" + thresholds = [100, 200, 400] + # At LOD 1, distance 95 → base LOD 0, but 95/(1-0.2) = 118.75 → LOD 1 + # so should stay at LOD 1 + assert compute_lod_level_with_hysteresis( + 95, thresholds, prev_lod=1, hysteresis=0.2) == 1 + + def test_same_lod_unchanged(self): + """If base LOD matches prev_lod, return prev_lod.""" + thresholds = [100, 200, 400] + assert 
compute_lod_level_with_hysteresis( + 50, thresholds, prev_lod=0, hysteresis=0.2) == 0 + assert compute_lod_level_with_hysteresis( + 150, thresholds, prev_lod=1, hysteresis=0.2) == 1 + + def test_dead_zone_spans_both_directions(self): + """Hysteresis dead zone should prevent both upgrade and downgrade. + + With threshold=100 and hysteresis=0.2: + - Downgrade requires distance > 100 * 1.2 = 120 + - Upgrade requires distance < 100 * 0.8 = 80 + - Between 80 and 120 is a dead zone where prev_lod sticks. + """ + thresholds = [100, 200, 400] + h = 0.2 + # At prev_lod=0 (below threshold): distances 101-119 in dead zone + for d in [101, 110, 119]: + assert compute_lod_level_with_hysteresis( + d, thresholds, prev_lod=0, hysteresis=h) == 0, \ + f"distance {d} should stay at LOD 0 (downgrade dead zone)" + # At 121, should finally downgrade + assert compute_lod_level_with_hysteresis( + 121, thresholds, prev_lod=0, hysteresis=h) == 1 + + # At prev_lod=1 (above threshold): distances 81-99 in dead zone + for d in [81, 90, 99]: + assert compute_lod_level_with_hysteresis( + d, thresholds, prev_lod=1, hysteresis=h) == 1, \ + f"distance {d} should stay at LOD 1 (upgrade dead zone)" + # At 79, should finally upgrade + assert compute_lod_level_with_hysteresis( + 79, thresholds, prev_lod=1, hysteresis=h) == 0 + + # --------------------------------------------------------------------------- # compute_lod_distances # --------------------------------------------------------------------------- @@ -75,6 +158,266 @@ def test_zero_max_lod(self): assert dists == [] +# --------------------------------------------------------------------------- +# compute_tile_roughness +# --------------------------------------------------------------------------- + +class TestComputeTileRoughness: + """Tests for bilinear-fit residual roughness metric.""" + + def test_flat_tile(self): + """A constant-elevation tile should have near-zero roughness.""" + tile = np.full((16, 16), 500.0, dtype=np.float32) + assert 
compute_tile_roughness(tile) == pytest.approx(0.0, abs=1e-4) + + def test_planar_tile(self): + """A perfectly planar (tilted) tile should have ~zero roughness. + + The bilinear fit matches a linear surface exactly, so residuals + are zero everywhere. + """ + ys = np.arange(16).reshape(16, 1).astype(np.float32) + xs = np.arange(16).reshape(1, 16).astype(np.float32) + tile = 100.0 + 3.0 * xs + 2.0 * ys + assert compute_tile_roughness(tile) == pytest.approx(0.0, abs=1e-4) + + def test_rough_tile(self): + """A tile with a central peak should have non-zero roughness.""" + tile = np.zeros((16, 16), dtype=np.float32) + tile[7:9, 7:9] = 100.0 # sharp peak + r = compute_tile_roughness(tile) + assert r > 1.0 + + def test_rougher_is_higher(self): + """A tile with bigger deviation should score higher.""" + tile_mild = np.zeros((16, 16), dtype=np.float32) + tile_mild[8, 8] = 10.0 + tile_wild = np.zeros((16, 16), dtype=np.float32) + tile_wild[8, 8] = 1000.0 + assert compute_tile_roughness(tile_wild) > compute_tile_roughness(tile_mild) + + def test_all_nan(self): + """All-NaN tile should return zero roughness.""" + tile = np.full((8, 8), np.nan, dtype=np.float32) + assert compute_tile_roughness(tile) == 0.0 + + def test_partial_nan(self): + """Tile with some NaN values should still return a valid float.""" + tile = np.ones((8, 8), dtype=np.float32) * 50.0 + tile[3:5, 3:5] = np.nan + r = compute_tile_roughness(tile) + assert np.isfinite(r) + + def test_nan_corner(self): + """NaN corner should be filled with mean of valid corners.""" + tile = np.zeros((8, 8), dtype=np.float32) + tile[0, 0] = np.nan # one corner NaN + r = compute_tile_roughness(tile) + assert np.isfinite(r) + + def test_tiny_tile(self): + """Tiles smaller than 2x2 should return 0.""" + assert compute_tile_roughness(np.array([[5.0]])) == 0.0 + assert compute_tile_roughness(np.zeros((1, 10))) == 0.0 + + +# --------------------------------------------------------------------------- +# compute_terrain_normals +# 
--------------------------------------------------------------------------- + +class TestComputeTerrainNormals: + """Tests for central-difference terrain normal computation.""" + + def test_flat_terrain(self): + """Flat terrain should produce all (0, 0, 1) normals.""" + from rtxpy.mesh import compute_terrain_normals + terrain = np.full((4, 4), 100.0, dtype=np.float32) + normals = compute_terrain_normals(terrain, 4, 4) + assert normals.shape == (4 * 4 * 3,) + nx = normals[0::3] + ny = normals[1::3] + nz = normals[2::3] + np.testing.assert_allclose(nx, 0.0, atol=1e-6) + np.testing.assert_allclose(ny, 0.0, atol=1e-6) + np.testing.assert_allclose(nz, 1.0, atol=1e-6) + + def test_x_slope(self): + """Constant slope in X (z = col) with psx=1 should tilt normals.""" + from rtxpy.mesh import compute_terrain_normals + H, W = 3, 5 + terrain = np.zeros((H, W), dtype=np.float32) + for c in range(W): + terrain[:, c] = float(c) + normals = compute_terrain_normals(terrain, H, W, psx=1.0, psy=1.0) + # Interior vertices: dz/dx = 1, normal = normalize(-1, 0, 1) + expected_nx = -1.0 / np.sqrt(2.0) + expected_nz = 1.0 / np.sqrt(2.0) + # Check an interior vertex (row=1, col=2) + idx = 1 * W + 2 + assert normals[idx * 3] == pytest.approx(expected_nx, abs=1e-5) + assert normals[idx * 3 + 1] == pytest.approx(0.0, abs=1e-5) + assert normals[idx * 3 + 2] == pytest.approx(expected_nz, abs=1e-5) + + def test_y_slope(self): + """Constant slope in Y (z = row) should tilt normals in Y.""" + from rtxpy.mesh import compute_terrain_normals + H, W = 5, 3 + terrain = np.zeros((H, W), dtype=np.float32) + for r in range(H): + terrain[r, :] = float(r) + normals = compute_terrain_normals(terrain, H, W, psx=1.0, psy=1.0) + expected_ny = -1.0 / np.sqrt(2.0) + expected_nz = 1.0 / np.sqrt(2.0) + idx = 2 * W + 1 + assert normals[idx * 3] == pytest.approx(0.0, abs=1e-5) + assert normals[idx * 3 + 1] == pytest.approx(expected_ny, abs=1e-5) + assert normals[idx * 3 + 2] == pytest.approx(expected_nz, abs=1e-5) + + 
def test_nan_elevation_gets_up_normal(self): + """NaN elevation pixels should get (0, 0, 1).""" + from rtxpy.mesh import compute_terrain_normals + terrain = np.ones((4, 4), dtype=np.float32) * 50.0 + terrain[1, 2] = np.nan + normals = compute_terrain_normals(terrain, 4, 4) + idx = 1 * 4 + 2 + assert normals[idx * 3] == pytest.approx(0.0, abs=1e-6) + assert normals[idx * 3 + 1] == pytest.approx(0.0, abs=1e-6) + assert normals[idx * 3 + 2] == pytest.approx(1.0, abs=1e-6) + + def test_pixel_spacing_affects_normals(self): + """Wider pixel spacing should flatten normals (smaller nx).""" + from rtxpy.mesh import compute_terrain_normals + H, W = 3, 5 + terrain = np.zeros((H, W), dtype=np.float32) + for c in range(W): + terrain[:, c] = float(c) + n1 = compute_terrain_normals(terrain, H, W, psx=1.0, psy=1.0) + n10 = compute_terrain_normals(terrain, H, W, psx=10.0, psy=1.0) + idx = 1 * W + 2 + # With psx=10, dz/dx = 1/10, so |nx| should be much smaller + assert abs(n10[idx * 3]) < abs(n1[idx * 3]) + + def test_all_unit_length(self): + """All normals should be unit-length.""" + from rtxpy.mesh import compute_terrain_normals + terrain = np.random.RandomState(42).rand(8, 8).astype(np.float32) * 100 + normals = compute_terrain_normals(terrain, 8, 8) + nx = normals[0::3] + ny = normals[1::3] + nz = normals[2::3] + lengths = np.sqrt(nx**2 + ny**2 + nz**2) + np.testing.assert_allclose(lengths, 1.0, atol=1e-5) + + +# --------------------------------------------------------------------------- +# compute_vertex_normals +# --------------------------------------------------------------------------- + +class TestComputeVertexNormals: + """Tests for area-weighted smooth vertex normal computation.""" + + def test_flat_quad(self): + """Two-triangle flat quad should produce all (0, 0, 1) normals.""" + from rtxpy.mesh import compute_vertex_normals + verts = np.array([ + 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, + ], dtype=np.float32) + indices = np.array([0, 1, 2, 0, 2, 3], dtype=np.int32) + 
normals = compute_vertex_normals(verts, indices) + assert normals.shape == (12,) + nz = normals[2::3] + np.testing.assert_allclose(nz, 1.0, atol=1e-6) + np.testing.assert_allclose(normals[0::3], 0.0, atol=1e-6) + + def test_dihedral_edge(self): + """Shared edge of 90° dihedral: vertex normals should be averaged.""" + from rtxpy.mesh import compute_vertex_normals + # Two triangles sharing edge along X axis, forming a V + verts = np.array([ + 0, 0, 0, # 0: left of shared edge + 2, 0, 0, # 1: right of shared edge + 1, 1, 0, # 2: flat face vertex + 1, -1, 0, # 3: also flat (mirror) + ], dtype=np.float32) + # Both triangles lie in z=0 plane with different orientations + indices = np.array([0, 1, 2, 0, 3, 1], dtype=np.int32) + normals = compute_vertex_normals(verts, indices) + # All vertices have faces in z=0, so nz should dominate + nz = normals[2::3] + assert all(abs(n) > 0.5 for n in nz) + + def test_isolated_vertex_gets_up_normal(self): + """Vertices not referenced by any triangle should get (0,0,1).""" + from rtxpy.mesh import compute_vertex_normals + verts = np.array([ + 0, 0, 0, 1, 0, 0, 0, 1, 0, # triangle + 5, 5, 5, # isolated vertex + ], dtype=np.float32) + indices = np.array([0, 1, 2], dtype=np.int32) + normals = compute_vertex_normals(verts, indices) + # Isolated vertex (index 3) should be (0, 0, 1) + assert normals[9] == pytest.approx(0.0, abs=1e-6) + assert normals[10] == pytest.approx(0.0, abs=1e-6) + assert normals[11] == pytest.approx(1.0, abs=1e-6) + + def test_all_unit_length(self): + """All normals should be unit-length.""" + from rtxpy.mesh import compute_vertex_normals + rng = np.random.RandomState(42) + verts = rng.rand(30).astype(np.float32) + indices = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 3, 6], + dtype=np.int32) + normals = compute_vertex_normals(verts, indices) + nx = normals[0::3] + ny = normals[1::3] + nz = normals[2::3] + lengths = np.sqrt(nx**2 + ny**2 + nz**2) + np.testing.assert_allclose(lengths, 1.0, atol=1e-5) + + +# 
--------------------------------------------------------------------------- +# compute_skirt_normals +# --------------------------------------------------------------------------- + +class TestComputeSkirtNormals: + """Tests for skirt outward-facing normal computation.""" + + def test_output_shape(self): + from rtxpy.mesh import compute_skirt_normals + H, W = 5, 7 + normals = compute_skirt_normals(H, W) + n_perim = 2 * (H + W) - 4 + assert normals.shape == (n_perim * 3,) + + def test_edge_directions(self): + """Top/right/bottom/left edges should have correct outward normals.""" + from rtxpy.mesh import compute_skirt_normals + H, W = 5, 7 + normals = compute_skirt_normals(H, W).reshape(-1, 3) + off = 0 + # Top: W verts, normal (0, -1, 0) + for i in range(W): + np.testing.assert_allclose(normals[off + i], [0, -1, 0], atol=1e-6) + off += W + # Right: H-1 verts, normal (1, 0, 0) + for i in range(H - 1): + np.testing.assert_allclose(normals[off + i], [1, 0, 0], atol=1e-6) + off += H - 1 + # Bottom: W-1 verts, normal (0, 1, 0) + for i in range(W - 1): + np.testing.assert_allclose(normals[off + i], [0, 1, 0], atol=1e-6) + off += W - 1 + # Left: H-2 verts, normal (-1, 0, 0) + for i in range(H - 2): + np.testing.assert_allclose(normals[off + i], [-1, 0, 0], atol=1e-6) + + def test_all_unit_length(self): + from rtxpy.mesh import compute_skirt_normals + normals = compute_skirt_normals(10, 10).reshape(-1, 3) + lengths = np.sqrt(np.sum(normals**2, axis=1)) + np.testing.assert_allclose(lengths, 1.0, atol=1e-6) + + # --------------------------------------------------------------------------- # simplify_mesh # --------------------------------------------------------------------------- @@ -161,7 +504,25 @@ def __init__(self): self.geometries = {} def add_geometry(self, gid, verts, indices, **kw): - self.geometries[gid] = (verts.copy(), indices.copy()) + normals = kw.get('normals') + self.geometries[gid] = ( + verts.copy(), indices.copy(), + normals.copy() if normals is not None 
else None) + return 0 + + def add_heightfield_geometry(self, gid, elevation, H, W, + spacing_x, spacing_y, ve=1.0, + tile_size=32, active_mask=None, + transform=None): + """Stub for heightfield GAS — stores metadata for test assertions.""" + self.geometries[gid] = { + 'type': 'heightfield', + 'H': H, 'W': W, + 'spacing_x': spacing_x, 'spacing_y': spacing_y, + 've': ve, 'tile_size': tile_size, + 'active_mask': active_mask.copy() if active_mask is not None else None, + 'transform': list(transform) if transform is not None else None, + } return 0 def remove_geometry(self, gid): @@ -225,6 +586,8 @@ def test_lod_varies_with_distance(self): pixel_spacing_x=1.0, pixel_spacing_y=1.0, max_lod=3, lod_distance_factor=1.0, ) + # Allow all 16 tiles to build in one pass for this test + mgr.per_tick_build_limit = 100 # Camera at corner — near tiles get LOD 0, far tiles get higher LOD mgr.update(np.array([0, 0, 0]), rtx, force=True) lods = mgr.tile_lods @@ -316,23 +679,22 @@ def test_boundary_shared_at_higher_subsample(self): f"tile(0,1) min_x={min_x_01}" ) - def test_interior_tile_has_no_skirt(self): - """Interior tiles (not at terrain edge) should have no skirt. #79""" + def test_interior_tile_no_skirt(self): + """Tiles have no skirt — edge stitching replaces skirts.""" terrain = self._make_terrain(256, 256) rtx = _FakeRTX() mgr = TerrainLODManager(terrain, tile_size=64, pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.per_tick_build_limit = 100 mgr.update(np.array([128, 128, 0]), rtx, force=True) - # Tile (1,1) is fully interior — no edges touch terrain boundary. - # Its mesh should have no skirt (vertex count = grid verts only). + # Tile (1,1) is fully interior — should have only grid verts, no skirt gid = _tile_gid(1, 1) verts = rtx.geometries[gid][0] - # With +1 overlap and subsample=1, tile covers 65x65 grid - # (64 tile_size + 1 overlap). No skirt → exactly 65*65 verts. 
+ n_grid = 65 * 65 n_verts = len(verts) // 3 - assert n_verts == 65 * 65, ( + assert n_verts == n_grid, ( f"Interior tile should have no skirt, got {n_verts} verts " - f"(expected {65 * 65})" + f"(expected {n_grid})" ) def test_get_stats(self): @@ -345,7 +707,542 @@ def test_get_stats(self): rtx = _FakeRTX() mgr.update(np.array([64, 64, 0]), rtx, force=True) stats = mgr.get_stats() - assert "LOD tiles:" in stats + assert "LOD:" in stats + assert "tiles" in stats + + def test_set_terrain_updates_tile_grid(self): + """set_terrain with different shape must update tile count.""" + terrain = self._make_terrain(128, 128) + mgr = TerrainLODManager(terrain, tile_size=64) + assert mgr.n_tiles == 4 # 2x2 + + # Replace with larger terrain + terrain2 = self._make_terrain(256, 256) + mgr.set_terrain(terrain2) + assert mgr.n_tiles == 16 # 4x4 + + def test_stale_tiles_evicted_from_cache(self): + """Tiles leaving distance range should be evicted from mesh cache.""" + terrain = self._make_terrain(512, 512) + rtx = _FakeRTX() + mgr = TerrainLODManager( + terrain, tile_size=128, + pixel_spacing_x=1.0, pixel_spacing_y=1.0, + max_lod=3, lod_distance_factor=1.0, + ) + mgr.per_tick_build_limit = 100 + # Build tiles near corner + mgr.update(np.array([0, 0, 0]), rtx, force=True) + cache_keys_before = set(mgr._tile_cache.keys()) + assert len(cache_keys_before) > 0 + + # Move camera far away so original tiles leave range + mgr.update(np.array([10000, 10000, 0]), rtx, force=True) + # Old tile cache entries should be evicted + cache_keys_after = set(mgr._tile_cache.keys()) + evicted = cache_keys_before - cache_keys_after + assert len(evicted) > 0, "Stale tile cache entries were not evicted" + + def test_offset_shifts_tile_vertices(self): + """World offset should shift all tile vertex positions.""" + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.update(np.array([64, 64, 0]), rtx, 
force=True) + gid = _tile_gid(0, 0) + verts_no_offset = rtx.geometries[gid][0].copy() + + # Apply an offset and rebuild + mgr.set_offset(100.0, 200.0) + mgr.update(np.array([164, 264, 0]), rtx, force=True) + verts_with_offset = rtx.geometries[gid][0] + + # X coords should be shifted by 100, Y by 200 + x_diff = verts_with_offset[0::3] - verts_no_offset[0::3] + y_diff = verts_with_offset[1::3] - verts_no_offset[1::3] + np.testing.assert_allclose(x_diff, 100.0, atol=0.01) + np.testing.assert_allclose(y_diff, 200.0, atol=0.01) + + def test_set_terrain_with_offset(self): + """set_terrain with offset should shift subsequent tile vertices.""" + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + # Replace terrain with an offset + terrain2 = self._make_terrain(128, 128) + mgr.set_terrain(terrain2, offset_x=50.0, offset_y=75.0) + mgr.update(np.array([114, 139, 0]), rtx, force=True) + gid = _tile_gid(0, 0) + verts = rtx.geometries[gid][0] + # Min X should be at offset (50.0), min Y at offset (75.0) + assert float(np.min(verts[0::3])) == pytest.approx(50.0) + assert float(np.min(verts[1::3])) == pytest.approx(75.0) + + def test_streaming_creates_tiles_beyond_bounds(self): + """Streaming callback should produce tiles at negative indices.""" + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + + # Callback returns a flat elevation grid + def fake_tile_fn(x_min, y_min, x_max, y_max, target_samples): + return np.full((target_samples, target_samples), 50.0, + dtype=np.float32) + + mgr.set_tile_data_fn(fake_tile_fn) + assert mgr._streaming + + mgr.per_tick_build_limit = 100 + # Camera beyond the initial terrain bounds (x < 0) + mgr.update(np.array([-200, 64, 0]), rtx, force=True) + + # Should have tiles with negative column indices + neg_tiles = [gid for gid in rtx.geometries + if 
is_terrain_lod_gid(gid) and '_c-' in gid] + assert len(neg_tiles) > 0, ( + f"Expected streaming tiles at negative col, got: " + f"{list(rtx.geometries.keys())}" + ) + + def test_streaming_tile_positions_correct(self): + """Streaming tiles should be positioned at correct world coords.""" + terrain = self._make_terrain(64, 64) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=10.0, pixel_spacing_y=10.0) + + # Fixed elevation for easy verification + def fake_tile_fn(x_min, y_min, x_max, y_max, target_samples): + return np.full((target_samples, target_samples), 42.0, + dtype=np.float32) + + mgr.set_tile_data_fn(fake_tile_fn) + mgr.per_tick_build_limit = 100 + # Camera at tile (0, 1) which is just past the initial terrain + # (initial terrain is 1×1 tiles at 64px × 10.0 spacing = 640 world units) + mgr.update(np.array([960, 320, 0]), rtx, force=True) + + gid = _tile_gid(0, 1) + if gid in rtx.geometries: + verts = rtx.geometries[gid][0] + x_min_v = float(np.min(verts[0::3])) + # Tile (0,1) starts at col 64 → x = 64 * 10.0 = 640.0 + assert x_min_v == pytest.approx(640.0, abs=1.0) + + def test_streaming_disabled_by_default(self): + """Streaming should be off when no callback is set.""" + terrain = self._make_terrain(128, 128) + mgr = TerrainLODManager(terrain, tile_size=64) + assert not mgr._streaming + assert mgr._tile_data_fn is None + + def test_set_tile_data_fn_none_disables_streaming(self): + """Setting callback to None should disable streaming.""" + terrain = self._make_terrain(128, 128) + mgr = TerrainLODManager(terrain, tile_size=64) + + mgr.set_tile_data_fn(lambda *a: None) + assert mgr._streaming + + mgr.set_tile_data_fn(None) + assert not mgr._streaming + + def test_streaming_callback_receives_correct_bounds(self): + """Callback should receive tile world bounds and its data should + appear in the built mesh Z values.""" + terrain = self._make_terrain(64, 64) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + 
pixel_spacing_x=10.0, pixel_spacing_y=10.0) + received_calls = [] + + def tracking_fn(x_min, y_min, x_max, y_max, target_samples): + received_calls.append({ + 'x_min': x_min, 'y_min': y_min, + 'x_max': x_max, 'y_max': y_max, + 'target_samples': target_samples, + }) + # Return a distinctive elevation + return np.full((target_samples, target_samples), 999.0, + dtype=np.float32) + + mgr.set_tile_data_fn(tracking_fn) + mgr.per_tick_build_limit = 100 + # Camera at tile (0, 2) — beyond initial 1×1 grid + mgr.update(np.array([1600, 320, 0]), rtx, force=True) + + # Verify callback was called for out-of-bounds tiles + assert len(received_calls) > 0, "Callback was never called" + + # Verify bounds are sensible (positive width/height) + for call in received_calls: + assert call['x_max'] > call['x_min'] + assert call['y_max'] > call['y_min'] + assert call['target_samples'] >= 2 + + # Verify the distinctive elevation appears in mesh Z values + gid = _tile_gid(0, 2) + if gid in rtx.geometries: + verts = rtx.geometries[gid][0] + # Surface verts (not skirt) should have z=999.0 + z_vals = verts[2::3] + assert float(np.max(z_vals)) == pytest.approx(999.0, abs=0.1) + + def test_streaming_stats_no_denominator(self): + """Streaming mode stats should not show misleading total.""" + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.set_tile_data_fn( + lambda *a: np.full((a[4], a[4]), 50.0, dtype=np.float32)) + mgr.per_tick_build_limit = 100 + mgr.update(np.array([64, 64, 0]), rtx, force=True) + stats = mgr.get_stats() + assert '/' not in stats, ( + f"Streaming stats should not have active/total format: {stats}") + assert 'tiles' in stats + + def test_hysteresis_prevents_flip(self): + """Tiles near LOD boundary should not flip on small movements.""" + terrain = self._make_terrain(512, 512) + rtx = _FakeRTX() + mgr = TerrainLODManager( + terrain, tile_size=128, + pixel_spacing_x=1.0, 
pixel_spacing_y=1.0, + max_lod=3, lod_distance_factor=1.0, + ) + mgr.per_tick_build_limit = 100 + # Place camera so some tiles are near LOD threshold + mgr.update(np.array([0, 0, 0]), rtx, force=True) + lods_1 = mgr.tile_lods + + # Small movement — tiles near boundary should keep their LOD + mgr.update(np.array([5, 5, 0]), rtx, force=True) + lods_2 = mgr.tile_lods + + # At least the tiles far from boundaries should be stable + stable = sum(1 for k in lods_1 if k in lods_2 and lods_1[k] == lods_2[k]) + assert stable > 0, "No tiles maintained their LOD across small movement" + + def test_tiles_have_normals(self): + """LOD tiles should pass per-vertex normals to add_geometry.""" + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.per_tick_build_limit = 100 + mgr.update(np.array([64, 64, 0]), rtx, force=True) + for gid, (verts, indices, normals) in rtx.geometries.items(): + assert normals is not None, f"Tile {gid} missing normals" + n_verts = len(verts) // 3 + assert len(normals) == n_verts * 3, ( + f"Tile {gid}: normals length {len(normals)} != " + f"verts count {n_verts} * 3") + # All normals should be unit-length + nx = normals[0::3] + ny = normals[1::3] + nz = normals[2::3] + lengths = np.sqrt(nx**2 + ny**2 + nz**2) + np.testing.assert_allclose(lengths, 1.0, atol=1e-4, + err_msg=f"Non-unit normals in {gid}") + + def test_threaded_building(self): + """Threaded mesh building should produce tiles over multiple ticks.""" + import time + terrain = self._make_terrain(256, 256) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.per_tick_build_limit = 2 # only 2 builds per tick + mgr.enable_threaded_building(max_workers=2) + + # First tick: submits builds to the thread pool + mgr.update(np.array([128, 128, 0]), rtx, force=True) + + # Subsequent ticks collect completed futures — allow time + # for the 
thread pool to finish (builds are fast but async) + for _ in range(50): + time.sleep(0.02) + mgr.update(np.array([128, 128, 0]), rtx) + if not mgr._has_in_flight_work and not mgr._pending_futures: + break + + # All visible tiles should eventually be built + assert len(rtx.geometries) > 0, "No tiles built with threaded building" + # Verify normals are present + for gid, (verts, indices, normals) in rtx.geometries.items(): + assert normals is not None, f"Threaded tile {gid} missing normals" + assert len(normals) == (len(verts) // 3) * 3 + + mgr.shutdown() + + def test_threaded_shutdown_cancels_pending(self): + """Shutdown should cancel in-flight futures.""" + terrain = self._make_terrain(256, 256) + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.enable_threaded_building() + mgr.shutdown() + assert mgr._executor is None + assert len(mgr._pending_futures) == 0 + assert len(mgr._io_futures) == 0 + assert not mgr._threaded + + def test_build_retry_budget(self): + """Tiles that fail repeatedly should stop retrying.""" + from unittest.mock import patch + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + + # Simulate a tile that has exhausted its retry budget + fail_key = (0, 0, 1, 1) # (tr, tc, lod, base_sub) + mgr._build_retries[fail_key] = mgr._MAX_BUILD_RETRIES + # Verify the tile is skipped: force a queue entry for this tile + queue = [(1.0, 0, 0, 1, 'terrain_lod_r0_c0')] + changed, pending = mgr._process_tile_queue(queue, rtx, ve=1.0) + assert fail_key not in mgr._pending_futures + assert fail_key not in mgr._tile_cache + + # After cancel_pending, retries should be cleared (terrain reload) + mgr._cancel_pending() + assert len(mgr._build_retries) == 0 + + def test_streaming_io_prefetch(self): + """Streaming tiles should prefetch I/O ahead of mesh builds.""" + import time + terrain = self._make_terrain(128, 128) + rtx = 
_FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + fetch_calls = [] + + def tracking_fn(x_min, y_min, x_max, y_max, target_samples): + fetch_calls.append((x_min, y_min, x_max, y_max)) + return np.full((target_samples, target_samples), 100.0, + dtype=np.float32) + + mgr.set_tile_data_fn(tracking_fn) + mgr.per_tick_build_limit = 2 # tight limit to force prefetch + mgr.enable_threaded_building(max_workers=2) + + # Position camera so some tiles are out-of-bounds (streaming) + mgr.update(np.array([64, 64, 0]), rtx, force=True) + + # Let threads complete and collect results + for _ in range(50): + time.sleep(0.02) + mgr.update(np.array([64, 64, 0]), rtx) + if not mgr._has_in_flight_work: + break + + # Should have tiles built — both in-bounds and streaming + assert len(rtx.geometries) > 0 + # The tile_data_fn should have been called for out-of-bounds tiles + assert len(fetch_calls) > 0 + mgr.shutdown() + + def test_batched_upload_reduces_gas_count(self): + """Batched mode should produce fewer GAS entries than tiles.""" + terrain = self._make_terrain(256, 256) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0, + max_lod=2, lod_distance_factor=1.0) + mgr.per_tick_build_limit = 100 + mgr.enable_batched_upload() + mgr.update(np.array([128, 128, 0]), rtx, force=True) + + # Scene should have batch GAS entries, not individual tile GAS + n_tiles = len(mgr._tile_lods) + assert n_tiles > 0, "No tiles assigned LOD" + n_gas = len(rtx.geometries) + assert n_gas < n_tiles, ( + f"Batch mode should reduce GAS count: {n_gas} GAS >= {n_tiles} tiles") + # All GAS IDs should be batch IDs + for gid in rtx.geometries: + assert gid.startswith('terrain_lod_batch_L'), ( + f"Non-batch GAS ID in scene: {gid}") + + def test_batched_upload_correct_geometry(self): + """Batched tiles should produce valid concatenated geometry.""" + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() 
+ mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.per_tick_build_limit = 100 + mgr.enable_batched_upload() + mgr.update(np.array([64, 64, 0]), rtx, force=True) + + for gid, (verts, indices, normals) in rtx.geometries.items(): + n_verts = len(verts) // 3 + assert normals is not None, f"Batch {gid} missing normals" + assert len(normals) == n_verts * 3 + # All indices should be within vertex bounds + assert np.all(indices >= 0) + assert np.all(indices < n_verts), ( + f"Index out of range in {gid}: max={np.max(indices)}, " + f"n_verts={n_verts}") + + def test_batched_remove_all(self): + """remove_all should clear batch GAS entries.""" + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.per_tick_build_limit = 100 + mgr.enable_batched_upload() + mgr.update(np.array([64, 64, 0]), rtx, force=True) + assert len(rtx.geometries) > 0 + mgr.remove_all(rtx) + for gid in rtx.geometries: + assert not is_terrain_lod_gid(gid) + assert len(mgr._batch_gids) == 0 + assert len(mgr._lod_tile_meshes) == 0 + + def test_batched_stats_show_gas_count(self): + """Stats should report batch GAS count.""" + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.per_tick_build_limit = 100 + mgr.enable_batched_upload() + mgr.update(np.array([64, 64, 0]), rtx, force=True) + stats = mgr.get_stats() + assert 'GAS' in stats + + def test_batched_stale_tile_eviction(self): + """Tiles leaving distance range should be unstaged from batches.""" + terrain = self._make_terrain(256, 256) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0, + max_lod=2, lod_distance_factor=0.5) + mgr.per_tick_build_limit = 100 + mgr.enable_batched_upload() + # Camera at corner — only nearby tiles in range + 
mgr.update(np.array([0, 0, 0]), rtx, force=True) + lods_1 = set(mgr._tile_lods.keys()) + # Move far away — some tiles should become stale + mgr.update(np.array([256, 256, 0]), rtx, force=True) + lods_2 = set(mgr._tile_lods.keys()) + # Some tiles from position 1 should have been evicted + evicted = lods_1 - lods_2 + # Evicted tiles should not appear in any batch + for lod, tiles in mgr._lod_tile_meshes.items(): + for k in evicted: + assert k not in tiles, ( + f"Evicted tile {k} still in LOD {lod} batch") + + +# --------------------------------------------------------------------------- +# Terrain-adaptive LOD (roughness) +# --------------------------------------------------------------------------- + +class TestTerrainAdaptiveLOD: + """Tests for roughness-based LOD threshold adaptation.""" + + @staticmethod + def _make_terrain(H, W, elevation=100.0): + return np.full((H, W), elevation, dtype=np.float32) + + def test_flat_terrain_uniform_roughness(self): + """All tiles on a flat terrain should get neutral roughness (1.0).""" + terrain = self._make_terrain(256, 256) + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + for scale in mgr._tile_roughness.values(): + assert scale == pytest.approx(1.0) + + def test_rough_tile_gets_higher_scale(self): + """A tile with a peak should get roughness_scale > 1.""" + terrain = self._make_terrain(256, 256) + # Add a sharp peak to tile (1, 1) — rows 64:128, cols 64:128 + terrain[90:100, 90:100] = 5000.0 + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + rough_scale = mgr._tile_roughness.get((1, 1), 1.0) + # Should be promoted (scale > 1) + assert rough_scale > 1.0, f"Rough tile scale {rough_scale} not > 1" + + def test_smooth_tile_gets_lower_scale(self): + """A flat tile adjacent to a rough one should get scale < 1.""" + terrain = self._make_terrain(256, 256) + # Make one tile very rough + terrain[90:100, 90:100] = 5000.0 + mgr = 
TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + # Tile (0, 0) is flat — should be demoted + smooth_scale = mgr._tile_roughness.get((0, 0), 1.0) + assert smooth_scale < 1.0, f"Smooth tile scale {smooth_scale} not < 1" + + def test_roughness_affects_lod_assignment(self): + """Rough tiles should get finer LOD than smooth tiles at same distance.""" + terrain = self._make_terrain(512, 512) + # Make tile (2, 2) rough — rows 256:384, cols 256:384 + terrain[300:320, 300:320] = 3000.0 + rtx = _FakeRTX() + mgr = TerrainLODManager( + terrain, tile_size=128, + pixel_spacing_x=1.0, pixel_spacing_y=1.0, + max_lod=3, lod_distance_factor=1.0, + ) + mgr.per_tick_build_limit = 100 + # Camera at center — all tiles roughly equidistant + mgr.update(np.array([256, 256, 0]), rtx, force=True) + lods = mgr.tile_lods + + # Rough tile (2, 2) should have same or lower LOD number (higher + # detail) than the smooth tile (0, 0) at comparable distance + rough_lod = lods.get((2, 2)) + smooth_lod = lods.get((0, 0)) + if rough_lod is not None and smooth_lod is not None: + assert rough_lod <= smooth_lod, ( + f"Rough tile LOD {rough_lod} > smooth tile LOD {smooth_lod}") + + def test_set_terrain_recomputes_roughness(self): + """Replacing terrain should recompute roughness.""" + terrain1 = self._make_terrain(128, 128) + mgr = TerrainLODManager(terrain1, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + # All uniform initially + for s in mgr._tile_roughness.values(): + assert s == pytest.approx(1.0) + + # Replace with terrain that has one rough tile + terrain2 = self._make_terrain(128, 128) + terrain2[10:20, 10:20] = 999.0 + mgr.set_terrain(terrain2) + # Should now have non-uniform roughness + scales = list(mgr._tile_roughness.values()) + assert max(scales) > min(scales), "Roughness not recomputed" + + def test_roughness_scale_range(self): + """Roughness scales should fall within [0.5, 2.0].""" + rng = np.random.RandomState(42) + terrain = rng.randn(256, 
256).astype(np.float32) * 100 + # Make one tile extra rough + terrain[50:70, 50:70] += 5000.0 + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + for scale in mgr._tile_roughness.values(): + assert 0.5 - 1e-6 <= scale <= 2.0 + 1e-6, ( + f"Scale {scale} outside [0.5, 2.0]") + + def test_streaming_tiles_get_neutral_roughness(self): + """Streaming tiles (out of bounds) should use default scale 1.0.""" + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + # Out-of-bounds tile (10, 10) — not in _tile_roughness + assert mgr._tile_roughness.get((10, 10), 1.0) == 1.0 # --------------------------------------------------------------------------- @@ -357,73 +1254,623 @@ def test_tile_gid_format(self): assert _tile_gid(0, 0) == "terrain_lod_r0_c0" assert _tile_gid(3, 7) == "terrain_lod_r3_c7" + def test_batch_gid_format(self): + assert _batch_gid(0) == "terrain_lod_batch_L0" + assert _batch_gid(3) == "terrain_lod_batch_L3" + def test_is_terrain_lod_gid(self): assert is_terrain_lod_gid("terrain_lod_r0_c0") assert is_terrain_lod_gid("terrain_lod_r12_c34") + assert is_terrain_lod_gid("terrain_lod_batch_L0") + assert is_terrain_lod_gid("terrain_lod_batch_L3") + assert is_terrain_lod_gid("terrain_lod_hf") assert not is_terrain_lod_gid("terrain") assert not is_terrain_lod_gid("terrain_skirt") assert not is_terrain_lod_gid("buildings_0") -class TestAddTileSkirt: - def test_adds_skirt_vertices(self): - """Skirt should add perimeter + wall vertices.""" - H, W = 4, 4 - n_verts = H * W - n_tris = (H - 1) * (W - 1) * 2 +# --------------------------------------------------------------------------- +# Edge stitching +# --------------------------------------------------------------------------- - verts = np.zeros(n_verts * 3, dtype=np.float32) - indices = np.zeros(n_tris * 3, dtype=np.int32) - for h in range(H): - for w in range(W): - idx = (h * W + 
w) * 3 - verts[idx] = float(w) - verts[idx + 1] = float(h) - verts[idx + 2] = float(h + w) +class TestEdgeStitching: + """Tests for boundary vertex stitching between tiles at different LODs.""" - new_v, new_i = _add_tile_skirt(verts, indices, H, W) - # Perimeter of 4x4 grid: 4+3+3+2 = 12 vertices added - n_perim = 2 * (H + W) - 4 - assert len(new_v) == (n_verts + n_perim) * 3 - assert len(new_i) > len(indices) - - def test_skirt_z_below_min(self): - H, W = 3, 3 - verts = np.zeros(9 * 3, dtype=np.float32) - for i in range(9): - verts[i * 3 + 2] = 10.0 # all z = 10 - indices = np.zeros(8 * 3, dtype=np.int32) - - new_v, _ = _add_tile_skirt(verts, indices, H, W) - skirt_z = new_v[9 * 3 + 2::3] # z of skirt vertices - assert np.all(skirt_z < 10.0) - - def test_no_edges_returns_unchanged(self): - """edges=all False should return original mesh unchanged. #79""" - H, W = 3, 3 - verts = np.zeros(9 * 3, dtype=np.float32) - indices = np.zeros(8 * 3, dtype=np.int32) - new_v, new_i = _add_tile_skirt( - verts, indices, H, W, edges=(False, False, False, False)) - np.testing.assert_array_equal(new_v, verts) - np.testing.assert_array_equal(new_i, indices) - - def test_partial_edges_fewer_wall_tris(self): - """Activating only some edges should produce fewer wall tris. 
#79""" - H, W = 4, 4 - n_verts = H * W - n_tris = (H - 1) * (W - 1) * 2 - verts = np.zeros(n_verts * 3, dtype=np.float32) - indices = np.zeros(n_tris * 3, dtype=np.int32) - for h in range(H): - for w in range(W): - idx = (h * W + w) * 3 - verts[idx] = float(w) - verts[idx + 1] = float(h) - verts[idx + 2] = float(h + w) + @staticmethod + def _make_terrain(H=256, W=256): + y = np.linspace(0, 100, H, dtype=np.float32) + x = np.linspace(0, 100, W, dtype=np.float32) + return y[:, None] + x[None, :] - _, all_i = _add_tile_skirt(verts, indices, H, W) - _, partial_i = _add_tile_skirt( - verts, indices, H, W, edges=(True, False, False, False)) - # Only top edge active → fewer wall triangles - assert len(partial_i) < len(all_i) + def test_coarser_neighbor_stitches_boundary(self): + """Finer tile boundary Z should match coarser neighbor's grid.""" + terrain = self._make_terrain(256, 256) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.per_tick_build_limit = 200 + # Force tile (1,0) to LOD 0 and tile (1,1) to LOD 1 by positioning + # the camera close to tile (1,0) + mgr.update(np.array([32, 32, 0]), rtx, ve=1.0, force=True) + # Check that adjacent tiles with different LODs exist + lods = mgr._tile_lods + # Find any pair where LODs differ + pairs_found = False + for (tr, tc), lod in lods.items(): + for edge, (nr, nc) in [('right', (tr, tc+1)), + ('bottom', (tr+1, tc))]: + nlod = lods.get((nr, nc), -1) + if nlod >= 0 and nlod != lod: + pairs_found = True + break + if pairs_found: + break + # If we found differently-LODed neighbors, the stitching ran + # (it's applied in _prepare_tile → _stitch_tile_boundary). + # Verify the boundary Z values are from the coarser level's pyramid. 
+ if pairs_found: + # The finer tile's boundary should have been modified + finer_tile = (tr, tc) if lod < nlod else (nr, nc) + finer_lod = lods[finer_tile] + gid = _tile_gid(*finer_tile) + assert gid in rtx.geometries or gid in mgr._active_tiles + + def test_same_lod_no_stitching(self): + """Tiles at the same LOD should not have their boundaries modified.""" + terrain = self._make_terrain(128, 128) + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + # Set all four tiles to the same LOD + mgr._tile_lods = {(0, 0): 1, (0, 1): 1, (1, 0): 1, (1, 1): 1} + # Build mesh for interior-adjacent tile (0,0) + mesh_data = mgr._build_tile_mesh(0, 0, 1) + assert mesh_data is not None + verts_orig = mesh_data[0].copy() + verts = mesh_data[0].copy() + # Stitch — all neighbors are same LOD, so nothing should change + mgr._stitch_tile_boundary(verts, 0, 0, 1) + np.testing.assert_array_equal(verts, verts_orig, + err_msg="Same-LOD neighbors should not be stitched") + + def test_stitch_tile_boundary_method(self): + """Direct test of _stitch_tile_boundary with controlled neighbor LODs.""" + terrain = self._make_terrain(128, 128) + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + # Manually set tile LODs: (0,0) at LOD 0, (0,1) at LOD 2 + mgr._tile_lods = {(0, 0): 0, (0, 1): 2} + # Build a mesh for tile (0,0) at LOD 0 + mesh_data = mgr._build_tile_mesh(0, 0, 0) + assert mesh_data is not None + verts_orig = mesh_data[0].copy() + verts = mesh_data[0].copy() + # Stitch — tile (0,0) LOD 0 has right neighbor (0,1) at LOD 2 + mgr._stitch_tile_boundary(verts, 0, 0, 0) + th, tw = mgr._tile_grid_dims(0, 0, 0) + # Right edge should be modified (neighbor is coarser) + right_col = tw - 1 + right_indices = np.arange(th) * tw + right_col + right_z_orig = verts_orig[right_indices * 3 + 2] + right_z_stitched = verts[right_indices * 3 + 2] + # Stitched Z should differ from original (interpolated from LOD 2) + assert not 
np.array_equal(right_z_orig, right_z_stitched), \ + "Right edge should be stitched to coarser neighbor" + # Left edge should NOT be modified (no neighbor on left) + left_indices = np.arange(th) * tw + np.testing.assert_array_equal( + verts[left_indices * 3 + 2], + verts_orig[left_indices * 3 + 2], + err_msg="Left edge should be unchanged (no neighbor)") + + def test_get_boundary_z_ref(self): + """_get_boundary_z_ref returns Z values from pyramid at correct edge.""" + terrain = self._make_terrain(128, 128) + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + # Get top boundary Z at LOD 0 for tile (0,0) — should be from row 0 + z_top = mgr._get_boundary_z_ref(0, 0, 'top', 0) + assert z_top is not None + pyr0 = mgr._get_pyramid_level(0) + np.testing.assert_array_equal(z_top, pyr0[0, :65]) + # Bottom boundary for tile (0,0) + z_bottom = mgr._get_boundary_z_ref(0, 0, 'bottom', 0) + assert z_bottom is not None + # Left boundary + z_left = mgr._get_boundary_z_ref(0, 0, 'left', 0) + assert z_left is not None + np.testing.assert_array_equal(z_left, pyr0[:65, 0]) + + def test_tile_grid_dims(self): + """_tile_grid_dims should return correct grid size for a tile.""" + terrain = self._make_terrain(128, 128) + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + # LOD 0 on 128x128 terrain with tile_size=64, base_subsample=1 + th, tw = mgr._tile_grid_dims(0, 0, 0) + assert th == 65 and tw == 65, f"Expected 65x65, got {th}x{tw}" + # LOD 1 should be coarser + th1, tw1 = mgr._tile_grid_dims(0, 0, 1) + assert th1 < th and tw1 < tw, \ + f"LOD 1 ({th1}x{tw1}) should be coarser than LOD 0 ({th}x{tw})" + + def test_heightfield_neighbor_stitching(self): + """TIN tile adjacent to heightfield LOD 0 should stitch to full-res.""" + terrain = self._make_terrain(256, 256) + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.enable_heightfield_lod0() + # Set tile LODs: 
(0,0) LOD 0 (heightfield), (0,1) LOD 1 (TIN) + mgr._tile_lods = {(0, 0): 0, (0, 1): 1} + # Build mesh for tile (0,1) at LOD 1 + mesh_data = mgr._build_tile_mesh(0, 1, 1) + assert mesh_data is not None + verts_orig = mesh_data[0].copy() + verts = mesh_data[0].copy() + # Stitch — tile (0,1) LOD 1 has left neighbor (0,0) at LOD 0 (HF) + mgr._stitch_tile_boundary(verts, 0, 1, 1) + th, tw = mgr._tile_grid_dims(0, 1, 1) + # Left edge should be modified (stitched to full-res pyramid 0) + left_indices = np.arange(th) * tw + left_z_orig = verts_orig[left_indices * 3 + 2] + left_z_stitched = verts[left_indices * 3 + 2] + assert not np.array_equal(left_z_orig, left_z_stitched), \ + "Left edge should be stitched to heightfield LOD 0 (full-res)" + + def test_stitch_with_nan_terrain(self): + """Stitching should handle NaN values in terrain without crashing.""" + terrain = self._make_terrain(128, 128) + # Inject NaN in the boundary region between tile (0,0) and (0,1) + terrain[0:65, 63:66] = np.nan + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr._tile_lods = {(0, 0): 0, (0, 1): 2} + mesh_data = mgr._build_tile_mesh(0, 0, 0) + assert mesh_data is not None + verts = mesh_data[0].copy() + # Should not crash even with NaN in the reference boundary + mgr._stitch_tile_boundary(verts, 0, 0, 0) + + def test_stitch_with_ve(self): + """Stitching + VE should apply both transformations correctly.""" + terrain = self._make_terrain(128, 128) + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr._tile_lods = {(0, 0): 0, (0, 1): 2} + mesh_data = mgr._build_tile_mesh(0, 0, 0) + assert mesh_data is not None + # _prepare_tile with VE=2.0 should stitch then scale Z + verts, _, _ = mgr._prepare_tile(mesh_data, 0, 0, 0, ve=2.0) + # Compare to _prepare_tile with VE=1.0 — Z should be 2× larger + verts_ve1, _, _ = mgr._prepare_tile(mesh_data, 0, 0, 0, ve=1.0) + np.testing.assert_allclose(verts[2::3], 
verts_ve1[2::3] * 2.0, + rtol=1e-5) + + def test_stitch_streaming_tile_with_cached_neighbor(self): + """Out-of-bounds (streaming) tiles stitch to in-bounds neighbors.""" + terrain = self._make_terrain(128, 128) + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + # Tile (2, 0) is out of bounds (only 2 tile rows: 0, 1) + # Tile (1, 0) is in-bounds and at coarser LOD + mgr._tile_lods = {(2, 0): 0, (1, 0): 2} + # Build a mesh for tile (1, 0) to use as reference + mesh_data_ref = mgr._build_tile_mesh(1, 0, 0) + assert mesh_data_ref is not None + # Cache it so the stitch code can find it + mgr._tile_cache[(1, 0, 0, 1)] = mesh_data_ref + # Build an in-bounds mesh and pretend it's for tile (2, 0) + mesh_data = mgr._build_tile_mesh(0, 0, 0) + assert mesh_data is not None + verts_orig = mesh_data[0].copy() + # Prepare for OOB tile — stitching should now happen via + # pyramid (neighbor (1,0) is in-bounds) + verts, _, _ = mgr._prepare_tile(mesh_data, 2, 0, 0, ve=1.0, + own=True) + # Top boundary of tile (2,0) should be modified to match + # the bottom of tile (1,0) + assert not np.array_equal(verts[2::3], verts_orig[2::3]), \ + "OOB tile should be stitched to in-bounds neighbor" + + def test_stitched_z_matches_coarser_pyramid(self): + """Stitched boundary Z values should match interpolated coarser pyramid.""" + terrain = self._make_terrain(128, 128) + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + # Tile (0,0) at LOD 0, tile (0,1) at LOD 2 (coarser) + mgr._tile_lods = {(0, 0): 0, (0, 1): 2} + mesh_data = mgr._build_tile_mesh(0, 0, 0) + assert mesh_data is not None + verts = mesh_data[0].copy() + mgr._stitch_tile_boundary(verts, 0, 0, 0) + + th, tw = mgr._tile_grid_dims(0, 0, 0) + # Right edge vertices (shared with coarser neighbor) + right_col = tw - 1 + right_indices = np.arange(th) * tw + right_col + stitched_z = verts[right_indices * 3 + 2] + + # Compute expected Z: interpolated from 
LOD 2 pyramid boundary + ref_z = mgr._get_boundary_z_ref(0, 0, 'right', 2) + assert ref_z is not None + n_self = th + n_ref = len(ref_z) + positions = (np.arange(n_self, dtype=np.float64) + * (n_ref - 1) / (n_self - 1)) + expected_z = np.interp( + positions, + np.arange(n_ref, dtype=np.float64), + ref_z.astype(np.float64)).astype(np.float32) + np.testing.assert_array_almost_equal( + stitched_z, expected_z, decimal=5, + err_msg="Stitched Z should match interpolated coarser pyramid") + + # Interior vertices should be unchanged + interior_indices = np.arange(th) * tw + (tw // 2) + interior_z = verts[interior_indices * 3 + 2] + interior_z_orig = mesh_data[0][interior_indices * 3 + 2] + np.testing.assert_array_equal( + interior_z, interior_z_orig, + err_msg="Interior vertices should not be modified by stitching") + + def test_needs_stitch_fast_check(self): + """_needs_stitch should return False when all neighbors are same LOD.""" + terrain = self._make_terrain(128, 128) + mgr = TerrainLODManager(terrain, tile_size=64, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr._tile_lods = {(0, 0): 1, (0, 1): 1, (1, 0): 1, (1, 1): 1} + assert not mgr._needs_stitch(0, 0, 1) + assert not mgr._needs_stitch(1, 1, 1) + # Change one neighbor to different LOD + mgr._tile_lods[(0, 1)] = 2 + assert mgr._needs_stitch(0, 0, 1) # right neighbor is coarser + + +# --------------------------------------------------------------------------- +# Heightfield LOD 0 +# --------------------------------------------------------------------------- + +class TestHeightfieldLOD0: + """Tests for heightfield ray marching on LOD 0 tiles.""" + + @staticmethod + def _make_terrain(H=256, W=256): + y = np.linspace(0, 100, H, dtype=np.float32) + x = np.linspace(0, 100, W, dtype=np.float32) + return y[:, None] + x[None, :] + + def test_enable_heightfield_creates_hf_gas(self): + """Enabling heightfield LOD 0 should produce a heightfield GAS.""" + terrain = self._make_terrain(256, 256) + rtx = _FakeRTX() + mgr = 
TerrainLODManager(terrain, tile_size=128, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.enable_heightfield_lod0() + mgr.update([64, 64, 100], rtx, ve=1.0, force=True) + assert rtx.has_geometry('terrain_lod_hf'), \ + "Heightfield GAS not created" + hf = rtx.geometries['terrain_lod_hf'] + assert hf['type'] == 'heightfield' + + def test_lod0_tiles_skip_mesh_building(self): + """LOD 0 in-bounds tiles should not create triangle meshes.""" + terrain = self._make_terrain(256, 256) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=128, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.enable_heightfield_lod0() + mgr.update([64, 64, 100], rtx, ve=1.0, force=True) + # LOD 0 tiles should be tracked but no individual TIN GAS + lod0_tiles = [k for k, v in mgr._tile_lods.items() if v == 0] + assert len(lod0_tiles) > 0, "No LOD 0 tiles assigned" + for tr, tc in lod0_tiles: + gid = _tile_gid(tr, tc) + assert not rtx.has_geometry(gid), \ + f"LOD 0 tile {gid} has TIN GAS — should use heightfield" + + def test_heightfield_active_mask(self): + """Active mask should cover only LOD 0 tile regions.""" + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=128, + pixel_spacing_x=1.0, pixel_spacing_y=1.0, + max_lod=3) + mgr.enable_heightfield_lod0() + # Single tile terrain — close camera → LOD 0 + mgr.update([64, 64, 50], rtx, ve=1.0, force=True) + hf = rtx.geometries.get('terrain_lod_hf') + assert hf is not None, "No heightfield GAS" + mask = hf['active_mask'] + assert mask is not None + # All AABB tiles should be active (single LOD tile covers everything) + assert np.all(mask), "Not all AABB tiles are active" + + def test_heightfield_partial_active_mask(self): + """Only LOD 0 tiles should have active AABBs.""" + terrain = self._make_terrain(256, 256) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=128, + pixel_spacing_x=1.0, pixel_spacing_y=1.0, + max_lod=3, lod_distance_factor=1.0) + 
mgr.enable_heightfield_lod0() + # Camera at corner — close tiles LOD 0, far tiles higher LOD + mgr.update([20, 20, 50], rtx, ve=1.0, force=True) + hf = rtx.geometries.get('terrain_lod_hf') + if hf is not None: + mask = hf['active_mask'] + # Not all AABB tiles should be active (some tiles are LOD 1+) + n_active = np.sum(mask) + n_total = len(mask) + # At least some should be inactive if any tiles are LOD 1+ + lod_counts = {} + for v in mgr._tile_lods.values(): + lod_counts[v] = lod_counts.get(v, 0) + 1 + if any(l > 0 for l in lod_counts): + assert n_active < n_total, \ + "All AABB tiles active but some LOD tiles are > 0" + + def test_heightfield_with_batched_upload(self): + """Heightfield LOD 0 + batched TIN LOD 1+ should coexist.""" + terrain = self._make_terrain(256, 256) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=128, + pixel_spacing_x=1.0, pixel_spacing_y=1.0, + max_lod=3, lod_distance_factor=1.0) + mgr.enable_heightfield_lod0() + mgr.enable_batched_upload() + mgr.update([20, 20, 50], rtx, ve=1.0, force=True) + # Should have heightfield GAS and possibly TIN batch GAS + gids = rtx.list_geometries() + hf_gids = [g for g in gids if g == 'terrain_lod_hf'] + batch_gids = [g for g in gids if g.startswith('terrain_lod_batch_')] + assert len(hf_gids) <= 1, "Multiple heightfield GAS" + # LOD 1+ tiles should be in batch GAS, not individual + for tr, tc in mgr._tile_lods: + lod = mgr._tile_lods[(tr, tc)] + if lod > 0: + gid = _tile_gid(tr, tc) + assert not rtx.has_geometry(gid), \ + f"LOD {lod} tile {gid} not batched" + + def test_heightfield_ve_update(self): + """VE change should rebuild heightfield with new VE.""" + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=128, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.enable_heightfield_lod0() + mgr.update([64, 64, 50], rtx, ve=1.0, force=True) + hf1 = rtx.geometries['terrain_lod_hf'] + assert hf1['ve'] == 1.0 + # Simulate VE change: clear 
tile_lods and force rebuild + mgr._tile_lods.clear() + mgr.update([64, 64, 50], rtx, ve=2.5, force=True) + hf2 = rtx.geometries['terrain_lod_hf'] + assert hf2['ve'] == 2.5 + + def test_heightfield_remove_all(self): + """remove_all should clear heightfield GAS.""" + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=128, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.enable_heightfield_lod0() + mgr.update([64, 64, 50], rtx, ve=1.0, force=True) + assert rtx.has_geometry('terrain_lod_hf') + mgr.remove_all(rtx) + assert not rtx.has_geometry('terrain_lod_hf') + assert len(mgr._tile_lods) == 0 + + def test_heightfield_stats_show_hf(self): + """Stats should label LOD 0 as 'HF' when heightfield enabled.""" + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=128, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.enable_heightfield_lod0() + mgr.update([64, 64, 50], rtx, ve=1.0, force=True) + stats = mgr.get_stats() + assert 'HF:' in stats, f"Stats missing HF label: {stats}" + + def test_heightfield_transform_has_offset(self): + """Heightfield transform should include world offset.""" + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=128, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.enable_heightfield_lod0() + mgr.set_offset(1000.0, 2000.0) + mgr.update([1064, 2064, 50], rtx, ve=1.0, force=True) + hf = rtx.geometries.get('terrain_lod_hf') + assert hf is not None + transform = hf['transform'] + assert transform[3] == 1000.0, f"X offset wrong: {transform[3]}" + assert transform[7] == 2000.0, f"Y offset wrong: {transform[7]}" + + def test_streaming_lod0_uses_tin(self): + """Out-of-bounds LOD 0 tiles should fall through to TIN.""" + terrain = self._make_terrain(128, 128) + rtx = _FakeRTX() + mgr = TerrainLODManager(terrain, tile_size=128, + pixel_spacing_x=1.0, pixel_spacing_y=1.0) + mgr.enable_heightfield_lod0() + 
+ # Set up a streaming tile data function + def tile_data_fn(x_min, y_min, x_max, y_max, target_samples): + return np.full((target_samples, target_samples), 50.0, + dtype=np.float32) + + mgr.set_tile_data_fn(tile_data_fn) + # Camera well outside bounds → out-of-bounds tiles should use TIN + mgr.update([300, 300, 50], rtx, ve=1.0, force=True) + # Out-of-bounds LOD 0 tiles should have individual GAS (not batched) + oob_lod0 = [(tr, tc) for (tr, tc), lod in mgr._tile_lods.items() + if lod == 0 and (tr < 0 or tr >= mgr._n_tile_rows + or tc < 0 or tc >= mgr._n_tile_cols)] + for tr, tc in oob_lod0: + gid = _tile_gid(tr, tc) + assert rtx.has_geometry(gid), \ + f"OOB LOD 0 tile {gid} should use TIN, not heightfield" + + +# --------------------------------------------------------------------------- +# Mesh chunk simplification +# --------------------------------------------------------------------------- + +class TestMeshChunkSimplification: + """Tests for placed geometry simplification at higher LOD levels.""" + + @staticmethod + def _make_grid_mesh(rows=10, cols=10): + """Create a regular grid triangle mesh for testing simplification.""" + verts = [] + for r in range(rows): + for c in range(cols): + verts.extend([float(c), float(r), float(r + c) * 0.1]) + indices = [] + for r in range(rows - 1): + for c in range(cols - 1): + i0 = r * cols + c + i1 = i0 + 1 + i2 = i0 + cols + i3 = i2 + 1 + indices.extend([i0, i1, i2, i1, i3, i2]) + return (np.array(verts, dtype=np.float32), + np.array(indices, dtype=np.int32)) + + def test_simplify_mesh_reduces_triangles(self): + """simplify_mesh with ratio < 1 should reduce triangle count + (or return original if trimesh decimation is unavailable).""" + verts, indices = self._make_grid_mesh(20, 20) + orig_n_tris = len(indices) // 3 + sv, si = simplify_mesh(verts, indices, 0.5) + new_n_tris = len(si) // 3 + # If fast_simplification is available, should have fewer triangles. + # Otherwise simplify_mesh gracefully returns original. 
+ assert new_n_tris <= orig_n_tris, \ + f"Should not increase triangles: got {new_n_tris} vs original {orig_n_tris}" + + def test_simplify_mesh_ratio_1_returns_original(self): + """simplify_mesh with ratio >= 1.0 should return original mesh.""" + verts, indices = self._make_grid_mesh(5, 5) + sv, si = simplify_mesh(verts, indices, 1.0) + np.testing.assert_array_equal(sv, verts) + np.testing.assert_array_equal(si, indices) + + def test_simplify_lod0_returns_original(self): + """simplify_mesh at ratio 1.0 (LOD 0) returns original arrays.""" + verts, indices = self._make_grid_mesh(10, 10) + sv, si = simplify_mesh(verts, indices, 1.0) + np.testing.assert_array_equal(sv, verts) + np.testing.assert_array_equal(si, indices) + + def test_simplify_empty_mesh(self): + """simplify_mesh with 0 faces should return original unchanged.""" + verts = np.array([0, 0, 0, 1, 0, 0, 0, 1, 0], dtype=np.float32) + indices = np.array([], dtype=np.int32) + sv, si = simplify_mesh(verts, indices, 0.5) + # Should not crash; returns original or empty + assert len(sv) >= 0 + assert len(si) >= 0 + + def test_simplify_single_face(self): + """simplify_mesh with 1 face should not crash.""" + verts = np.array([0, 0, 0, 1, 0, 0, 0, 1, 0], dtype=np.float32) + indices = np.array([0, 1, 2], dtype=np.int32) + sv, si = simplify_mesh(verts, indices, 0.5) + # 1 face can't simplify further — should return original + assert len(si) // 3 >= 1 + + def test_simplify_high_lod_clamps_to_last_ratio(self): + """LOD index beyond ratio table length should clamp to last entry.""" + verts, indices = self._make_grid_mesh(20, 20) + ratios = (1.0, 0.5, 0.25, 0.1) + # LOD 5 should clamp to index 3 (ratio 0.1) + idx = min(5, len(ratios) - 1) + sv, si = simplify_mesh(verts, indices, ratios[idx]) + assert len(si) <= len(indices) + + def test_build_lod_chain_progressive(self): + """build_lod_chain should produce progressively simpler meshes.""" + verts, indices = self._make_grid_mesh(20, 20) + chain = build_lod_chain(verts, 
indices, ratios=(1.0, 0.5, 0.25)) + assert len(chain) == 3 + # Each level should have equal or fewer triangles + prev_n = len(chain[0][1]) // 3 + for level, (v, i) in enumerate(chain[1:], 1): + n = len(i) // 3 + assert n <= prev_n, \ + f"LOD {level} has {n} tris, more than LOD {level-1} ({prev_n})" + prev_n = n + + +# --------------------------------------------------------------------------- +# Tile lifecycle callbacks +# --------------------------------------------------------------------------- + +class TestTileCallbacks: + """Tests for set_tile_callbacks tile lifecycle notifications.""" + + @staticmethod + def _make_terrain(rows, cols): + np.random.seed(42) + return np.random.rand(rows, cols).astype(np.float32) * 100 + + def test_on_added_called_for_each_tile(self): + """on_added should fire for every tile built on first update.""" + terrain = self._make_terrain(256, 256) + rtx = _FakeRTX() + added = [] + mgr = TerrainLODManager( + terrain, tile_size=128, + pixel_spacing_x=1.0, pixel_spacing_y=1.0, + max_lod=2, lod_distance_factor=3.0, + ) + mgr.set_tile_callbacks( + on_added=lambda tr, tc, elev: added.append((tr, tc, elev)), + ) + mgr.per_tick_build_limit = 100 + mgr.update(np.array([128, 128, 100]), rtx, force=True) + # Should have at least one tile added + assert len(added) > 0 + # Each callback should receive (tr, tc, elevation_tile) + for tr, tc, elev in added: + assert isinstance(tr, (int, np.integer)) + assert isinstance(tc, (int, np.integer)) + + def test_on_removed_called_on_eviction(self): + """on_removed should fire when tiles leave the distance range.""" + terrain = self._make_terrain(512, 512) + rtx = _FakeRTX() + removed = [] + mgr = TerrainLODManager( + terrain, tile_size=128, + pixel_spacing_x=1.0, pixel_spacing_y=1.0, + max_lod=3, lod_distance_factor=1.0, + ) + mgr.set_tile_callbacks( + on_removed=lambda tr, tc: removed.append((tr, tc)), + ) + mgr.per_tick_build_limit = 100 + # Build near corner + mgr.update(np.array([0, 0, 0]), rtx, 
force=True) + tiles_before = set(mgr._tile_lods.keys()) + assert len(tiles_before) > 0 + # Move far away so original tiles leave range + mgr.update(np.array([10000, 10000, 0]), rtx, force=True) + assert len(removed) > 0 + # Every removed tile should have been in tiles_before + for tr, tc in removed: + assert (tr, tc) in tiles_before + + def test_callbacks_not_called_when_none(self): + """No error when callbacks are not set.""" + terrain = self._make_terrain(256, 256) + rtx = _FakeRTX() + mgr = TerrainLODManager( + terrain, tile_size=128, + pixel_spacing_x=1.0, pixel_spacing_y=1.0, + ) + mgr.per_tick_build_limit = 100 + # Should not raise + mgr.update(np.array([128, 128, 100]), rtx, force=True) + mgr.update(np.array([10000, 10000, 0]), rtx, force=True) diff --git a/rtxpy/tests/test_overlay_tiles.py b/rtxpy/tests/test_overlay_tiles.py new file mode 100644 index 0000000..a29d71a --- /dev/null +++ b/rtxpy/tests/test_overlay_tiles.py @@ -0,0 +1,164 @@ +"""Tests for per-tile overlay compositing.""" + +import numpy as np +import pytest + +from rtxpy.viewer.overlay_tiles import OverlayTileManager + + +class TestOverlayTileManager: + """Tests for OverlayTileManager composite generation.""" + + def test_empty_manager_returns_none(self): + mgr = OverlayTileManager(tile_size=64) + d, r, c = mgr.get_composite({(0, 0), (0, 1)}) + assert d is None + assert r == 0 + assert c == 0 + + def test_single_tile_composite(self): + mgr = OverlayTileManager(tile_size=64) + data = np.ones((64, 64), dtype=np.float32) * 5.0 + mgr.set_tile(0, 0, data) + # get_composite returns None for GPU (no cupy in test), but + # the internal _composite should be built + mgr.get_composite({(0, 0)}) + assert mgr._composite is not None + assert mgr._composite.shape == (64, 64) + np.testing.assert_allclose(mgr._composite, 5.0) + assert mgr._origin_row == 0 + assert mgr._origin_col == 0 + + def test_multi_tile_composite_offsets(self): + mgr = OverlayTileManager(tile_size=64) + mgr.set_tile(1, 2, np.ones((64, 
64), dtype=np.float32) * 1.0) + mgr.set_tile(2, 3, np.ones((64, 64), dtype=np.float32) * 2.0) + mgr.get_composite({(1, 2), (2, 3)}) + # Bounding box: rows 1-2, cols 2-3 → composite 128×128 + assert mgr._composite.shape == (128, 128) + assert mgr._origin_row == 1 * 64 + assert mgr._origin_col == 2 * 64 + # Tile (1,2) at local (0,0), tile (2,3) at local (64,64) + np.testing.assert_allclose(mgr._composite[0:64, 0:64], 1.0) + np.testing.assert_allclose(mgr._composite[64:128, 64:128], 2.0) + # Gaps should be NaN + assert np.all(np.isnan(mgr._composite[0:64, 64:128])) + + def test_populate_from_array(self): + mgr = OverlayTileManager(tile_size=64) + overlay = np.arange(128 * 128, dtype=np.float32).reshape(128, 128) + mgr.populate_from_array(overlay, 64, 2, 2) + assert mgr.has_tile(0, 0) + assert mgr.has_tile(0, 1) + assert mgr.has_tile(1, 0) + assert mgr.has_tile(1, 1) + + def test_populate_skips_all_nan_tiles(self): + mgr = OverlayTileManager(tile_size=64) + overlay = np.full((128, 128), np.nan, dtype=np.float32) + # Only top-left has data + overlay[0:32, 0:32] = 1.0 + mgr.populate_from_array(overlay, 64, 2, 2) + assert mgr.has_tile(0, 0) + assert not mgr.has_tile(0, 1) + assert not mgr.has_tile(1, 0) + assert not mgr.has_tile(1, 1) + + def test_remove_tile(self): + mgr = OverlayTileManager(tile_size=64) + mgr.set_tile(0, 0, np.ones((64, 64), dtype=np.float32)) + assert mgr.has_tile(0, 0) + mgr.remove_tile(0, 0) + assert not mgr.has_tile(0, 0) + + def test_invalidate_forces_recomposite(self): + mgr = OverlayTileManager(tile_size=64) + mgr.set_tile(0, 0, np.ones((64, 64), dtype=np.float32)) + mgr.get_composite({(0, 0)}) + comp1 = mgr._composite + # Without invalidate, same tile set → cached + mgr.get_composite({(0, 0)}) + # With invalidate, should rebuild + mgr.invalidate() + mgr.set_tile(0, 0, np.ones((64, 64), dtype=np.float32) * 99) + # Reset throttle so rebuild happens immediately in tests + mgr._last_rebuild = 0.0 + mgr.get_composite({(0, 0)}) + 
np.testing.assert_allclose(mgr._composite, 99.0) + + def test_visible_subset_only(self): + """Only visible tiles with data are composited.""" + mgr = OverlayTileManager(tile_size=64) + mgr.set_tile(0, 0, np.ones((64, 64), dtype=np.float32)) + mgr.set_tile(5, 5, np.ones((64, 64), dtype=np.float32) * 2) + # Only request tile (0,0) as visible + mgr.get_composite({(0, 0)}) + assert mgr._composite.shape == (64, 64) + assert mgr._origin_row == 0 + + def test_color_lut(self): + mgr = OverlayTileManager(tile_size=64) + lut = np.zeros((256, 3), dtype=np.float32) + mgr.set_color_lut(lut) + assert mgr.color_lut is lut + + +class TestTextureTileManager: + """Tests for TextureTileManager RGB composite generation.""" + + def test_empty_returns_none(self): + from rtxpy.viewer.overlay_tiles import TextureTileManager + mgr = TextureTileManager(tile_size=64) + d, r, c = mgr.get_composite({(0, 0)}) + assert d is None + + def test_single_tile_composite(self): + from rtxpy.viewer.overlay_tiles import TextureTileManager + mgr = TextureTileManager(tile_size=64) + data = np.ones((64, 64, 3), dtype=np.float32) * 0.5 + mgr.set_tile(0, 0, data) + mgr.get_composite({(0, 0)}) + assert mgr._composite is not None + assert mgr._composite.shape == (64, 64, 3) + np.testing.assert_allclose(mgr._composite, 0.5) + + def test_multi_tile_offsets(self): + from rtxpy.viewer.overlay_tiles import TextureTileManager + mgr = TextureTileManager(tile_size=64) + mgr.set_tile(1, 2, np.ones((64, 64, 3), dtype=np.float32) * 0.3) + mgr.set_tile(2, 3, np.ones((64, 64, 3), dtype=np.float32) * 0.7) + mgr.get_composite({(1, 2), (2, 3)}) + assert mgr._composite.shape == (128, 128, 3) + assert mgr._origin_row == 64 + assert mgr._origin_col == 128 + np.testing.assert_allclose(mgr._composite[0:64, 0:64], 0.3) + np.testing.assert_allclose(mgr._composite[64:128, 64:128], 0.7) + # Gaps should be zero + np.testing.assert_allclose(mgr._composite[0:64, 64:128], 0.0) + + def test_populate_from_array(self): + from 
rtxpy.viewer.overlay_tiles import TextureTileManager + mgr = TextureTileManager(tile_size=64) + tex = np.ones((128, 128, 3), dtype=np.float32) * 0.5 + mgr.populate_from_array(tex, 64, 2, 2) + assert mgr.has_tile(0, 0) + assert mgr.has_tile(1, 1) + + def test_populate_skips_zero_tiles(self): + from rtxpy.viewer.overlay_tiles import TextureTileManager + mgr = TextureTileManager(tile_size=64) + tex = np.zeros((128, 128, 3), dtype=np.float32) + tex[0:32, 0:32] = 0.5 # Only top-left has data + mgr.populate_from_array(tex, 64, 2, 2) + assert mgr.has_tile(0, 0) + assert not mgr.has_tile(0, 1) + assert not mgr.has_tile(1, 0) + assert not mgr.has_tile(1, 1) + + def test_remove_tile(self): + from rtxpy.viewer.overlay_tiles import TextureTileManager + mgr = TextureTileManager(tile_size=64) + mgr.set_tile(0, 0, np.ones((64, 64, 3), dtype=np.float32)) + assert mgr.has_tile(0, 0) + mgr.remove_tile(0, 0) + assert not mgr.has_tile(0, 0) diff --git a/rtxpy/tiles.py b/rtxpy/tiles.py index 58caf1a..a0d541f 100644 --- a/rtxpy/tiles.py +++ b/rtxpy/tiles.py @@ -169,6 +169,31 @@ def _build_latlon_grids(raster): return yy, xx +def _build_crs_transformer(raster): + """Build a pyproj Transformer from the raster's CRS to WGS84. + + Returns the transformer, or None if CRS info is unavailable. + """ + try: + import pyproj + crs = raster.rio.crs + except (ImportError, AttributeError): + return None + if crs is None: + return None + if crs.is_geographic: + return 'geographic' # sentinel: x=lon, y=lat, no transform needed + try: + return pyproj.Transformer.from_crs( + crs, "+proj=longlat +datum=WGS84 +no_defs", always_xy=True) + except Exception: + try: + return pyproj.Transformer.from_crs( + crs, "EPSG:4326", always_xy=True) + except Exception: + return None + + def _compute_pixel_spacing_meters(raster): """Estimate the ground-truth pixel spacing in metres. 
def fetch_rgb_for_bounds(self, crs_x_min, crs_y_min, crs_x_max, crs_y_max,
                         out_h, out_w):
    """Fetch basemap RGB for an arbitrary CRS bounding box.

    Builds a small lat/lon grid, fetches/caches needed XYZ tiles,
    and composites them.  Runs synchronously (call from background
    thread).

    Parameters
    ----------
    crs_x_min, crs_y_min, crs_x_max, crs_y_max : float
        Bounding box in the raster's native CRS (e.g. UTM metres).
    out_h, out_w : int
        Output pixel dimensions.

    Returns
    -------
    np.ndarray, shape (out_h, out_w, 3), float32 [0-1], or None.
        ``None`` is returned when no CRS transformer is available or when
        no output pixel maps to a finite lon/lat.  Pixels not covered by
        any fetched tile stay zero.
    """
    if self._crs_transformer is None:
        return None

    rgb = np.zeros((out_h, out_w, 3), dtype=np.float32)
    if rgb.size == 0:
        # Degenerate request (zero rows or columns): nothing to sample,
        # and min/max over an empty grid would raise.
        return rgb

    # Build CRS coordinate grid for the output tile.
    # Row 0 = max y (north) to match raster row convention.
    xs = np.linspace(crs_x_min, crs_x_max, out_w, dtype=np.float64)
    ys = np.linspace(crs_y_max, crs_y_min, out_h, dtype=np.float64)
    xx, yy = np.meshgrid(xs, ys)

    # Transform to WGS84
    if self._crs_transformer == 'geographic':
        lons, lats = xx, yy
    else:
        lons, lats = self._crs_transformer.transform(xx, yy)

    # Bounds over usable pixels only: a transformer may hand back NaN/inf
    # for points it cannot project, and those must not reach the tile
    # enumeration below.
    usable = np.isfinite(lats) & np.isfinite(lons)
    if not usable.any():
        return None
    lat_min = float(lats[usable].min())
    lat_max = float(lats[usable].max())
    lon_min = float(lons[usable].min())
    lon_max = float(lons[usable].max())

    # Find which XYZ tiles cover this area
    tile_list = tiles_for_bounds(lat_min, lon_min, lat_max, lon_max,
                                 self._zoom)

    # Fetch tiles (synchronous, using cache + disk + network)
    for coord in tile_list:
        tile_array = self._get_or_fetch_tile(coord)
        if tile_array is None:
            continue

        # Geographic extent of this tile (NW and SE corners)
        tx, ty, tz = coord
        tile_h, tile_w = tile_array.shape[:2]
        nw_lat, nw_lon = tile_to_lat_lon(tx, ty, tz)
        se_lat, se_lon = tile_to_lat_lon(tx + 1, ty + 1, tz)
        lat_span = nw_lat - se_lat
        lon_span = se_lon - nw_lon
        if lat_span == 0 or lon_span == 0:
            continue

        # Output pixels that fall inside this tile (non-finite coordinates
        # compare False and are therefore skipped automatically)
        mask = ((lats >= se_lat) & (lats <= nw_lat) &
                (lons >= nw_lon) & (lons <= se_lon))
        rows, cols = np.where(mask)
        if rows.size == 0:
            continue

        # Nearest-neighbour lookup: fractional position inside the tile,
        # scaled to tile pixels and clamped to the valid index range.
        row_fracs = (nw_lat - lats[rows, cols]) / lat_span
        col_fracs = (lons[rows, cols] - nw_lon) / lon_span
        tile_rows = np.clip((row_fracs * tile_h).astype(int),
                            0, tile_h - 1)
        tile_cols = np.clip((col_fracs * tile_w).astype(int),
                            0, tile_w - 1)
        rgb[rows, cols] = tile_array[tile_rows, tile_cols]

    return rgb


def _get_or_fetch_tile(self, coord):
    """Get a tile from cache, disk, or network (synchronous).

    Parameters
    ----------
    coord : tuple
        ``(x, y, z)`` XYZ tile coordinate.

    Returns
    -------
    np.ndarray float32 in [0, 1] with RGB channels last, or None when the
    tile can be neither read from disk nor downloaded and decoded.
    """
    cache = self._tile_cache

    # In-memory cache.  Refresh recency on a hit so the eviction below is
    # truly least-recently-used instead of first-in-first-out
    # (``popitem(last=False)`` already requires an OrderedDict).
    if coord in cache:
        cache.move_to_end(coord)
        return cache[coord]

    def _decode(source):
        # Decode an image (path or file-like) to float32 RGB in [0, 1].
        from PIL import Image
        img = Image.open(source).convert('RGB')
        return np.asarray(img, dtype=np.float32) / 255.0

    tx, ty, tz = coord
    tile_array = None
    disk_path = self._disk_path(coord)

    # Disk cache
    if disk_path.exists():
        try:
            tile_array = _decode(disk_path)
        except Exception:
            # Unreadable/corrupt cached file: fall back to a fresh download
            # (which overwrites the bad file on success).
            tile_array = None

    # Download
    if tile_array is None:
        url = self._url_template.format(x=tx, y=ty, z=tz)
        try:
            req = Request(url, headers={'User-Agent': 'rtxpy/0.1'})
            with urlopen(req, timeout=15) as resp:
                data = resp.read()
        except Exception:
            return None
        try:
            import io
            tile_array = _decode(io.BytesIO(data))
        except Exception:
            return None
        # Save to disk cache (best effort; a read-only cache dir is fine)
        try:
            disk_path.parent.mkdir(parents=True, exist_ok=True)
            with open(disk_path, 'wb') as f:
                f.write(data)
        except Exception:
            pass

    # Store in memory LRU cache, evicting the least recently used entry
    cache[coord] = tile_array
    if len(cache) > self._cache_limit:
        cache.popitem(last=False)

    return tile_array
+""" + +import math + +from numba import cuda + + +@cuda.jit +def hydro_splat_kernel( + trails, # (N*T, 2) float32 — (row, col) per trail point + ages, # (N,) int32 — per-particle age + lifetimes, # (N,) int32 — per-particle lifetime + colors, # (N, 3) float32 — per-particle (r, g, b) + radii, # (N,) int32 — per-particle splat radius + trail_len, # int32 scalar — trail points per particle + base_alpha, # float32 scalar — base alpha intensity + min_vis_age, # int32 scalar — minimum visible age + ref_depth, # float32 scalar — depth-scaling reference distance + terrain, # (tH, tW) float32 — terrain elevation + depth_t, # (sh, sw) float32 — ray-trace t-values for occlusion + output, # (sh, sw, 3) float32 — frame buffer (atomic add) + # Camera basis — scalar args to avoid tiny GPU allocations + cam_x, cam_y, cam_z, + fwd_x, fwd_y, fwd_z, + rgt_x, rgt_y, rgt_z, + up_x, up_y, up_z, + # Projection params + fov_scale, aspect_ratio, + # Terrain/world params + psx, psy, ve, subsample_f, min_depth, max_depth, +): + idx = cuda.grid(1) + if idx >= trails.shape[0]: + return + + # Compute alpha on-GPU from per-particle ages/lifetimes + pidx = idx // trail_len + tidx = idx % trail_len + age = ages[pidx] + lifetime = lifetimes[pidx] + + # Trail point not yet laid down + if age <= tidx: + return + + # Fade in / fade out / trail decay + fade_in = (age - min_vis_age) * 0.1 + if fade_in < 0.0: + fade_in = 0.0 + elif fade_in > 1.0: + fade_in = 1.0 + fade_out = (lifetime - age) * 0.05 + if fade_out < 0.0: + fade_out = 0.0 + elif fade_out > 1.0: + fade_out = 1.0 + # Quadratic trail decay — comet-tail effect + t = float(tidx) / float(trail_len) + trail_fade = (1.0 - t) * (1.0 - t) + a = base_alpha * fade_in * fade_out * trail_fade + + # Head glow: bright spark at particle position + if tidx == 0: + a = a * 1.5 + + if a < 1e-6: + return + + row = trails[idx, 0] + col = trails[idx, 1] + + # Terrain Z lookup (nearest-neighbor, clamped) + tH = terrain.shape[0] + tW = terrain.shape[1] + sr = 
int(row / subsample_f) + sc = int(col / subsample_f) + if sr < 0: + sr = 0 + elif sr >= tH: + sr = tH - 1 + if sc < 0: + sc = 0 + elif sc >= tW: + sc = tW - 1 + z_raw = terrain[sr, sc] + if z_raw != z_raw: # NaN check + z_raw = 0.0 + z_val = z_raw * ve + 3.0 + + # World position + wx = col * psx + wy = row * psy + + # Camera-relative + dx = wx - cam_x + dy = wy - cam_y + dz = z_val - cam_z + + # Depth along forward axis + depth = dx * fwd_x + dy * fwd_y + dz * fwd_z + if depth <= min_depth: + return + if max_depth > 0.0 and depth > max_depth: + return + + # Depth-scaled alpha: closer = brighter, farther = fainter. + # Prevents zoomed-out over-saturation from dense overlapping particles. + depth_scale = ref_depth / (depth + ref_depth) + a = a * depth_scale + + if a < 1e-6: + return + + inv_depth = 1.0 / (depth + 1e-10) + u_cam = dx * rgt_x + dy * rgt_y + dz * rgt_z + v_cam = dx * up_x + dy * up_y + dz * up_z + u_ndc = u_cam * inv_depth / (fov_scale * aspect_ratio) + v_ndc = v_cam * inv_depth / fov_scale + + sh = output.shape[0] + sw = output.shape[1] + sx = int((u_ndc + 1.0) * 0.5 * sw) + sy = int((1.0 - v_ndc) * 0.5 * sh) + + if sx < 0 or sx >= sw or sy < 0 or sy >= sh: + return + + # Depth test: cull particles occluded by terrain. + # Convert ray t-value at this pixel to forward depth, then compare + # to the particle's forward depth (already computed as `depth`). 
+ if depth_t.shape[0] > 0: + t_val = depth_t[sy, sx] + if t_val > 0.0 and t_val < 1.0e20: + # Forward depth = t / sqrt(1 + u_cam^2 + v_cam^2) + u_px = (2.0 * float(sx) / float(sw) - 1.0) * fov_scale * aspect_ratio + v_px = (1.0 - 2.0 * float(sy) / float(sh)) * fov_scale + inv_cos = math.sqrt(1.0 + u_px * u_px + v_px * v_px) + terrain_fwd = t_val / inv_cos + if depth > terrain_fwd: + return + + # Per-particle color and radius + color_r = colors[pidx, 0] + color_g = colors[pidx, 1] + color_b = colors[pidx, 2] + r = radii[pidx] + if r < 1: + r = 1 + # Head glow: +1px radius halo at particle position + if tidx == 0: + r = r + 1 + + # Circular stamp splat + for offy in range(-r, r + 1): + for offx in range(-r, r + 1): + dist_sq = offx * offx + offy * offy + if dist_sq > r * r: + continue + falloff = 1.0 - math.sqrt(dist_sq) / r + px = sx + offx + py = sy + offy + if px < 0 or px >= sw or py < 0 or py >= sh: + continue + contrib = a * falloff + cuda.atomic.add(output, (py, px, 0), contrib * color_r) + cuda.atomic.add(output, (py, px, 1), contrib * color_g) + cuda.atomic.add(output, (py, px, 2), contrib * color_b) + + +@cuda.jit +def hydro_advect_kernel( + # Particle state (GPU-resident, modified in-place) + particles, # (N, 2) float32 — (row, col) positions + ages, # (N,) int32 + lifetimes, # (N,) int32 + trails, # (N, T, 2) float32 — trail history + particle_accum, # (N,) float32 — max-tracked stream weight + particle_raw_order, # (N,) int32 — max-tracked raw Strahler order + colors, # (N, 3) float32 — per-particle RGB + radii, # (N,) int32 — per-particle splat radius + # Grid textures (GPU-resident, read-only) + flow_u, # (H, W) float32 — MFD flow col-component + flow_v, # (H, W) float32 — MFD flow row-component + slope_mag, # (H, W) float32 — normalized slope + stream_order, # (H, W) float32 — normalized stream order (or empty) + stream_order_raw, # (H, W) int32 — raw Strahler order (or empty) + accum_norm, # (H, W) float32 — normalized flow accumulation + # Palette 
for color lookup (9, 3) float32 + palette, # (9, 3) float32 — stream order color palette + # Output: respawn flags + respawn_flags, # (N,) int32 — 1 if particle needs respawn + # Scalar params + speed, dt_scale, trail_len, + has_so, # int32: 1 if stream_order is valid + has_slope, # int32: 1 if slope_mag is valid + has_raw_order, # int32: 1 if stream_order_raw / particle_raw_order valid + # Window offset for streaming: particle coords are global, + # flow field covers a window starting at (win_r0, win_c0) + win_r0, win_c0, + # RNG seed + rng_base, # int64 — base seed for per-particle RNG +): + """Advect one hydro particle: bilinear flow lookup, trail shift, respawn detection.""" + i = cuda.grid(1) + N = particles.shape[0] + if i >= N: + return + + H = flow_u.shape[0] + W = flow_u.shape[1] + + row = particles[i, 0] + col = particles[i, 1] + + # Shift trail buffer: slot 0 = current pos (before advection) + t = trail_len - 1 + while t > 0: + trails[i, t, 0] = trails[i, t - 1, 0] + trails[i, t, 1] = trails[i, t - 1, 1] + t -= 1 + trails[i, 0, 0] = row + trails[i, 0, 1] = col + + # Map global particle position to local flow field coordinates + local_r = row - win_r0 + local_c = col - win_c0 + + # Bilinear interpolation of MFD flow vectors + r_clean = local_r + c_clean = local_c + if r_clean != r_clean: + r_clean = 0.0 + if c_clean != c_clean: + c_clean = 0.0 + if r_clean < 0.0: + r_clean = 0.0 + elif r_clean > H - 1.0: + r_clean = H - 1.0 + if c_clean < 0.0: + c_clean = 0.0 + elif c_clean > W - 1.0: + c_clean = W - 1.0 + + r0 = int(r_clean) + c0 = int(c_clean) + if r0 > H - 2: + r0 = H - 2 + if c0 > W - 2: + c0 = W - 2 + if r0 < 0: + r0 = 0 + if c0 < 0: + c0 = 0 + r1 = r0 + 1 + c1 = c0 + 1 + + dr = r_clean - float(r0) + dc = c_clean - float(c0) + w00 = (1.0 - dr) * (1.0 - dc) + w01 = (1.0 - dr) * dc + w10 = dr * (1.0 - dc) + w11 = dr * dc + + u_val = (flow_u[r0, c0] * w00 + flow_u[r0, c1] * w01 + + flow_u[r1, c0] * w10 + flow_u[r1, c1] * w11) + v_val = (flow_v[r0, c0] * 
w00 + flow_v[r0, c1] * w01 + + flow_v[r1, c0] * w10 + flow_v[r1, c1] * w11) + + # Integer indices for grid lookups + ri = int(r_clean) + ci = int(c_clean) + if ri > H - 1: + ri = H - 1 + if ci > W - 1: + ci = W - 1 + if ri < 0: + ri = 0 + if ci < 0: + ci = 0 + + # Max-track stream order / accumulation + if has_so: + cur_val = stream_order[ri, ci] + else: + cur_val = accum_norm[ri, ci] + old_val = particle_accum[i] + if cur_val > old_val: + particle_accum[i] = cur_val + # Update raw order + if has_raw_order: + cur_raw = stream_order_raw[ri, ci] + if cur_raw > particle_raw_order[i]: + particle_raw_order[i] = cur_raw + # Recompute color + radius from new weight + if has_raw_order: + raw_o = particle_raw_order[i] + idx = raw_o + if idx < 1: + idx = 1 + if idx > 8: + idx = 8 + colors[i, 0] = palette[idx, 0] + colors[i, 1] = palette[idx, 1] + colors[i, 2] = palette[idx, 2] + rad = raw_o + 1 + if rad < 2: + rad = 2 + if rad > 5: + rad = 5 + radii[i] = rad + else: + a_val = particle_accum[i] + colors[i, 0] = 0.02 + a_val * 0.43 + colors[i, 1] = 0.10 + a_val * 0.65 + colors[i, 2] = 0.55 + a_val * 0.40 + rad = 2 + int(a_val * 3.0) + if rad < 2: + rad = 2 + if rad > 5: + rad = 5 + radii[i] = rad + + # Simple xorshift64 RNG seeded per-particle per-frame + s = rng_base * 2654435761 + i * 1442695040888963407 + s = s ^ (s >> 17) + s = s * 6364136223846793005 + s = s ^ (s >> 31) + # Two uniform floats in [-0.1, 0.1] for jitter + jitter_r = ((s & 0xFFFF) / 65535.0 - 0.5) * 0.2 + s = s * 6364136223846793005 + 1 + s = s ^ (s >> 31) + jitter_c = ((s & 0xFFFF) / 65535.0 - 0.5) * 0.2 + + # Slope-based speed + slope_f = 1.0 + if has_slope: + slope_f = 0.3 + 0.7 * slope_mag[ri, ci] + + # Advect (in global coordinates — flow vectors are direction-only) + particles[i, 0] = row + (v_val + jitter_r) * speed * dt_scale * slope_f + particles[i, 1] = col + (u_val + jitter_c) * speed * dt_scale * slope_f + + # Age + ages[i] = ages[i] + 1 + + # Respawn detection: OOB (relative to flow window), 
@cuda.jit
def hydro_respawn_kernel(
    # Particle state (GPU-resident, modified in-place)
    particles,           # (N, 2) float32
    ages,                # (N,) int32
    lifetimes,           # (N,) int32
    trails,              # (N, T, 2) float32
    particle_accum,      # (N,) float32
    particle_raw_order,  # (N,) int32
    colors,              # (N, 3) float32
    radii,               # (N,) int32
    # Respawn data (uploaded from CPU)
    respawn_indices,     # (M,) int32 — which particles to respawn
    spawn_rows,          # (M,) float32 — new row positions (global coords)
    spawn_cols,          # (M,) float32 — new col positions (global coords)
    new_lifetimes,       # (M,) int32
    # Grid lookups
    stream_order,        # (H, W) float32
    stream_order_raw,    # (H, W) int32
    accum_norm,          # (H, W) float32
    palette,             # (9, 3) float32
    # Scalars
    trail_len, has_so, has_raw_order,
    # Window offset
    win_r0, win_c0,
):
    """Re-seed the particles listed in ``respawn_indices``.

    One thread handles one respawn entry: it writes the new position, zeroes
    the age, installs the new lifetime, collapses the whole trail onto the
    spawn point and re-derives stream weight, colour and radius from the
    grid cell under the spawn point (global coords minus the window offset,
    clamped to the grid).
    """
    slot = cuda.grid(1)
    if slot >= respawn_indices.shape[0]:
        return

    p = respawn_indices[slot]
    spawn_r = spawn_rows[slot]
    spawn_c = spawn_cols[slot]

    # Core particle state
    particles[p, 0] = spawn_r
    particles[p, 1] = spawn_c
    ages[p] = 0
    lifetimes[p] = new_lifetimes[slot]

    # Every trail slot starts at the spawn point
    for k in range(trail_len):
        trails[p, k, 0] = spawn_r
        trails[p, k, 1] = spawn_c

    # Grid extent comes from whichever weight grid is in use
    if has_so:
        n_rows = stream_order.shape[0]
        n_cols = stream_order.shape[1]
    else:
        n_rows = accum_norm.shape[0]
        n_cols = accum_norm.shape[1]

    # Global → window-local cell, clamped into the grid
    gr = min(max(int(spawn_r - win_r0), 0), n_rows - 1)
    gc = min(max(int(spawn_c - win_c0), 0), n_cols - 1)

    # Stream weight at the spawn cell
    if has_so:
        weight = stream_order[gr, gc]
    else:
        weight = accum_norm[gr, gc]
    particle_accum[p] = weight

    if has_raw_order:
        # Discrete palette keyed by raw order (palette rows 1..8)
        order = stream_order_raw[gr, gc]
        particle_raw_order[p] = order
        pal = min(max(order, 1), 8)
        colors[p, 0] = palette[pal, 0]
        colors[p, 1] = palette[pal, 1]
        colors[p, 2] = palette[pal, 2]
        radii[p] = min(max(order + 1, 2), 5)
    else:
        # Continuous blue ramp driven by the normalised weight
        colors[p, 0] = 0.02 + weight * 0.43
        colors[p, 1] = 0.10 + weight * 0.65
        colors[p, 2] = 0.55 + weight * 0.40
        radii[p] = min(max(2 + int(weight * 3.0), 2), 5)
'hydro_ages', 'hydro_lifetimes', 'hydro_max_age', 'hydro_n_particles', 'hydro_trail_len', 'hydro_trails', - 'hydro_speed', 'hydro_min_depth', 'hydro_ref_depth', + 'hydro_speed', 'hydro_min_depth', 'hydro_max_depth', 'hydro_ref_depth', 'hydro_dot_radius', 'hydro_alpha', 'hydro_min_visible_age', 'hydro_accum_threshold', @@ -32,16 +32,27 @@ class HydroState: 'hydro_particle_colors', 'hydro_particle_radii', 'hydro_particle_raw_order', - # GPU buffers + # GPU buffers (rendering) 'd_hydro_trails', 'd_hydro_ages', 'd_hydro_lifetimes', 'd_hydro_colors', 'd_hydro_radii', 'hydro_done_event', + # GPU-resident advection state + 'd_hydro_particles', + 'd_hydro_particle_accum', + 'd_hydro_particle_raw_order', + 'd_hydro_flow_u', 'd_hydro_flow_v', + 'd_hydro_slope_mag', + 'd_hydro_stream_order', 'd_hydro_stream_order_raw', + 'd_hydro_accum_norm', + 'd_hydro_palette', + 'd_hydro_respawn_flags', ) def __init__(self): self.hydro_data = None self.hydro_enabled = False + self.hydro_lazy = False self.hydro_flow_u_px = None self.hydro_flow_v_px = None self.hydro_flow_accum_norm = None @@ -57,6 +68,7 @@ def __init__(self): self.hydro_trails = None self.hydro_speed = 0.75 self.hydro_min_depth = 0.0 + self.hydro_max_depth = 0.0 # 0 = unlimited self.hydro_ref_depth = 1.0 self.hydro_dot_radius = 2 self.hydro_alpha = 0.5 @@ -75,10 +87,22 @@ def __init__(self): self.hydro_particle_radii = None self.hydro_particle_raw_order = None - # GPU buffers + # GPU buffers (rendering) self.d_hydro_trails = None self.d_hydro_ages = None self.d_hydro_lifetimes = None self.d_hydro_colors = None self.d_hydro_radii = None self.hydro_done_event = None + # GPU-resident advection state + self.d_hydro_particles = None + self.d_hydro_particle_accum = None + self.d_hydro_particle_raw_order = None + self.d_hydro_flow_u = None + self.d_hydro_flow_v = None + self.d_hydro_slope_mag = None + self.d_hydro_stream_order = None + self.d_hydro_stream_order_raw = None + self.d_hydro_accum_norm = None + self.d_hydro_palette = 
None + self.d_hydro_respawn_flags = None diff --git a/rtxpy/viewer/hydro_manager.py b/rtxpy/viewer/hydro_manager.py new file mode 100644 index 0000000..7798108 --- /dev/null +++ b/rtxpy/viewer/hydro_manager.py @@ -0,0 +1,1229 @@ +"""HydroManager: manages hydrological flow particle simulation. + +Encapsulates flow field computation, particle lifecycle (spawn, advect, +respawn), GPU splatting, and streaming tile integration. Replaces the +~850 lines of hydro code previously embedded in engine.py. + +Streaming support +----------------- +When a ``TerrainLODManager`` is attached (via ``set_lod_manager``), the +flow field covers a *window* of tiles centred on the camera. Particles +are stored in **global pixel coordinates** so their positions stay valid +across window shifts. The advection kernel receives ``(win_r0, win_c0)`` +offsets to translate global coords → local flow field indices. + +The window is recomputed asynchronously (``ThreadPoolExecutor(1)``) when +the camera moves more than ~2 tiles from the window centre. 
+""" + +from __future__ import annotations + +import math +import time +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from .terrain_lod import TerrainLODManager + +try: + import cupy as cp + has_cupy = True +except ImportError: + cp = None + has_cupy = False + + +# Stream-order colour palette (shared with engine.py for overlay LUT) +STREAM_ORDER_PALETTE = np.array([ + [0.0, 0.0, 0.0 ], # 0: unused + [0.50, 0.80, 1.00], # 1: pale sky blue (headwaters) + [0.38, 0.68, 0.98], # 2: light blue + [0.28, 0.55, 0.95], # 3: sky blue + [0.18, 0.42, 0.90], # 4: medium blue + [0.10, 0.30, 0.85], # 5: royal blue + [0.06, 0.20, 0.78], # 6: deep blue + [0.03, 0.12, 0.70], # 7: dark blue + [0.01, 0.06, 0.60], # 8: navy (major rivers) +], dtype=np.float32) + + +def build_stream_palette_lut(max_order): + """Build a 256-entry color LUT for stream order overlay rendering.""" + lut = np.zeros((256, 3), dtype=np.float32) + denom = max(max_order - 1, 1) + for i in range(256): + order = int(round(1 + (i / 255.0) * denom)) + order = max(1, min(8, order)) + lut[i] = STREAM_ORDER_PALETTE[order] + return lut + + +def color_from_order(order_norm, raw_order=None): + """Map stream order → (R, G, B) per particle.""" + if raw_order is not None: + idx = np.clip(raw_order, 1, 8).astype(int) + colors = STREAM_ORDER_PALETTE[idx].copy() + else: + colors = np.empty((len(order_norm), 3), dtype=np.float32) + colors[:, 0] = 0.02 + order_norm * 0.43 + colors[:, 1] = 0.10 + order_norm * 0.65 + colors[:, 2] = 0.55 + order_norm * 0.40 + return np.clip(colors, 0.0, 1.0) + + +def radius_from_order(order_norm, raw_order=None): + """Map stream order → radius (2–5) per particle.""" + if raw_order is not None: + return np.clip(raw_order + 1, 2, 5).astype(np.int32) + return np.clip(2 + (order_norm * 3).astype(np.int32), + 2, 5).astype(np.int32) + + +class HydroManager: + """Manages hydrological flow particle simulation and rendering. 
+ + Parameters + ---------- + hydro_state : HydroState + The viewer's HydroState object (owns CPU-side particle arrays + and GPU buffer references). + """ + + __slots__ = ( + '_state', + # LOD integration + '_lod_manager', + '_tile_data_fn', + '_crs_transform', + # Flow field window (streaming) + '_win_r0', '_win_c0', + '_win_h', '_win_w', + '_window_center_tr', '_window_center_tc', + '_window_radius', + '_window_future', + '_window_executor', + '_last_window_time', + # Streaming stream_link overlay (pending for engine pickup) + '_pending_stream_overlay', + '_pending_overlay_bounds', + # Terrain reference + '_terrain_np', + '_psx', '_psy', + ) + + def __init__(self, hydro_state): + self._state = hydro_state + self._lod_manager = None + self._tile_data_fn = None + self._crs_transform = None + + # Streaming window state + self._win_r0 = 0.0 + self._win_c0 = 0.0 + self._win_h = 0 + self._win_w = 0 + self._window_center_tr = None + self._window_center_tc = None + self._window_radius = 5 # tiles in each direction + self._window_future = None + self._window_executor = None + self._last_window_time = 0.0 + + # Streaming stream_link overlay (set by _compute_windowed_flow) + self._pending_stream_overlay = None + self._pending_overlay_bounds = None # (win_r0, win_c0, win_h, win_w) + + # Terrain ref (set during init) + self._terrain_np = None + self._psx = 1.0 + self._psy = 1.0 + + # ------------------------------------------------------------------ + # Configuration + # ------------------------------------------------------------------ + + def set_lod_manager(self, mgr: 'TerrainLODManager'): + """Attach LOD manager for streaming tile integration.""" + self._lod_manager = mgr + + def set_tile_data_fn(self, fn): + """Set the tile data callback for streaming elevation fetches.""" + self._tile_data_fn = fn + + def set_crs_transform(self, x0, y0, dx, dy): + """Store CRS origin and pixel spacing for coord conversion.""" + self._crs_transform = (float(x0), float(y0), float(dx), 
float(dy)) + + def set_terrain_ref(self, terrain_np, psx, psy): + """Store terrain array and pixel spacing for Z lookups.""" + self._terrain_np = terrain_np + self._psx = psx + self._psy = psy + + @property + def streaming(self): + """True if streaming tile mode is active.""" + return self._lod_manager is not None and self._tile_data_fn is not None + + # ------------------------------------------------------------------ + # Initialisation + # ------------------------------------------------------------------ + + def init_from_flow(self, flow_accum, terrain_data, psx, psy, **kwargs): + """Initialize hydro particles from flow accumulation grid. + + This is the main entry point — equivalent to the old + ``_init_hydro`` on InteractiveViewer. + + Parameters + ---------- + flow_accum : array-like, shape (H, W) + Flow accumulation grid. + terrain_data : array-like, shape (H, W) + Terrain elevation for slope computation. + psx, psy : float + Pixel spacing (world units per pixel). + **kwargs + Optional overrides: n_particles, max_age, trail_len, speed, + accum_threshold, color, alpha, dot_radius, min_visible_age, + stream_order, flow_dir_mfd, elevation. 
+ """ + st = self._state + self._psx = psx + self._psy = psy + + if hasattr(flow_accum, 'get'): + flow_accum = flow_accum.get() + flow_accum = np.asarray(flow_accum, dtype=np.float64) + H, W = flow_accum.shape + + # Stream order grid (optional but strongly recommended) + stream_order = kwargs.pop('stream_order', None) + if stream_order is not None: + if hasattr(stream_order, 'get'): + stream_order = stream_order.get() + stream_order = np.asarray(stream_order, dtype=np.float64) + stream_order = np.nan_to_num(stream_order, nan=0.0) + has_stream_order = stream_order is not None and (stream_order > 0).any() + + st.hydro_data = True + + # Apply optional overrides + for key, attr, conv in [ + ('n_particles', 'hydro_n_particles', int), + ('max_age', 'hydro_max_age', int), + ('trail_len', 'hydro_trail_len', int), + ('speed', 'hydro_speed', float), + ('accum_threshold', 'hydro_accum_threshold', int), + ('color', 'hydro_color', tuple), + ('alpha', 'hydro_alpha', float), + ('dot_radius', 'hydro_dot_radius', int), + ('min_visible_age', 'hydro_min_visible_age', int), + ]: + if key in kwargs: + val = conv(kwargs[key]) + setattr(st, attr, val) + + # MFD flow vectors + flow_dir_mfd = kwargs.pop('flow_dir_mfd', None) + + sqrt2_inv = 1.0 / np.sqrt(2.0) + _dir_dr = np.array([0.0, sqrt2_inv, 1.0, sqrt2_inv, + 0.0, -sqrt2_inv, -1.0, -sqrt2_inv]) + _dir_dc = np.array([1.0, sqrt2_inv, 0.0, -sqrt2_inv, + -1.0, -sqrt2_inv, 0.0, sqrt2_inv]) + + if flow_dir_mfd is not None: + if hasattr(flow_dir_mfd, 'get'): + flow_dir_mfd = flow_dir_mfd.get() + frac = np.asarray(flow_dir_mfd, dtype=np.float64) + frac = np.nan_to_num(frac, nan=0.0) + flow_v = np.tensordot(_dir_dr, frac, axes=([0], [0])) + flow_u = np.tensordot(_dir_dc, frac, axes=([0], [0])) + valid_flow = np.any(frac > 0, axis=0) + del frac + else: + elevation = kwargs.pop('elevation', None) + if elevation is not None: + if hasattr(elevation, 'get'): + elevation = elevation.get() + elevation = np.asarray(elevation, dtype=np.float64) + else: 
+ if hasattr(terrain_data, 'get'): + elevation = terrain_data.get() + else: + elevation = np.asarray(terrain_data, dtype=np.float64) + elevation = elevation.astype(np.float64) + + nan_mask_elev = np.isnan(elevation) + elev_clean = np.where(nan_mask_elev, 1e10, elevation) + + sqrt2 = np.sqrt(2.0) + mfd_p = 1.1 + flow_u = np.zeros((H, W), dtype=np.float64) + flow_v = np.zeros((H, W), dtype=np.float64) + + _nb_offsets = [ + (-1, -1, sqrt2), (-1, 0, 1.0), (-1, 1, sqrt2), + ( 0, -1, 1.0), ( 0, 1, 1.0), + ( 1, -1, sqrt2), ( 1, 0, 1.0), ( 1, 1, sqrt2), + ] + _nb_dr = np.array([-sqrt2_inv, -1.0, -sqrt2_inv, + 0.0, 0.0, + sqrt2_inv, 1.0, sqrt2_inv]) + _nb_dc = np.array([-sqrt2_inv, 0.0, sqrt2_inv, + -1.0, 1.0, + -sqrt2_inv, 0.0, sqrt2_inv]) + + for k, (dr, dc, dist) in enumerate(_nb_offsets): + cr = slice(max(0, -dr), H - max(0, dr)) + cc = slice(max(0, -dc), W - max(0, dc)) + nr = slice(max(0, -dr) + dr, H - max(0, dr) + dr) + nc = slice(max(0, -dc) + dc, W - max(0, dc) + dc) + drop = elev_clean[cr, cc] - elev_clean[nr, nc] + slope = np.maximum(drop / dist, 0.0) + weight = slope ** mfd_p + flow_v[cr, cc] += weight * _nb_dr[k] + flow_u[cr, cc] += weight * _nb_dc[k] + + flow_u[nan_mask_elev] = 0.0 + flow_v[nan_mask_elev] = 0.0 + + # Normalize to unit vectors + mag = np.sqrt(flow_u**2 + flow_v**2) + valid_flow = mag > 0 + flow_u[valid_flow] /= mag[valid_flow] + flow_v[valid_flow] /= mag[valid_flow] + + st.hydro_flow_u_px = flow_u.astype(np.float32) + st.hydro_flow_v_px = flow_v.astype(np.float32) + + # Normalize accumulation + fa_clipped = np.clip(flow_accum, 1, None) + log_fa = np.log10(fa_clipped) + threshold = np.log10(max(st.hydro_accum_threshold, 1)) + log_max = log_fa.max() + if log_max > threshold: + accum_norm = np.clip( + (log_fa - threshold) / (log_max - threshold), 0, 1) + else: + accum_norm = np.zeros_like(log_fa) + st.hydro_flow_accum_norm = accum_norm.astype(np.float32) + + # Store stream order grids + if has_stream_order: + max_order = stream_order.max() + 
so_norm = (stream_order / max(max_order, 1)).astype(np.float32) + st.hydro_stream_order = so_norm + st.hydro_stream_order_raw = stream_order.astype(np.int32) + print(f" Stream order: max {int(max_order)}, " + f"{int((stream_order > 0).sum())} stream cells") + else: + st.hydro_stream_order = None + st.hydro_stream_order_raw = None + + # Stream link grid + stream_link_grid = kwargs.pop('stream_link', None) + if stream_link_grid is not None: + if hasattr(stream_link_grid, 'get'): + stream_link_grid = stream_link_grid.get() + st.hydro_stream_link = np.nan_to_num( + np.asarray(stream_link_grid, dtype=np.float64), nan=0.0 + ).astype(np.int32) + else: + st.hydro_stream_link = None + + # Build spawn probabilities + if has_stream_order: + spawn_weights = np.where(stream_order > 0, + np.sqrt(stream_order), 0.0) + spawn_weights[~valid_flow] = 0.0 + else: + spawn_weights = accum_norm.copy() + spawn_weights[~valid_flow] = 0.0 + + flat_weights = spawn_weights.ravel() + valid_mask = flat_weights > 0 + valid_indices = np.nonzero(valid_mask)[0] + if len(valid_indices) > 0: + valid_probs = flat_weights[valid_indices].astype(np.float64) + valid_probs /= valid_probs.sum() + else: + valid_flow_flat = valid_flow.ravel() + valid_indices = np.nonzero(valid_flow_flat)[0] + if len(valid_indices) > 0: + valid_probs = np.ones(len(valid_indices), dtype=np.float64) + valid_probs /= valid_probs.sum() + else: + valid_indices = np.arange(H * W) + valid_probs = np.ones(H * W, dtype=np.float64) / (H * W) + st.hydro_spawn_indices = valid_indices + st.hydro_spawn_valid_probs = valid_probs + + # Spawn initial particles + N = st.hydro_n_particles + chosen = np.random.choice(len(valid_indices), N, p=valid_probs) + indices = valid_indices[chosen] + rows = (indices // W).astype(np.float32) + \ + np.random.uniform(-0.5, 0.5, N).astype(np.float32) + cols = (indices % W).astype(np.float32) + \ + np.random.uniform(-0.5, 0.5, N).astype(np.float32) + rows = np.clip(rows, 0, H - 1) + cols = np.clip(cols, 0, W - 
1) + + st.hydro_particles = np.column_stack([rows, cols]).astype(np.float32) + st.hydro_ages = np.random.randint(0, st.hydro_max_age, N).astype(np.int32) + st.hydro_lifetimes = np.random.randint( + st.hydro_max_age // 2, st.hydro_max_age, N).astype(np.int32) + st.hydro_trails = np.zeros( + (N, st.hydro_trail_len, 2), dtype=np.float32) + for t in range(st.hydro_trail_len): + st.hydro_trails[:, t, :] = st.hydro_particles + + # Compute terrain slope magnitude + if hasattr(terrain_data, 'get'): + elev = terrain_data.get().astype(np.float64) + else: + elev = np.asarray(terrain_data, dtype=np.float64) + grad_row, grad_col = np.gradient(np.nan_to_num(elev, nan=0.0)) + slope_mag = np.sqrt(grad_row**2 + grad_col**2).astype(np.float32) + p95 = np.percentile(slope_mag[slope_mag > 0], 95) \ + if (slope_mag > 0).any() else 1.0 + slope_norm = np.clip( + slope_mag / max(p95, 1e-6), 0, 1).astype(np.float32) + st.hydro_slope_mag = slope_norm + + # Per-particle visual properties + r_idx = np.clip(np.floor(rows).astype(int), 0, H - 1) + c_idx = np.clip(np.floor(cols).astype(int), 0, W - 1) + if has_stream_order: + order_val = so_norm[r_idx, c_idx].astype(np.float32) + raw_order = st.hydro_stream_order_raw[r_idx, c_idx] + else: + order_val = accum_norm[r_idx, c_idx].astype(np.float32) + raw_order = None + st.hydro_particle_accum = order_val + st.hydro_particle_raw_order = raw_order + st.hydro_particle_colors = color_from_order( + order_val, raw_order=raw_order) + st.hydro_particle_radii = radius_from_order( + order_val, raw_order=raw_order) + + # Min render distance and depth-scaled alpha reference + world_diag = np.sqrt((W * psx)**2 + (H * psy)**2) + st.hydro_min_depth = 1.0 + st.hydro_max_depth = world_diag * 0.35 + st.hydro_ref_depth = world_diag * 0.15 + + # Set streaming window to cover full initial terrain + self._win_r0 = 0.0 + self._win_c0 = 0.0 + self._win_h = H + self._win_w = W + + # Upload to GPU + self._upload_to_gpu(N) + + print(f" Hydro flow initialized on {H}x{W} grid 
" + f"({N} particles, threshold={st.hydro_accum_threshold})") + + def compute_from_terrain(self, raster): + """Compute hydrological flow from terrain elevation on GPU. + + Uses xrspatial MFD functions. Called lazily on first hydro + enable or after terrain reload. + + Parameters + ---------- + raster : xarray.DataArray + Terrain elevation (may be CuPy-backed). + + Returns + ------- + dict or None + ``{'stream_order_raw': ..., 'stream_link': ...}`` on success, + or None on failure. Caller uses these to register overlays. + """ + try: + from xrspatial import fill as _fill + from xrspatial import flow_direction_mfd as _fd_mfd + from xrspatial import flow_accumulation_mfd as _fa_mfd + from xrspatial import stream_order_mfd as _so_mfd + from xrspatial import stream_link_mfd as _sl_mfd + except ImportError: + print("Hydro requires xrspatial: pip install xrspatial") + return None + try: + from scipy.ndimage import uniform_filter + except ImportError: + print("Hydro requires scipy: pip install scipy") + return None + + print("Computing hydrological flow on GPU...") + data = raster.data + is_cupy = hasattr(data, 'get') + + # Condition DEM + elev_np = data.get() if is_cupy else np.array(data) + elev_np = elev_np.astype(np.float32) + ocean = (elev_np == 0.0) | np.isnan(elev_np) + elev_np[ocean] = -100.0 + + smoothed = uniform_filter(elev_np, size=15, mode='nearest') + smoothed[ocean] = -100.0 + del elev_np + + if is_cupy: + sm = cp.asarray(smoothed) + else: + sm = smoothed + del smoothed + + filled = _fill(raster.copy(data=sm)) + fill_depth = filled.data - sm + resolved = filled.data + fill_depth * 0.01 + del filled, fill_depth, sm + + if is_cupy: + cp.random.seed(0) + resolved += cp.random.uniform( + 0, 0.0001, resolved.shape, dtype=cp.float32) + resolved[cp.asarray(ocean)] = -100.0 + else: + np.random.seed(0) + resolved += np.random.uniform( + 0, 0.0001, resolved.shape).astype(np.float32) + resolved[ocean] = -100.0 + + resolved_da = raster.copy(data=resolved) + fd_mfd = 
_fd_mfd(resolved_da, boundary='nearest') + fa_mfd = _fa_mfd(fd_mfd) + del resolved_da, resolved + + so = _so_mfd(fd_mfd, fa_mfd, threshold=50) + sl = _sl_mfd(fd_mfd, fa_mfd, threshold=50) + + fa_out = fa_mfd.data + fd_out = fd_mfd.data + so_out = so.data + sl_out = sl.data + xp = cp if is_cupy else np + if is_cupy: + ocean_gpu = cp.asarray(ocean) + fa_out[ocean_gpu] = cp.nan + fd_out[:, ocean_gpu] = cp.nan + so_out[ocean_gpu] = cp.nan + sl_out[ocean_gpu] = cp.nan + else: + fa_out[ocean] = np.nan + fd_out[:, ocean] = np.nan + so_out[ocean] = np.nan + sl_out[ocean] = np.nan + + sl_clean = xp.nan_to_num(sl_out, nan=0.0).astype(xp.float32) + + terrain_data = raster.data + self.init_from_flow( + fa_out, + terrain_data, + self._psx, self._psy, + flow_dir_mfd=fd_out, + stream_order=so_out, + stream_link=sl_clean, + ) + + H, W = raster.shape + print(f" Hydro flow computed on GPU ({H}x{W} grid, MFD)") + + # Return overlay data for the engine to register + result = {} + st = self._state + if st.hydro_stream_order_raw is not None: + result['stream_order_raw'] = st.hydro_stream_order_raw + max_order = int(st.hydro_stream_order_raw.max()) + result['palette_lut'] = build_stream_palette_lut(max_order) + + sl_np = sl_clean.get() if is_cupy else np.asarray(sl_clean) + so_raw = st.hydro_stream_order_raw.astype(np.float32) + sl_color = np.where( + (sl_np <= 0) | (so_raw <= 0), + np.float32(np.nan), so_raw) + result['stream_link_overlay'] = sl_color + result['palette_lut'] = build_stream_palette_lut(max_order) + + return result + + def _upload_to_gpu(self, N): + """Upload all particle + grid arrays to GPU.""" + if not has_cupy: + return + st = self._state + st.d_hydro_particles = cp.asarray(st.hydro_particles) + st.d_hydro_ages = cp.asarray(st.hydro_ages) + st.d_hydro_lifetimes = cp.asarray(st.hydro_lifetimes) + st.d_hydro_trails = cp.asarray(st.hydro_trails) + st.d_hydro_colors = cp.asarray(st.hydro_particle_colors) + st.d_hydro_radii = cp.asarray(st.hydro_particle_radii) + 
st.d_hydro_particle_accum = cp.asarray(st.hydro_particle_accum) + if st.hydro_particle_raw_order is not None: + st.d_hydro_particle_raw_order = cp.asarray( + st.hydro_particle_raw_order) + else: + st.d_hydro_particle_raw_order = cp.zeros(N, dtype=cp.int32) + st.d_hydro_flow_u = cp.asarray(st.hydro_flow_u_px) + st.d_hydro_flow_v = cp.asarray(st.hydro_flow_v_px) + if st.hydro_slope_mag is not None: + st.d_hydro_slope_mag = cp.asarray(st.hydro_slope_mag) + else: + st.d_hydro_slope_mag = cp.empty((0, 0), dtype=cp.float32) + if st.hydro_stream_order is not None: + st.d_hydro_stream_order = cp.asarray(st.hydro_stream_order) + else: + st.d_hydro_stream_order = cp.empty((0, 0), dtype=cp.float32) + if st.hydro_stream_order_raw is not None: + st.d_hydro_stream_order_raw = cp.asarray( + st.hydro_stream_order_raw) + else: + st.d_hydro_stream_order_raw = cp.empty((0, 0), dtype=cp.int32) + st.d_hydro_accum_norm = cp.asarray(st.hydro_flow_accum_norm) + st.d_hydro_palette = cp.asarray(STREAM_ORDER_PALETTE) + st.d_hydro_respawn_flags = cp.zeros(N, dtype=cp.int32) + + # ------------------------------------------------------------------ + # Per-tick update + # ------------------------------------------------------------------ + + def update_particles(self, dt_scale=1.0): + """Advect hydro particles one tick on GPU. + + Two-pass: GPU advection kernel, then CPU respawn batch. 
+ """ + if not has_cupy: + return + st = self._state + if st.d_hydro_flow_u is None or st.d_hydro_particles is None: + return + + from ._hydro_kernels import hydro_advect_kernel, hydro_respawn_kernel + + N = st.d_hydro_particles.shape[0] + + has_so = 1 if (st.hydro_stream_order is not None) else 0 + has_slope = 1 if (st.hydro_slope_mag is not None) else 0 + has_raw = 1 if (st.hydro_stream_order_raw is not None) else 0 + + speed = float(st.hydro_speed) + trail_len = int(st.hydro_trail_len) + rng_base = np.random.randint(0, 2**62) + + threadsperblock = 256 + blockspergrid = (N + threadsperblock - 1) // threadsperblock + + hydro_advect_kernel[blockspergrid, threadsperblock]( + st.d_hydro_particles, + st.d_hydro_ages, + st.d_hydro_lifetimes, + st.d_hydro_trails, + st.d_hydro_particle_accum, + st.d_hydro_particle_raw_order, + st.d_hydro_colors, + st.d_hydro_radii, + st.d_hydro_flow_u, + st.d_hydro_flow_v, + st.d_hydro_slope_mag, + st.d_hydro_stream_order, + st.d_hydro_stream_order_raw, + st.d_hydro_accum_norm, + st.d_hydro_palette, + st.d_hydro_respawn_flags, + speed, float(dt_scale), trail_len, + has_so, has_slope, has_raw, + float(self._win_r0), float(self._win_c0), + rng_base, + ) + + # Read back respawn flags and handle respawns on CPU + respawn_flags = st.d_hydro_respawn_flags.get() + respawn_idx = np.nonzero(respawn_flags)[0] + n_respawn = len(respawn_idx) + + if n_respawn > 0: + H, W = st.d_hydro_flow_u.shape + chosen = np.random.choice( + len(st.hydro_spawn_indices), n_respawn, + p=st.hydro_spawn_valid_probs) + flat_indices = st.hydro_spawn_indices[chosen] + # Spawn positions are in flow-field-local coords; + # convert to global by adding window offset + spawn_rows = (flat_indices // W).astype(np.float32) + \ + np.random.uniform(-0.5, 0.5, n_respawn).astype(np.float32) + \ + float(self._win_r0) + spawn_cols = (flat_indices % W).astype(np.float32) + \ + np.random.uniform(-0.5, 0.5, n_respawn).astype(np.float32) + \ + float(self._win_c0) + spawn_rows = 
np.clip(spawn_rows, + self._win_r0, self._win_r0 + H - 1) + spawn_cols = np.clip(spawn_cols, + self._win_c0, self._win_c0 + W - 1) + new_lifetimes = np.random.randint( + st.hydro_max_age // 2, st.hydro_max_age, + n_respawn).astype(np.int32) + + d_respawn_idx = cp.asarray(respawn_idx.astype(np.int32)) + d_spawn_rows = cp.asarray(spawn_rows) + d_spawn_cols = cp.asarray(spawn_cols) + d_new_lifetimes = cp.asarray(new_lifetimes) + + blocks_r = (n_respawn + threadsperblock - 1) // threadsperblock + hydro_respawn_kernel[blocks_r, threadsperblock]( + st.d_hydro_particles, + st.d_hydro_ages, + st.d_hydro_lifetimes, + st.d_hydro_trails, + st.d_hydro_particle_accum, + st.d_hydro_particle_raw_order, + st.d_hydro_colors, + st.d_hydro_radii, + d_respawn_idx, + d_spawn_rows, + d_spawn_cols, + d_new_lifetimes, + st.d_hydro_stream_order, + st.d_hydro_stream_order_raw, + st.d_hydro_accum_norm, + st.d_hydro_palette, + trail_len, has_so, has_raw, + float(self._win_r0), float(self._win_c0), + ) + + # ------------------------------------------------------------------ + # GPU splatting + # ------------------------------------------------------------------ + + def splat_gpu(self, d_frame, camera_pos, look_at, fov, ve, + subsample_factor, terrain_gpu, depth_t=None): + """Project and splat hydro particles on GPU. + + Parameters + ---------- + d_frame : cupy.ndarray, shape (H, W, 3) + GPU frame buffer (float32 0-1). Modified in-place. + camera_pos : tuple of 3 floats + look_at : tuple of 3 floats + fov : float + Vertical field of view in degrees. + ve : float + Vertical exaggeration. + subsample_factor : float + terrain_gpu : cupy.ndarray, shape (tH, tW) + GPU terrain for Z lookup. + depth_t : cupy.ndarray or None + Depth buffer for occlusion culling. 
+ """ + st = self._state + if st.d_hydro_particles is None or st.d_hydro_trails is None: + return + if not has_cupy: + return + + from ._hydro_kernels import hydro_splat_kernel + from ..analysis.render import _compute_camera_basis + + N = st.d_hydro_particles.shape[0] + trail_len = st.hydro_trail_len + total = N * trail_len + + forward, right, cam_up = _compute_camera_basis( + tuple(camera_pos), tuple(look_at), (0, 0, 1), + ) + fov_scale = math.tan(math.radians(fov) / 2.0) + aspect_ratio = d_frame.shape[1] / d_frame.shape[0] + + d_trails_flat = st.d_hydro_trails.reshape(-1, 2) + + if depth_t is None: + depth_t = cp.empty((0, 0), dtype=cp.float32) + + threadsperblock = 256 + blockspergrid = (total + threadsperblock - 1) // threadsperblock + + hydro_splat_kernel[blockspergrid, threadsperblock]( + d_trails_flat, + st.d_hydro_ages, + st.d_hydro_lifetimes, + st.d_hydro_colors, + st.d_hydro_radii, + trail_len, + float(st.hydro_alpha), + int(st.hydro_min_visible_age), + float(st.hydro_ref_depth), + terrain_gpu, + depth_t, + d_frame, + float(camera_pos[0]), float(camera_pos[1]), + float(camera_pos[2]), + float(forward[0]), float(forward[1]), float(forward[2]), + float(right[0]), float(right[1]), float(right[2]), + float(cam_up[0]), float(cam_up[1]), float(cam_up[2]), + float(fov_scale), float(aspect_ratio), + float(self._psx), + float(self._psy), + float(ve), + float(subsample_factor), + float(st.hydro_min_depth), + float(st.hydro_max_depth), + ) + + cp.clip(d_frame, 0, 1, out=d_frame) + + # ------------------------------------------------------------------ + # Streaming window management + # ------------------------------------------------------------------ + + def update_streaming_window(self, camera_row, camera_col): + """Check if the flow field window needs to shift for streaming. + + Called each tick when streaming is active. 
If the camera has + moved more than 2 tiles from the window centre and enough time + has passed, triggers an async recompute of the flow field over + a new tile window. + + Parameters + ---------- + camera_row, camera_col : float + Camera position in pixel (row, col) coordinates. + """ + if not self.streaming: + return + + mgr = self._lod_manager + ts = mgr._tile_size + + cam_tr = int(camera_row / ts) + cam_tc = int(camera_col / ts) + + # Check if we need to shift + if self._window_center_tr is not None: + dr = abs(cam_tr - self._window_center_tr) + dc = abs(cam_tc - self._window_center_tc) + if dr <= 2 and dc <= 2: + return # still within comfort zone + + # Don't start a new compute if one is in flight + if self._window_future is not None and not self._window_future.done(): + return + + # Throttle: at most every 3 seconds + now = time.monotonic() + if now - self._last_window_time < 3.0: + return + + # Collect the result of any previous future + if self._window_future is not None and self._window_future.done(): + self._apply_window_result(self._window_future.result()) + self._window_future = None + + # Submit async recompute + self._last_window_time = now + radius = self._window_radius + r0_tile = cam_tr - radius + c0_tile = cam_tc - radius + r1_tile = cam_tr + radius + 1 + c1_tile = cam_tc + radius + 1 + + # Pixel bounds of the window + win_r0 = r0_tile * ts + win_c0 = c0_tile * ts + win_r1 = r1_tile * ts + win_c1 = c1_tile * ts + + if self._window_executor is None: + self._window_executor = ThreadPoolExecutor(max_workers=1) + + self._window_future = self._window_executor.submit( + self._compute_windowed_flow, + win_r0, win_c0, win_r1, win_c1, + cam_tr, cam_tc, + ) + + def check_streaming_result(self): + """Poll for completed async window recompute. 
Call each tick.""" + if self._window_future is not None and self._window_future.done(): + try: + result = self._window_future.result() + if result is not None: + self._apply_window_result(result) + except Exception as e: + print(f"Hydro streaming window error: {e}") + self._window_future = None + + def pop_streaming_overlay(self): + """Return and clear pending streaming stream overlay. + + Returns ``(overlay, win_r0, win_c0)`` or ``(None, 0, 0)``. + """ + ov = self._pending_stream_overlay + bounds = self._pending_overlay_bounds + self._pending_stream_overlay = None + self._pending_overlay_bounds = None + if ov is None or bounds is None: + return None, 0, 0 + return ov, bounds[0], bounds[1] + + @staticmethod + def _compute_stream_overlay(elevation, threshold=50): + """Compute stream overlay via D8 flow accumulation (CPU). + + Subsamples large grids for performance, then upsamples result. + Returns ``(H, W)`` float32 array: stream order values (1-8) + where streams exist, NaN elsewhere. + """ + H, W = elevation.shape + + # Subsample to keep computation under ~1M cells + max_cells = 1024 + subsample = max(1, max(H, W) // max_cells) + if subsample > 1: + elev_s = elevation[::subsample, ::subsample].copy() + else: + elev_s = elevation + sH, sW = elev_s.shape + N = sH * sW + + elev_clean = np.nan_to_num(elev_s, nan=1e10).astype(np.float64) + flat_elev = elev_clean.ravel() + valid_cell = flat_elev < 1e9 + + # --- D8 flow direction (vectorized) --- + idx_grid = np.arange(N, dtype=np.int64).reshape(sH, sW) + target = np.full(N, -1, dtype=np.int64) + max_slope = np.full((sH, sW), 0.0, dtype=np.float64) + + sqrt2 = np.sqrt(2.0) + nb_offsets = [ + (-1, -1, sqrt2), (-1, 0, 1.0), (-1, 1, sqrt2), + (0, -1, 1.0), (0, 1, 1.0), + (1, -1, sqrt2), (1, 0, 1.0), (1, 1, sqrt2), + ] + for dr, dc, dist in nb_offsets: + rs = slice(max(0, -dr), sH - max(0, dr)) + cs = slice(max(0, -dc), sW - max(0, dc)) + rn = slice(max(0, -dr) + dr, sH - max(0, dr) + dr) + cn = slice(max(0, -dc) + dc, sW - 
max(0, dc) + dc) + drop = elev_clean[rs, cs] - elev_clean[rn, cn] + slope = drop / dist + better = slope > max_slope[rs, cs] + if better.any(): + src = idx_grid[rs, cs][better].ravel() + dst = idx_grid[rn, cn][better].ravel() + target[src] = dst + max_slope[rs, cs] = np.maximum(max_slope[rs, cs], slope) + + # --- Flow accumulation: propagate high→low --- + order = np.argsort(-flat_elev) + accum = np.ones(N, dtype=np.float64) + for idx in order: + if not valid_cell[idx]: + continue + tgt = target[idx] + if tgt >= 0: + accum[tgt] += accum[idx] + + accum_2d = accum.reshape(sH, sW) + + # --- Detect streams and map to stream order --- + is_stream = accum_2d > threshold + if not is_stream.any(): + return np.full((H, W), np.nan, dtype=np.float32) + + stream_vals = accum_2d[is_stream] + log_vals = np.log10(np.maximum(stream_vals, 1.0)) + log_min = np.log10(max(threshold, 1.0)) + log_max = log_vals.max() + log_range = max(log_max - log_min, 1.0) + order_vals = 1.0 + (log_vals - log_min) / log_range * 7.0 + order_vals = np.clip(order_vals, 1.0, 8.0).astype(np.float32) + + overlay_s = np.full((sH, sW), np.nan, dtype=np.float32) + overlay_s[is_stream] = order_vals + + # Upsample back to full resolution + if subsample > 1: + from PIL import Image + # Nearest-neighbor for categorical data + img = Image.fromarray(overlay_s) + overlay = np.array( + img.resize((W, H), Image.NEAREST), dtype=np.float32) + else: + overlay = overlay_s + + return overlay + + def _compute_windowed_flow(self, win_r0, win_c0, win_r1, win_c1, + center_tr, center_tc): + """Build flow field for the given pixel window (runs off main thread). + + Assembles elevation from: + 1. Initial terrain (for in-bounds pixels) + 2. LOD tile cache (for cached streaming tiles) + 3. tile_data_fn (for uncached areas) + + Then computes MFD flow vectors and spawn probabilities. + + Returns a dict with all the arrays needed to update GPU state, + or None on failure. 
+ """ + mgr = self._lod_manager + if mgr is None: + return None + + win_h = win_r1 - win_r0 + win_c = win_c1 - win_c0 + if win_h <= 0 or win_c <= 0: + return None + + # Assemble elevation array for the window + elevation = np.full((win_h, win_c), np.nan, dtype=np.float32) + + # 1. Fill from initial terrain where available + terrain_np = mgr._terrain_np + t_H, t_W = terrain_np.shape + # Overlap region between window and initial terrain + src_r0 = max(0, -win_r0) + src_c0 = max(0, -win_c0) + src_r1 = min(win_h, t_H - win_r0) + src_c1 = min(win_c, t_W - win_c0) + ter_r0 = max(0, win_r0) + ter_c0 = max(0, win_c0) + if src_r1 > src_r0 and src_c1 > src_c0: + ter_r1 = ter_r0 + (src_r1 - src_r0) + ter_c1 = ter_c0 + (src_c1 - src_c0) + elevation[src_r0:src_r1, src_c0:src_c1] = \ + terrain_np[ter_r0:ter_r1, ter_c0:ter_c1] + + # 2. Fill remaining NaN areas from tile_data_fn + if self._tile_data_fn is not None and np.isnan(elevation).any(): + ts = mgr._tile_size + crs_tf = self._crs_transform + if crs_tf is not None: + crs_x0, crs_y0, crs_dx, crs_dy = crs_tf + psx = abs(mgr._psx) + psy = abs(mgr._psy) + + # Iterate tile-sized blocks in the window + for br in range(0, win_h, ts): + for bc in range(0, win_c, ts): + br1 = min(br + ts, win_h) + bc1 = min(bc + ts, win_c) + block = elevation[br:br1, bc:bc1] + if not np.isnan(block).any(): + continue + + # Global pixel coords of this block + gr0 = win_r0 + br + gc0 = win_c0 + bc + gr1 = win_r0 + br1 + gc1 = win_c0 + bc1 + + # Convert to CRS + x_min = crs_x0 + gc0 * crs_dx + x_max = crs_x0 + gc1 * crs_dx + y_min = crs_y0 + gr0 * crs_dy + y_max = crs_y0 + gr1 * crs_dy + + if x_min > x_max: + x_min, x_max = x_max, x_min + if y_min > y_max: + y_min, y_max = y_max, y_min + + try: + tile_data = self._tile_data_fn( + x_min, y_min, x_max, y_max, + max(br1 - br, bc1 - bc)) + except Exception: + continue + + if tile_data is not None: + td = np.asarray(tile_data, dtype=np.float32) + # Resize to match block dimensions + if td.shape != (br1 - br, 
bc1 - bc): + from PIL import Image + img = Image.fromarray(td) + img = img.resize( + (bc1 - bc, br1 - br), + Image.BILINEAR) + td = np.array(img, dtype=np.float32) + # Only fill NaN cells + nan_mask = np.isnan(elevation[br:br1, bc:bc1]) + elevation[br:br1, bc:bc1] = np.where( + nan_mask, td, elevation[br:br1, bc:bc1]) + + # If still mostly NaN, bail + valid_frac = np.isfinite(elevation).mean() + if valid_frac < 0.1: + return None + + # Fill remaining NaN with nearest-neighbor + for _ in range(min(50, max(win_h, win_c))): + still_nan = np.isnan(elevation) + if not still_nan.any(): + break + padded = np.pad(elevation, 1, mode='edge') + neighbors = np.stack([ + padded[:-2, 1:-1], padded[2:, 1:-1], + padded[1:-1, :-2], padded[1:-1, 2:], + ], axis=0) + with np.errstate(all='ignore'): + fill_vals = np.nanmean(neighbors, axis=0) + elevation = np.where( + still_nan & np.isfinite(fill_vals), fill_vals, elevation) + + # Compute stream overlay for streaming tiles + try: + stream_overlay = self._compute_stream_overlay(elevation) + except Exception: + stream_overlay = None + + # Compute MFD flow from elevation + nan_mask = np.isnan(elevation) + elev_clean = np.where(nan_mask, 1e10, elevation).astype(np.float64) + + sqrt2 = np.sqrt(2.0) + sqrt2_inv = 1.0 / sqrt2 + mfd_p = 1.1 + flow_u = np.zeros((win_h, win_c), dtype=np.float64) + flow_v = np.zeros((win_h, win_c), dtype=np.float64) + + _nb_offsets = [ + (-1, -1, sqrt2), (-1, 0, 1.0), (-1, 1, sqrt2), + ( 0, -1, 1.0), ( 0, 1, 1.0), + ( 1, -1, sqrt2), ( 1, 0, 1.0), ( 1, 1, sqrt2), + ] + _nb_dr = np.array([-sqrt2_inv, -1.0, -sqrt2_inv, + 0.0, 0.0, + sqrt2_inv, 1.0, sqrt2_inv]) + _nb_dc = np.array([-sqrt2_inv, 0.0, sqrt2_inv, + -1.0, 1.0, + -sqrt2_inv, 0.0, sqrt2_inv]) + + for k, (dr, dc, dist) in enumerate(_nb_offsets): + cr = slice(max(0, -dr), win_h - max(0, dr)) + cc = slice(max(0, -dc), win_c - max(0, dc)) + nr = slice(max(0, -dr) + dr, win_h - max(0, dr) + dr) + nc = slice(max(0, -dc) + dc, win_c - max(0, dc) + dc) + drop = 
elev_clean[cr, cc] - elev_clean[nr, nc] + slope = np.maximum(drop / dist, 0.0) + weight = slope ** mfd_p + flow_v[cr, cc] += weight * _nb_dr[k] + flow_u[cr, cc] += weight * _nb_dc[k] + + flow_u[nan_mask] = 0.0 + flow_v[nan_mask] = 0.0 + + # Normalize + mag = np.sqrt(flow_u**2 + flow_v**2) + valid_flow = mag > 0 + flow_u[valid_flow] /= mag[valid_flow] + flow_v[valid_flow] /= mag[valid_flow] + + flow_u = flow_u.astype(np.float32) + flow_v = flow_v.astype(np.float32) + + # Slope magnitude + grad_row, grad_col = np.gradient( + np.nan_to_num(elevation, nan=0.0).astype(np.float64)) + slope_mag = np.sqrt(grad_row**2 + grad_col**2).astype(np.float32) + p95 = np.percentile(slope_mag[slope_mag > 0], 95) \ + if (slope_mag > 0).any() else 1.0 + slope_norm = np.clip( + slope_mag / max(p95, 1e-6), 0, 1).astype(np.float32) + + # Accumulation-based spawn weights (simple: use flow magnitude) + spawn_weights = np.where(valid_flow, mag.astype(np.float32), 0.0) + flat_weights = spawn_weights.ravel() + valid_mask = flat_weights > 0 + valid_indices = np.nonzero(valid_mask)[0] + if len(valid_indices) > 0: + valid_probs = flat_weights[valid_indices].astype(np.float64) + valid_probs /= valid_probs.sum() + else: + valid_indices = np.arange(win_h * win_c) + valid_probs = np.ones(win_h * win_c, dtype=np.float64) / \ + (win_h * win_c) + + # Accumulation norm for particle colouring + accum_norm = np.clip(mag.astype(np.float32) / + max(mag.max(), 1e-6), 0, 1) + + return { + 'win_r0': win_r0, 'win_c0': win_c0, + 'win_h': win_h, 'win_w': win_c, + 'center_tr': center_tr, 'center_tc': center_tc, + 'flow_u': flow_u, 'flow_v': flow_v, + 'slope_mag': slope_norm, + 'accum_norm': accum_norm.astype(np.float32), + 'spawn_indices': valid_indices, + 'spawn_probs': valid_probs, + 'stream_overlay': stream_overlay, + } + + def _apply_window_result(self, result): + """Apply a completed windowed flow computation to GPU state.""" + if result is None: + return + + st = self._state + self._win_r0 = 
float(result['win_r0']) + self._win_c0 = float(result['win_c0']) + self._win_h = result['win_h'] + self._win_w = result['win_w'] + self._window_center_tr = result['center_tr'] + self._window_center_tc = result['center_tc'] + + # Update CPU-side grids + st.hydro_flow_u_px = result['flow_u'] + st.hydro_flow_v_px = result['flow_v'] + st.hydro_slope_mag = result['slope_mag'] + st.hydro_flow_accum_norm = result['accum_norm'] + st.hydro_spawn_indices = result['spawn_indices'] + st.hydro_spawn_valid_probs = result['spawn_probs'] + + # Stream order not available in windowed mode (would need + # full xrspatial which is too slow for async recompute) + st.hydro_stream_order = None + st.hydro_stream_order_raw = None + + # Upload new grids to GPU + if has_cupy: + st.d_hydro_flow_u = cp.asarray(result['flow_u']) + st.d_hydro_flow_v = cp.asarray(result['flow_v']) + st.d_hydro_slope_mag = cp.asarray(result['slope_mag']) + st.d_hydro_accum_norm = cp.asarray(result['accum_norm']) + st.d_hydro_stream_order = cp.empty((0, 0), dtype=cp.float32) + st.d_hydro_stream_order_raw = cp.empty((0, 0), dtype=cp.int32) + + # Store stream overlay for engine to pick up + stream_ov = result.get('stream_overlay') + if stream_ov is not None: + self._pending_stream_overlay = stream_ov + self._pending_overlay_bounds = ( + result['win_r0'], result['win_c0'], + result['win_h'], result['win_w'], + ) + + print(f" Hydro window shifted to " + f"({result['win_r0']}, {result['win_c0']}) " + f"size {result['win_h']}x{result['win_w']}") + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + def shutdown(self): + """Clean up thread pool.""" + if self._window_executor is not None: + self._window_executor.shutdown(wait=False) + self._window_executor = None diff --git a/rtxpy/viewer/keybindings.py b/rtxpy/viewer/keybindings.py index 82d1c93..a0035bc 100644 --- a/rtxpy/viewer/keybindings.py +++ 
b/rtxpy/viewer/keybindings.py @@ -40,7 +40,6 @@ 'T': '_action_cycle_time', # Time-of-day 'Y': '_action_toggle_hydro', # Hydro flow particles 'N': '_action_toggle_clouds', # Cloud layer - 'A': '_action_toggle_terrain_lod', # Distance-based terrain LOD } # Lowercase key bindings — checked after shift bindings @@ -58,7 +57,6 @@ ']': '_action_observer_elev_up', 'f': '_save_screenshot', 'y': '_action_cycle_color_stretch', - 'b': '_action_cycle_mesh_type', 'u': '_action_cycle_basemap_fwd', ',': '_action_overlay_alpha_down', '.': '_action_overlay_alpha_up', diff --git a/rtxpy/viewer/overlay_tiles.py b/rtxpy/viewer/overlay_tiles.py new file mode 100644 index 0000000..3ce2aa2 --- /dev/null +++ b/rtxpy/viewer/overlay_tiles.py @@ -0,0 +1,409 @@ +"""Per-tile overlay compositing for LOD terrain. + +Manages overlay data (hydro stream_link, etc.) on a per-tile basis and +composites visible tiles into a single contiguous GPU array with pixel +offsets. The render kernel indexes the composite using:: + + ov_y = elev_y - overlay_offset_y + ov_x = elev_x - overlay_offset_x + +This keeps the Numba CUDA kernel simple (single 2D array + two int +offsets) while supporting unbounded tiled terrain. +""" + +import threading +import time + +import numpy as np + + +class OverlayTileManager: + """Composites per-tile overlay arrays into a GPU-ready buffer. + + Parameters + ---------- + tile_size : int + Tile size in pixels (must match TerrainLODManager). 
+ """ + + __slots__ = ( + '_tile_size', + '_tile_overlays', # {(tr, tc): np.ndarray} + '_color_lut', # np.ndarray (256, 3) float32 or None + '_composite', # np.ndarray (H, W) float32 or None + '_d_composite', # cupy ndarray or None + '_origin_row', # int: pixel row of composite[0,0] + '_origin_col', # int: pixel col of composite[0,0] + '_dirty', # bool + '_composited_tiles', # frozenset of (tr, tc) + '_last_rebuild', # float: monotonic time of last rebuild + '_lock', + ) + + _REBUILD_INTERVAL = 0.25 + + def __init__(self, tile_size): + self._tile_size = tile_size + self._tile_overlays = {} + self._color_lut = None + self._composite = None + self._d_composite = None + self._origin_row = 0 + self._origin_col = 0 + self._dirty = True + self._composited_tiles = frozenset() + self._last_rebuild = 0.0 + self._lock = threading.Lock() + + # ------------------------------------------------------------------ + # Tile data management + # ------------------------------------------------------------------ + + def set_color_lut(self, lut): + """Set the palette LUT for categorical overlay coloring.""" + self._color_lut = lut + + @property + def color_lut(self): + return self._color_lut + + def set_tile(self, tr, tc, data): + """Store overlay data for tile (tr, tc). + + Parameters + ---------- + data : np.ndarray, shape (th, tw) + Overlay values (float32, NaN = transparent). 
+ """ + with self._lock: + self._tile_overlays[(tr, tc)] = np.asarray(data, dtype=np.float32) + self._dirty = True + + def remove_tile(self, tr, tc): + """Remove overlay for tile (tr, tc).""" + with self._lock: + if (tr, tc) in self._tile_overlays: + del self._tile_overlays[(tr, tc)] + self._dirty = True + + def has_tile(self, tr, tc): + return (tr, tc) in self._tile_overlays + + def clear(self): + """Remove all tile overlays.""" + with self._lock: + self._tile_overlays.clear() + self._composite = None + self._d_composite = None + self._dirty = True + + def populate_from_array(self, overlay, tile_size, n_tile_rows, n_tile_cols): + """Slice a monolithic overlay array into per-tile chunks. + + Parameters + ---------- + overlay : np.ndarray, shape (H, W) + Full-terrain overlay (e.g. from compute_from_terrain). + tile_size : int + Tile size in pixels. + n_tile_rows, n_tile_cols : int + Number of tile rows/columns in the initial terrain grid. + """ + H, W = overlay.shape + with self._lock: + for tr in range(n_tile_rows): + for tc in range(n_tile_cols): + r0 = tr * tile_size + c0 = tc * tile_size + r1 = min(r0 + tile_size, H) + c1 = min(c0 + tile_size, W) + tile_data = overlay[r0:r1, c0:c1] + # Skip all-NaN tiles to save memory + if np.all(np.isnan(tile_data)): + continue + self._tile_overlays[(tr, tc)] = tile_data.copy() + self._dirty = True + + # ------------------------------------------------------------------ + # Composite for rendering + # ------------------------------------------------------------------ + + def get_composite(self, visible_tiles): + """Return (gpu_array, offset_row, offset_col) for the render kernel. + + Parameters + ---------- + visible_tiles : set of (tr, tc) + Currently visible tile coordinates. + + Returns + ------- + (d_composite, offset_row, offset_col) or (None, 0, 0) + GPU overlay array and pixel offsets into elev_y/elev_x space. 
+ """ + # Filter to tiles that actually have overlay data + with self._lock: + tiles_with_data = visible_tiles & set(self._tile_overlays.keys()) + + if not tiles_with_data: + return None, 0, 0 + + tiles_frozen = frozenset(tiles_with_data) + + # Rebuild composite only when tile set changed or data is dirty + if not self._dirty and tiles_frozen == self._composited_tiles: + if self._d_composite is not None: + return self._d_composite, self._origin_row, self._origin_col + + # Superset: visible set contracted but data unchanged + if (not self._dirty + and tiles_frozen <= self._composited_tiles + and self._d_composite is not None): + return self._d_composite, self._origin_row, self._origin_col + + # Throttle rebuilds to avoid per-frame GPU uploads + now = time.monotonic() + if (self._dirty + and self._d_composite is not None + and (now - self._last_rebuild) < self._REBUILD_INTERVAL): + return self._d_composite, self._origin_row, self._origin_col + + self._dirty = False + self._composited_tiles = tiles_frozen + self._last_rebuild = now + + ts = self._tile_size + + # Compute bounding box in tile coordinates + min_tr = min(tr for tr, tc in tiles_with_data) + max_tr = max(tr for tr, tc in tiles_with_data) + min_tc = min(tc for tr, tc in tiles_with_data) + max_tc = max(tc for tr, tc in tiles_with_data) + + # Pixel origin of the composite + origin_row = min_tr * ts + origin_col = min_tc * ts + + # Composite dimensions + comp_h = (max_tr - min_tr + 1) * ts + comp_w = (max_tc - min_tc + 1) * ts + + # Build composite (NaN = no data / transparent) + composite = np.full((comp_h, comp_w), np.nan, dtype=np.float32) + + with self._lock: + for tr, tc in tiles_with_data: + tile_data = self._tile_overlays.get((tr, tc)) + if tile_data is None: + continue + r0 = (tr - min_tr) * ts + c0 = (tc - min_tc) * ts + th, tw = tile_data.shape + composite[r0:r0 + th, c0:c0 + tw] = tile_data + + self._composite = composite + self._origin_row = origin_row + self._origin_col = origin_col + + # Upload to 
class TextureTileManager:
    """Composite per-tile RGB textures into one GPU-ready buffer.

    Same composite-with-offset pattern as :class:`OverlayTileManager`
    but for ``(H, W, 3)`` float32 RGB data (basemap imagery).

    Parameters
    ----------
    tile_size : int
        Tile size in pixels (must match TerrainLODManager).
    """

    __slots__ = (
        '_tile_size',
        '_tile_textures',      # {(tr, tc): np.ndarray (th, tw, 3)}
        '_composite',          # np.ndarray (H, W, 3) float32 or None
        '_d_composite',        # cupy ndarray or None
        '_origin_row',         # int
        '_origin_col',         # int
        '_dirty',
        '_composited_tiles',
        '_last_rebuild',       # float: monotonic time of last rebuild
        '_lock',
    )

    # Minimum seconds between composite rebuilds.  Background fetch
    # threads calling set_tile() each mark dirty; without throttling
    # we'd rebuild+upload every frame while fetches stream in.
    _REBUILD_INTERVAL = 0.25

    def __init__(self, tile_size):
        self._tile_size = tile_size
        self._tile_textures = {}
        self._composite = None
        self._d_composite = None
        self._origin_row = 0
        self._origin_col = 0
        self._dirty = True
        self._composited_tiles = frozenset()
        self._last_rebuild = 0.0
        self._lock = threading.Lock()

    def set_tile(self, tr, tc, data):
        """Store RGB texture data for tile (tr, tc).

        Parameters
        ----------
        data : np.ndarray, shape (th, tw, 3)
            RGB values as float32 [0-1].  Converted to float32 on store.
        """
        with self._lock:
            self._tile_textures[(tr, tc)] = np.asarray(data, dtype=np.float32)
            self._dirty = True

    def remove_tile(self, tr, tc):
        """Remove texture for tile (tr, tc); no-op if it is not stored."""
        with self._lock:
            if (tr, tc) in self._tile_textures:
                del self._tile_textures[(tr, tc)]
                self._dirty = True

    def has_tile(self, tr, tc):
        """Return True if a texture is stored for tile (tr, tc)."""
        return (tr, tc) in self._tile_textures

    def clear(self):
        """Remove all tile textures and drop the cached composites."""
        with self._lock:
            self._tile_textures.clear()
            self._composite = None
            self._d_composite = None
            self._dirty = True

    def populate_from_array(self, rgb_texture, tile_size, n_tile_rows, n_tile_cols):
        """Slice a monolithic RGB texture into per-tile chunks.

        Parameters
        ----------
        rgb_texture : np.ndarray, shape (H, W, 3)
            Full-terrain RGB texture (float32 [0-1]).
        tile_size : int
            Tile size in pixels.
        n_tile_rows, n_tile_cols : int
            Number of tile rows/columns.
        """
        H, W = rgb_texture.shape[:2]
        with self._lock:
            for tr in range(n_tile_rows):
                for tc in range(n_tile_cols):
                    r0 = tr * tile_size
                    c0 = tc * tile_size
                    r1 = min(r0 + tile_size, H)
                    c1 = min(c0 + tile_size, W)
                    tile_data = rgb_texture[r0:r1, c0:c1]
                    # Skip all-zero tiles (no texture data)
                    if np.all(tile_data == 0):
                        continue
                    self._tile_textures[(tr, tc)] = tile_data.copy()
            self._dirty = True

    def get_composite(self, visible_tiles):
        """Return (gpu_array, offset_row, offset_col) for the render kernel.

        Parameters
        ----------
        visible_tiles : iterable of (tr, tc)
            Currently visible tile coordinates.

        Returns
        -------
        (d_composite, offset_row, offset_col) or (None, 0, 0)
            GPU RGB array ``(H, W, 3)`` and pixel offsets.  ``(None, 0, 0)``
            is returned when no visible tile has texture data or when
            ``cupy`` cannot be imported.
        """
        with self._lock:
            # Filter the visible set directly instead of copying every
            # stored key into a new set on each call.
            textures = self._tile_textures
            tiles_with_data = {t for t in visible_tiles if t in textures}

        if not tiles_with_data:
            return None, 0, 0

        tiles_frozen = frozenset(tiles_with_data)
        have_gpu = self._d_composite is not None

        # Fast path: data unchanged and the existing composite already
        # covers every needed tile (same set, or the visible set shrank).
        if (have_gpu and not self._dirty
                and tiles_frozen <= self._composited_tiles):
            return self._d_composite, self._origin_row, self._origin_col

        # Throttle: don't rebuild more than once per _REBUILD_INTERVAL.
        # Background fetch threads mark dirty in bursts; without this
        # we'd rebuild+upload every frame while fetches stream in.
        now = time.monotonic()
        if (have_gpu and self._dirty
                and (now - self._last_rebuild) < self._REBUILD_INTERVAL):
            return self._d_composite, self._origin_row, self._origin_col

        self._dirty = False
        self._composited_tiles = tiles_frozen
        self._last_rebuild = now

        ts = self._tile_size

        # Bounding box of the needed tiles, in tile coordinates.
        min_tr = min(tr for tr, tc in tiles_with_data)
        max_tr = max(tr for tr, tc in tiles_with_data)
        min_tc = min(tc for tr, tc in tiles_with_data)
        max_tc = max(tc for tr, tc in tiles_with_data)

        origin_row = min_tr * ts
        origin_col = min_tc * ts

        comp_h = (max_tr - min_tr + 1) * ts
        comp_w = (max_tc - min_tc + 1) * ts

        # Build composite (zeros = no texture)
        composite = np.zeros((comp_h, comp_w, 3), dtype=np.float32)

        with self._lock:
            for tr, tc in tiles_with_data:
                tile_data = self._tile_textures.get((tr, tc))
                if tile_data is None:
                    continue  # removed since the visible set was filtered
                r0 = (tr - min_tr) * ts
                c0 = (tc - min_tc) * ts
                # Clip to the composite bounds.  NumPy silently clips the
                # destination slice but not the source, so a tile larger
                # than its slot at the bottom/right edge would otherwise
                # raise a shape-mismatch ValueError.
                th = min(tile_data.shape[0], comp_h - r0)
                tw = min(tile_data.shape[1], comp_w - c0)
                composite[r0:r0 + th, c0:c0 + tw] = tile_data[:th, :tw]

        self._composite = composite
        self._origin_row = origin_row
        self._origin_col = origin_col

        try:
            import cupy
            self._d_composite = cupy.asarray(composite)
        except ImportError:
            self._d_composite = None
            return None, 0, 0

        return self._d_composite, origin_row, origin_col

    def invalidate(self):
        """Force recomposite on next get_composite call."""
        self._dirty = True
def clear_all_caches(self):
    """Clear all terrain mesh caches (baked meshes and LOD tiles).

    Call this when terrain data or resolution changes to ensure no
    stale geometry survives across the different caching layers.

    When an LOD manager is attached, this performs the same
    invalidation sequence the manager's own ``set_terrain`` runs:
    in-flight builds are cancelled, the pyramid, mesh, prefetched-data
    and batch-staging caches are emptied, per-tile LOD state is
    dropped, and the heightfield is flagged for rebuild when enabled.
    """
    self._baked_mesh_cache.clear()

    lod_manager = self._terrain_lod_manager
    if lod_manager is None:
        return

    # Cancel first: a build finishing after the clear would otherwise
    # re-insert geometry produced from the old data.
    lod_manager._cancel_pending()
    lod_manager._pyramid_cache.clear()
    lod_manager._tile_cache.clear()
    lod_manager._tile_data_cache.clear()
    lod_manager._lod_tile_meshes.clear()
    lod_manager._dirty_lods.clear()
    lod_manager._tile_lods.clear()
    if lod_manager._hf_enabled:
        lod_manager._hf_dirty = True
""" +import logging +import math +from concurrent.futures import ThreadPoolExecutor + import numpy as np -from ..lod import compute_lod_level, compute_lod_distances +logger = logging.getLogger(__name__) + +from ..lod import (compute_lod_distances, + compute_lod_level_with_hysteresis, + compute_tile_roughness) class TerrainLODManager: @@ -41,8 +55,22 @@ class TerrainLODManager: '_max_lod', '_lod_distances', '_lod_distance_factor', '_H', '_W', '_n_tile_rows', '_n_tile_cols', '_tile_centers', '_tile_lods', '_tile_cache', - '_active_tiles', '_last_update_pos', '_update_threshold', - '_base_subsample', + '_active_tiles', '_last_update_pos', '_last_update_fwd', + '_update_threshold_sq', + '_base_subsample', '_pyramid_cache', '_has_in_flight_work', + '_tile_half_diag_sq', '_tile_half_diag', + '_max_dist_sq', '_max_dist', '_streaming_max_dist', + '_tile_world_x', '_tile_world_y', + '_offset_x', '_offset_y', + 'max_tiles', 'per_tick_build_limit', + '_tile_roughness', + '_tile_data_fn', '_streaming', '_crs_origin', '_crs_spacing', + '_tile_data_cache', '_io_futures', + '_executor', '_pending_futures', '_threaded', + '_batched', '_lod_tile_meshes', '_dirty_lods', '_batch_gids', + '_hf_enabled', '_hf_gid', '_hf_dirty', '_hf_tile_size', + '_build_retries', + '_on_tile_added', '_on_tile_removed', ) def __init__(self, terrain_np, tile_size=128, @@ -63,39 +91,182 @@ def __init__(self, terrain_np, tile_size=128, self._n_tile_rows = (H + tile_size - 1) // tile_size self._n_tile_cols = (W + tile_size - 1) // tile_size + # World-space offset applied to tile vertices (for stable camera + # across terrain reloads). Must be set before _recompute_tile_centers. + self._offset_x = 0.0 + self._offset_y = 0.0 + + # Streaming: when set, tiles outside terrain_np bounds are fetched + # from this callback instead of from the pyramid cache. 
+ self._tile_data_fn = None + self._streaming = False + self._tile_data_cache = {} # cache_key -> np.ndarray (prefetched I/O) + self._io_futures = {} # cache_key -> Future (in-flight prefetches) + # CRS coordinate origin and signed spacing — used to convert pixel + # indices to CRS coordinates for tile_data_fn calls. Without this, + # the callback receives viewer world-space coords (pixel * abs(spacing)) + # which are NOT valid CRS coordinates (UTM easting/northing). + self._crs_origin = None # (crs_x0, crs_y0) or None + self._crs_spacing = None # (crs_dx, crs_dy) signed, or None + # Pre-compute tile centres in world coordinates self._tile_centers = {} - for tr in range(self._n_tile_rows): - for tc in range(self._n_tile_cols): - r0 = tr * tile_size - c0 = tc * tile_size - r1 = min(r0 + tile_size, H) - c1 = min(c0 + tile_size, W) - cx = (c0 + c1) * 0.5 * pixel_spacing_x - cy = (r0 + r1) * 0.5 * pixel_spacing_y - self._tile_centers[(tr, tc)] = (cx, cy) + self._recompute_tile_centers() - # LOD distance thresholds - tile_diag = np.sqrt( - (tile_size * pixel_spacing_x) ** 2 - + (tile_size * pixel_spacing_y) ** 2 - ) + # LOD distance thresholds (store squared for fast comparison) + tile_diag_sq = ((tile_size * pixel_spacing_x) ** 2 + + (tile_size * pixel_spacing_y) ** 2) + self._tile_half_diag_sq = tile_diag_sq * 0.25 + self._tile_half_diag = np.sqrt(tile_diag_sq * 0.25) self._lod_distances = compute_lod_distances( - tile_diag, factor=lod_distance_factor, max_lod=max_lod) + np.sqrt(tile_diag_sq), factor=lod_distance_factor, + max_lod=max_lod) + max_dist = (self._lod_distances[-1] * 3.0 + if self._lod_distances else float('inf')) + self._max_dist_sq = max_dist * max_dist + self._max_dist = max_dist + self._streaming_max_dist = (self._lod_distances[-1] * 1.5 + if self._lod_distances else max_dist) + # Tile world-space dimensions (for streaming radius computation) + self._tile_world_x = tile_size * abs(pixel_spacing_x) + self._tile_world_y = tile_size * abs(pixel_spacing_y) 
+ + # Lazy pyramid cache: level -> downsampled array (built on demand) + self._pyramid_cache = {} + + # Per-tile roughness scale factors for terrain-adaptive LOD. + # Rough tiles keep finer detail at greater distance; smooth tiles + # drop to coarser LOD sooner. Computed from full-res terrain. + self._tile_roughness = {} # (tr, tc) -> scale float + self._compute_all_roughness() + + # Progressive loading limits + self.max_tiles = float('inf') # no cap; frustum + distance cull suffice + self.per_tick_build_limit = 8 # max new tile meshes built per update # Per-tile state self._tile_lods = {} # (tr, tc) -> current LOD level - self._tile_cache = {} # (tr, tc, lod, base_sub) -> (verts, indices) + self._tile_cache = {} # (tr, tc, lod, base_sub) -> (verts, indices, normals) self._active_tiles = set() # GAS IDs currently in the scene + self._build_retries = {} # cache_key -> int (failed build count) - # Movement threshold before re-evaluating LOD + # Movement/rotation threshold before re-evaluating LOD self._last_update_pos = None - self._update_threshold = tile_diag * 0.25 + self._last_update_fwd = None + update_threshold = np.sqrt(tile_diag_sq) * 0.25 + self._update_threshold_sq = update_threshold * update_threshold + self._has_in_flight_work = False + + # Threaded mesh building: tiles are built in background threads + # and collected next tick. Disabled by default for test compat; + # engine.py enables it after construction. + self._executor = None + self._pending_futures = {} # cache_key -> Future + self._threaded = False + + # Batched upload: tiles at the same LOD are concatenated into a + # single GAS to reduce IAS instance count. Disabled by default; + # engine.py enables it after construction. 
def enable_threaded_building(self, max_workers=4):
    """Turn on background mesh building backed by a thread pool.

    Intended to be called once after construction so tile
    triangulation runs off the main thread; finished meshes are
    picked up by :meth:`update`.  An existing pool is reused, in
    which case *max_workers* has no effect.
    """
    pool = self._executor
    if pool is None:
        pool = ThreadPoolExecutor(max_workers=max_workers)
        self._executor = pool
    self._threaded = True
def shutdown(self):
    """Stop the background thread pool and reset batch/heightfield state.

    When a pool exists, every pending build and I/O future is
    cancelled, the prefetched tile data is discarded, and the pool is
    shut down without waiting for running tasks.  Batched-upload and
    heightfield modes are switched off in all cases.
    """
    pool = self._executor
    if pool is not None:
        for in_flight in (self._pending_futures, self._io_futures):
            for fut in in_flight.values():
                fut.cancel()
            in_flight.clear()
        self._tile_data_cache.clear()
        pool.shutdown(wait=False)
        self._executor = None

    self._threaded = False

    for batch_state in (self._lod_tile_meshes,
                        self._dirty_lods,
                        self._batch_gids):
        batch_state.clear()
    self._batched = False

    self._hf_dirty = False
    self._hf_enabled = False
def set_offset(self, offset_x, offset_y):
    """Set the world-space offset applied to tile vertices.

    Used after a terrain reload to keep the camera position stable:
    tile vertices are shifted so the same geographic point maps to
    the same world-space position as before the reload.  Calling with
    the current offset is a no-op.
    """
    unchanged = (offset_x == self._offset_x
                 and offset_y == self._offset_y)
    if unchanged:
        return

    self._offset_x, self._offset_y = offset_x, offset_y

    # Cached geometry was built with the old offset, so drop in-flight
    # builds and every cache before recomputing tile centres.
    self._cancel_pending()
    for stale in (self._tile_cache,
                  self._tile_data_cache,
                  self._lod_tile_meshes,
                  self._dirty_lods,
                  self._tile_lods):
        stale.clear()
    if self._hf_enabled:
        self._hf_dirty = True
    self._recompute_tile_centers()
def set_crs_transform(self, origin_x, origin_y, dx, dy):
    """Record the CRS origin and signed per-pixel spacing.

    Needed for streaming tiles, where pixel indices must be turned
    into real CRS coordinates (e.g. UTM easting/northing) before
    ``tile_data_fn`` is called.  Without it the callback would be
    given viewer world-space coordinates, which are not CRS values.

    Parameters
    ----------
    origin_x, origin_y : float
        CRS coordinate of the first pixel (row 0, col 0).
    dx, dy : float
        Signed CRS spacing per pixel.  For UTM rasters, ``dx > 0``
        (easting increases with column) and ``dy < 0`` (northing
        decreases with row, since row 0 = north).
    """
    origin = (origin_x, origin_y)
    spacing = (dx, dy)
    self._crs_origin, self._crs_spacing = origin, spacing
def set_terrain(self, terrain_np, offset_x=None, offset_y=None):
    """Replace the terrain data (e.g. after dynamic reload).

    Parameters
    ----------
    terrain_np : np.ndarray, shape (H, W)
        New elevation array.
    offset_x, offset_y : float or None
        New world-space offset; ``None`` keeps the current value.

    The tile grid is resized for the new array (floor division in
    streaming mode, ceiling otherwise), tile centres and roughness are
    recomputed, in-flight builds are cancelled and every cache is
    dropped.  Per-tile build-failure counters are reset as well, so
    tiles that were given up on against the old data get a fresh set
    of attempts against the new data.
    """
    self._terrain_np = terrain_np
    H, W = terrain_np.shape
    self._H = H
    self._W = W
    ts = self._tile_size
    if self._streaming:
        # Partial edge tiles are left to streaming (floor division).
        self._n_tile_rows = H // ts
        self._n_tile_cols = W // ts
    else:
        self._n_tile_rows = (H + ts - 1) // ts
        self._n_tile_cols = (W + ts - 1) // ts
    if offset_x is not None:
        self._offset_x = offset_x
    if offset_y is not None:
        self._offset_y = offset_y
    self._recompute_tile_centers()
    self._compute_all_roughness()
    self._cancel_pending()
    self._pyramid_cache.clear()
    self._tile_cache.clear()
    self._tile_data_cache.clear()
    self._lod_tile_meshes.clear()
    self._dirty_lods.clear()
    self._tile_lods.clear()
    # Failure counts refer to builds against the previous terrain.
    self._build_retries.clear()
    if self._hf_enabled:
        self._hf_dirty = True
At most + ``per_tick_build_limit`` new meshes are built per call so the + viewer stays responsive while tiles load progressively. + + Frustum culling prevents *building* tiles outside the view cone + but never *removes* already-built tiles — they stay in the scene + until they leave the distance range. This avoids flicker when + the camera rotates. + + LOD transitions use hysteresis (20% dead zone) to prevent + popping when the camera oscillates near a threshold boundary. + Parameters ---------- camera_pos : array-like @@ -135,6 +451,11 @@ def update(self, camera_pos, rtx, ve=1.0, force=False): Current vertical exaggeration. force : bool If True, rebuild all tiles regardless of movement threshold. + camera_front : array-like or None + Camera forward direction ``[x, y, z]``. When provided, + tiles outside the view frustum are skipped for building. + fov : float + Horizontal field of view in degrees (used for frustum cull). Returns ------- @@ -143,56 +464,462 @@ def update(self, camera_pos, rtx, ve=1.0, force=False): """ cam_x, cam_y = float(camera_pos[0]), float(camera_pos[1]) - # Skip if camera hasn't moved enough - if not force and self._last_update_pos is not None: + # Skip if camera hasn't moved or rotated enough + if not force and not self._has_in_flight_work and self._last_update_pos is not None: dx = cam_x - self._last_update_pos[0] dy = cam_y - self._last_update_pos[1] - if dx * dx + dy * dy < self._update_threshold ** 2: + moved = dx * dx + dy * dy >= self._update_threshold_sq + # Check rotation change via forward-direction dot product + rotated = False + if not moved and camera_front is not None and self._last_update_fwd is not None: + dot = (float(camera_front[0]) * self._last_update_fwd[0] + + float(camera_front[1]) * self._last_update_fwd[1]) + rotated = dot < 0.95 # ~18° change + if not moved and not rotated: return False self._last_update_pos = (cam_x, cam_y) + if camera_front is not None: + self._last_update_fwd = (float(camera_front[0]), + 
float(camera_front[1])) + + # Frustum culling setup: project forward to XY plane. + # When looking steeply down (small XY component), the 2D + # projection is unreliable — disable culling in that case. + frustum_cos = None + fwd_x = fwd_y = 0.0 + if camera_front is not None: + fwd_x = float(camera_front[0]) + fwd_y = float(camera_front[1]) + fwd_len_sq = fwd_x * fwd_x + fwd_y * fwd_y + # Only cull when the camera has a meaningful horizontal + # component (fwd_len > 0.3 ≈ pitch shallower than ~73°). + if fwd_len_sq > 0.09: # 0.3^2 + fwd_len = math.sqrt(fwd_len_sq) + fwd_x /= fwd_len + fwd_y /= fwd_len + # Widen frustum for steeper pitch: as fwd_len shrinks + # toward 0.3 the ground footprint of the view cone grows. + pitch_margin = 20.0 * (1.0 - fwd_len) / 0.7 # 0..~20° + half_angle = math.radians(fov * 0.5 + 20.0 + pitch_margin) + frustum_cos = math.cos(half_angle) + + tile_half_diag_sq = self._tile_half_diag_sq + # Use a tighter radius for streaming to keep tile count + # manageable. Beyond the last LOD threshold all tiles are at + # max LOD — adding more doesn't improve quality but each one + # grows the batch GAS that must be rebuilt on LOD changes. + if self._streaming: + max_dist = self._streaming_max_dist + max_dist_sq = max_dist * max_dist + else: + max_dist_sq = self._max_dist_sq + max_dist = self._max_dist + + # Pass 1: find all tiles within distance range, and flag which + # are also inside the view frustum (candidates for building). + # Uses squared distances to avoid sqrt per tile. + in_frustum = [] # tiles that passed distance + frustum culling + in_range_ids = set() # all distance-valid tiles (frustum-independent) + in_range_rc = set() # (tr, tc) pairs for all in-range tiles + + lod_distances = self._lod_distances + max_lod = self._max_lod + tile_lods = self._tile_lods + + # Determine tile iteration range. For streaming mode, compute + # tiles near the camera (unbounded grid). Otherwise use the + # fixed grid from the initial terrain array. 
+ if self._streaming: + ts = self._tile_size + tc_f = (cam_x - self._offset_x) / (ts * self._psx) - 0.5 + tr_f = (cam_y - self._offset_y) / (ts * self._psy) - 0.5 + tc_cam = int(round(tc_f)) + tr_cam = int(round(tr_f)) + r_tc = int(max_dist / self._tile_world_x) + 1 + r_tr = int(max_dist / self._tile_world_y) + 1 + tile_iter = ((tr, tc) + for tr in range(tr_cam - r_tr, tr_cam + r_tr + 1) + for tc in range(tc_cam - r_tc, tc_cam + r_tc + 1)) + else: + tile_iter = ((tr, tc) + for tr in range(self._n_tile_rows) + for tc in range(self._n_tile_cols)) + + for tr, tc in tile_iter: + cx, cy = self._tile_center(tr, tc) + dx_t = cx - cam_x + dy_t = cy - cam_y + dist_sq = dx_t * dx_t + dy_t * dy_t + if dist_sq > max_dist_sq: + continue + + # Track all in-range tiles (prevents eviction on rotate). + gid = _tile_gid(tr, tc) + in_range_ids.add(gid) + in_range_rc.add((tr, tc)) + + # Frustum check — skip LOD/roughness for tiles behind the + # camera. Tiles close enough to overlap the view origin + # (dist_sq <= tile_half_diag_sq) are always kept. + dist = math.sqrt(dist_sq) + if frustum_cos is not None and dist_sq > tile_half_diag_sq: + inv_dist = 1.0 / dist + cos_angle = (dx_t * fwd_x + dy_t * fwd_y) * inv_dist + tile_margin = self._tile_half_diag * inv_dist + if cos_angle + tile_margin < frustum_cos: + continue + + prev_lod = tile_lods.get((tr, tc), -1) + # Terrain-adaptive LOD: scale distance by roughness so + # rough tiles appear closer (keep detail) and smooth tiles + # appear farther (drop detail sooner). 
+ roughness_scale = self._tile_roughness.get((tr, tc), 1.0) + effective_dist = dist / roughness_scale if roughness_scale > 0 else dist + lod = compute_lod_level_with_hysteresis( + effective_dist, lod_distances, prev_lod) + lod = min(lod, max_lod) + entry = (dist_sq, tr, tc, lod, gid) + in_frustum.append(entry) + + # Build candidates are the in-frustum tiles, sorted by distance + in_frustum.sort() + if len(in_frustum) > self.max_tiles: + in_frustum = in_frustum[:self.max_tiles] + + # Two-pass build: fill holes (never-built tiles) before + # spending budget on LOD transitions. A tile at the wrong LOD + # is much better than a missing tile. + unbuilt = [] + lod_updates = [] + for entry in in_frustum: + _dist_sq, tr, tc, lod, gid = entry + prev_lod = tile_lods.get((tr, tc), -1) + if lod == prev_lod and not force: + continue + if gid not in self._active_tiles: + unbuilt.append(entry) + else: + lod_updates.append(entry) changed = False - new_tile_ids = set() - for tr in range(self._n_tile_rows): - for tc in range(self._n_tile_cols): - cx, cy = self._tile_centers[(tr, tc)] - dist = np.sqrt((cam_x - cx) ** 2 + (cam_y - cy) ** 2) - - lod = compute_lod_level(dist, self._lod_distances) - lod = min(lod, self._max_lod) - - tile_id = _tile_gid(tr, tc) - new_tile_ids.add(tile_id) - - prev_lod = self._tile_lods.get((tr, tc), -1) - if lod != prev_lod or force: - verts, indices = self._get_tile_mesh(tr, tc, lod) - if verts is not None: - # Apply VE - if ve != 1.0: - verts = verts.copy() - verts[2::3] *= ve - rtx.add_geometry(tile_id, verts, indices) - self._tile_lods[(tr, tc)] = lod - changed = True - - # Remove stale tiles (shouldn't happen, but be safe) - for old_id in self._active_tiles - new_tile_ids: - rtx.remove_geometry(old_id) - changed = True - - self._active_tiles = new_tile_ids + # --- Phase A: collect completed async builds --- + changed |= self._collect_completed_builds(rtx, ve) + + # --- Phase B: process tile queue (build/upload/stage) --- + b_changed, pending = 
self._process_tile_queue( + unbuilt + lod_updates, rtx, ve) + changed |= b_changed + + # --- Phase C: prefetch I/O for streaming tiles --- + if (self._streaming and self._threaded + and self._executor is not None): + self._prefetch_streaming_tiles(unbuilt + lod_updates) + + # --- Phase D: remove stale tiles --- + changed |= self._remove_stale_tiles(rtx, in_range_ids, in_range_rc) + + # --- Phase E: rebuild heightfield and batch GAS --- + if self._hf_enabled and self._hf_dirty: + changed |= self._rebuild_heightfield(rtx, ve) + if self._batched: + changed |= self._rebuild_batches(rtx) + + self._has_in_flight_work = (pending or bool(self._pending_futures) + or bool(self._io_futures) + or bool(self._dirty_lods)) + return changed + + # ------------------------------------------------------------------ + # update() helper phases + # ------------------------------------------------------------------ + + _MAX_BUILD_RETRIES = 3 + + def _collect_completed_builds(self, rtx, ve): + """Phase A: harvest completed async mesh builds and I/O prefetches. + + Freshly completed meshes are uploaded/staged immediately so they + don't depend on re-passing the frustum check (the build was + already approved on a prior tick). 
+ """ + changed = False + if self._pending_futures: + batched = self._batched + for cache_key, fut in list(self._pending_futures.items()): + if not fut.done(): + continue + del self._pending_futures[cache_key] + exc = fut.exception() + if exc is not None: + retries = self._build_retries.get(cache_key, 0) + 1 + self._build_retries[cache_key] = retries + if retries >= self._MAX_BUILD_RETRIES: + logger.warning( + "tile %s build failed %d times, giving up: %s", + cache_key[:2], retries, exc) + else: + logger.debug( + "tile %s build failed (attempt %d): %s", + cache_key[:2], retries, exc) + continue + self._build_retries.pop(cache_key, None) + result = fut.result() + if result is not None and result[0] is not None: + self._tile_cache[cache_key] = result + # Upload/stage immediately — don't wait for frustum re-check + tr, tc, lod = cache_key[0], cache_key[1], cache_key[2] + gid = _tile_gid(tr, tc) + if batched: + self._stage_tile(gid, tr, tc, lod, result, ve) + else: + changed |= self._upload_tile( + gid, tr, tc, lod, result, ve, rtx) + else: + # Cache empty result so this tile is not rebuilt + self._tile_cache[cache_key] = (None, None, None) + + if self._io_futures: + for cache_key, fut in list(self._io_futures.items()): + if not fut.done(): + continue + del self._io_futures[cache_key] + exc = fut.exception() + if exc is not None: + logger.debug("tile %s I/O prefetch failed: %s", + cache_key[:2], exc) + continue + result = fut.result() + if result is not None: + self._tile_data_cache[cache_key] = result + return changed + + def _process_tile_queue(self, queue, rtx, ve): + """Phase B: build, upload, or stage tiles from the priority queue. + + Returns (changed, pending) where *changed* is True if any tile + GAS was added and *pending* is True if work was deferred. 
+ """ + # Pre-populate pyramid levels so background threads don't race + if self._threaded: + for lvl in range(self._max_lod + 1): + self._get_pyramid_level(lvl) + + changed = False + builds = 0 + pending = False + batched = self._batched + hf_enabled = self._hf_enabled + tile_lods = self._tile_lods + + for _dist_sq, tr, tc, lod, gid in queue: + # Heightfield handles LOD 0 in-bounds tiles — no mesh + # building needed; the heightfield GAS covers them. + if hf_enabled and lod == 0: + in_bounds = (0 <= tr < self._n_tile_rows + and 0 <= tc < self._n_tile_cols) + if in_bounds: + prev_lod = tile_lods.get((tr, tc), -1) + if prev_lod > 0 and batched: + self._unstage_tile(tr, tc, prev_lod) + was_new = (tr, tc) not in self._tile_lods + self._tile_lods[(tr, tc)] = 0 + self._active_tiles.add(gid) + self._hf_dirty = True + if was_new and self._on_tile_added is not None: + self._fire_tile_added(tr, tc) + continue + + cache_key = (tr, tc, lod, self._base_subsample) + + # Already in the cache — upload/stage immediately (free) + cached = self._tile_cache.get(cache_key) + if cached is not None: + if cached[0] is None: + continue # empty tile (no data available) + if batched: + self._stage_tile(gid, tr, tc, lod, cached, ve) + else: + changed |= self._upload_tile( + gid, tr, tc, lod, cached, ve, rtx) + continue + + # Already building in a background thread — skip + if cache_key in self._pending_futures: + pending = True + continue + + # Give up on tiles that have failed too many times + if self._build_retries.get(cache_key, 0) >= self._MAX_BUILD_RETRIES: + continue + + # Respect per-tick build limit + if builds >= self.per_tick_build_limit: + pending = True + continue + builds += 1 + + if self._threaded and self._executor is not None: + fut = self._executor.submit( + self._get_tile_mesh, tr, tc, lod) + self._pending_futures[cache_key] = fut + pending = True + else: + result = self._get_tile_mesh(tr, tc, lod) + if result[0] is not None: + if batched: + self._stage_tile(gid, tr, tc, 
lod, result, ve) + else: + changed |= self._upload_tile( + gid, tr, tc, lod, result, ve, rtx) + + return changed, pending + + def _prefetch_streaming_tiles(self, queue, max_prefetches=8): + """Phase C: submit I/O prefetches for out-of-bounds streaming tiles. + + Overlaps zarr reads with the current batch of mesh builds so + tile data is ready when the build budget frees up next tick. + At most *max_prefetches* new I/O futures are submitted per tick + to avoid overwhelming the thread pool or network. + """ + submitted = 0 + for _dist_sq, tr, tc, lod, _gid in queue: + if submitted >= max_prefetches: + break + cache_key = (tr, tc, lod, self._base_subsample) + if (0 <= tr < self._n_tile_rows + and 0 <= tc < self._n_tile_cols): + continue # in-bounds tiles use pyramid, not I/O + if (cache_key in self._tile_cache + or cache_key in self._pending_futures + or cache_key in self._tile_data_cache + or cache_key in self._io_futures): + continue + self._io_futures[cache_key] = self._executor.submit( + self._fetch_tile_data, tr, tc, lod) + submitted += 1 + + def _remove_stale_tiles(self, rtx, in_range_ids, in_range_rc): + """Phase D: remove tiles that left the distance range. + + Never removes tiles merely outside the frustum (avoids flicker + when the camera rotates). Evicts mesh cache entries to bound + memory, including orphaned None-cached entries from streaming + tiles that were never activated. + + Returns True if any tile GAS was removed. + """ + stale_ids = self._active_tiles - in_range_ids + + changed = False + batched = self._batched + hf_enabled = self._hf_enabled + + if stale_ids: + # Remove individual GAS entries for stale tiles. In batched + # mode, tiles from the pre-batching initial build still have + # individual GAS entries that must be cleaned up. 
+ for old_id in stale_ids: + if rtx.has_geometry(old_id): + rtx.remove_geometry(old_id) + changed = True + + stale_keys = {k for k in self._tile_lods + if _tile_gid(*k) in stale_ids} + stale_rc = {(k[0], k[1]) for k in stale_keys} + on_removed = self._on_tile_removed + for k in stale_keys: + prev_lod = self._tile_lods.get(k, -1) + if hf_enabled and prev_lod == 0: + self._hf_dirty = True + elif batched and prev_lod >= 0: + self._unstage_tile(k[0], k[1], prev_lod) + del self._tile_lods[k] + if on_removed is not None: + on_removed(k[0], k[1]) + # Evict all cached LOD variants for stale tiles (single pass) + for cache_k in [ck for ck in self._tile_cache + if (ck[0], ck[1]) in stale_rc]: + del self._tile_cache[cache_k] + # Cancel in-flight builds for stale tiles (single pass) + for fk in [fk for fk in self._pending_futures + if (fk[0], fk[1]) in stale_rc]: + self._pending_futures.pop(fk).cancel() + # Cancel in-flight I/O prefetches (single pass) + for fk in [fk for fk in self._io_futures + if (fk[0], fk[1]) in stale_rc]: + self._io_futures.pop(fk).cancel() + # Evict prefetched tile data (single pass) + for dk in [dk for dk in self._tile_data_cache + if (dk[0], dk[1]) in stale_rc]: + del self._tile_data_cache[dk] + # Evict stale build retry entries + for brk in [brk for brk in self._build_retries + if (brk[0], brk[1]) in stale_rc]: + del self._build_retries[brk] + self._active_tiles -= stale_ids + + # Evict orphaned cache entries: tiles cached (including + # None-cached streaming tiles outside DEM extent) but not in + # distance range and not pending build. Without this, None + # entries from tiles the camera flew past accumulate forever. 
+ if self._streaming and self._tile_cache: + pending_rc = {(k[0], k[1]) for k in self._pending_futures} + orphan_keys = [ck for ck in self._tile_cache + if (ck[0], ck[1]) not in in_range_rc + and (ck[0], ck[1]) not in pending_rc] + for ck in orphan_keys: + del self._tile_cache[ck] + return changed def remove_all(self, rtx): """Remove all LOD tile geometries from the scene.""" + self._cancel_pending() + # Remove batch and heightfield GAS entries + for gid in list(self._batch_gids): + rtx.remove_geometry(gid) + self._batch_gids.clear() + self._lod_tile_meshes.clear() + self._dirty_lods.clear() + if self._hf_enabled: + self._hf_dirty = False + # Remove individual tile GAS entries (may be no-ops in batched mode) for tile_id in list(self._active_tiles): rtx.remove_geometry(tile_id) self._active_tiles.clear() self._tile_lods.clear() self._last_update_pos = None + self._last_update_fwd = None + + def get_metrics(self): + """Return structured LOD metrics for programmatic monitoring. + + Returns + ------- + dict + ``active_tiles``: number of tiles with an assigned LOD, + ``cached_variants``: total mesh variants stored, + ``in_flight_builds``: pending mesh build futures, + ``in_flight_io``: pending I/O prefetch futures, + ``pyramid_levels``: number of cached pyramid levels, + ``pyramid_bytes``: approximate memory used by pyramid cache, + ``lod_counts``: dict mapping LOD level to tile count. 
+ """ + from collections import Counter + pyr_bytes = sum(arr.nbytes for arr in self._pyramid_cache.values() + if hasattr(arr, 'nbytes')) + return { + 'active_tiles': len(self._tile_lods), + 'cached_variants': len(self._tile_cache), + 'in_flight_builds': len(self._pending_futures), + 'in_flight_io': len(self._io_futures), + 'pyramid_levels': len(self._pyramid_cache), + 'pyramid_bytes': pyr_bytes, + 'lod_counts': dict(Counter(self._tile_lods.values())), + } def get_stats(self): """Return a summary string of LOD state.""" @@ -200,43 +927,704 @@ def get_stats(self): return "LOD: no tiles" from collections import Counter counts = Counter(self._tile_lods.values()) - parts = [f"L{lvl}:{cnt}" for lvl, cnt in sorted(counts.items())] + parts = [] + for lvl, cnt in sorted(counts.items()): + label = "HF" if self._hf_enabled and lvl == 0 else f"L{lvl}" + parts.append(f"{label}:{cnt}") + active = len(self._tile_lods) + if self._batched or self._hf_enabled: + n_gas = len(self._batch_gids) + staged = sum(len(t) for t in self._lod_tile_meshes.values()) + total_verts = sum(len(v) // 3 + for tiles in self._lod_tile_meshes.values() + for v, _, _ in tiles.values()) + suffix = f", {n_gas} GAS, {staged}stg/{total_verts//1000}Kv" + else: + suffix = "" + if self._streaming: + return f"LOD: {active} tiles ({', '.join(parts)}{suffix})" total = self._n_tile_rows * self._n_tile_cols - return f"LOD tiles: {total} ({', '.join(parts)})" + return f"LOD: {active}/{total} tiles ({', '.join(parts)}{suffix})" # ------------------------------------------------------------------ # Internal # ------------------------------------------------------------------ + def _compute_all_roughness(self): + """Compute roughness scale factors for all in-bounds tiles. 
+ + Uses ``compute_tile_roughness`` on full-resolution tile slices, + then maps raw roughness to a scale factor in ``[0.5, 2.0]`` + via exponential interpolation:: + + scale = 0.5 * 4^t where t = normalized roughness in [0, 1] + + Smooth tiles get scale < 1 (LOD demoted earlier); rough tiles + get scale > 1 (LOD promoted — finer detail kept at greater + distance). When all tiles have equal roughness, every tile + gets scale 1.0 (neutral). + """ + ts = self._tile_size + raw = {} + for tr in range(self._n_tile_rows): + for tc in range(self._n_tile_cols): + r0 = tr * ts + c0 = tc * ts + r1 = min(r0 + ts, self._H) + c1 = min(c0 + ts, self._W) + raw[(tr, tc)] = compute_tile_roughness( + self._terrain_np[r0:r1, c0:c1]) + + if not raw: + self._tile_roughness = {} + return + + vals = np.array(list(raw.values())) + r_min = float(np.min(vals)) + r_max = float(np.max(vals)) + + # If the roughest tile is negligible compared to the terrain's + # elevation range, all tiles are effectively flat — skip + # adaptation to avoid amplifying float32 noise. + elev_range = float( + np.nanmax(self._terrain_np) - np.nanmin(self._terrain_np)) + roughness_floor = max(1e-4, elev_range * 1e-3) + if r_max < roughness_floor or r_max - r_min < 1e-10: + self._tile_roughness = {k: 1.0 for k in raw} + return + + # Log-normalize roughness so the distribution isn't compressed + # by extreme outliers (e.g. one sharp peak among gentle hills). 
+ log_vals = np.log1p(vals) + log_min = float(np.min(log_vals)) + log_max = float(np.max(log_vals)) + log_range = log_max - log_min + + self._tile_roughness = {} + if log_range < 1e-10: + for k in raw: + self._tile_roughness[k] = 1.0 + return + for k, r in raw.items(): + t = (np.log1p(r) - log_min) / log_range # 0..1 + # Exponential: 0.5 * 4^t → t=0 → 0.5, t=0.5 → 1.0, t=1 → 2.0 + self._tile_roughness[k] = 0.5 * (4.0 ** t) + + def _cancel_pending(self): + """Cancel all in-flight background build and I/O futures.""" + for fut in self._pending_futures.values(): + fut.cancel() + self._pending_futures.clear() + for fut in self._io_futures.values(): + fut.cancel() + self._io_futures.clear() + self._build_retries.clear() + + def _upload_tile(self, gid, tr, tc, lod, mesh_data, ve, rtx): + """Stitch, apply VE, and upload a completed tile mesh. + + Returns True if the tile was successfully added. + """ + # own=False: rtx.add_geometry copies into GPU buffers internally + verts, indices, norms = self._prepare_tile( + mesh_data, tr, tc, lod, ve, own=False) + rc = rtx.add_geometry(gid, verts, indices, normals=norms) + if rc == 0 or rc is None: + self._active_tiles.add(gid) + self._tile_lods[(tr, tc)] = lod + if self._on_tile_added is not None: + self._fire_tile_added(tr, tc) + return True + return False + + def _fire_tile_added(self, tr, tc): + """Invoke the on_tile_added callback with this tile's elevation.""" + ts = self._tile_size + in_bounds = (0 <= tr < self._n_tile_rows + and 0 <= tc < self._n_tile_cols) + if in_bounds: + r0, c0 = tr * ts, tc * ts + r1 = min(r0 + ts, self._H) + c1 = min(c0 + ts, self._W) + elev = self._terrain_np[r0:r1, c0:c1] + else: + # Streaming tile — check prefetch cache + cache_key = (tr, tc) + elev = self._tile_data_cache.get(cache_key) + try: + self._on_tile_added(tr, tc, elev) + except Exception: + logger.debug("on_tile_added callback failed for (%d, %d)", + tr, tc, exc_info=True) + + def _needs_stitch(self, tr, tc, lod): + """Check whether any 
neighbor requires boundary stitching.""" + for nr, nc in ((tr - 1, tc), (tr + 1, tc), + (tr, tc - 1), (tr, tc + 1)): + nlod = self._tile_lods.get((nr, nc), -1) + if nlod < 0 or nlod == lod: + continue + if nlod > lod: + return True + if self._hf_enabled and nlod == 0 and lod > 0: + return True + return False + + def _prepare_tile(self, mesh_data, tr, tc, lod, ve, own=True): + """Prepare cached mesh for upload: stitch boundaries and apply VE. + + Parameters + ---------- + own : bool + If True (default), returned arrays are guaranteed to be + independent copies safe for long-term storage. If False, + cache references may be returned when no mutation is needed + (suitable for immediate GPU upload that copies internally). + """ + cached_verts, indices, cached_normals = mesh_data + needs_stitch = self._needs_stitch(tr, tc, lod) + needs_ve = ve != 1.0 + + if not needs_stitch and not needs_ve: + if own: + return (cached_verts.copy(), indices.copy(), + cached_normals.copy()) + return cached_verts, indices, cached_normals + + verts = cached_verts.copy() + if needs_stitch: + self._stitch_tile_boundary(verts, tr, tc, lod) + if needs_ve: + verts[2::3] *= ve + norms = cached_normals.copy() + norms[2::3] /= ve + length = np.sqrt(norms[0::3]**2 + norms[1::3]**2 + norms[2::3]**2) + length[length < 1e-10] = 1.0 + norms[0::3] /= length + norms[1::3] /= length + norms[2::3] /= length + else: + norms = cached_normals.copy() if own else cached_normals + return verts, indices.copy() if own else indices, norms + + def _stage_tile(self, gid, tr, tc, lod, mesh_data, ve): + """Stage a tile mesh for batch upload (deferred GAS build). + + The tile's stitched+VE-transformed mesh is stored per LOD level. + The batch GAS for that level is rebuilt by :meth:`_rebuild_batches` + at the end of the update tick. 
+ """ + # own=True: staged data persists across ticks, needs owned copies + verts, indices, normals = self._prepare_tile(mesh_data, tr, tc, lod, ve) + # If tile was at a different LOD, unstage from old level + prev_lod = self._tile_lods.get((tr, tc), -1) + if prev_lod >= 0 and prev_lod != lod: + self._unstage_tile(tr, tc, prev_lod) + # Stage into the new LOD's batch + if lod not in self._lod_tile_meshes: + self._lod_tile_meshes[lod] = {} + self._lod_tile_meshes[lod][(tr, tc)] = (verts, indices, normals) + self._dirty_lods.add(lod) + self._active_tiles.add(gid) + self._tile_lods[(tr, tc)] = lod + if self._on_tile_added is not None: + self._fire_tile_added(tr, tc) + + def _tile_grid_dims(self, tr, tc, lod): + """Compute grid dimensions (th, tw) for a tile without building it.""" + bs = self._base_subsample + pyr_level = min(lod, self._max_lod) + pyr_sub = bs * (2 ** pyr_level) + pyr = self._get_pyramid_level(pyr_level) + pH, pW = pyr.shape + subsample = bs * (2 ** lod) + extra = subsample // pyr_sub + + ts = self._tile_size + r0_pyr = (tr * ts) // pyr_sub + c0_pyr = (tc * ts) // pyr_sub + r1_full = min((tr + 1) * ts, self._H) + c1_full = min((tc + 1) * ts, self._W) + r1_pyr = min((r1_full + pyr_sub - 1) // pyr_sub + 1, pH) + c1_pyr = min((c1_full + pyr_sub - 1) // pyr_sub + 1, pW) + + th = len(range(r0_pyr, r1_pyr, extra)) + tw = len(range(c0_pyr, c1_pyr, extra)) + return th, tw + + def _stitch_tile_boundary(self, verts, tr, tc, lod): + """Snap boundary vertex Z to match neighbors at different LODs. + + For each edge where the neighbor has a different LOD (or is a + heightfield tile), the boundary vertices are interpolated from + the reference pyramid level or the neighbor's cached mesh. + Works for both in-bounds and streaming tiles. 
+ """ + in_bounds = (0 <= tr < self._n_tile_rows + and 0 <= tc < self._n_tile_cols) + n_verts = len(verts) // 3 + if n_verts < 4: + return + if in_bounds: + th, tw = self._tile_grid_dims(tr, tc, lod) + if n_verts != th * tw or th < 2 or tw < 2: + return + else: + # Streaming tile: infer grid dims from vertex count. + # Streaming meshes are regular grids, so th * tw == n_verts. + # Count vertices sharing the first row's Y coordinate. + y_coords = verts[1::3] + y0 = y_coords[0] + tol = abs(self._psy) * 0.1 + tw = int(np.sum(np.abs(y_coords - y0) < max(tol, 1e-6))) + if tw < 2: + return + th = n_verts // tw + if th < 2 or th * tw != n_verts: + return + + neighbors = { + 'top': (tr - 1, tc), + 'bottom': (tr + 1, tc), + 'left': (tr, tc - 1), + 'right': (tr, tc + 1), + } + + for edge, (nr, nc) in neighbors.items(): + neighbor_lod = self._tile_lods.get((nr, nc), -1) + if neighbor_lod < 0 or neighbor_lod == lod: + continue + + # Determine reference LOD for stitching + if neighbor_lod > lod: + # Neighbor is coarser — stitch to its grid + ref_lod = neighbor_lod + elif self._hf_enabled and neighbor_lod == 0 and lod > 0: + # Neighbor is heightfield LOD 0 — stitch to full-res + ref_lod = 0 + else: + continue + + ref_z = self._get_boundary_z_ref(tr, tc, edge, ref_lod) + if ref_z is None or len(ref_z) < 1: + continue + + n_ref = len(ref_z) + + if edge in ('top', 'bottom'): + n_self = tw + row = 0 if edge == 'top' else th - 1 + edge_indices = row * tw + np.arange(tw) + else: + n_self = th + col = 0 if edge == 'left' else tw - 1 + edge_indices = np.arange(th) * tw + col + + if n_self < 2: + continue + + # Interpolate reference Z at our boundary positions + positions = (np.arange(n_self, dtype=np.float64) + * (n_ref - 1) / (n_self - 1)) + z_interp = np.interp( + positions, + np.arange(n_ref, dtype=np.float64), + ref_z.astype(np.float64)) + verts[edge_indices * 3 + 2] = z_interp.astype(np.float32) + + def _get_boundary_z_ref(self, tr, tc, edge, ref_lod): + """Get Z values from a 
reference level at a shared tile boundary. + + Returns the 1-D array of elevation values that the reference + (coarser or heightfield) tile would have along the shared edge. + Falls back to extracting Z from the neighbor's cached mesh + when pyramid data is not available (streaming tiles). + """ + tile_in_bounds = (0 <= tr < self._n_tile_rows + and 0 <= tc < self._n_tile_cols) + + # Pyramid path works when the tile itself is in-bounds + if tile_in_bounds: + return self._get_boundary_z_ref_pyramid(tr, tc, edge, ref_lod) + + # Tile is streaming (out of bounds) — determine the neighbor + # and try to get boundary Z from IT + if edge == 'top': + nr, nc = tr - 1, tc + elif edge == 'bottom': + nr, nc = tr + 1, tc + elif edge == 'left': + nr, nc = tr, tc - 1 + else: + nr, nc = tr, tc + 1 + + neighbor_in_bounds = (0 <= nr < self._n_tile_rows + and 0 <= nc < self._n_tile_cols) + + if neighbor_in_bounds: + # Neighbor is in-bounds — use pyramid via neighbor's coords + opposite = {'top': 'bottom', 'bottom': 'top', + 'left': 'right', 'right': 'left'} + return self._get_boundary_z_ref_pyramid( + nr, nc, opposite[edge], ref_lod) + + # Both tiles are streaming — extract from neighbor's cached mesh + return self._get_boundary_z_from_cache(nr, nc, edge, ref_lod) + + def _get_boundary_z_from_cache(self, nr, nc, edge, ref_lod): + """Extract boundary Z values from a neighbor's cached mesh. + + Used for streaming tiles that have no pyramid data. Reads the + un-VE'd mesh from ``_tile_cache`` and returns Z at the shared + edge. + + Parameters + ---------- + nr, nc : int + Neighbor tile row/column. + edge : str + Edge from the perspective of the tile being stitched + ('top'/'bottom'/'left'/'right'). The neighbor's shared + edge is the opposite. + ref_lod : int + LOD of the neighbor tile. 
+ """ + cache_key = (nr, nc, ref_lod, self._base_subsample) + cached = self._tile_cache.get(cache_key) + if cached is None or cached[0] is None: + return None + + n_verts_cached, _, _ = cached + nverts = len(n_verts_cached) // 3 + if nverts < 4: + return None + + # Infer grid dims from cached mesh + y_coords = n_verts_cached[1::3] + y0 = y_coords[0] + tol = abs(self._psy) * 0.1 + n_tw = int(np.sum(np.abs(y_coords - y0) < max(tol, 1e-6))) + if n_tw < 2: + return None + n_th = nverts // n_tw + if n_th < 2 or n_th * n_tw != nverts: + return None + + z_vals = n_verts_cached[2::3] + + # Map caller's edge to neighbor's opposite edge + opposite = {'top': 'bottom', 'bottom': 'top', + 'left': 'right', 'right': 'left'} + neighbor_edge = opposite[edge] + + if neighbor_edge == 'top': + return z_vals[:n_tw].copy() + elif neighbor_edge == 'bottom': + return z_vals[(n_th - 1) * n_tw:n_th * n_tw].copy() + elif neighbor_edge == 'left': + return z_vals[::n_tw].copy() + elif neighbor_edge == 'right': + return z_vals[n_tw - 1::n_tw].copy() + return None + + def _get_boundary_z_ref_pyramid(self, tr, tc, edge, ref_lod): + """Get Z values from a pyramid level at a shared boundary. + + Original pyramid-based implementation for in-bounds tiles. 
+ """ + pyr_level = min(ref_lod, self._max_lod) + pyr = self._get_pyramid_level(pyr_level) + pH, pW = pyr.shape + pyr_sub = self._base_subsample * (2 ** pyr_level) + subsample_ref = self._base_subsample * (2 ** ref_lod) + extra = subsample_ref // pyr_sub + + ts = self._tile_size + r0_full = tr * ts + c0_full = tc * ts + r1_full = min(r0_full + ts, self._H) + c1_full = min(c0_full + ts, self._W) + c0_pyr = c0_full // pyr_sub + c1_pyr = min((c1_full + pyr_sub - 1) // pyr_sub + 1, pW) + r0_pyr = r0_full // pyr_sub + r1_pyr = min((r1_full + pyr_sub - 1) // pyr_sub + 1, pH) + + if edge == 'top': + r = r0_full // pyr_sub + if r >= pH: + return None + return pyr[r, c0_pyr:c1_pyr:extra].copy() + elif edge == 'bottom': + r = min(r1_full // pyr_sub, pH - 1) + return pyr[r, c0_pyr:c1_pyr:extra].copy() + elif edge == 'left': + c = c0_full // pyr_sub + if c >= pW: + return None + return pyr[r0_pyr:r1_pyr:extra, c].copy() + elif edge == 'right': + c = min(c1_full // pyr_sub, pW - 1) + return pyr[r0_pyr:r1_pyr:extra, c].copy() + return None + + def _unstage_tile(self, tr, tc, lod): + """Remove a tile from its LOD batch and mark the level dirty.""" + tiles = self._lod_tile_meshes.get(lod) + if tiles is not None: + tiles.pop((tr, tc), None) + self._dirty_lods.add(lod) + + def _rebuild_heightfield(self, rtx, ve): + """Rebuild the heightfield GAS for LOD 0 in-bounds tiles. + + Uploads the subsampled elevation array and builds AABB custom + primitives with an active mask derived from current LOD 0 tile + assignments. Only AABB tiles overlapping LOD 0 LOD-tiles get + valid bounds; all others are degenerate (zero-volume). + + Returns True if the heightfield GAS was added or updated. 
+ """ + self._hf_dirty = False + + # Gather in-bounds LOD 0 tiles + lod0_tiles = [ + (tr, tc) for (tr, tc), lod in self._tile_lods.items() + if lod == 0 + and 0 <= tr < self._n_tile_rows + and 0 <= tc < self._n_tile_cols + ] + + gid = self._hf_gid + if not lod0_tiles: + # No LOD 0 in-bounds tiles — remove heightfield GAS + if gid in self._batch_gids: + rtx.remove_geometry(gid) + self._batch_gids.discard(gid) + return True + return False + + # Get subsampled elevation (pyramid level 0) + bs = self._base_subsample + pyr = self._get_pyramid_level(0) + pH, pW = pyr.shape + + hf_ts = self._hf_tile_size + num_tiles_x = math.ceil((pW - 1) / hf_ts) + num_tiles_y = math.ceil((pH - 1) / hf_ts) + + # Build active mask: AABB tiles covered by LOD 0 LOD-tiles. + # Use 2D view + slice assignment instead of per-element Python loop. + active_mask = np.zeros((num_tiles_y, num_tiles_x), dtype=bool) + ts = self._tile_size + for tr, tc in lod0_tiles: + r0_full = tr * ts + c0_full = tc * ts + r1_full = min(r0_full + ts, self._H) + c1_full = min(c0_full + ts, self._W) + # Map to heightfield grid pixel coords + r0_hf = r0_full // bs + c0_hf = c0_full // bs + r1_hf = min((r1_full + bs - 1) // bs, pH) + c1_hf = min((c1_full + bs - 1) // bs, pW) + # Map to AABB tile indices and mark active via slice + ty0 = r0_hf // hf_ts + tx0 = c0_hf // hf_ts + ty1 = min(max(ty0, (r1_hf - 1) // hf_ts) + 1, num_tiles_y) + tx1 = min(max(tx0, (c1_hf - 1) // hf_ts) + 1, num_tiles_x) + active_mask[ty0:ty1, tx0:tx1] = True + active_mask = active_mask.ravel() + + # IAS transform encodes world offset + transform = [ + 1.0, 0.0, 0.0, self._offset_x, + 0.0, 1.0, 0.0, self._offset_y, + 0.0, 0.0, 1.0, 0.0, + ] + + rc = rtx.add_heightfield_geometry( + gid, pyr, pH, pW, + spacing_x=self._psx * bs, + spacing_y=self._psy * bs, + ve=ve, + tile_size=hf_ts, + active_mask=active_mask, + transform=transform) + + if rc == 0 or rc is None: + self._batch_gids.add(gid) + return True + return False + + def _rebuild_batches(self, 
rtx): + """Rebuild batch GAS entries for dirty LOD levels. + + Concatenates all staged tile meshes at each dirty LOD level + into a single vertex/index/normal buffer and uploads as one + GAS. Empty levels have their batch GAS removed. + + To avoid frame spikes when multiple LODs are dirty (e.g. on + camera rotation), at most one non-empty LOD batch is rebuilt + per tick. Empty-LOD removals are always processed immediately + since they're cheap. + + Returns True if any GAS was added or updated. + """ + if not self._dirty_lods: + return False + changed = False + rebuilt_one = False + remaining = set() + for lod in sorted(self._dirty_lods): + batch_gid = _batch_gid(lod) + tiles = self._lod_tile_meshes.get(lod, {}) + if not tiles: + # All tiles left this LOD — remove the batch GAS (cheap) + if batch_gid in self._batch_gids: + rtx.remove_geometry(batch_gid) + self._batch_gids.discard(batch_gid) + changed = True + continue + + # Only rebuild one non-empty batch per tick to cap frame time + if rebuilt_one: + remaining.add(lod) + continue + rebuilt_one = True + + all_verts = [] + all_indices = [] + all_normals = [] + vert_offset = 0 + for (v, idx, n) in tiles.values(): + all_verts.append(v) + shifted = idx + vert_offset + all_indices.append(shifted) + all_normals.append(n) + vert_offset += len(v) // 3 + + batch_verts = np.concatenate(all_verts) + batch_indices = np.concatenate(all_indices) + batch_normals = np.concatenate(all_normals) + rc = rtx.add_geometry( + batch_gid, batch_verts, batch_indices, + normals=batch_normals) + if rc == 0 or rc is None: + self._batch_gids.add(batch_gid) + changed = True + self._dirty_lods = remaining + return changed + def _get_tile_mesh(self, tr, tc, lod): - """Build or retrieve cached tile mesh.""" + """Build or retrieve cached tile mesh. + + Returns references to cached arrays — caller must copy before + mutating (e.g. for VE scaling). 
+ + If the requested LOD produces a tile too small to triangulate + (< 2×2 grid — common for edge tiles at high subsampling), + falls back to progressively lower LOD levels until one works. + """ cache_key = (tr, tc, lod, self._base_subsample) cached = self._tile_cache.get(cache_key) if cached is not None: - return cached[0].copy(), cached[1].copy() + return cached + + in_bounds = (0 <= tr < self._n_tile_rows + and 0 <= tc < self._n_tile_cols) + + if in_bounds: + # Try requested LOD, fall back to lower levels for small edge tiles + for try_lod in range(lod, -1, -1): + verts, indices, normals = self._build_tile_mesh(tr, tc, try_lod) + if verts is not None: + entry = (verts, indices, normals) + self._tile_cache[cache_key] = entry + return entry + elif self._tile_data_fn is not None: + verts, indices, normals = self._build_streaming_tile_mesh(tr, tc, lod) + if verts is not None: + entry = (verts, indices, normals) + self._tile_cache[cache_key] = entry + return entry + + return None, None, None + + def _get_pyramid_level(self, level): + """Return the downsampled terrain array for *level*, building lazily. + + Level 0 is the terrain at ``base_subsample`` resolution. Each + subsequent level halves via 2x2 NaN-aware box averaging. + Evicts cached levels above ``_max_lod`` to bound memory. 
+ """ + cached = self._pyramid_cache.get(level) + if cached is not None: + return cached - verts, indices = self._build_tile_mesh(tr, tc, lod) - if verts is not None: - self._tile_cache[cache_key] = (verts.copy(), indices.copy()) - return verts, indices + if level == 0: + bs = self._base_subsample + if bs > 1: + arr = self._terrain_np[::bs, ::bs].copy() + else: + arr = self._terrain_np + else: + prev = self._get_pyramid_level(level - 1) + arr = _box_downsample_2x(prev) + + self._pyramid_cache[level] = arr + + # Evict levels beyond current max_lod to bound memory + for stale in [k for k in self._pyramid_cache + if k > self._max_lod]: + del self._pyramid_cache[stale] + + return arr def _build_tile_mesh(self, tr, tc, lod): - """Triangulate a single tile at the given LOD level.""" + """Triangulate a single tile at the given LOD level. + + Pyramid levels are built lazily on first access (box-filter + averaged). This avoids pre-computing the full pyramid for + large zarr-backed terrains. + """ from .. import mesh as mesh_mod subsample = self._base_subsample * (2 ** lod) - r0 = tr * self._tile_size - c0 = tc * self._tile_size - # Extend by one pixel so adjacent tiles share boundary vertices, - # eliminating the one-pixel gap that causes shading seams. - r1 = min(r0 + self._tile_size + 1, self._H) - c1 = min(c0 + self._tile_size + 1, self._W) - - # Extract tile data with subsampling - tile = self._terrain_np[r0:r1:subsample, c0:c1:subsample] + + # Pyramid level for this LOD (level 0 = base_subsample) + pyr_level = min(lod, self._max_lod) + pyr_arr = self._get_pyramid_level(pyr_level) + pyr_sub = self._base_subsample * (2 ** pyr_level) + pH, pW = pyr_arr.shape + + # Map full-res tile pixel coords to pyramid pixel coords. + # Use ceil division for the end boundary so adjacent tiles + # always share exactly one overlapping row/column, even when + # tile_size isn't evenly divisible by pyr_sub. 
+ r0_full = tr * self._tile_size + c0_full = tc * self._tile_size + r1_full = min(r0_full + self._tile_size, self._H) + c1_full = min(c0_full + self._tile_size, self._W) + r0_pyr = r0_full // pyr_sub + c0_pyr = c0_full // pyr_sub + r1_pyr = min((r1_full + pyr_sub - 1) // pyr_sub + 1, pH) + c1_pyr = min((c1_full + pyr_sub - 1) // pyr_sub + 1, pW) + + # If this LOD needs further subsampling beyond the pyramid level + # (happens if lod > max precomputed level), stride the remainder. + extra_sub = subsample // pyr_sub + tile = pyr_arr[r0_pyr:r1_pyr:extra_sub, c0_pyr:c1_pyr:extra_sub] th, tw = tile.shape if th < 2 or tw < 2: - return None, None + return None, None, None + + # Replace NaN values (common at edges from UTM reprojection) + # to avoid degenerate triangles and NaN normals. + bad = ~np.isfinite(tile) + if bad.all(): + return None, None, None + if bad.any(): + tile = tile.copy() + tile[bad] = float(np.nanmean(tile)) # Triangulate using the fast numba/CUDA path num_verts = th * tw @@ -245,109 +1633,172 @@ def _build_tile_mesh(self, tr, tc, lod): indices = np.zeros(num_tris * 3, dtype=np.int32) mesh_mod.triangulate_terrain(verts, indices, tile, scale=1.0) + # Compute smooth per-vertex normals from the elevation grid + eff_sub = pyr_sub * extra_sub + normals = mesh_mod.compute_terrain_normals( + tile, th, tw, + psx=eff_sub * self._psx, + psy=eff_sub * self._psy) + # Transform from local grid coords to world coords. - # triangulate_terrain writes x=w, y=h in grid-local pixel indices. - # We need: x = (c0 + w*subsample) * psx - # y = (r0 + h*subsample) * psy - verts[0::3] = verts[0::3] * subsample * self._psx + c0 * self._psx - verts[1::3] = verts[1::3] * subsample * self._psy + r0 * self._psy - - # Only add skirt on exterior edges (terrain boundary). - # Interior edges shared with adjacent tiles via the +1 overlap - # don't need skirt — overlapping skirt walls cause artifacts. 
- edges = ( - tr == 0, # top - tc == self._n_tile_cols - 1, # right - tr == self._n_tile_rows - 1, # bottom - tc == 0, # left - ) - verts, indices = _add_tile_skirt(verts, indices, th, tw, edges=edges) - - return verts, indices + # Each pixel in the tile spans subsample full-res pixels. + # The world offset keeps positions stable across terrain reloads. + verts[0::3] = verts[0::3] * eff_sub * self._psx + c0_full * self._psx + self._offset_x + verts[1::3] = verts[1::3] * eff_sub * self._psy + r0_full * self._psy + self._offset_y + + return verts, indices, normals + + def _fetch_tile_data(self, tr, tc, lod): + """Fetch raw elevation data for a streaming tile (I/O only). + + Called from the thread pool to prefetch zarr reads ahead of + mesh building. Returns ``np.ndarray`` or ``None``. + """ + ts = self._tile_size + subsample = self._base_subsample * (2 ** lod) + target_samples = max(2, ts // subsample) + + c0 = tc * ts + r0 = tr * ts + + # Convert pixel indices to CRS coordinates (e.g. UTM) when + # available; otherwise fall back to viewer world-space coords. + if self._crs_origin is not None and self._crs_spacing is not None: + crs_x0, crs_y0 = self._crs_origin + crs_dx, crs_dy = self._crs_spacing + wx0 = crs_x0 + c0 * crs_dx + wy0 = crs_y0 + r0 * crs_dy + wx1 = crs_x0 + (c0 + ts) * crs_dx + wy1 = crs_y0 + (r0 + ts) * crs_dy + else: + wx0 = c0 * self._psx + self._offset_x + wy0 = r0 * self._psy + self._offset_y + wx1 = (c0 + ts) * self._psx + self._offset_x + wy1 = (r0 + ts) * self._psy + self._offset_y + + x_min, x_max = min(wx0, wx1), max(wx0, wx1) + y_min, y_max = min(wy0, wy1), max(wy0, wy1) + + try: + tile = self._tile_data_fn(x_min, y_min, x_max, y_max, + target_samples) + except Exception: + return None + if tile is None: + return None + return np.asarray(tile, dtype=np.float32) + + def _build_streaming_tile_mesh(self, tr, tc, lod): + """Build a tile mesh from streamed data (outside initial terrain). 
+ + Uses prefetched tile data from ``_tile_data_cache`` when + available, otherwise fetches synchronously via ``_tile_data_fn``. + """ + from .. import mesh as mesh_mod + + # Check prefetch cache first (populated by I/O futures) + cache_key = (tr, tc, lod, self._base_subsample) + tile = self._tile_data_cache.pop(cache_key, None) + + if tile is None: + # No prefetched data — fetch synchronously + tile = self._fetch_tile_data(tr, tc, lod) + if tile is None: + return None, None, None + + ts = self._tile_size + + # Replace NaN / extreme values with 0 — NaN creates degenerate + # triangles and fill-value leaks produce vertices at -21M Z. + bad = ~np.isfinite(tile) + if bad.all(): + return None, None, None + if bad.any(): + tile = tile.copy() + tile[bad] = 0.0 + + # World-space bounds of this tile + c0 = tc * ts + r0 = tr * ts + wx0 = c0 * self._psx + self._offset_x + wy0 = r0 * self._psy + self._offset_y + wx1 = (c0 + ts) * self._psx + self._offset_x + wy1 = (r0 + ts) * self._psy + self._offset_y + th, tw = tile.shape + if th < 2 or tw < 2: + return None, None, None + + # Triangulate + num_verts = th * tw + num_tris = (th - 1) * (tw - 1) * 2 + verts = np.zeros(num_verts * 3, dtype=np.float32) + indices = np.zeros(num_tris * 3, dtype=np.int32) + mesh_mod.triangulate_terrain(verts, indices, tile, scale=1.0) + + # Compute smooth normals using effective pixel spacing + if tw > 1: + eff_psx = (wx1 - wx0) / (tw - 1) + else: + eff_psx = 1.0 + if th > 1: + eff_psy = (wy1 - wy0) / (th - 1) + else: + eff_psy = 1.0 + normals = mesh_mod.compute_terrain_normals( + tile, th, tw, psx=eff_psx, psy=eff_psy) + + # Transform to world coordinates. + # triangulate_terrain places vertex at (col, row, z) with scale=1. + # Map col ∈ [0, tw-1] → [wx0, wx1], row ∈ [0, th-1] → [wy0, wy1]. 
+ if tw > 1: + verts[0::3] = verts[0::3] / (tw - 1) * (wx1 - wx0) + wx0 + else: + verts[0::3] = (wx0 + wx1) * 0.5 + if th > 1: + verts[1::3] = verts[1::3] / (th - 1) * (wy1 - wy0) + wy0 + else: + verts[1::3] = (wy0 + wy1) * 0.5 + + return verts, indices, normals # ------------------------------------------------------------------ # Module-level helpers # ------------------------------------------------------------------ + +def _box_downsample_2x(arr): + """Downsample a 2D array by 2× using NaN-aware box averaging.""" + import warnings + + H, W = arr.shape + # Trim to even dimensions + hh = H - H % 2 + ww = W - W % 2 + if hh < 2 or ww < 2: + return arr.copy() + block = arr[:hh, :ww].reshape(hh // 2, 2, ww // 2, 2) + # nanmean emits "Mean of empty slice" for all-NaN blocks (water); + # the resulting NaN is correct, so suppress the warning. + with warnings.catch_warnings(): + warnings.simplefilter('ignore', RuntimeWarning) + out = np.nanmean(block, axis=(1, 3)).astype(arr.dtype) + return out + + def _tile_gid(tr, tc): """Geometry ID for a terrain LOD tile.""" return f'terrain_lod_r{tr}_c{tc}' -def is_terrain_lod_gid(gid): - """Return True if *gid* belongs to a terrain LOD tile.""" - return gid.startswith('terrain_lod_r') +def _batch_gid(lod): + """Geometry ID for a batched terrain LOD GAS at *lod* level.""" + return f'terrain_lod_batch_L{lod}' + + +def is_terrain_lod_gid(gid): + """Return True if *gid* belongs to a terrain LOD tile or batch.""" + return gid.startswith('terrain_lod_') -def _add_tile_skirt(vertices, indices, H, W, skirt_depth=None, - edges=(True, True, True, True)): - """Add a thin skirt around specified tile edges. - Parameters - ---------- - edges : tuple of bool - ``(top, right, bottom, left)`` — which edges get skirt - geometry. Interior tile edges shared with adjacent tiles - should be False to avoid overlapping wall triangles. 
- """ - if not any(edges): - return vertices, indices - - z_vals = vertices[2::3] - z_min = float(np.nanmin(z_vals)) - z_max = float(np.nanmax(z_vals)) - - if skirt_depth is None: - z_range = z_max - z_min - skirt_depth = max(0.5, z_range * 0.02) - - skirt_z = z_min - skirt_depth - - # Build clockwise perimeter (same order as mesh.add_terrain_skirt) - top = np.arange(W, dtype=np.int32) - right = (np.arange(1, H, dtype=np.int32)) * W + (W - 1) - bottom = (H - 1) * W + np.arange(W - 2, -1, -1, dtype=np.int32) - left = np.arange(H - 2, 0, -1, dtype=np.int32) * W - perim = np.concatenate([top, right, bottom, left]) - n_perim = len(perim) - n_orig = len(vertices) // 3 - - skirt_verts = np.empty(n_perim * 3, dtype=np.float32) - skirt_verts[0::3] = vertices[perim * 3] - skirt_verts[1::3] = vertices[perim * 3 + 1] - skirt_verts[2::3] = skirt_z - - # Mask: only create wall triangles for active edges. - # Perimeter segments per edge: top W-1, right H-1, bottom W-1, left H-1. - edge_top, edge_right, edge_bottom, edge_left = edges - seg_mask = np.zeros(n_perim, dtype=bool) - off = 0 - for active, count in [(edge_top, W - 1), (edge_right, H - 1), - (edge_bottom, W - 1), (edge_left, H - 1)]: - if active: - seg_mask[off:off + count] = True - off += count - - active_segs = np.where(seg_mask)[0].astype(np.int32) - if len(active_segs) == 0: - return vertices, indices - - idx_next = (active_segs + 1) % n_perim - top_a = perim[active_segs] - top_b = perim[idx_next] - bot_a = (n_orig + active_segs).astype(np.int32) - bot_b = (n_orig + idx_next).astype(np.int32) - - n_active = len(active_segs) - wall_tris = np.empty(n_active * 6, dtype=np.int32) - wall_tris[0::6] = top_a - wall_tris[1::6] = bot_b - wall_tris[2::6] = top_b - wall_tris[3::6] = top_a - wall_tris[4::6] = bot_a - wall_tris[5::6] = bot_b - - new_verts = np.concatenate([vertices, skirt_verts]) - new_indices = np.concatenate([indices, wall_tris]) - return new_verts, new_indices