diff --git a/src/sharp/utils/gaussians.py b/src/sharp/utils/gaussians.py index 9c037eb..5c3f683 100644 --- a/src/sharp/utils/gaussians.py +++ b/src/sharp/utils/gaussians.py @@ -320,6 +320,8 @@ def load_ply(path: Path) -> tuple[Gaussians3D, SceneMetaData]: # Parse color space. color_space_index = supplement_data.get("color_space", 1) color_space = cs_utils.decode_color_space(color_space_index) + colors = torch.from_numpy(colors).view(1, -1, 3).float() + if color_space == "sRGB": colors = cs_utils.sRGB2linearRGB(colors) @@ -327,7 +329,6 @@ def load_ply(path: Path) -> tuple[Gaussians3D, SceneMetaData]: quaternions = torch.from_numpy(quaternions).view(1, -1, 4).float() singular_values = torch.exp(torch.from_numpy(scale_logits).view(1, -1, 3)).float() opacities = torch.sigmoid(torch.from_numpy(opacity_logits).view(1, -1)).float() - colors = torch.from_numpy(colors).view(1, -1, 3).float() gaussians = Gaussians3D( mean_vectors=mean_vectors, diff --git a/src/sharp/utils/gsplat.py b/src/sharp/utils/gsplat.py index 0ce72e2..d4e4971 100644 --- a/src/sharp/utils/gsplat.py +++ b/src/sharp/utils/gsplat.py @@ -118,7 +118,7 @@ def forward( # Colorspace conversion. if self.color_space == "sRGB": - pass + rendered_color = cs_utils.linearRGB2sRGB(rendered_color) elif self.color_space == "linearRGB": rendered_color = cs_utils.linearRGB2sRGB(rendered_color) else: