From 71af9eb6691b3149ee1db53b05a9c357746c845a Mon Sep 17 00:00:00 2001
From: User
Date: Thu, 18 Dec 2025 21:59:50 +0000
Subject: [PATCH] feat: Add SPLAT and SOG export formats for 3D Gaussians

---
 src/sharp/cli/predict.py     |  26 +++-
 src/sharp/utils/gaussians.py | 230 +++++++++++++++++++++++++++++++++++
 2 files changed, 255 insertions(+), 1 deletion(-)

diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py
index 8914bb5..5b18bfa 100644
--- a/src/sharp/cli/predict.py
+++ b/src/sharp/cli/predict.py
@@ -26,6 +26,8 @@
     Gaussians3D,
     SceneMetaData,
     save_ply,
+    save_splat,
+    save_sog,
     unproject_gaussians,
 )
 
@@ -66,6 +68,15 @@
     default=False,
     help="Whether to render trajectory for checkpoint.",
 )
+@click.option(
+    "-f",
+    "--format",
+    "export_formats",
+    type=click.Choice(["ply", "splat", "sog"], case_sensitive=False),
+    multiple=True,
+    default=["ply"],
+    help="Output format(s). Can specify multiple: -f ply -f splat -f sog",
+)
 @click.option(
     "--device",
     type=str,
@@ -78,12 +89,16 @@ def predict_cli(
     output_path: Path,
     checkpoint_path: Path,
     with_rendering: bool,
+    export_formats: tuple[str, ...],
     device: str,
     verbose: bool,
 ):
     """Predict Gaussians from input images."""
     logging_utils.configure(logging.DEBUG if verbose else logging.INFO)
+
+    # Normalize export formats to lowercase
+    export_formats = tuple(fmt.lower() for fmt in export_formats)
+
     extensions = io.get_supported_image_extensions()
 
     image_paths = []
@@ -145,7 +160,16 @@
         gaussians = predict_image(gaussian_predictor, image, f_px, torch.device(device))
 
         LOGGER.info("Saving 3DGS to %s", output_path)
-        save_ply(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.ply")
+        for fmt in export_formats:
+            if fmt == "ply":
+                save_ply(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.ply")
+                LOGGER.info("Saved PLY: %s", output_path / f"{image_path.stem}.ply")
+            elif fmt == "splat":
+                save_splat(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.splat")
+                LOGGER.info("Saved SPLAT: %s", output_path / f"{image_path.stem}.splat")
+            elif fmt == "sog":
+                save_sog(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.sog")
+                LOGGER.info("Saved SOG: %s", output_path / f"{image_path.stem}.sog")
 
         if with_rendering:
             output_video_path = (output_path / image_path.stem).with_suffix(".mp4")
diff --git a/src/sharp/utils/gaussians.py b/src/sharp/utils/gaussians.py
index 9c037eb..b97b45b 100644
--- a/src/sharp/utils/gaussians.py
+++ b/src/sharp/utils/gaussians.py
@@ -478,3 +478,233 @@ def _inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor:
     plydata.write(path)
 
     return plydata
+
+
+@torch.no_grad()
+def save_splat(
+    gaussians: Gaussians3D, f_px: float, image_shape: tuple[int, int], path: Path
+) -> None:
+    """Save Gaussians to .splat format (compact binary format for web viewers).
+
+    The .splat format is a simple binary format used by web-based 3DGS viewers.
+    Each Gaussian is stored as 32 bytes:
+    - 12 bytes: xyz position (3 x float32)
+    - 12 bytes: scales (3 x float32)
+    - 4 bytes: RGBA color (4 x uint8)
+    - 4 bytes: quaternion rotation (4 x uint8, encoded as (q * 128 + 128))
+
+    Gaussians are sorted by size * opacity (descending) for progressive rendering.
+    """
+    xyz = gaussians.mean_vectors.flatten(0, 1).cpu().numpy()
+    scales = gaussians.singular_values.flatten(0, 1).cpu().numpy()
+    quats = gaussians.quaternions.flatten(0, 1).cpu().numpy()
+    colors_rgb = cs_utils.linearRGB2sRGB(gaussians.colors.flatten(0, 1)).cpu().numpy()
+    opacities = gaussians.opacities.flatten(0, 1).cpu().numpy()
+
+    # Sort by size * opacity (descending) for progressive rendering
+    sort_idx = np.argsort(-(scales.prod(axis=1) * opacities))
+
+    # Normalize quaternions
+    quats = quats / np.linalg.norm(quats, axis=1, keepdims=True)
+
+    with open(path, "wb") as f:
+        for i in sort_idx:
+            f.write(xyz[i].astype(np.float32).tobytes())
+            f.write(scales[i].astype(np.float32).tobytes())
+            rgba = np.concatenate([colors_rgb[i], [opacities[i]]])
+            f.write((rgba * 255).clip(0, 255).astype(np.uint8).tobytes())
+            f.write((quats[i] * 128 + 128).clip(0, 255).astype(np.uint8).tobytes())
+
+
+@torch.no_grad()
+def save_sog(
+    gaussians: Gaussians3D, f_px: float, image_shape: tuple[int, int], path: Path
+) -> None:
+    """Save Gaussians to SOG format (Spatially Ordered Gaussians).
+
+    SOG is a highly compressed format using quantization and WebP images.
+    Typically 15-20x smaller than PLY. The format stores data in a ZIP archive
+    containing WebP images for positions, rotations, scales, and colors.
+
+    Reference: https://github.com/aras-p/sog-format
+    """
+    import io
+    import json
+    import math
+    import zipfile
+
+    from PIL import Image
+
+    xyz = gaussians.mean_vectors.flatten(0, 1).cpu().numpy()
+    scales = gaussians.singular_values.flatten(0, 1).cpu().numpy()
+    quats = gaussians.quaternions.flatten(0, 1).cpu().numpy()
+    colors_linear = gaussians.colors.flatten(0, 1).cpu().numpy()
+    opacities = gaussians.opacities.flatten(0, 1).cpu().numpy()
+
+    num_gaussians = len(xyz)
+
+    # Compute image dimensions (roughly square)
+    img_width = int(math.ceil(math.sqrt(num_gaussians)))
+    img_height = int(math.ceil(num_gaussians / img_width))
+    total_pixels = img_width * img_height
+
+    # Pad arrays to fill image
+    def pad_array(arr: np.ndarray, total: int) -> np.ndarray:
+        if len(arr) < total:
+            pad_shape = (total - len(arr),) + arr.shape[1:]
+            return np.concatenate([arr, np.zeros(pad_shape, dtype=arr.dtype)])
+        return arr
+
+    xyz = pad_array(xyz, total_pixels)
+    scales = pad_array(scales, total_pixels)
+    quats = pad_array(quats, total_pixels)
+    colors_linear = pad_array(colors_linear, total_pixels)
+    opacities = pad_array(opacities, total_pixels)
+
+    # Normalize quaternions
+    quats = quats / (np.linalg.norm(quats, axis=1, keepdims=True) + 1e-8)
+
+    # === 1. Encode positions (16-bit per axis with symmetric log transform) ===
+    def symlog(x: np.ndarray) -> np.ndarray:
+        return np.sign(x) * np.log1p(np.abs(x))
+
+    xyz_log = symlog(xyz)
+    mins = xyz_log.min(axis=0)
+    maxs = xyz_log.max(axis=0)
+
+    # Avoid division by zero
+    ranges = maxs - mins
+    ranges = np.where(ranges < 1e-8, 1.0, ranges)
+
+    # Quantize to 16-bit
+    xyz_norm = (xyz_log - mins) / ranges
+    xyz_q16 = (xyz_norm * 65535).clip(0, 65535).astype(np.uint16)
+
+    means_l = (xyz_q16 & 0xFF).astype(np.uint8)
+    means_u = (xyz_q16 >> 8).astype(np.uint8)
+
+    # === 2. Encode quaternions (smallest-three, 26-bit) ===
+    def encode_quaternion(q: np.ndarray) -> np.ndarray:
+        """Encode quaternion using smallest-three method."""
+        # Find largest component
+        abs_q = np.abs(q)
+        mode = np.argmax(abs_q, axis=1)
+
+        # Ensure the largest component is positive
+        signs = np.sign(q[np.arange(len(q)), mode])
+        q = q * signs[:, None]
+
+        # Extract the three smallest components
+        result = np.zeros((len(q), 4), dtype=np.uint8)
+        sqrt2_inv = 1.0 / math.sqrt(2)
+
+        for i in range(len(q)):
+            m = mode[i]
+            # Get indices of the three kept components
+            kept = [j for j in range(4) if j != m]
+            vals = q[i, kept]
+            # Quantize from [-sqrt2/2, sqrt2/2] to [0, 255]
+            encoded = ((vals * sqrt2_inv + 0.5) * 255).clip(0, 255).astype(np.uint8)
+            result[i, :3] = encoded
+            result[i, 3] = 252 + m  # Mode in alpha channel
+
+        return result
+
+    quats_encoded = encode_quaternion(quats)
+
+    # === 3. Build scale codebook (256 entries) ===
+    # SOG stores scales in LOG space - the renderer does exp(codebook[idx])
+    scales_log = np.log(np.maximum(scales, 1e-10))
+    scales_log_flat = scales_log.flatten()
+
+    # Use percentiles for codebook (in log space)
+    percentiles = np.linspace(0, 100, 256)
+    scale_codebook = np.percentile(scales_log_flat, percentiles).astype(np.float32)
+
+    # Quantize values to nearest codebook entry
+    def quantize_to_codebook(values: np.ndarray, codebook: np.ndarray) -> np.ndarray:
+        indices = np.searchsorted(codebook, values)
+        indices = np.clip(indices, 0, len(codebook) - 1)
+        # Check if previous index is closer
+        prev_indices = np.clip(indices - 1, 0, len(codebook) - 1)
+        dist_curr = np.abs(values - codebook[indices])
+        dist_prev = np.abs(values - codebook[prev_indices])
+        use_prev = (dist_prev < dist_curr) & (indices > 0)
+        indices = np.where(use_prev, prev_indices, indices)
+        return indices.astype(np.uint8)
+
+    scales_q = np.stack(
+        [
+            quantize_to_codebook(scales_log[:, 0], scale_codebook),
+            quantize_to_codebook(scales_log[:, 1], scale_codebook),
+            quantize_to_codebook(scales_log[:, 2], scale_codebook),
+        ],
+        axis=1,
+    )
+
+    # === 4. Build SH0 codebook and encode colors ===
+    SH_C0 = 0.28209479177387814
+    sh0_coeffs = (colors_linear - 0.5) / SH_C0
+    sh0_flat = sh0_coeffs.flatten()
+
+    sh0_percentiles = np.linspace(0, 100, 256)
+    sh0_codebook = np.percentile(sh0_flat, sh0_percentiles).astype(np.float32)
+
+    sh0_r = quantize_to_codebook(sh0_coeffs[:, 0], sh0_codebook)
+    sh0_g = quantize_to_codebook(sh0_coeffs[:, 1], sh0_codebook)
+    sh0_b = quantize_to_codebook(sh0_coeffs[:, 2], sh0_codebook)
+    sh0_a = (opacities * 255).clip(0, 255).astype(np.uint8)
+
+    # === 5. Create images ===
+    def create_image(data: np.ndarray, width: int, height: int) -> Image.Image:
+        data = data.reshape(height, width, -1)
+        if data.shape[2] == 3:
+            return Image.fromarray(data, mode="RGB")
+        elif data.shape[2] == 4:
+            return Image.fromarray(data, mode="RGBA")
+        else:
+            raise ValueError(f"Unexpected channel count: {data.shape[2]}")
+
+    means_l_img = create_image(means_l, img_width, img_height)
+    means_u_img = create_image(means_u, img_width, img_height)
+    quats_img = create_image(quats_encoded, img_width, img_height)
+    scales_img = create_image(scales_q, img_width, img_height)
+
+    sh0_data = np.stack([sh0_r, sh0_g, sh0_b, sh0_a], axis=1)
+    sh0_img = create_image(sh0_data, img_width, img_height)
+
+    # === 6. Create meta.json ===
+    meta = {
+        "version": 2,
+        "count": num_gaussians,
+        "antialias": False,
+        "means": {
+            "mins": mins.tolist(),
+            "maxs": maxs.tolist(),
+            "files": ["means_l.webp", "means_u.webp"],
+        },
+        "scales": {"codebook": scale_codebook.tolist(), "files": ["scales.webp"]},
+        "quats": {"files": ["quats.webp"]},
+        "sh0": {"codebook": sh0_codebook.tolist(), "files": ["sh0.webp"]},
+    }
+
+    # === 7. Save as ZIP archive ===
+    path = Path(path)
+    if path.suffix.lower() != ".sog":
+        path = path.with_suffix(".sog")
+
+    with zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED) as zf:
+        # Save images as lossless WebP
+        for name, img in [
+            ("means_l.webp", means_l_img),
+            ("means_u.webp", means_u_img),
+            ("quats.webp", quats_img),
+            ("scales.webp", scales_img),
+            ("sh0.webp", sh0_img),
+        ]:
+            buf = io.BytesIO()
+            img.save(buf, format="WEBP", lossless=True)
+            zf.writestr(name, buf.getvalue())
+
+        # Save meta.json
+        zf.writestr("meta.json", json.dumps(meta, indent=2))
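
Note on verifying the .splat output: the 32-byte record layout documented in save_splat can be decoded directly with NumPy, which makes a cheap round-trip check possible. The sketch below is illustrative only and is not part of the patch (the name read_splat is invented here); it assumes exactly the field order and the (q * 128 + 128) quaternion encoding described in the docstring above.

import numpy as np


def read_splat(path):
    # One record per Gaussian: 3x float32 position, 3x float32 scale,
    # 4x uint8 RGBA, 4x uint8 quaternion -- 32 bytes in total.
    record = np.dtype(
        [
            ("xyz", np.float32, 3),
            ("scale", np.float32, 3),
            ("rgba", np.uint8, 4),
            ("quat", np.uint8, 4),
        ]
    )
    data = np.fromfile(path, dtype=record)
    xyz = data["xyz"]
    scales = data["scale"]
    rgba = data["rgba"].astype(np.float32) / 255.0
    # Undo the (q * 128 + 128) quantization applied on export.
    quats = (data["quat"].astype(np.float32) - 128.0) / 128.0
    return xyz, scales, rgba, quats

Comparing the decoded positions and renormalized quaternions against the corresponding Gaussians3D tensors, up to quantization error, is a quick way to catch field-order or byte-order regressions.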
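
The .sog archive can be inspected in a similar way without a viewer. The sketch below is likewise illustrative and not part of the patch (load_sog_positions is an invented name); it assumes the archive layout written by save_sog (meta.json plus lossless WebP planes) and only reconstructs positions by recombining the two 8-bit planes into 16-bit values and inverting the symmetric log transform.

import json
import zipfile

import numpy as np
from PIL import Image


def load_sog_positions(path):
    # Read meta.json and the two position planes from the .sog ZIP archive.
    with zipfile.ZipFile(path) as zf:
        meta = json.loads(zf.read("meta.json"))
        means_l = np.asarray(Image.open(zf.open("means_l.webp")), dtype=np.uint16)
        means_u = np.asarray(Image.open(zf.open("means_u.webp")), dtype=np.uint16)
    # Recombine the low/high bytes into 16-bit quantized coordinates.
    q16 = (means_u << 8) | means_l
    mins = np.asarray(meta["means"]["mins"], dtype=np.float64)
    maxs = np.asarray(meta["means"]["maxs"], dtype=np.float64)
    xyz_log = q16.reshape(-1, 3) / 65535.0 * (maxs - mins) + mins
    # Invert the symmetric log transform symlog(x) = sign(x) * log1p(|x|).
    xyz = np.sign(xyz_log) * np.expm1(np.abs(xyz_log))
    # Drop the padding Gaussians added to fill the image.
    return xyz[: meta["count"]]

Positions recovered this way should match gaussians.mean_vectors up to the 16-bit quantization step, which is a reasonable smoke test before loading the file into a SOG-capable viewer.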