diff --git a/.gitignore b/.gitignore index 6dea9b1..180be97 100644 --- a/.gitignore +++ b/.gitignore @@ -206,4 +206,6 @@ marimo/_static/ marimo/_lsp/ __marimo__/ -shine/_version.py \ No newline at end of file +shine/_version.py + +results/ \ No newline at end of file diff --git a/DESIGN.md b/DESIGN.md index 67545ea..0b93413 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -24,26 +24,26 @@ SHINE treats shear measurement as a Bayesian inverse problem. Instead of measuri ```mermaid graph TD Config[YAML Configuration] --> Loader[Config Handler] - Loader --> Scene[Scene Modelling (NumPyro + JAX-GalSim)] - + Loader --> Scene[Scene Modelling
NumPyro + JAX-GalSim] + subgraph "Forward Model (JAX)" - Priors[Priors (NumPyro)] --> Scene - Scene --> Galaxy[Galaxy Generation (Sersic/Morphology)] + Priors[Priors
NumPyro] --> Scene + Scene --> Galaxy[Galaxy Generation
Sersic/Morphology] Scene --> PSF[PSF Modelling] Galaxy --> Convolve[Convolution] PSF --> Convolve Convolve --> Noise[Noise Model] Noise --> ModelImage[Simulated Image] end - - Data[Observed Data (Fits/HDF5)] --> Likelihood + + Data[Observed Data
Fits/HDF5] --> Likelihood ModelImage --> Likelihood[Likelihood Evaluation] - - Likelihood --> Inference[Inference Engine (NumPyro/BlackJAX)] + + Likelihood --> Inference[Inference Engine
NumPyro/BlackJAX] Inference --> Posterior[Shear Posterior] - + subgraph "Workflow Management" - WMS[WMS (Slurm/Cluster)] --> Config + WMS[WMS
Slurm/Cluster] --> Config WMS --> Data end ``` diff --git a/configs/test_run.yaml b/configs/test_run.yaml new file mode 100644 index 0000000..77021bf --- /dev/null +++ b/configs/test_run.yaml @@ -0,0 +1,38 @@ +image: + pixel_scale: 0.2 + size_x: 32 + size_y: 32 + n_objects: 1 + noise: + type: Gaussian + sigma: 0.01 + +psf: + type: Gaussian + sigma: 0.1 + +gal: + type: Exponential + flux: + type: LogNormal + mean: 100.0 + sigma: 0.1 + half_light_radius: + type: Uniform + min: 0.3 + max: 0.8 + shear: + type: G1G2 + g1: + type: Normal + mean: 0.05 + sigma: 0.02 + g2: + type: Normal + mean: -0.05 + sigma: 0.02 + +inference: + warmup: 100 + samples: 100 + chains: 1 diff --git a/shine/config.py b/shine/config.py new file mode 100644 index 0000000..b915c1a --- /dev/null +++ b/shine/config.py @@ -0,0 +1,83 @@ +from typing import Union, Optional, Dict, Any, List +from pydantic import BaseModel, Field, validator +import yaml +from pathlib import Path + +# --- Distribution Models (for Priors) --- + +class DistributionConfig(BaseModel): + type: str + mean: Optional[float] = None + sigma: Optional[float] = None + min: Optional[float] = None + max: Optional[float] = None + + # Allow extra fields for other distributions + class Config: + extra = "allow" + +# --- Component Models --- + +class NoiseConfig(BaseModel): + type: str = "Gaussian" + sigma: float + +class ImageConfig(BaseModel): + pixel_scale: float + size_x: int + size_y: int + n_objects: int = 1 # Default to 1 for simple tests + noise: NoiseConfig + +class PSFConfig(BaseModel): + type: str = "Gaussian" + sigma: float + beta: Optional[float] = 2.5 # For Moffat + +class ShearComponentConfig(BaseModel): + # Can be a fixed float or a distribution + type: Optional[str] = None # If None, assume fixed value in parent or handled elsewhere + mean: Optional[float] = 0.0 + sigma: Optional[float] = 0.05 + + # To handle the case where it's just a float in YAML, we might need a custom validator + # but for now let's assume structured input as per design doc + +class ShearConfig(BaseModel): + type: str = "G1G2" + g1: Union[float, DistributionConfig] + g2: Union[float, DistributionConfig] + +class GalaxyConfig(BaseModel): + type: str = "Exponential" # Changed default from Sersic to Exponential + n: Optional[Union[float, DistributionConfig]] = None # Make optional for Exponential + flux: Union[float, DistributionConfig] + half_light_radius: Union[float, DistributionConfig] = Field(..., alias="half_light_radius") + shear: ShearConfig + +class InferenceConfig(BaseModel): + warmup: int = 500 + samples: int = 1000 + chains: int = 1 + dense_mass: bool = False + +class ShineConfig(BaseModel): + image: ImageConfig + psf: PSFConfig + gal: GalaxyConfig + inference: InferenceConfig = Field(default_factory=InferenceConfig) + data_path: Optional[str] = None + output_path: str = "results" + +class ConfigHandler: + @staticmethod + def load(path: str) -> ShineConfig: + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Config file not found: {path}") + + with open(path, "r") as f: + data = yaml.safe_load(f) + + # Basic validation and type conversion via Pydantic + return ShineConfig(**data) diff --git a/shine/data.py b/shine/data.py new file mode 100644 index 0000000..5540cf0 --- /dev/null +++ b/shine/data.py @@ -0,0 +1,83 @@ +import jax.numpy as jnp +import jax +from dataclasses import dataclass +from typing import Optional, Any, Dict +from shine.config import ShineConfig + +@dataclass +class Observation: + image: jnp.ndarray + noise_map: jnp.ndarray + psf_config: Dict[str, Any] # Store PSF config instead of object + wcs: Any = None + +class DataLoader: + @staticmethod + def load(config: ShineConfig) -> Observation: + if config.data_path and config.data_path != "None": + # TODO: Implement real data loading (Fits/HDF5) + raise NotImplementedError("Real data loading not yet implemented. Use synthetic generation.") + else: + print("No data path provided. Generating synthetic data...") + return DataLoader.generate_synthetic(config) + + @staticmethod + def generate_synthetic(config: ShineConfig) -> Observation: + import galsim + + # 1. Define PSF + if config.psf.type == "Gaussian": + psf = galsim.Gaussian(sigma=config.psf.sigma) + else: + raise NotImplementedError(f"PSF type {config.psf.type} not supported for synthetic gen") + + # 2. Define Galaxy (using mean values from config for "truth") + def get_mean(param): + if isinstance(param, (float, int)): + return float(param) + # If it's a distribution config + if param.mean is not None: + return param.mean + # Handle Uniform + if param.type == 'Uniform' and param.min is not None and param.max is not None: + return (param.min + param.max) / 2.0 + return param.mean # Fallback (might still be None if not handled) + + gal_flux = get_mean(config.gal.flux) + gal_hlr = get_mean(config.gal.half_light_radius) + + # Shear + g1 = get_mean(config.gal.shear.g1) + g2 = get_mean(config.gal.shear.g2) + shear = galsim.Shear(g1=g1, g2=g2) + + # Create Galaxy Object - Use Exponential (Sersic n=1) + gal = galsim.Exponential(half_light_radius=gal_hlr, flux=gal_flux) + gal = gal.shear(shear) + + # Convolve + final = galsim.Convolve([gal, psf]) + + # 3. Draw Image + image = final.drawImage(nx=config.image.size_x, + ny=config.image.size_y, + scale=config.image.pixel_scale).array + + # 4. Add Noise + rng = galsim.BaseDeviate(0) + noise_sigma = config.image.noise.sigma + noise = galsim.GaussianNoise(rng, sigma=noise_sigma) + + # GalSim image for noise addition + gs_image = galsim.Image(image) + gs_image.addNoise(noise) + noisy_image = gs_image.array + + noise_map = jnp.ones_like(noisy_image) * (noise_sigma**2) + + # Return JAX arrays and PSF config + return Observation( + image=jnp.array(noisy_image), + noise_map=noise_map, + psf_config={'type': config.psf.type, 'sigma': config.psf.sigma} + ) diff --git a/shine/inference.py b/shine/inference.py new file mode 100644 index 0000000..559ca0d --- /dev/null +++ b/shine/inference.py @@ -0,0 +1,55 @@ +import jax +import numpyro +from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO +from numpyro.infer.autoguide import AutoDelta +from typing import Dict, Any +import arviz as az + +class HMCInference: + def __init__(self, model, num_warmup=500, num_samples=1000, num_chains=1, dense_mass=False): + self.model = model + self.num_warmup = num_warmup + self.num_samples = num_samples + self.num_chains = num_chains + self.dense_mass = dense_mass + + def run(self, rng_key, observed_data, extra_args=None): + if extra_args is None: + extra_args = {} + + kernel = NUTS(self.model, dense_mass=self.dense_mass) + mcmc = MCMC(kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, num_chains=self.num_chains) + + mcmc.run(rng_key, observed_data=observed_data, **extra_args) + mcmc.print_summary() + + # Convert to ArviZ InferenceData + return az.from_numpyro(mcmc) + +class MAPInference: + def __init__(self, model, num_steps=1000, learning_rate=1e-2): + self.model = model + self.num_steps = num_steps + self.learning_rate = learning_rate + + def run(self, rng_key, observed_data, extra_args=None): + if extra_args is None: + extra_args = {} + + guide = AutoDelta(self.model) + optimizer = numpyro.optim.Adam(step_size=self.learning_rate) + svi = SVI(self.model, guide, optimizer, loss=Trace_ELBO()) + + svi_result = svi.run(rng_key, self.num_steps, observed_data=observed_data, **extra_args) + + params = svi_result.params + # The params from AutoDelta are the MAP estimates (in unconstrained space usually, + # but AutoDelta returns constrained values if using `median` init or similar, + # actually AutoDelta parameters are the values themselves). + + # We need to sample from the guide to get the values in the proper structure if needed, + # but for AutoDelta, params contains the values. + # Note: AutoDelta names parameters with `_auto_loc` suffix sometimes or keeps original names depending on version. + # Let's check the guide median. + + return guide.median(params) diff --git a/shine/main.py b/shine/main.py new file mode 100644 index 0000000..9c74f4a --- /dev/null +++ b/shine/main.py @@ -0,0 +1,130 @@ +import argparse +import jax +import jax.numpy as jnp +from pathlib import Path +import pickle +import xarray as xr + +from shine.config import ConfigHandler +from shine.data import DataLoader +from shine.scene import SceneBuilder +from shine.inference import HMCInference, MAPInference + +def main(): + parser = argparse.ArgumentParser(description="SHINE: SHear INference Environment") + parser.add_argument("--config", type=str, required=True, help="Path to configuration YAML") + parser.add_argument("--mode", type=str, default="hmc", choices=["hmc", "map"], help="Inference mode: hmc or map") + parser.add_argument("--output", type=str, default=None, help="Output directory") + + args = parser.parse_args() + + # 1. Load Config + config = ConfigHandler.load(args.config) + + # Override output path if provided + if args.output: + config.output_path = args.output + + output_dir = Path(config.output_path) + output_dir.mkdir(parents=True, exist_ok=True) + + # 2. Load Data + observation = DataLoader.load(config) + + # Save observation for reference + jnp.savez(output_dir / "observation.npz", + image=observation.image, + noise_map=observation.noise_map) + + # 3. Build Model + scene_builder = SceneBuilder(config) + model_fn = scene_builder.build_model() + + # 4. Run Inference + print(f"Starting inference in {args.mode.upper()} mode...") + rng_key = jax.random.PRNGKey(42) + + if args.mode == "hmc": + engine = HMCInference( + model=model_fn, + num_warmup=config.inference.warmup, + num_samples=config.inference.samples, + num_chains=config.inference.chains, + dense_mass=config.inference.dense_mass + ) + results = engine.run( + rng_key=rng_key, + observed_data=observation.image, + extra_args={"psf_config": observation.psf_config} + ) + + # Save results + output_file = output_dir / "posterior.nc" + results.to_netcdf(output_file) + print(f"Results saved to {output_file}") + + # Print summary + print(results.posterior) + + elif args.mode == "map": + engine = MAPInference(model=model_fn) + results = engine.run( + rng_key=rng_key, + observed_data=observation.image, + extra_args={"psf_config": observation.psf_config} + ) + + # Save MAP estimates + output_file = output_dir / "map_results.pkl" + with open(output_file, "wb") as f: + pickle.dump(results, f) + print(f"MAP estimates saved to {output_file}") + + print("MAP Estimates:") + for k, v in results.items(): + print(f"{k}: {v}") + + # Calculate Residuals + # We need to re-render the scene with MAP parameters + # This is a bit tricky without exposing the render function directly from the model. + # A clean way is to use numpyro.handlers.substitute and trace the model. + + from numpyro.handlers import substitute, trace, seed + + # We need to run the model with the MAP parameters + # The model function expects (observed_data, psf) + # We pass None for observed_data to get the model prediction (if the model returns it or we capture the deterministic site) + # But our model samples 'obs' at the end. + # We can capture the 'obs' site mean (which is the model image). + + # However, our model defines 'obs' as Normal(model_image, sigma). + # We want 'model_image'. + # We didn't expose 'model_image' as a deterministic site in the builder. + # Let's modify the builder to expose it, OR we can just inspect the 'obs' distribution mean in the trace. + + model_with_params = substitute(model_fn, results) + + # Trace the model execution + # We need to pass the same args as during inference + traced_model = trace(seed(model_with_params, rng_key)) + trace_out = traced_model.get_trace(observed_data=None, psf_config=observation.psf_config) + + # The 'obs' site contains the distribution with the mean we want + obs_node = trace_out['obs'] + model_image = obs_node['fn'].mean + + residual = observation.image - model_image + chi2 = jnp.sum((residual**2) / observation.noise_map) + dof = observation.image.size # approx + reduced_chi2 = chi2 / dof + + print(f"Reduced Chi2: {reduced_chi2:.4f}") + + jnp.savez(output_dir / "residuals.npz", + data=observation.image, + model=model_image, + residual=residual, + chi2=chi2) + +if __name__ == "__main__": + main() diff --git a/shine/scene.py b/shine/scene.py new file mode 100644 index 0000000..0428fa6 --- /dev/null +++ b/shine/scene.py @@ -0,0 +1,82 @@ +import numpyro +import numpyro.distributions as dist +import jax.numpy as jnp +import jax +import jax_galsim as galsim +from shine.config import ShineConfig, DistributionConfig + +class SceneBuilder: + def __init__(self, config: ShineConfig): + self.config = config + + def _parse_prior(self, name: str, param_config): + """Helper to create NumPyro distributions from config or return fixed value.""" + if isinstance(param_config, (float, int)): + return float(param_config) + + if isinstance(param_config, DistributionConfig): + if param_config.type == 'Normal': + return numpyro.sample(name, dist.Normal(param_config.mean, param_config.sigma)) + elif param_config.type == 'LogNormal': + return numpyro.sample(name, dist.LogNormal(jnp.log(param_config.mean), param_config.sigma)) + elif param_config.type == 'Uniform': + return numpyro.sample(name, dist.Uniform(param_config.min, param_config.max)) + else: + raise ValueError(f"Unknown distribution type: {param_config.type}") + + return param_config + + def build_model(self): + """Returns a callable model function for NumPyro.""" + + def model(observed_data=None, psf_config=None): + # Define GSParams with fixed FFT size to avoid dynamic shape issues in JAX + fft_size = 128 + gsparams = galsim.GSParams(maximum_fft_size=fft_size, minimum_fft_size=fft_size) + + # --- 0. Build PSF from config using JAX-GalSim --- + if psf_config['type'] == 'Gaussian': + psf = galsim.Gaussian(sigma=psf_config['sigma'], gsparams=gsparams) + else: + raise NotImplementedError(f"PSF type {psf_config['type']} not implemented") + + # --- 1. Global Parameters (Shear) --- + g1 = self._parse_prior("g1", self.config.gal.shear.g1) + g2 = self._parse_prior("g2", self.config.gal.shear.g2) + shear = galsim.Shear(g1=g1, g2=g2) + + # --- 2. Galaxy Population --- + n_galaxies = self.config.image.n_objects + + with numpyro.plate("galaxies", n_galaxies): + flux = self._parse_prior("flux", self.config.gal.flux) + hlr = self._parse_prior("hlr", self.config.gal.half_light_radius) + + x = self.config.image.size_x / 2.0 + y = self.config.image.size_y / 2.0 + + # --- 3. Differentiable Rendering --- + def render_one_galaxy(flux, hlr, x, y): + # Use Exponential instead of Sersic (not available in jax_galsim yet) + gal = galsim.Exponential(half_light_radius=hlr, flux=flux, gsparams=gsparams) + gal = gal.shear(shear) + gal = galsim.Convolve([gal, psf], gsparams=gsparams) + return gal.drawImage(nx=self.config.image.size_x, + ny=self.config.image.size_y, + scale=self.config.image.pixel_scale, + offset=(x - self.config.image.size_x/2 + 0.5, y - self.config.image.size_y/2 + 0.5) + ).array + + flux = jnp.atleast_1d(flux) + hlr = jnp.atleast_1d(hlr) + x = jnp.atleast_1d(x) + y = jnp.atleast_1d(y) + + galaxy_images = jax.vmap(render_one_galaxy)(flux, hlr, x, y) + model_image = jnp.sum(galaxy_images, axis=0) + + # --- 4. Likelihood --- + sigma = self.config.image.noise.sigma + numpyro.sample("obs", dist.Normal(model_image, sigma), obs=observed_data) + + return model