-
Notifications
You must be signed in to change notification settings - Fork 317
feat: Add SPLAT and SOG export formats for 3D Gaussians #24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -478,3 +478,233 @@ def _inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor: | |||||
|
|
||||||
| plydata.write(path) | ||||||
| return plydata | ||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
def save_splat(
    gaussians: Gaussians3D, f_px: float, image_shape: tuple[int, int], path: Path
) -> None:
    """Save Gaussians to .splat format (compact binary format for web viewers).

    The .splat format is a simple binary format used by web-based 3DGS viewers.
    Each Gaussian is stored as one packed 32-byte record:
        - 12 bytes: xyz position (3 x little-endian float32)
        - 12 bytes: scales (3 x little-endian float32)
        - 4 bytes: RGBA color (4 x uint8)
        - 4 bytes: quaternion rotation (4 x uint8, encoded as (q * 128 + 128))

    Gaussians are sorted by size * opacity (descending) for progressive rendering.
    Colors are converted from linear RGB to sRGB before being quantized.

    Args:
        gaussians: The Gaussians to export. Every field is flattened over its
            first two dimensions (assumed to be batch x count - TODO confirm).
        f_px: Currently unused by this exporter; the .splat record layout has
            no slot for camera metadata.
        image_shape: Currently unused by this exporter (same reason as f_px).
        path: Destination file. It is written exactly as given.
    """
    xyz = gaussians.mean_vectors.flatten(0, 1).cpu().numpy()
    scales = gaussians.singular_values.flatten(0, 1).cpu().numpy()
    quats = gaussians.quaternions.flatten(0, 1).cpu().numpy()
    colors_rgb = cs_utils.linearRGB2sRGB(gaussians.colors.flatten(0, 1)).cpu().numpy()
    opacities = gaussians.opacities.flatten(0, 1).cpu().numpy()

    # Sort by size * opacity (descending) for progressive rendering.
    sort_idx = np.argsort(-(scales.prod(axis=1) * opacities))

    # Normalize quaternions. The norm is clamped so a degenerate all-zero
    # quaternion encodes as the neutral byte 128 instead of NaN.
    quat_norms = np.linalg.norm(quats, axis=1, keepdims=True)
    quats = quats / np.maximum(quat_norms, 1e-8)

    # Packed record layout (no alignment padding => exactly 32 bytes each).
    record_dtype = np.dtype(
        [
            ("xyz", "<f4", (3,)),
            ("scales", "<f4", (3,)),
            ("rgba", "u1", (4,)),
            ("rot", "u1", (4,)),
        ]
    )

    # Fill every record with array operations and write the file once,
    # rather than issuing four small writes per Gaussian from a Python loop.
    records = np.empty(len(sort_idx), dtype=record_dtype)
    records["xyz"] = xyz[sort_idx]
    records["scales"] = scales[sort_idx]
    rgba = np.concatenate([colors_rgb, opacities.reshape(-1, 1)], axis=1)
    records["rgba"] = (rgba[sort_idx] * 255).clip(0, 255).astype(np.uint8)
    records["rot"] = (quats[sort_idx] * 128 + 128).clip(0, 255).astype(np.uint8)

    with open(path, "wb") as f:
        f.write(records.tobytes())
|
|
||||||
|
|
||||||
def _sog_encode_quaternions(q: np.ndarray) -> np.ndarray:
    """Encode an (N, 4) array of unit quaternions with the smallest-three scheme.

    Returns an (N, 4) uint8 array: the three kept components quantized from
    [-sqrt(2)/2, sqrt(2)/2] to [0, 255], followed by 252 + index of the
    dropped (largest-magnitude) component.
    """
    largest = np.argmax(np.abs(q), axis=1)

    # Flip each quaternion so its largest component is positive; q and -q
    # describe the same rotation, and the decoder rebuilds the dropped
    # component as a non-negative square root.
    signs = np.sign(q[np.arange(len(q)), largest])
    q = q * signs[:, None]

    # Row m lists the component indices kept when component m is dropped.
    kept_lut = np.array([[1, 2, 3], [0, 2, 3], [0, 1, 3], [0, 1, 2]])
    kept = np.take_along_axis(q, kept_lut[largest], axis=1)

    sqrt2_inv = float(1.0 / np.sqrt(2.0))
    encoded = ((kept * sqrt2_inv + 0.5) * 255).clip(0, 255).astype(np.uint8)
    mode = (252 + largest).astype(np.uint8)
    return np.concatenate([encoded, mode[:, None]], axis=1)


def _sog_quantize_to_codebook(values: np.ndarray, codebook: np.ndarray) -> np.ndarray:
    """Map each value to the uint8 index of its nearest entry in a sorted codebook."""
    last = len(codebook) - 1
    upper = np.clip(np.searchsorted(codebook, values), 0, last)
    lower = np.clip(upper - 1, 0, last)
    # searchsorted yields the first entry >= value; the entry below may be closer.
    lower_is_closer = np.abs(values - codebook[lower]) < np.abs(values - codebook[upper])
    return np.where(lower_is_closer & (upper > 0), lower, upper).astype(np.uint8)


@torch.no_grad()
def save_sog(
    gaussians: Gaussians3D, f_px: float, image_shape: tuple[int, int], path: Path
) -> None:
    """Save Gaussians to SOG format (Spatially Ordered Gaussians).

    SOG is a highly compressed format using quantization and WebP images.
    Typically 15-20x smaller than PLY. The format stores data in a ZIP archive
    containing WebP images for positions, rotations, scales, and colors,
    plus a meta.json describing how to decode them.

    Reference: https://github.com/aras-p/sog-format

    Quantization statistics (position bounds and both codebooks) are computed
    from the real Gaussians only. The images are padded with zero pixels
    afterwards, so the padding cannot skew the encoding. Colors are converted
    from linear RGB to sRGB first, matching what save_splat in this module
    writes.

    Args:
        gaussians: The Gaussians to export. Every field is flattened over its
            first two dimensions (assumed to be batch x count - TODO confirm).
        f_px: Currently unused by this exporter; no camera metadata is written.
        image_shape: Currently unused by this exporter (same reason as f_px).
        path: Destination file. If its suffix is not ".sog" (case-insensitive),
            the suffix is replaced with ".sog" before writing.

    Raises:
        ValueError: If there are no Gaussians to export.
    """
    import io
    import json
    import math
    import zipfile

    from PIL import Image

    xyz = gaussians.mean_vectors.flatten(0, 1).cpu().numpy()
    scales = gaussians.singular_values.flatten(0, 1).cpu().numpy()
    quats = gaussians.quaternions.flatten(0, 1).cpu().numpy()
    # Convert to sRGB like save_splat does, so both exports show the same colors.
    colors = cs_utils.linearRGB2sRGB(gaussians.colors.flatten(0, 1)).cpu().numpy()
    opacities = gaussians.opacities.flatten(0, 1).cpu().numpy()

    num_gaussians = len(xyz)
    if num_gaussians == 0:
        # Without this guard the image-height computation divides by zero.
        raise ValueError("Cannot export an empty set of Gaussians to SOG")

    # Compute image dimensions (roughly square).
    img_width = int(math.ceil(math.sqrt(num_gaussians)))
    img_height = int(math.ceil(num_gaussians / img_width))
    total_pixels = img_width * img_height

    # Normalize quaternions (epsilon guards against zero-length input).
    quats = quats / (np.linalg.norm(quats, axis=1, keepdims=True) + 1e-8)

    # === 1. Encode positions (16-bit per axis with symmetric log transform) ===
    xyz_log = np.sign(xyz) * np.log1p(np.abs(xyz))
    mins = xyz_log.min(axis=0)
    maxs = xyz_log.max(axis=0)

    # Avoid division by zero on axes where every position is identical.
    ranges = maxs - mins
    ranges = np.where(ranges < 1e-8, 1.0, ranges)

    # Quantize to 16-bit, then split into low/high bytes (one image each).
    xyz_norm = (xyz_log - mins) / ranges
    xyz_q16 = (xyz_norm * 65535).clip(0, 65535).astype(np.uint16)
    means_l = (xyz_q16 & 0xFF).astype(np.uint8)
    means_u = (xyz_q16 >> 8).astype(np.uint8)

    # === 2. Encode quaternions (smallest-three) ===
    quats_encoded = _sog_encode_quaternions(quats)

    # === 3. Build scale codebook (256 entries) ===
    # SOG stores scales in LOG space - the renderer does exp(codebook[idx]).
    scales_log = np.log(np.maximum(scales, 1e-10))
    percentiles = np.linspace(0, 100, 256)
    scale_codebook = np.percentile(scales_log.ravel(), percentiles).astype(np.float32)
    scales_q = _sog_quantize_to_codebook(scales_log, scale_codebook)

    # === 4. Build SH0 codebook and encode colors ===
    SH_C0 = 0.28209479177387814
    sh0_coeffs = (colors - 0.5) / SH_C0
    sh0_codebook = np.percentile(sh0_coeffs.ravel(), percentiles).astype(np.float32)
    sh0_rgb = _sog_quantize_to_codebook(sh0_coeffs, sh0_codebook)
    sh0_a = (opacities * 255).clip(0, 255).astype(np.uint8)
    sh0_data = np.concatenate([sh0_rgb, sh0_a.reshape(-1, 1)], axis=1)

    # === 5. Create images ===
    def create_image(pixels: np.ndarray) -> Image.Image:
        """Lay out (N, C) uint8 pixels row-major into a zero-padded image."""
        channels = pixels.shape[1]
        if channels not in (3, 4):
            raise ValueError(f"Unexpected channel count: {channels}")
        canvas = np.zeros((total_pixels, channels), dtype=np.uint8)
        canvas[:num_gaussians] = pixels
        # The mode (RGB / RGBA) is inferred from the uint8 (H, W, C) shape.
        return Image.fromarray(canvas.reshape(img_height, img_width, channels))

    images = [
        ("means_l.webp", create_image(means_l)),
        ("means_u.webp", create_image(means_u)),
        ("quats.webp", create_image(quats_encoded)),
        ("scales.webp", create_image(scales_q)),
        ("sh0.webp", create_image(sh0_data)),
    ]

    # === 6. Create meta.json ===
    meta = {
        "version": 2,
        "count": num_gaussians,
        "antialias": False,
        "means": {
            "mins": mins.tolist(),
            "maxs": maxs.tolist(),
            "files": ["means_l.webp", "means_u.webp"],
        },
        "scales": {"codebook": scale_codebook.tolist(), "files": ["scales.webp"]},
        "quats": {"files": ["quats.webp"]},
        "sh0": {"codebook": sh0_codebook.tolist(), "files": ["sh0.webp"]},
    }

    # === 7. Save as ZIP archive ===
    path = Path(path)
    if path.suffix.lower() != ".sog":
        path = path.with_suffix(".sog")

    with zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED) as zf:
        # Lossless WebP is already compressed, so deflating it again costs
        # time for no gain; store the images as-is.
        for name, img in images:
            buf = io.BytesIO()
            img.save(buf, format="WEBP", lossless=True)
            zf.writestr(name, buf.getvalue(), compress_type=zipfile.ZIP_STORED)

        # meta.json is plain text and keeps the archive's default (deflate).
        zf.writestr("meta.json", json.dumps(meta, indent=2))
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In addition to the Gaussians, sharp's save_ply function also saves a lot of metadata, such as the image size, camera intrinsics/extrinsics, and color space. This information is important for rendering.