diff --git a/.gitignore b/.gitignore
index 6dea9b1..180be97 100644
--- a/.gitignore
+++ b/.gitignore
@@ -206,4 +206,6 @@ marimo/_static/
marimo/_lsp/
__marimo__/
-shine/_version.py
\ No newline at end of file
+shine/_version.py
+
+results/
\ No newline at end of file
diff --git a/DESIGN.md b/DESIGN.md
index 67545ea..0b93413 100644
--- a/DESIGN.md
+++ b/DESIGN.md
@@ -24,26 +24,26 @@ SHINE treats shear measurement as a Bayesian inverse problem. Instead of measuri
```mermaid
graph TD
Config[YAML Configuration] --> Loader[Config Handler]
- Loader --> Scene[Scene Modelling (NumPyro + JAX-GalSim)]
-
+ Loader --> Scene[Scene Modelling
NumPyro + JAX-GalSim]
+
subgraph "Forward Model (JAX)"
- Priors[Priors (NumPyro)] --> Scene
- Scene --> Galaxy[Galaxy Generation (Sersic/Morphology)]
+ Priors[Priors
NumPyro] --> Scene
+ Scene --> Galaxy[Galaxy Generation
Sersic/Morphology]
Scene --> PSF[PSF Modelling]
Galaxy --> Convolve[Convolution]
PSF --> Convolve
Convolve --> Noise[Noise Model]
Noise --> ModelImage[Simulated Image]
end
-
- Data[Observed Data (Fits/HDF5)] --> Likelihood
+
+ Data[Observed Data
Fits/HDF5] --> Likelihood
ModelImage --> Likelihood[Likelihood Evaluation]
-
- Likelihood --> Inference[Inference Engine (NumPyro/BlackJAX)]
+
+ Likelihood --> Inference[Inference Engine
NumPyro/BlackJAX]
Inference --> Posterior[Shear Posterior]
-
+
subgraph "Workflow Management"
- WMS[WMS (Slurm/Cluster)] --> Config
+ WMS[WMS
Slurm/Cluster] --> Config
WMS --> Data
end
```
diff --git a/configs/test_run.yaml b/configs/test_run.yaml
new file mode 100644
index 0000000..77021bf
--- /dev/null
+++ b/configs/test_run.yaml
@@ -0,0 +1,38 @@
+image:
+ pixel_scale: 0.2
+ size_x: 32
+ size_y: 32
+ n_objects: 1
+ noise:
+ type: Gaussian
+ sigma: 0.01
+
+psf:
+ type: Gaussian
+ sigma: 0.1
+
+gal:
+ type: Exponential
+ flux:
+ type: LogNormal
+ mean: 100.0
+ sigma: 0.1
+ half_light_radius:
+ type: Uniform
+ min: 0.3
+ max: 0.8
+ shear:
+ type: G1G2
+ g1:
+ type: Normal
+ mean: 0.05
+ sigma: 0.02
+ g2:
+ type: Normal
+ mean: -0.05
+ sigma: 0.02
+
+inference:
+ warmup: 100
+ samples: 100
+ chains: 1
diff --git a/shine/config.py b/shine/config.py
new file mode 100644
index 0000000..b915c1a
--- /dev/null
+++ b/shine/config.py
@@ -0,0 +1,83 @@
+from typing import Union, Optional, Dict, Any, List
+from pydantic import BaseModel, Field, validator
+import yaml
+from pathlib import Path
+
+# --- Distribution Models (for Priors) ---
+
+class DistributionConfig(BaseModel):
+ type: str
+ mean: Optional[float] = None
+ sigma: Optional[float] = None
+ min: Optional[float] = None
+ max: Optional[float] = None
+
+ # Allow extra fields for other distributions
+ class Config:
+ extra = "allow"
+
+# --- Component Models ---
+
+class NoiseConfig(BaseModel):
+ type: str = "Gaussian"
+ sigma: float
+
+class ImageConfig(BaseModel):
+ pixel_scale: float
+ size_x: int
+ size_y: int
+ n_objects: int = 1 # Default to 1 for simple tests
+ noise: NoiseConfig
+
+class PSFConfig(BaseModel):
+ type: str = "Gaussian"
+ sigma: float
+ beta: Optional[float] = 2.5 # For Moffat
+
+class ShearComponentConfig(BaseModel):
+ # Can be a fixed float or a distribution
+ type: Optional[str] = None # If None, assume fixed value in parent or handled elsewhere
+ mean: Optional[float] = 0.0
+ sigma: Optional[float] = 0.05
+
+ # To handle the case where it's just a float in YAML, we might need a custom validator
+ # but for now let's assume structured input as per design doc
+
+class ShearConfig(BaseModel):
+ type: str = "G1G2"
+ g1: Union[float, DistributionConfig]
+ g2: Union[float, DistributionConfig]
+
+class GalaxyConfig(BaseModel):
+ type: str = "Exponential" # Changed default from Sersic to Exponential
+ n: Optional[Union[float, DistributionConfig]] = None # Make optional for Exponential
+ flux: Union[float, DistributionConfig]
+ half_light_radius: Union[float, DistributionConfig] = Field(..., alias="half_light_radius")
+ shear: ShearConfig
+
+class InferenceConfig(BaseModel):
+ warmup: int = 500
+ samples: int = 1000
+ chains: int = 1
+ dense_mass: bool = False
+
+class ShineConfig(BaseModel):
+ image: ImageConfig
+ psf: PSFConfig
+ gal: GalaxyConfig
+ inference: InferenceConfig = Field(default_factory=InferenceConfig)
+ data_path: Optional[str] = None
+ output_path: str = "results"
+
+class ConfigHandler:
+ @staticmethod
+ def load(path: str) -> ShineConfig:
+ path = Path(path)
+ if not path.exists():
+ raise FileNotFoundError(f"Config file not found: {path}")
+
+ with open(path, "r") as f:
+ data = yaml.safe_load(f)
+
+ # Basic validation and type conversion via Pydantic
+ return ShineConfig(**data)
diff --git a/shine/data.py b/shine/data.py
new file mode 100644
index 0000000..5540cf0
--- /dev/null
+++ b/shine/data.py
@@ -0,0 +1,83 @@
+import jax.numpy as jnp
+import jax
+from dataclasses import dataclass
+from typing import Optional, Any, Dict
+from shine.config import ShineConfig
+
+@dataclass
+class Observation:
+ image: jnp.ndarray
+ noise_map: jnp.ndarray
+ psf_config: Dict[str, Any] # Store PSF config instead of object
+ wcs: Any = None
+
+class DataLoader:
+ @staticmethod
+ def load(config: ShineConfig) -> Observation:
+ if config.data_path and config.data_path != "None":
+ # TODO: Implement real data loading (Fits/HDF5)
+ raise NotImplementedError("Real data loading not yet implemented. Use synthetic generation.")
+ else:
+ print("No data path provided. Generating synthetic data...")
+ return DataLoader.generate_synthetic(config)
+
+ @staticmethod
+ def generate_synthetic(config: ShineConfig) -> Observation:
+ import galsim
+
+ # 1. Define PSF
+ if config.psf.type == "Gaussian":
+ psf = galsim.Gaussian(sigma=config.psf.sigma)
+ else:
+ raise NotImplementedError(f"PSF type {config.psf.type} not supported for synthetic gen")
+
+ # 2. Define Galaxy (using mean values from config for "truth")
+ def get_mean(param):
+ if isinstance(param, (float, int)):
+ return float(param)
+ # If it's a distribution config
+ if param.mean is not None:
+ return param.mean
+ # Handle Uniform
+ if param.type == 'Uniform' and param.min is not None and param.max is not None:
+ return (param.min + param.max) / 2.0
+ return param.mean # Fallback (might still be None if not handled)
+
+ gal_flux = get_mean(config.gal.flux)
+ gal_hlr = get_mean(config.gal.half_light_radius)
+
+ # Shear
+ g1 = get_mean(config.gal.shear.g1)
+ g2 = get_mean(config.gal.shear.g2)
+ shear = galsim.Shear(g1=g1, g2=g2)
+
+ # Create Galaxy Object - Use Exponential (Sersic n=1)
+ gal = galsim.Exponential(half_light_radius=gal_hlr, flux=gal_flux)
+ gal = gal.shear(shear)
+
+ # Convolve
+ final = galsim.Convolve([gal, psf])
+
+ # 3. Draw Image
+ image = final.drawImage(nx=config.image.size_x,
+ ny=config.image.size_y,
+ scale=config.image.pixel_scale).array
+
+ # 4. Add Noise
+ rng = galsim.BaseDeviate(0)
+ noise_sigma = config.image.noise.sigma
+ noise = galsim.GaussianNoise(rng, sigma=noise_sigma)
+
+ # GalSim image for noise addition
+ gs_image = galsim.Image(image)
+ gs_image.addNoise(noise)
+ noisy_image = gs_image.array
+
+ noise_map = jnp.ones_like(noisy_image) * (noise_sigma**2)
+
+ # Return JAX arrays and PSF config
+ return Observation(
+ image=jnp.array(noisy_image),
+ noise_map=noise_map,
+ psf_config={'type': config.psf.type, 'sigma': config.psf.sigma}
+ )
diff --git a/shine/inference.py b/shine/inference.py
new file mode 100644
index 0000000..559ca0d
--- /dev/null
+++ b/shine/inference.py
@@ -0,0 +1,55 @@
+import jax
+import numpyro
+from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
+from numpyro.infer.autoguide import AutoDelta
+from typing import Dict, Any
+import arviz as az
+
+class HMCInference:
+ def __init__(self, model, num_warmup=500, num_samples=1000, num_chains=1, dense_mass=False):
+ self.model = model
+ self.num_warmup = num_warmup
+ self.num_samples = num_samples
+ self.num_chains = num_chains
+ self.dense_mass = dense_mass
+
+ def run(self, rng_key, observed_data, extra_args=None):
+ if extra_args is None:
+ extra_args = {}
+
+ kernel = NUTS(self.model, dense_mass=self.dense_mass)
+ mcmc = MCMC(kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, num_chains=self.num_chains)
+
+ mcmc.run(rng_key, observed_data=observed_data, **extra_args)
+ mcmc.print_summary()
+
+ # Convert to ArviZ InferenceData
+ return az.from_numpyro(mcmc)
+
+class MAPInference:
+ def __init__(self, model, num_steps=1000, learning_rate=1e-2):
+ self.model = model
+ self.num_steps = num_steps
+ self.learning_rate = learning_rate
+
+ def run(self, rng_key, observed_data, extra_args=None):
+ if extra_args is None:
+ extra_args = {}
+
+ guide = AutoDelta(self.model)
+ optimizer = numpyro.optim.Adam(step_size=self.learning_rate)
+ svi = SVI(self.model, guide, optimizer, loss=Trace_ELBO())
+
+ svi_result = svi.run(rng_key, self.num_steps, observed_data=observed_data, **extra_args)
+
+ params = svi_result.params
+ # The params from AutoDelta are the MAP estimates (in unconstrained space usually,
+ # but AutoDelta returns constrained values if using `median` init or similar,
+ # actually AutoDelta parameters are the values themselves).
+
+ # We need to sample from the guide to get the values in the proper structure if needed,
+ # but for AutoDelta, params contains the values.
+ # Note: AutoDelta names parameters with `_auto_loc` suffix sometimes or keeps original names depending on version.
+ # Let's check the guide median.
+
+ return guide.median(params)
diff --git a/shine/main.py b/shine/main.py
new file mode 100644
index 0000000..9c74f4a
--- /dev/null
+++ b/shine/main.py
@@ -0,0 +1,130 @@
+import argparse
+import jax
+import jax.numpy as jnp
+from pathlib import Path
+import pickle
+import xarray as xr
+
+from shine.config import ConfigHandler
+from shine.data import DataLoader
+from shine.scene import SceneBuilder
+from shine.inference import HMCInference, MAPInference
+
+def main():
+ parser = argparse.ArgumentParser(description="SHINE: SHear INference Environment")
+ parser.add_argument("--config", type=str, required=True, help="Path to configuration YAML")
+ parser.add_argument("--mode", type=str, default="hmc", choices=["hmc", "map"], help="Inference mode: hmc or map")
+ parser.add_argument("--output", type=str, default=None, help="Output directory")
+
+ args = parser.parse_args()
+
+ # 1. Load Config
+ config = ConfigHandler.load(args.config)
+
+ # Override output path if provided
+ if args.output:
+ config.output_path = args.output
+
+ output_dir = Path(config.output_path)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ # 2. Load Data
+ observation = DataLoader.load(config)
+
+ # Save observation for reference
+ jnp.savez(output_dir / "observation.npz",
+ image=observation.image,
+ noise_map=observation.noise_map)
+
+ # 3. Build Model
+ scene_builder = SceneBuilder(config)
+ model_fn = scene_builder.build_model()
+
+ # 4. Run Inference
+ print(f"Starting inference in {args.mode.upper()} mode...")
+ rng_key = jax.random.PRNGKey(42)
+
+ if args.mode == "hmc":
+ engine = HMCInference(
+ model=model_fn,
+ num_warmup=config.inference.warmup,
+ num_samples=config.inference.samples,
+ num_chains=config.inference.chains,
+ dense_mass=config.inference.dense_mass
+ )
+ results = engine.run(
+ rng_key=rng_key,
+ observed_data=observation.image,
+ extra_args={"psf_config": observation.psf_config}
+ )
+
+ # Save results
+ output_file = output_dir / "posterior.nc"
+ results.to_netcdf(output_file)
+ print(f"Results saved to {output_file}")
+
+ # Print summary
+ print(results.posterior)
+
+ elif args.mode == "map":
+ engine = MAPInference(model=model_fn)
+ results = engine.run(
+ rng_key=rng_key,
+ observed_data=observation.image,
+ extra_args={"psf_config": observation.psf_config}
+ )
+
+ # Save MAP estimates
+ output_file = output_dir / "map_results.pkl"
+ with open(output_file, "wb") as f:
+ pickle.dump(results, f)
+ print(f"MAP estimates saved to {output_file}")
+
+ print("MAP Estimates:")
+ for k, v in results.items():
+ print(f"{k}: {v}")
+
+ # Calculate Residuals
+ # We need to re-render the scene with MAP parameters
+ # This is a bit tricky without exposing the render function directly from the model.
+ # A clean way is to use numpyro.handlers.substitute and trace the model.
+
+ from numpyro.handlers import substitute, trace, seed
+
+ # We need to run the model with the MAP parameters
+ # The model function expects (observed_data, psf)
+ # We pass None for observed_data to get the model prediction (if the model returns it or we capture the deterministic site)
+ # But our model samples 'obs' at the end.
+ # We can capture the 'obs' site mean (which is the model image).
+
+ # However, our model defines 'obs' as Normal(model_image, sigma).
+ # We want 'model_image'.
+ # We didn't expose 'model_image' as a deterministic site in the builder.
+ # Let's modify the builder to expose it, OR we can just inspect the 'obs' distribution mean in the trace.
+
+ model_with_params = substitute(model_fn, results)
+
+ # Trace the model execution
+ # We need to pass the same args as during inference
+ traced_model = trace(seed(model_with_params, rng_key))
+ trace_out = traced_model.get_trace(observed_data=None, psf_config=observation.psf_config)
+
+ # The 'obs' site contains the distribution with the mean we want
+ obs_node = trace_out['obs']
+ model_image = obs_node['fn'].mean
+
+ residual = observation.image - model_image
+ chi2 = jnp.sum((residual**2) / observation.noise_map)
+ dof = observation.image.size # approx
+ reduced_chi2 = chi2 / dof
+
+ print(f"Reduced Chi2: {reduced_chi2:.4f}")
+
+ jnp.savez(output_dir / "residuals.npz",
+ data=observation.image,
+ model=model_image,
+ residual=residual,
+ chi2=chi2)
+
+if __name__ == "__main__":
+ main()
diff --git a/shine/scene.py b/shine/scene.py
new file mode 100644
index 0000000..0428fa6
--- /dev/null
+++ b/shine/scene.py
@@ -0,0 +1,82 @@
+import numpyro
+import numpyro.distributions as dist
+import jax.numpy as jnp
+import jax
+import jax_galsim as galsim
+from shine.config import ShineConfig, DistributionConfig
+
+class SceneBuilder:
+ def __init__(self, config: ShineConfig):
+ self.config = config
+
+ def _parse_prior(self, name: str, param_config):
+ """Helper to create NumPyro distributions from config or return fixed value."""
+ if isinstance(param_config, (float, int)):
+ return float(param_config)
+
+ if isinstance(param_config, DistributionConfig):
+ if param_config.type == 'Normal':
+ return numpyro.sample(name, dist.Normal(param_config.mean, param_config.sigma))
+ elif param_config.type == 'LogNormal':
+ return numpyro.sample(name, dist.LogNormal(jnp.log(param_config.mean), param_config.sigma))
+ elif param_config.type == 'Uniform':
+ return numpyro.sample(name, dist.Uniform(param_config.min, param_config.max))
+ else:
+ raise ValueError(f"Unknown distribution type: {param_config.type}")
+
+ return param_config
+
+ def build_model(self):
+ """Returns a callable model function for NumPyro."""
+
+ def model(observed_data=None, psf_config=None):
+ # Define GSParams with fixed FFT size to avoid dynamic shape issues in JAX
+ fft_size = 128
+ gsparams = galsim.GSParams(maximum_fft_size=fft_size, minimum_fft_size=fft_size)
+
+ # --- 0. Build PSF from config using JAX-GalSim ---
+ if psf_config['type'] == 'Gaussian':
+ psf = galsim.Gaussian(sigma=psf_config['sigma'], gsparams=gsparams)
+ else:
+ raise NotImplementedError(f"PSF type {psf_config['type']} not implemented")
+
+ # --- 1. Global Parameters (Shear) ---
+ g1 = self._parse_prior("g1", self.config.gal.shear.g1)
+ g2 = self._parse_prior("g2", self.config.gal.shear.g2)
+ shear = galsim.Shear(g1=g1, g2=g2)
+
+ # --- 2. Galaxy Population ---
+ n_galaxies = self.config.image.n_objects
+
+ with numpyro.plate("galaxies", n_galaxies):
+ flux = self._parse_prior("flux", self.config.gal.flux)
+ hlr = self._parse_prior("hlr", self.config.gal.half_light_radius)
+
+ x = self.config.image.size_x / 2.0
+ y = self.config.image.size_y / 2.0
+
+ # --- 3. Differentiable Rendering ---
+ def render_one_galaxy(flux, hlr, x, y):
+ # Use Exponential instead of Sersic (not available in jax_galsim yet)
+ gal = galsim.Exponential(half_light_radius=hlr, flux=flux, gsparams=gsparams)
+ gal = gal.shear(shear)
+ gal = galsim.Convolve([gal, psf], gsparams=gsparams)
+ return gal.drawImage(nx=self.config.image.size_x,
+ ny=self.config.image.size_y,
+ scale=self.config.image.pixel_scale,
+ offset=(x - self.config.image.size_x/2 + 0.5, y - self.config.image.size_y/2 + 0.5)
+ ).array
+
+ flux = jnp.atleast_1d(flux)
+ hlr = jnp.atleast_1d(hlr)
+ x = jnp.atleast_1d(x)
+ y = jnp.atleast_1d(y)
+
+ galaxy_images = jax.vmap(render_one_galaxy)(flux, hlr, x, y)
+ model_image = jnp.sum(galaxy_images, axis=0)
+
+ # --- 4. Likelihood ---
+ sigma = self.config.image.noise.sigma
+ numpyro.sample("obs", dist.Normal(model_image, sigma), obs=observed_data)
+
+ return model