26 changes: 25 additions & 1 deletion src/sharp/cli/predict.py
@@ -26,6 +26,8 @@
Gaussians3D,
SceneMetaData,
save_ply,
save_splat,
save_sog,
unproject_gaussians,
)

@@ -66,6 +68,15 @@
default=False,
help="Whether to render trajectory for checkpoint.",
)
@click.option(
"-f",
"--format",
"export_formats",
type=click.Choice(["ply", "splat", "sog"], case_sensitive=False),
multiple=True,
default=["ply"],
help="Output format(s). Can specify multiple: -f ply -f splat -f sog",
)
@click.option(
"--device",
type=str,
@@ -78,12 +89,16 @@ def predict_cli(
output_path: Path,
checkpoint_path: Path,
with_rendering: bool,
export_formats: tuple[str, ...],
device: str,
verbose: bool,
):
"""Predict Gaussians from input images."""
logging_utils.configure(logging.DEBUG if verbose else logging.INFO)

# Normalize export formats to lowercase
export_formats = tuple(fmt.lower() for fmt in export_formats)

extensions = io.get_supported_image_extensions()

image_paths = []
@@ -145,7 +160,16 @@ def predict_cli(
gaussians = predict_image(gaussian_predictor, image, f_px, torch.device(device))

LOGGER.info("Saving 3DGS to %s", output_path)
save_ply(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.ply")
for fmt in export_formats:
if fmt == "ply":
save_ply(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.ply")
LOGGER.info("Saved PLY: %s", output_path / f"{image_path.stem}.ply")
elif fmt == "splat":
save_splat(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.splat")
LOGGER.info("Saved SPLAT: %s", output_path / f"{image_path.stem}.splat")
elif fmt == "sog":
save_sog(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.sog")
LOGGER.info("Saved SOG: %s", output_path / f"{image_path.stem}.sog")

if with_rendering:
output_video_path = (output_path / image_path.stem).with_suffix(".mp4")
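For context, a minimal standalone sketch (not part of the diff) of how Click's multiple=True option collects repeated -f flags into a tuple, mirroring the new export_formats option above; the demo command name and the use of click.testing.CliRunner are illustrative assumptions.

import click
from click.testing import CliRunner


@click.command()
@click.option(
    "-f",
    "--format",
    "export_formats",
    type=click.Choice(["ply", "splat", "sog"], case_sensitive=False),
    multiple=True,
    default=["ply"],
)
def demo(export_formats: tuple[str, ...]) -> None:
    # Normalize to lowercase before dispatching, as the CLI change above does.
    click.echo(str(tuple(fmt.lower() for fmt in export_formats)))


if __name__ == "__main__":
    result = CliRunner().invoke(demo, ["-f", "ply", "-f", "SOG"])
    print(result.output.strip())  # ('ply', 'sog')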
230 changes: 230 additions & 0 deletions src/sharp/utils/gaussians.py
@@ -478,3 +478,233 @@ def _inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor:

plydata.write(path)
return plydata


@torch.no_grad()
def save_splat(
gaussians: Gaussians3D, f_px: float, image_shape: tuple[int, int], path: Path
) -> None:
"""Save Gaussians to .splat format (compact binary format for web viewers).

The .splat format is a simple binary format used by web-based 3DGS viewers.
Each Gaussian is stored as 32 bytes:
- 12 bytes: xyz position (3 x float32)
- 12 bytes: scales (3 x float32)
- 4 bytes: RGBA color (4 x uint8)
- 4 bytes: quaternion rotation (4 x uint8, encoded as (q * 128 + 128))

Gaussians are sorted by size * opacity (descending) for progressive rendering.
"""
xyz = gaussians.mean_vectors.flatten(0, 1).cpu().numpy()
scales = gaussians.singular_values.flatten(0, 1).cpu().numpy()
quats = gaussians.quaternions.flatten(0, 1).cpu().numpy()
colors_rgb = cs_utils.linearRGB2sRGB(gaussians.colors.flatten(0, 1)).cpu().numpy()
opacities = gaussians.opacities.flatten(0, 1).cpu().numpy()

# Sort by size * opacity (descending) for progressive rendering
sort_idx = np.argsort(-(scales.prod(axis=1) * opacities))

# Normalize quaternions
quats = quats / np.linalg.norm(quats, axis=1, keepdims=True)

with open(path, "wb") as f:
for i in sort_idx:
f.write(xyz[i].astype(np.float32).tobytes())
f.write(scales[i].astype(np.float32).tobytes())
rgba = np.concatenate([colors_rgb[i], [opacities[i]]])
f.write((rgba * 255).clip(0, 255).astype(np.uint8).tobytes())
f.write((quats[i] * 128 + 128).clip(0, 255).astype(np.uint8).tobytes())
Comment on lines +510 to +516 (@Enter-tainer, Dec 21, 2025): In addition to the Gaussians, sharp's save_ply function also saves metadata such as image size, camera intrinsics/extrinsics, and colorspace. This information is important for rendering.
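For context, a minimal sketch (not part of the diff) of how the 32-byte records written by save_splat above could be read back with NumPy; the structured field names are assumptions chosen to mirror the writer.

import numpy as np

# Structured dtype matching the 32-byte layout from the save_splat docstring.
splat_dtype = np.dtype([
    ("xyz", np.float32, 3),    # 12 bytes: position
    ("scale", np.float32, 3),  # 12 bytes: per-axis scale
    ("rgba", np.uint8, 4),     # 4 bytes: sRGB color + opacity
    ("quat", np.uint8, 4),     # 4 bytes: rotation, stored as q * 128 + 128
])


def load_splat(path):
    records = np.fromfile(path, dtype=splat_dtype)
    # Undo the uint8 quaternion quantization used by the writer.
    quats = (records["quat"].astype(np.float32) - 128.0) / 128.0
    return records["xyz"], records["scale"], records["rgba"], quats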



@torch.no_grad()
def save_sog(
Review comment (@Enter-tainer, Dec 21, 2025): The linear to sRGB transform is missing, so the SOG colors come out very dark. (Edit: sorry, I got it wrong; SOG does use linear sRGB.)

gaussians: Gaussians3D, f_px: float, image_shape: tuple[int, int], path: Path
) -> None:
"""Save Gaussians to SOG format (Spatially Ordered Gaussians).

SOG is a highly compressed format using quantization and WebP images.
Typically 15-20x smaller than PLY. The format stores data in a ZIP archive
containing WebP images for positions, rotations, scales, and colors.

Reference: https://github.com/aras-p/sog-format
Review suggestion: replace the reference with https://developer.playcanvas.com/user-manual/gaussian-splatting/formats/sog/

"""
import io
import json
import math
import zipfile

from PIL import Image

xyz = gaussians.mean_vectors.flatten(0, 1).cpu().numpy()
scales = gaussians.singular_values.flatten(0, 1).cpu().numpy()
quats = gaussians.quaternions.flatten(0, 1).cpu().numpy()
colors_linear = gaussians.colors.flatten(0, 1).cpu().numpy()
opacities = gaussians.opacities.flatten(0, 1).cpu().numpy()

num_gaussians = len(xyz)

# Compute image dimensions (roughly square)
img_width = int(math.ceil(math.sqrt(num_gaussians)))
img_height = int(math.ceil(num_gaussians / img_width))
total_pixels = img_width * img_height

# Pad arrays to fill image
def pad_array(arr: np.ndarray, total: int) -> np.ndarray:
if len(arr) < total:
pad_shape = (total - len(arr),) + arr.shape[1:]
return np.concatenate([arr, np.zeros(pad_shape, dtype=arr.dtype)])
return arr

xyz = pad_array(xyz, total_pixels)
scales = pad_array(scales, total_pixels)
quats = pad_array(quats, total_pixels)
colors_linear = pad_array(colors_linear, total_pixels)
opacities = pad_array(opacities, total_pixels)

# Normalize quaternions
quats = quats / (np.linalg.norm(quats, axis=1, keepdims=True) + 1e-8)

# === 1. Encode positions (16-bit per axis with symmetric log transform) ===
def symlog(x: np.ndarray) -> np.ndarray:
return np.sign(x) * np.log1p(np.abs(x))

xyz_log = symlog(xyz)
mins = xyz_log.min(axis=0)
maxs = xyz_log.max(axis=0)

# Avoid division by zero
ranges = maxs - mins
ranges = np.where(ranges < 1e-8, 1.0, ranges)

# Quantize to 16-bit
xyz_norm = (xyz_log - mins) / ranges
xyz_q16 = (xyz_norm * 65535).clip(0, 65535).astype(np.uint16)

means_l = (xyz_q16 & 0xFF).astype(np.uint8)
means_u = (xyz_q16 >> 8).astype(np.uint8)

# === 2. Encode quaternions (smallest-three, 26-bit) ===
def encode_quaternion(q: np.ndarray) -> np.ndarray:
"""Encode quaternion using smallest-three method."""
# Find largest component
abs_q = np.abs(q)
mode = np.argmax(abs_q, axis=1)

# Ensure the largest component is positive
signs = np.sign(q[np.arange(len(q)), mode])
q = q * signs[:, None]

# Extract the three smallest components
result = np.zeros((len(q), 4), dtype=np.uint8)
sqrt2_inv = 1.0 / math.sqrt(2)

for i in range(len(q)):
m = mode[i]
# Get indices of the three kept components
kept = [j for j in range(4) if j != m]
vals = q[i, kept]
# Quantize from [-sqrt2/2, sqrt2/2] to [0, 255]
encoded = ((vals * sqrt2_inv + 0.5) * 255).clip(0, 255).astype(np.uint8)
result[i, :3] = encoded
result[i, 3] = 252 + m # Mode in alpha channel

return result

quats_encoded = encode_quaternion(quats)

# === 3. Build scale codebook (256 entries) ===
# SOG stores scales in LOG space - the renderer does exp(codebook[idx])
scales_log = np.log(np.maximum(scales, 1e-10))
scales_log_flat = scales_log.flatten()

# Use percentiles for codebook (in log space)
percentiles = np.linspace(0, 100, 256)
scale_codebook = np.percentile(scales_log_flat, percentiles).astype(np.float32)

# Quantize values to nearest codebook entry
def quantize_to_codebook(values: np.ndarray, codebook: np.ndarray) -> np.ndarray:
indices = np.searchsorted(codebook, values)
indices = np.clip(indices, 0, len(codebook) - 1)
# Check if previous index is closer
prev_indices = np.clip(indices - 1, 0, len(codebook) - 1)
dist_curr = np.abs(values - codebook[indices])
dist_prev = np.abs(values - codebook[prev_indices])
use_prev = (dist_prev < dist_curr) & (indices > 0)
indices = np.where(use_prev, prev_indices, indices)
return indices.astype(np.uint8)

scales_q = np.stack(
[
quantize_to_codebook(scales_log[:, 0], scale_codebook),
quantize_to_codebook(scales_log[:, 1], scale_codebook),
quantize_to_codebook(scales_log[:, 2], scale_codebook),
],
axis=1,
)

# === 4. Build SH0 codebook and encode colors ===
SH_C0 = 0.28209479177387814
sh0_coeffs = (colors_linear - 0.5) / SH_C0
sh0_flat = sh0_coeffs.flatten()

sh0_percentiles = np.linspace(0, 100, 256)
sh0_codebook = np.percentile(sh0_flat, sh0_percentiles).astype(np.float32)

sh0_r = quantize_to_codebook(sh0_coeffs[:, 0], sh0_codebook)
sh0_g = quantize_to_codebook(sh0_coeffs[:, 1], sh0_codebook)
sh0_b = quantize_to_codebook(sh0_coeffs[:, 2], sh0_codebook)
sh0_a = (opacities * 255).clip(0, 255).astype(np.uint8)

# === 5. Create images ===
def create_image(data: np.ndarray, width: int, height: int) -> Image.Image:
data = data.reshape(height, width, -1)
if data.shape[2] == 3:
return Image.fromarray(data, mode="RGB")
elif data.shape[2] == 4:
return Image.fromarray(data, mode="RGBA")
else:
raise ValueError(f"Unexpected channel count: {data.shape[2]}")

means_l_img = create_image(means_l, img_width, img_height)
means_u_img = create_image(means_u, img_width, img_height)
quats_img = create_image(quats_encoded, img_width, img_height)
scales_img = create_image(scales_q, img_width, img_height)

sh0_data = np.stack([sh0_r, sh0_g, sh0_b, sh0_a], axis=1)
sh0_img = create_image(sh0_data, img_width, img_height)

# === 6. Create meta.json ===
meta = {
"version": 2,
"count": num_gaussians,
"antialias": False,
"means": {
"mins": mins.tolist(),
"maxs": maxs.tolist(),
"files": ["means_l.webp", "means_u.webp"],
},
"scales": {"codebook": scale_codebook.tolist(), "files": ["scales.webp"]},
"quats": {"files": ["quats.webp"]},
"sh0": {"codebook": sh0_codebook.tolist(), "files": ["sh0.webp"]},
}

# === 7. Save as ZIP archive ===
path = Path(path)
if path.suffix.lower() != ".sog":
path = path.with_suffix(".sog")

with zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED) as zf:

Review comment: does ZIP_DEFLATED make sense for lossless WebP?

# Save images as lossless WebP
for name, img in [
("means_l.webp", means_l_img),
("means_u.webp", means_u_img),
("quats.webp", quats_img),
("scales.webp", scales_img),
("sh0.webp", sh0_img),
]:
buf = io.BytesIO()
img.save(buf, format="WEBP", lossless=True)
zf.writestr(name, buf.getvalue())

# Save meta.json
zf.writestr("meta.json", json.dumps(meta, indent=2))