"""File containing Synthesizer library generation and galaxy creation utilities."""
import copy
import inspect
import os
import threading
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import astropy.units as u
import h5py
import matplotlib.patheffects as PathEffects
import matplotlib.pyplot as plt
import numpy as np
import torch
from astropy.cosmology import Cosmology, Planck18, z_at_value
from dill.source import getsource
from joblib import Parallel, delayed, parallel_config
from matplotlib.ticker import FuncFormatter, ScalarFormatter
from scipy.linalg import inv
from scipy.stats import qmc
from . import logger
try:
from synthesizer.conversions import lnu_to_fnu
from synthesizer.emission_models import EmissionModel
from synthesizer.emission_models.attenuation import Inoue14
from synthesizer.emissions import plot_spectra
from synthesizer.grid import Grid
from synthesizer.instruments import UVJ, FilterCollection, Instrument
from synthesizer.parametric import SFH, Galaxy, Stars, ZDist
from synthesizer.particle.stars import sample_sfzh
from synthesizer.pipeline import Pipeline
synthesizer_available = True
except ImportError as e:
logger.warning(
f"Synthesizer dependencies not installed ({e}). "
"Only the SBI functions will be available."
)
synthesizer_available = False
from tqdm import tqdm
from unyt import (
Angstrom,
Jy,
Mpc,
Msun,
Myr,
Unit,
define_unit,
dimensionless,
eV,
nJy,
uJy,
um,
unyt_array,
unyt_quantity,
yr,
)
from .noise_models import (
EmpiricalUncertaintyModel,
)
from .utils import check_log_scaling, check_scaling, list_parameters, save_emission_model
file_path = os.path.dirname(os.path.realpath(__file__))
library_folder = os.path.join(os.path.dirname(os.path.dirname(file_path)), "libraries")
# Thread-local storage for data shared with parallel workers (see _init_worker)
_thread_local = threading.local()
UNIT_DICT = {
"log10metallicity": "log10(Zmet)",
"metallicity": "Zmet",
"av": "mag",
"tau_v": "mag",
"tau_v_ism": "mag",
"tau_v_birth": "mag",
"weight_fraction": "dimensionless",
"log_sfr": "log10(Msun/yr)",
"sfr": "Msun/yr",
"log_stellar_mass": "log10(Msun)",
"log_surviving_mass": "log10(Msun)",
"stellar_mass": "Msun",
}
if synthesizer_available:
uvj = UVJ()
uvj = {
"U": FilterCollection(filters=[uvj["U"]]),
"V": FilterCollection(filters=[uvj["V"]]),
"J": FilterCollection(filters=[uvj["J"]]),
}
else:
uvj = {}
tophats = {
"MUV": {"lam_eff": 1500 * Angstrom, "lam_fwhm": 100 * Angstrom},
}
# FilterCollection is only imported when synthesizer is available
muv_filter = (
FilterCollection(tophat_dict=tophats, verbose=False) if synthesizer_available else None
)
try:
define_unit("log10_Msun", 1 * dimensionless)
except RuntimeError:
pass
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
except ImportError:
rank = 0
size = 1
comm = None
# ------------------------------------------
# Functions for galaxy parameters
# ------------------------------------------
if not synthesizer_available:
# Make dummy classes for type checking
class SFH:
"""Dummy class for SFH to allow type checking without synthesizer installed."""
class Common:
"""Dummy class."""
pass
class ZDist:
"""Dummy class."""
class Common:
"""Dummy class."""
pass
class EmissionModel:
"""Dummy class."""
pass
class Galaxy:
"""Dummy class."""
pass
class Grid:
"""Dummy class."""
pass
class Instrument:
"""Dummy class."""
pass
class Pipeline:
"""Dummy class."""
pass
def calculate_muv(galaxy, cosmo=Planck18):
"""Calculate the apparent mUV magnitude of a galaxy.
Parameters
----------
galaxy : Galaxy
The galaxy object containing stellar spectra.
cosmo : Cosmology, optional
The cosmology to use for redshift calculations, by default Planck18.
Returns:
-------
dict
Dictionary containing the MUV magnitude for each stellar spectrum in the galaxy.
"""
z = galaxy.redshift
phots = {}
for key in list(galaxy.stars.spectra.keys()):
lnu = galaxy.stars.spectra[key].get_photo_lnu(muv_filter).photo_lnu[0]
phot = lnu_to_fnu(lnu, cosmo=cosmo, redshift=z)
phots[key] = phot
return phots
def calculate_MUV(galaxy, cosmo=Planck18):
"""Calculate the absolute MUV magnitude of a galaxy.
Parameters
----------
galaxy : Galaxy
The galaxy object containing stellar spectra.
cosmo : Cosmology, optional
The cosmology to use for redshift calculations, by default Planck18.
Returns:
-------
dict
Dictionary containing the MUV magnitude for each stellar spectrum in the galaxy.
"""
phots = {}
for key in list(galaxy.stars.spectra.keys()):
lnu = galaxy.stars.spectra[key].get_photo_lnu(muv_filter).photo_lnu[0]
phots[key] = lnu
return phots
def calculate_sfr(galaxy, timescale=10 * Myr):
"""Calculate the star formation rate (SFR) of a galaxy over a specified timescale.
Args:
galaxy: An instance of a synthesizer.parametric.Galaxy object.
timescale: The timescale over which to calculate the SFR (default is 10 Myr).
Returns:
The star formation rate as a float.
"""
timescale = (0, timescale.to("yr").value) # Convert timescale to years
sfr = galaxy.stars.calculate_average_sfr(t_range=timescale)
return sfr
def calculate_mass_weighted_age(galaxy):
"""Calculate the mass-weighted age of the stars in the galaxy."""
return galaxy.stars.get_mass_weighted_age().to("Myr")
def calculate_lum_weighted_age(galaxy, spectra_type="total", filter_code="V"):
"""Calculate the luminosity-weighted age of the stars in the galaxy."""
return galaxy.stars.get_lum_weighted_age(spectra_type=spectra_type, filter_code=filter_code).to(
"Myr"
)
def calculate_flux_weighted_age(galaxy, spectra_type="total", filter_code="JWST/NIRCam.F444W"):
"""Calculate the flux-weighted age of the stars in the galaxy."""
return galaxy.stars.get_flux_weighted_age(
spectra_type=spectra_type, filter_code=filter_code
).to("Myr")
def calculate_colour(
galaxy: Galaxy,
filter1: str,
filter2: str,
emission_model_key: str = "total",
rest_frame: bool = False,
) -> float:
"""Measures the colour of a galaxy between two filters (filter1 - filter2).
Args:
galaxy: An instance of a synthesizer.parametric.Galaxy object.
filter1: The first filter code (e.g., 'JWST/NIRCam.F444W').
filter2: The second filter code (e.g., 'JWST/NIRCam.F115W').
emission_model_key: The key for the emission model to use (default is 'total').
rest_frame: Whether to use the rest frame (default is False).
Returns:
The colour of the galaxy as a float.
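Example:
A minimal sketch; ``gal`` is assumed to be a synthesizer Galaxy whose
spectra have already been generated:
>>> c = calculate_colour(gal, "JWST/NIRCam.F115W", "JWST/NIRCam.F444W")
>>> uvj_col = calculate_colour(gal, "U", "V", rest_frame=True)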
"""
if (filter1 in ["U", "V", "J"] or filter2 in ["U", "V", "J"]) and not rest_frame:
logger.warning(
"Warning: Using 'U', 'V', or 'J' filters in the observed frame is not recommended. "
"Set 'rest_frame=True' to use these filters in the rest frame."
)
for i, filter_code in enumerate([filter1, filter2]):
sed = locate_sed(galaxy, emission_model_key)
if sed.photo_fnu is None or filter_code not in sed.photo_fnu:
try:
if filter_code in ["U", "V", "J"]:
filters = uvj[filter_code]
else:
filters = FilterCollection(filter_codes=[filter_code])
if rest_frame:
galaxy.stars.get_photo_lnu(filters)
else:
galaxy.stars.get_photo_fnu(filters)
except ValueError:
raise ValueError(
f"Filter '{filter_code}' is not available in the "
f"emission model '{emission_model_key}'."
)
if i == 0:
flux1 = galaxy.stars.spectra[emission_model_key]
if rest_frame:
flux1 = flux1.photo_lnu[filter_code]
else:
flux1 = flux1.photo_fnu[filter_code]
else:
flux2 = galaxy.stars.spectra[emission_model_key]
if rest_frame:
flux2 = flux2.photo_lnu[filter_code]
else:
flux2 = flux2.photo_fnu[filter_code]
colour = 2.5 * np.log10(flux2 / flux1)
return colour
def locate_sed(galaxy, emission_model_key):
"""Locate the SED in the galaxy based on the emission model key."""
if emission_model_key in galaxy.spectra.keys():
sed = galaxy.spectra[emission_model_key]
elif emission_model_key in galaxy.stars.spectra.keys():
sed = galaxy.stars.spectra[emission_model_key]
elif emission_model_key in galaxy.gas.spectra.keys():
sed = galaxy.gas.spectra[emission_model_key]
elif emission_model_key in galaxy.black_holes.spectra.keys():
sed = galaxy.black_holes.spectra[emission_model_key]
else:
raise ValueError(f"Emission model key '{emission_model_key}' not found in galaxy spectra.")
return sed
def calculate_d4000(galaxy: Galaxy, emission_model_key: str = "total") -> float:
"""Measures the D4000 index of a galaxy.
Args:
galaxy: An instance of a synthesizer.parametric.Galaxy object.
emission_model_key: The key for the emission model to use (default is 'total').
Returns:
The D4000 index as a float.
"""
sed = locate_sed(galaxy, emission_model_key)
d4000 = sed.measure_d4000()
return d4000
def calculate_beta(galaxy: Galaxy, emission_model_key: str = "total") -> float:
"""Measures the beta index of a galaxy.
Args:
galaxy: An instance of a synthesizer.parametric.Galaxy object.
emission_model_key: The key for the emission model to use (default is 'total').
Returns:
The beta index as a float.
"""
sed = locate_sed(galaxy, emission_model_key)
beta = sed.measure_beta()
return beta
def calculate_balmer_decrement(galaxy: Galaxy, emission_model_key: str = "total") -> float:
"""Measures the Balmer decrement a galaxy.
Args:
galaxy: An instance of a synthesizer.parametric.Galaxy object.
emission_model_key: The key for the emission model to use (default is 'total').
Returns:
The Balmer decrement as a float.
"""
sed = locate_sed(galaxy, emission_model_key)
balmer_decrement = sed.measure_balmer_break(integration_method="simps")
return balmer_decrement
def calculate_line_flux(
galaxy: Galaxy, emission_model, line="Ha", emission_model_key="total", cosmo=Planck18
):
"""Measures the equivalent widths of specific emission lines in a galaxy.
Args:
galaxy: An instance of a synthesizer.parametric.Galaxy object.
emission_model: An instance of a synthesizer.emission_models.EmissionModel.
line: The name of the emission line to measure (default is 'Ha').
emission_model_key: The key for the emission model to use (default is 'total').
cosmo: An instance of astropy.cosmology.Cosmology (default is Planck18).
Returns:
A dictionary with line names as keys and their equivalent widths as values.
"""
from synthesizer.emissions.utils import aliases
line = aliases.get(line, line)  # Handle aliases for line names
lines = galaxy.stars.get_lines([line], emission_model)
flux = lines.get_flux(cosmo=cosmo, z=galaxy.redshift)[0]
return flux
def calculate_line_ew(galaxy: Galaxy, emission_model, line="Ha", emission_model_key="total"):
"""Measures the rest-frame equivalent widths of specific emission lines in a galaxy.
Args:
galaxy: An instance of a synthesizer.parametric.Galaxy object.
emission_model: An instance of a synthesizer.emission_models.EmissionModel.
line: The name of the emission line to measure (default is 'Ha').
emission_model_key: The key for the emission model to use (default is 'total').
Returns:
The rest-frame equivalent width of the line.
"""
from synthesizer.emissions.utils import aliases
line = aliases.get(line, line) # Handle aliases for line names
galaxy.stars.get_lines([line], emission_model)
line = galaxy.stars.lines[emission_model_key]
return line.equivalent_width[0]
def calculate_burstiness(galaxy: Galaxy):
"""Calculate the burstiness parameter of a galaxy.
Burstiness is defined as the ratio of the SFR averaged over the last 10 Myr
to the SFR averaged over the last 100 Myr.
Args:
galaxy: An instance of a synthesizer.parametric.Galaxy object.
Returns:
The burstiness parameter as a dimensionless float.
"""
sfr_10 = galaxy.stars.calculate_average_sfr((0, 1e7))
sfr_100 = galaxy.stars.calculate_average_sfr((0, 1e8))
return (sfr_10 / sfr_100).to_value("dimensionless")
def calculate_line_luminosity(
galaxy: Galaxy, emission_model, line="Ha", emission_model_key="total"
):
"""Measures the luminosity of specific emission lines in a galaxy.
Args:
galaxy: An instance of a synthesizer.parametric.Galaxy object.
emission_model: An instance of a synthesizer.emission_models.EmissionModel.
line: The name of the emission line to measure (default is 'Ha').
emission_model_key: The key for the emission model to use (default is 'total').
Returns:
The luminosity of the line.
"""
from synthesizer.emissions.utils import aliases
line = aliases.get(line, line) # Handle aliases for line names
galaxy.stars.get_lines([line], emission_model)
return galaxy.stars.lines[emission_model_key].luminosity[0]
def calculate_sfh_quantile(galaxy, quantile=0.5, norm=False, cosmo=Planck18):
"""Calculate the lookback time at which a certain fraction of the total mass is formed.
Args:
galaxy: An instance of a synthesizer.parametric.Galaxy object.
quantile: The fraction of total mass formed (default is 0.5 for median).
norm: If True then the age is as a fraction of the age of the universe at
the redshift of the galaxy.
cosmo: Cosmology object to use for age calculations, default is Planck18.
Returns:
The lookback time in Myr at which the specified fraction of total mass is formed.
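Example:
A minimal sketch; ``gal`` is assumed to be a parametric Galaxy with an SFH:
>>> t50 = calculate_sfh_quantile(gal, quantile=0.5)  # half-mass lookback time
>>> t50_norm = calculate_sfh_quantile(gal, quantile=0.5, norm=True)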
"""
if not isinstance(galaxy, Galaxy):
raise TypeError("galaxy must be an instance of synthesizer.parametric.Galaxy")
if not hasattr(galaxy.stars, "sfh"):
raise AttributeError("galaxy.stars must have a 'sfh' attribute")
assert 0 <= quantile <= 1, "quantile must be between 0 and 1."
mass_bins = galaxy.stars.sf_hist
ages = galaxy.stars.ages
# from young to old
cumulative_mass = np.cumsum(mass_bins[::-1])
# Find the time at which the specified quantile of total mass is formed
total_mass = cumulative_mass[-1]
target_mass = quantile * total_mass
# Find the index where cumulative mass exceeds target mass
index = np.searchsorted(cumulative_mass, target_mass)
lookback_time = ages[::-1][index].to("Myr") # Get the corresponding age from the ages array
if norm:
# Normalize the lookback time to the age of the universe at the galaxy's redshift
age_of_universe = cosmo.age(galaxy.redshift).to("Myr").value
lookback_time = lookback_time.value / age_of_universe
return lookback_time
def calculate_surviving_mass(galaxy, grid: Grid):
"""Calculate the surviving mass of the stellar population in a galaxy.
Args:
galaxy: An instance of a synthesizer.parametric.Galaxy object.
grid: An instance of synthesizer.grid.Grid containing the SPS grid.
Returns:
The surviving stellar mass as a unyt_array in log10(Msun).
"""
mass = galaxy.get_surviving_mass(grid)
mass = np.log10(mass.to_value("Msun"))
mass = unyt_array(mass, "log10_Msun")
return mass
def calculate_xi_ion0(galaxy, emission_model, emission_model_key="total"):
"""Calculate the ionizing photon production efficiency (xi_ion) of a galaxy.
Args:
galaxy: An instance of a synthesizer.parametric.Galaxy object.
emission_model: An instance of a synthesizer.emission_models.EmissionModel.
emission_model_key: The key for the emission model to use (default is 'total').
Returns:
The ionizing photon production efficiency (xi_ion) as a unyt_quantity in Hz/erg.
"""
from synthesizer.emissions.utils import Ha
from unyt import Hz, erg, s
galaxy.stars.get_lines(([Ha]), emission_model)
ha_luminosity = galaxy.stars.lines[emission_model_key].luminosity[0]
MUV = calculate_MUV(galaxy, cosmo=Planck18)[emission_model_key].to("erg/s/Hz")
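# Case-B recombination calibration between the Halpha luminosity and the
# ionising photon production rate, Q(H0) ~= 7.37e11 Hz/(erg/s) * L(Halpha).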
conv = 7.37 * 1e11 * Hz / (erg / s)
Q_H0 = ha_luminosity * conv
xi_ion0 = Q_H0 / MUV
return xi_ion0.to("Hz/erg")
def calculate_Ndot_ion(galaxy, emission_model_key="total", ionisation_energy=13.6 * eV, limit=100):
"""Calculate the ionizing photon production rate (Ndot_ion) of a galaxy.
Args:
galaxy: An instance of a synthesizer.parametric.Galaxy object.
emission_model_key: The key for the emission model to use (default is 'total').
ionisation_energy: The ionisation energy threshold (default is 13.6 eV).
limit: An upper bound on the number of subintervals
used by the adaptive integration algorithm.
Returns:
The ionizing photon production rate (Ndot_ion) as a unyt_quantity in s^-1.
"""
sed = locate_sed(galaxy, emission_model_key)
return sed.calculate_ionising_photon_production_rate(
ionisation_energy=ionisation_energy,
limit=limit,
)
def calculate_agn_fraction(
galaxy,
total_em_key="total",
agn_em_key="agn_attenuated",
min_wav_rest=1 * um,
max_wav_rest=30 * um,
):
"""Compute the fraction of the total luminosity in a given wavelength range due to the AGN.
Parameters
----------
galaxy : Galaxy
The Galaxy object containing the spectra.
total_em_key : str
The key for the total emission model in the galaxy.
agn_em_key : str
The key for the AGN emission model in the galaxy.
min_wav_rest : float
The minimum rest-frame wavelength (in unyt units) for the integration.
max_wav_rest : float
The maximum rest-frame wavelength (in unyt units) for the integration.
Returns:
-------
float
The fraction of the total luminosity in the specified wavelength range due to the AGN.
"""
if len(galaxy.spectra) == 0:
raise ValueError("Galaxy has no spectra. Please run `get_spectra` first.")
galaxy.get_spectra_combined()
if total_em_key not in galaxy.spectra:
raise ValueError(
f"Total emission model '{total_em_key}' not found in galaxy spectra."
f" Available keys: {list(galaxy.spectra.keys())}"
)
if agn_em_key not in galaxy.black_holes.spectra:
raise ValueError(
f"AGN emission model '{agn_em_key}' not found in galaxy spectra."
f" Available keys: {list(galaxy.black_holes.spectra.keys())}"
)
wav = galaxy.spectra[total_em_key].lam
total_lum = galaxy.spectra[total_em_key].luminosity # erg/s
agn_lum = galaxy.black_holes.spectra[agn_em_key].luminosity # erg/s
mask = (wav >= min_wav_rest) & (wav <= max_wav_rest)
agn_fraction = np.trapezoid(agn_lum[mask], x=wav[mask]) / np.trapezoid(
total_lum[mask], x=wav[mask]
)
return agn_fraction
def calculate_ml(galaxy, emission_model_key):
"""Calculate M/L ratio (in solar units)."""
raise NotImplementedError
class SUPP_FUNCTIONS:
"""A class to hold supplementary functions for galaxy analysis."""
calculate_muv = calculate_muv
calculate_MUV = calculate_MUV
calculate_sfr = calculate_sfr
calculate_mass_weighted_age = calculate_mass_weighted_age
calculate_lum_weighted_age = calculate_lum_weighted_age
calculate_flux_weighted_age = calculate_flux_weighted_age
calculate_colour = calculate_colour
calculate_d4000 = calculate_d4000
calculate_beta = calculate_beta
calculate_balmer_decrement = calculate_balmer_decrement
calculate_line_flux = calculate_line_flux
calculate_line_ew = calculate_line_ew
calculate_sfh_quantile = calculate_sfh_quantile
calculate_surviving_mass = calculate_surviving_mass
calculate_xi_ion0 = calculate_xi_ion0
calculate_Ndot_ion = calculate_Ndot_ion
calculate_agn_fraction = calculate_agn_fraction
calculate_burstiness = calculate_burstiness
def __repr__(self) -> str:
"""Provides a nicely formatted string of available functions."""
header = "Available supplementary functions:\n"
# Introspect the class to find all callable attributes that don't start with '_'
method_names: List[str] = sorted(
[
attr
for attr in dir(self)
if not attr.startswith("_") and callable(getattr(self, attr))
]
)
# Format the list with bullet points
function_list = "\n".join(f" - {name}" for name in method_names)
return f"{header}{function_list}"
def __str__(self):
"""Provides a nicely formatted string of available functions."""
return repr(self)
# ------------------------------------------
def generate_random_DB_sfh(
Nparam=3,
tx_alpha=5,
redshift=6,
logmass=8,
logsfr=-1,
seed: Optional[int] = None,
):
"""Generate a random Dense Basis SFH.
This function generates a random star formation history (SFH) using a dense basis
representation. It creates a star formation history with a specified number of
parameters, a Dirichlet distribution for the time series, and allows for customization
of the log mass and log SFR.
Parameters
----------
Nparam : int, optional
Number of parameters for the SFH, by default 3
tx_alpha : float, optional
Concentration parameter for the Dirichlet distribution, by default 5
redshift : float, optional
Redshift of the SFH, by default 6
logmass : float, optional
Logarithm of the stellar mass in solar masses, by default 8
logsfr : float, optional
Logarithm of the star formation rate in solar masses per year, by default -1
seed : int, optional
Seed for the random number generator, by default None
Returns:
-------
SFH.DenseBasis
A DenseBasis SFH object with the generated parameters.
np.ndarray
An array of time series parameters for the SFH.
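Examples
--------
A minimal sketch:
>>> sfh, txs = generate_random_DB_sfh(Nparam=3, redshift=6, logmass=9, seed=0)
>>> txs.shape
(3,)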
"""
if seed is not None:
np.random.seed(seed)
txs = np.cumsum(np.random.dirichlet(np.ones(Nparam + 1) * tx_alpha))[:-1]
db_tuple = (logmass, logsfr, Nparam, *txs)
sfh = SFH.DenseBasis(db_tuple, redshift)
return sfh, txs
def generate_sfh_grid(
sfh_type: Type[SFH.Common],
sfh_priors: Dict[str, Dict[str, Any]],
redshift: Union[Dict[str, Any], float],
max_redshift: float = 15,
cosmo: Type[Cosmology] = Planck18,
) -> Tuple[List[Type[SFH.Common]], np.ndarray]:
"""Generate a grid of SFHs based on prior distributions.
This function creates a grid of SFH models by combining possible parameter
values, which can depend explicitly on the redshift. It first draws
redshifts, calculates maximum stellar ages at each
redshift, and then creates parameter combinations within valid ranges.
Parameters
----------
sfh_type : Type[SFH]
The star formation history class to instantiate
sfh_priors : Dict[str, Dict[str, Any]]
Dictionary of prior distributions for SFH parameters. Each parameter should have:
- 'prior': scipy.stats distribution
- 'min': minimum value
- 'max': maximum value
- 'size': number of samples to draw
- 'units': astropy unit (optional)
- 'name': parameter name for SFH constructor (optional, defaults to the key)
- 'depends_on': special flag if parameter depends on redshift ('max_redshift')
redshift : Union[Dict[str, Any], float]
Either a single redshift value or a dictionary with:
- 'prior': scipy.stats distribution
- 'min': minimum redshift
- 'max': maximum redshift
- 'size': number of redshift samples
max_redshift : float, optional
Maximum possible redshift to consider for age calculations, by default 15
cosmo : Type[Cosmology], optional
Cosmology to use for age calculations, by default Planck18
Returns:
-------
Tuple[List[SFH], np.ndarray]
- List of SFH objects with parameters drawn from the priors
- Array of parameter combinations, where the first column is redshift followed
by SFH parameters
Notes:
-----
For parameters that depend on redshift (marked with 'depends_on': 'max_redshift'),
the maximum allowed value will be dynamically adjusted based on
the age of the universe at that redshift.
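Examples
--------
A sketch of the expected prior format, assuming ``scipy.stats.uniform`` priors
and an exponential SFH class (the parameter name 'tau' is illustrative):
>>> from scipy.stats import uniform
>>> priors = {
...     "tau": {
...         "prior": uniform, "min": 10, "max": 1000, "size": 4,
...         "units": u.Myr, "depends_on": "max_redshift",
...     },
... }
>>> sfhs, combos = generate_sfh_grid(SFH.Exponential, priors, redshift=6.0)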
"""
# Draw redshifts
if isinstance(redshift, dict):
redshifts = redshift["prior"].rvs(
size=int(redshift["size"]),
loc=redshift["min"],
scale=redshift["max"] - redshift["min"],
)
else:
redshifts = np.array([redshift])
# Calculate maximum ages at each redshift
max_ages = (cosmo.age(redshifts) - cosmo.age(max_redshift)).to(u.Myr).value
# Prepare parameter arrays for each parameter
param_arrays = [redshifts]
param_names = ["redshift"] # Track parameter names for later use
# Process each SFH parameter
for key in sfh_priors.keys():
param_data = sfh_priors[key]
param_size = int(param_data["size"])
param_min = param_data["min"]
param_max = param_data["max"]
# If parameter depends on redshift, adjust for each redshift
if "depends_on" in param_data and param_data["depends_on"] == "max_redshift":
# Create parameter values for each redshift
all_values = []
for z, max_age in zip(redshifts, max_ages):
# Adjust maximum value based on the age of the universe at this redshift
adjusted_max = min(param_max, max_age)
if "units" in param_data and param_data["units"] is not None:
adjusted_max = adjusted_max * param_data["units"]
if hasattr(adjusted_max, "value"):
adjusted_max = adjusted_max.value
# Draw values for this parameter at this redshift
values = param_data["prior"].rvs(
size=param_size // len(redshifts), # Distribute samples across redshifts
loc=param_min,
scale=adjusted_max - param_min,
)
all_values.append(values)
# Combine values from all redshifts
param_values = np.concatenate(all_values)
else:
# Parameter doesn't depend on redshift, draw from the same distribution
param_values = param_data["prior"].rvs(
size=param_size, loc=param_min, scale=param_max - param_min
)
param_arrays.append(param_values)
param_names.append(param_data.get("name", key)) # Use specified name or key
# Create parameter combinations using meshgrid
mesh_arrays = np.meshgrid(*param_arrays, indexing="ij")
param_combinations = np.stack([arr.flatten() for arr in mesh_arrays], axis=1)
# Create SFH objects for each parameter combination
sfhs = []
for params in tqdm(param_combinations, desc="Generating SFHs", disable=(rank != 0)):
z = params[0] # First parameter is always redshift
max_age = (cosmo.age(z) - cosmo.age(max_redshift)).to(u.Myr).value
# Create parameter dictionary for SFH constructor
sfh_params = {param_names[i + 1]: params[i + 1] for i in range(len(param_names) - 1)}
# Apply units if not None
for key, value in sfh_params.items():
if "units" in sfh_priors[key] and sfh_priors[key]["units"] is not None:
sfh_params[key] = value * sfh_priors[key]["units"]
# Add max_age parameter
sfh_params["max_age"] = max_age * Myr
# Create and append SFH instance
sfh = sfh_type(**sfh_params)
sfhs.append(sfh)
return sfhs, param_combinations
def generate_emission_models(
emission_model: Type[EmissionModel],
varying_params: dict,
grid: Grid,
fixed_params: dict = None,
):
"""Generate a grid of emission models based on varying parameters.
Parameters
----------
emission_model : Type[EmissionModel]
The emission model class to instantiate. E.g., TotalEmission or PacmanEmission
varying_params : dict
Dictionary of varying parameters for emission model. Each parameter should have:
- 'prior': scipy.stats distribution
- 'min': minimum value
- 'max': maximum value
- 'size': number of samples to draw
- 'units': unyt unit (optional)
- 'name': parameter name for constructor (optional, defaults to the key)
grid : Grid
Grid object containing the SPS grid.
fixed_params : dict, optional
Dictionary of fixed parameters for the emission model. Each parameter should have:
- 'value': fixed value
- 'units': unyt unit (optional)
- 'name': parameter name for constructor (optional, defaults to the key)
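Examples
--------
A sketch, assuming a PacmanEmission-style model with a uniform prior on tau_v;
``grid`` is assumed to be a loaded synthesizer Grid:
>>> from scipy.stats import uniform
>>> from synthesizer.emission_models import PacmanEmission
>>> models, drawn = generate_emission_models(
...     PacmanEmission,
...     {"tau_v": {"prior": uniform, "min": 0.0, "max": 2.0, "size": 5}},
...     grid=grid,
... )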
"""
# Create parameter combinations using meshgrid
varying_param_arrays = []
for key, param_data in varying_params.items():
args = dict(param_data)
if "min" in args:
args["loc"] = param_data["min"]
if "max" in args:
args["scale"] = param_data["max"] - param_data["min"]
# Remove keys that the scipy.stats rvs call does not understand
for extra in ("min", "max", "units", "name", "prior"):
args.pop(extra, None)
# Draw values for this parameter
param_values = param_data["prior"].rvs(
**args,
)
varying_param_arrays.append(param_values)
# Create parameter combinations using meshgrid
varying_param_mesh = np.meshgrid(*varying_param_arrays, indexing="ij")
varying_param_combinations = np.stack([arr.flatten() for arr in varying_param_mesh], axis=1)
# Create emission model objects for each parameter combination
emission_models = []
out_params = {}
for i, params in tqdm(
enumerate(varying_param_combinations),
desc="Generating Emission Models",
disable=(rank != 0),
):
# Create parameter dictionary for emission model constructor
emission_params = {key: params[j] for j, key in enumerate(varying_params.keys())}
# Apply units if not None
for key, value in emission_params.items():
if "units" in varying_params[key] and varying_params[key]["units"] is not None:
emission_params[key] = value * varying_params[key]["units"]
if key not in out_params:
out_params[key] = []
out_params[key].append(emission_params[key])
# store value of varying parameter(s) in dictionary for this emission model
if fixed_params is None:
fixed_params = {}
emission_params.update(fixed_params)
# Create and append emission model instance
emission_model_instance = emission_model(grid=grid, **emission_params)
emission_models.append(emission_model_instance)
return emission_models, out_params
def draw_from_hypercube(
param_ranges,
N: int = 1e6,
model: Type[qmc.QMCEngine] = qmc.LatinHypercube,
rng: Optional[np.random.Generator] = None,
unlog_keys: Optional[List[str]] = None,
):
"""Draw N samples from a hypercube defined by the parameter ranges.
Parameters
----------
param_ranges : dict
Dictionary where keys are parameter names and values are tuples (min, max).
Values can be unyt_quantities to carry units.
N : int, optional
Number of samples to draw, by default 1e6.
model : Type[qmc], optional
The sampling model to use, by default LatinHypercube.
rng : Optional[np.random.Generator], optional
Random number generator to use for sampling, by default None.
unlog_keys : Optional[List[str]], optional
List of keys in param_ranges that should be unlogged
(i.e., raised to power of 10). Units will be preserved
even if this doesn't really make sense (e.g. Msun).
Returns:
-------
dict
Dictionary where keys are parameter names and values are arrays of sampled values.
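Examples
--------
A minimal sketch; 'log_mass' is unlogged and renamed to 'mass' on return:
>>> ranges = {"redshift": (0.0, 10.0), "log_mass": (7.0, 11.0)}
>>> samples = draw_from_hypercube(ranges, N=1000, unlog_keys=["log_mass"])
>>> sorted(samples.keys())
['mass', 'redshift']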
"""
if unlog_keys is None:
unlog_keys = []
# check if model takes 'rng' or 'seed' as an argument
if "rng" in inspect.signature(model).parameters:
key = "rng"
elif "seed" in inspect.signature(model).parameters:
key = "seed"
else:
raise ValueError("The model must accept either 'rng' or 'seed' as an argument.")
rng_ = {key: rng} if rng is not None else {}
# Create a Latin Hypercube sampler
sampler = model(d=len(param_ranges), **rng_)
# Generate samples in the unit hypercube
sample = sampler.random(int(N))
low = [param_ranges[key][0] for key in param_ranges.keys()]
high = [param_ranges[key][1] for key in param_ranges.keys()]
units = []
for pos, (i, j) in enumerate(zip(low, high)):
assert i < j, f"Parameter range {i} must be less than {j}"
if isinstance(i, unyt_quantity):
# If the parameter is a unyt_quantity, extract the value and unit
units.append(i.units)
i = i.value
j = j.value
low[pos] = i
high[pos] = j
else:
# Otherwise, just use the value
units.append(None)
# Scale samples to the specified ranges
scaled_samples = qmc.scale(
sample,
np.array(low),
np.array(high),
)
all_param_dict = {}
for i, key in enumerate(param_ranges.keys()):
samples = scaled_samples[:, i].astype(np.float32)
if key in unlog_keys:
samples = 10**samples
# If the key is in unlog_keys, raise to power of 10
key = key.replace("log_", "") # Remove 'log_' prefix if present
if units[i] is not None:
# If the parameter has units, convert samples to unyt_quantity
samples = unyt_array(samples, units=units[i])
if np.any(~np.isfinite(samples)):
raise ValueError(
f"Non-finite values found in samples for parameter '{key}'. "
"Check the parameter ranges and ensure they are valid."
)
all_param_dict[key] = samples
return all_param_dict
def load_hypercube_from_npy(file_path: str):
"""Load a hypercube from a .npy file.
Parameters
----------
file_path : str
Path to the .npy file containing the hypercube data.
Returns:
-------
np.ndarray
Array of shape (N, M) containing the loaded hypercube data.
"""
# Load the hypercube data from the .npy file
hypercube = np.load(file_path)
return hypercube.astype(np.float32)
def generate_sfh_basis(
sfh_type: Type[SFH.Common],
sfh_param_names: List[str],
sfh_param_arrays: List[np.ndarray],
redshifts: Union[Dict[str, Any], float, np.ndarray],
sfh_param_units: List[Union[None, Unit]] = None,
max_redshift: float = 20,
calculate_min_age: bool = False,
min_age_frac=0.001,
cosmo: Type[Cosmology] = Planck18,
iterate_redshifts: bool = False,
) -> Tuple[List[Type[SFH.Common]], np.ndarray]:
"""Generate a grid of SFHs based on parameter arrays and redshifts.
Parameters
----------
sfh_type : Type[SFH]
The star formation history class to instantiate
sfh_param_names : List[str]
List of parameter names for SFH constructor
sfh_param_arrays : List[np.ndarray]
List of parameter arrays for SFH constructor.
Should have the same length in the first dimension as sfh_param_names.
if values are lambda functions the input will be the max age given max_redshift.
redshifts : Union[Dict[str, Any], float]
Either a single redshift value, an array of redshifts, or a dictionary with:
'prior': scipy.stats distribution
'min': minimum redshift
'max': maximum redshift
'size': number of redshift samples
sfh_param_units : List[Union[None, Unit]], optional
List of units for each SFH parameter (None for dimensionless), by default None
max_redshift : float, optional
Maximum possible redshift to consider for age calculations, by default 20
cosmo : Type[Cosmology], optional
Cosmology to use for age calculations, by default Planck18
calculate_min_age : bool, optional
If True, calculate the lookback time at which only min_age_frac of total mass
is formed, by default False
min_age_frac : float, optional
Fraction of total mass formed to calculate the minimum age, by default 0.001
iterate_redshifts : bool, optional
If True, iterate over redshifts and create an SFH for each, by default False.
If False, assume the redshift and SFH parameter arrays are a 1:1 mapping of
redshift to SFH parameters.
Returns:
-------
Tuple[List[SFH], np.ndarray]
List of SFH objects with parameters drawn from the priors
Array of parameter combinations, where the first column is redshift
followed by SFH parameters
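Examples
--------
A sketch with a constant SFH; '_norm' parameters are interpreted as fractions
of the maximum age at each redshift (SFH.Constant accepting ``min_age`` and
``max_age`` is an assumption):
>>> sfhs, zs = generate_sfh_basis(
...     SFH.Constant,
...     ["min_age_norm"],
...     [np.random.uniform(0.0, 0.5, 100)],
...     redshifts=np.random.uniform(4.0, 8.0, 100),
... )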
"""
if isinstance(redshifts, dict):
redshifts = redshifts["prior"].rvs(
size=int(redshifts["size"]),
loc=redshifts["min"],
scale=redshifts["max"] - redshifts["min"],
)
elif isinstance(redshifts, (float, int)):
redshifts = np.array([redshifts])
if not iterate_redshifts:
# extend redshifts to match the length of sfh_param_arrays
redshifts = np.full(len(sfh_param_arrays[0]), redshifts)
elif isinstance(redshifts, np.ndarray):
pass
else:
raise ValueError("redshifts must be a dictionary, float/int, or numpy array")
# Calculate maximum ages at each redshift
max_ages = (cosmo.age(redshifts) - cosmo.age(max_redshift)).to(u.Myr).value
sfhs = []
if sfh_param_units is None:
# If no units are provided, assume all parameters are dimensionless
sfh_param_units = [None] * len(sfh_param_names)
all_redshifts = []
param_names_i = [i.replace("_norm", "") for i in sfh_param_names]
if isinstance(sfh_param_arrays, tuple):
sfh_param_arrays = list(sfh_param_arrays)
if isinstance(sfh_param_arrays, (list, np.ndarray)):
for pos, (set_unit, param) in enumerate(zip(sfh_param_units, sfh_param_arrays)):
if isinstance(param, unyt_array):
# If the parameter is a unyt_array, extract the unit
sfh_param_units[pos] = param.units
# Convert to numpy array if needed
sfh_param_arrays[pos] = param.value
elif isinstance(param, (list, tuple)):
# If it's a list or tuple, convert to numpy array
sfh_param_arrays[pos] = np.array(param)
assert isinstance(sfh_param_units[pos], (Unit, type(None)))
if len(sfh_param_names) == len(sfh_param_arrays):
sfh_param_arrays = np.vstack(sfh_param_arrays).T
if iterate_redshifts:
for i, redshift in enumerate(redshifts):
params = copy.deepcopy(sfh_param_arrays)
for j, row in enumerate(params):
row_params = {}
for k, param in enumerate(row):
# Check if the parameter is a function
if callable(param):
# Call the function with the maximum age
row_params[k] = param(max_ages[i])
else:
# Otherwise, just use the parameter value
row_params[k] = param
if sfh_param_units[k] is not None:
# Apply units if not None
row_params[k] = row_params[k] * sfh_param_units[k]
# Create parameter dictionary for SFH constructor
sfh_params = {
sfh_param_names[sf]: row_params[sf] for sf in range(len(sfh_param_names))
}
if "sfh_timescale" in sfh_param_names:
sfh_params["max_age"] = sfh_params["min_age"] + sfh_params["sfh_timescale"]
sfh_params.pop("sfh_timescale")
# Add max_age parameter
if "max_age" in sfh_param_names:
sfh_params["max_age"] = min(max_ages[i] * Myr, sfh_params["max_age"])
else:
sfh_params["max_age"] = max_ages[i] * Myr
# Create and append SFH instance
sfh = sfh_type(**sfh_params)
sfh.redshift = redshift
sfhs.append(sfh)
all_redshifts.append(redshift)
else:
assert len(redshifts) == len(
sfh_param_arrays
), """If iterate_redshifts is False, len(redshifts)
must equal len(sfh_param_arrays)"""
for i, redshift in tqdm(enumerate(redshifts), desc="Creating SFHs"):
params = copy.deepcopy(sfh_param_arrays[i])
row_params = {}
for j, param in enumerate(params):
# Check if the parameter is a function
if callable(param):
# Call the function with the maximum age
row_params[j] = param(max_ages[i])
elif sfh_param_names[j].endswith("_norm"):
# If the parameter is normalized to max age, multiply by max_age
row_params[j] = param * max_ages[i] * Myr
else:
# Otherwise, just use the parameter value
row_params[j] = param
if sfh_param_units[j] is not None:
# Apply units if not None
row_params[j] = row_params[j] * sfh_param_units[j]
# Create parameter dictionary for SFH constructor
sfh_params = {param_names_i[sf]: row_params[sf] for sf in range(len(param_names_i))}
# remove _norm from parameter names
if "sfh_timescale" in sfh_param_names:
sfh_params["max_age"] = sfh_params["min_age"] + sfh_params["sfh_timescale"]
sfh_params.pop("sfh_timescale")
# Add max_age parameter
if "max_age" in sfh_param_names:
sfh_params["max_age"] = min(max_ages[i] * Myr, sfh_params["max_age"])
else:
sfh_params["max_age"] = max_ages[i] * Myr
# Create and append SFH instance
sfh = sfh_type(**sfh_params)
sfh.redshift = redshift
sfhs.append(sfh)
all_redshifts.append(redshift)
if calculate_min_age:
# Calculate lookback time at which only min_age_frac of total mass is formed
min_ages = []
for i, sfh in enumerate(sfhs):
max_age = sfh.max_age
# Calculate the cumulative mass formed
age, sfr = sfh.calculate_sfh(t_range=(0, 1.1 * max_age), dt=1e6 * yr)
sfr = sfr / sfr.max()
mass_formed = np.cumsum(sfr[::-1])[::-1] / np.sum(sfr)
total_mass = mass_formed[0]
# Find the age at which min_age_frac of total mass is formed.
# np.interp needs increasing sample points, so reverse both arrays.
min_age = np.interp(min_age_frac * total_mass, mass_formed[::-1], age[::-1])
min_ages.append(min_age)
return np.array(sfhs), redshifts
def create_galaxy(
sfh: Type[SFH.Common],
redshift: float,
metal_dist: Type[ZDist.Common],
grid: Grid,
log_stellar_masses: Union[float, list] = 9,
**galaxy_kwargs,
) -> Type[Galaxy]:
"""Create a new galaxy with the specified parameters."""
# Initialise the parametric Stars object
assert not isinstance(log_stellar_masses, (unyt_array, unyt_quantity)), (
"log_stellar_masses must be a float or list of floats, not a unyt array"
)
single_mass = (
log_stellar_masses[0] if isinstance(log_stellar_masses, (list)) else log_stellar_masses
)
single_mass = 10**single_mass * Msun
# If there are any black hole or gas kwargs, pop them out of galaxy_kwargs
bh_kwargs = {}
gas_kwargs = {}
for key in list(galaxy_kwargs.keys()):
if key.startswith("bh_"):
new_key = key.replace("bh_", "")
bh_kwargs[new_key] = galaxy_kwargs.pop(key)
elif key.startswith("gas_"):
new_key = key.replace("gas_", "")
gas_kwargs[new_key] = galaxy_kwargs.pop(key)
param_stars = Stars(
log10ages=grid.log10ages,
metallicities=grid.metallicity,
sf_hist=sfh,
metal_dist=metal_dist,
initial_mass=single_mass,
**galaxy_kwargs, # most parameters want to be on the emitter
)
# Define the number of stellar particles
n = len(log_stellar_masses) if isinstance(log_stellar_masses, list) else 1
if n > 1:
# Sample the parametric SFZH to create "fake" stellar particles
part_stars = sample_sfzh(
sfzh=param_stars.sfzh,
log10ages=np.log10(param_stars.ages),
log10metallicities=np.log10(param_stars.metallicities),
nstar=n,
current_masses=np.power(10.0, np.asarray(log_stellar_masses, dtype=float)) * Msun,
redshift=redshift,
coordinates=np.random.normal(0, 0.01, (n, 3)) * Mpc,
centre=np.zeros(3) * Mpc,
)
part_stars.__dict__.update(
galaxy_kwargs
) # Add any additional parameters to the stars object
from synthesizer.particle import Galaxy
else:
from synthesizer.parametric import Galaxy
part_stars = param_stars
if len(bh_kwargs) > 0:
from synthesizer.parametric import BlackHole
bh = BlackHole(**bh_kwargs)
# Workaround: alias the singular attribute under the plural name expected downstream
bh.accretion_rates_eddington = bh.accretion_rate_eddington
else:
bh = None
if len(gas_kwargs) > 0:
from synthesizer.parametric import Gas
gas = Gas(**gas_kwargs)
else:
gas = None
# And create the galaxy
galaxy = Galaxy(stars=part_stars, redshift=redshift, gas=gas, black_holes=bh)
return galaxy
def _init_worker(grid, alt_parametrizations, fixed_params):
"""Initialize worker process with shared data."""
_thread_local.grid = grid
_thread_local.alt_parametrizations = alt_parametrizations
_thread_local.fixed_params = fixed_params
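# A sketch of how these helpers could be wired together (illustrative only).
# Each worker must call _init_worker itself: threading.local attributes set in
# one thread are not visible from another.
#
#   def _worker(batch):
#       _init_worker(grid, alt_parametrizations, fixed_params)
#       return _process_galaxy_batch(batch)
#
#   with parallel_config(backend="threading", n_jobs=8):
#       results = Parallel()(delayed(_worker)(batch) for batch in batches)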
def _process_galaxy_batch(galaxy_indices_and_data):
"""Process a batch of galaxies in a single worker."""
# Access shared data from thread-local storage
grid = _thread_local.grid
alt_parametrizations = _thread_local.alt_parametrizations
base_params = _thread_local.fixed_params
galaxies = []
for galaxy_idx, galaxy_data in galaxy_indices_and_data:
# Reconstruct minimal parameters for this galaxy
params = base_params.copy() # Only copy once per batch item
params.update(galaxy_data["varying_params"])
save_params = copy.deepcopy(params)
# Create galaxy
gal = create_galaxy(
sfh=galaxy_data["sfh"],
redshift=galaxy_data["redshift"],
metal_dist=galaxy_data["metal_dist"],
log_stellar_masses=galaxy_data["log_stellar_mass"],
grid=grid, # Use shared reference
**params,
)
# Process parameters (reuse logic from original)
save_params["redshift"] = galaxy_data["redshift"]
try:
save_params.update(galaxy_data["sfh"].parameters)
except AttributeError:
# This occurs if the SFH is just a burst, e.g. a single value
save_params["sfh"] = galaxy_data["sfh"]
try:
save_params.update(galaxy_data["metal_dist"].parameters)
except AttributeError:
save_params["metal_dist"] = galaxy_data["metal_dist"]
# Apply alternative parametrizations
if len(alt_parametrizations) > 0:
to_remove = set()
for key, (new_key, func) in alt_parametrizations.items():
if key in save_params:
if isinstance(new_key, str):
save_params[new_key] = func(save_params)
to_remove.add(key)
elif isinstance(new_key, (list, tuple)):
for k in new_key:
save_params[k] = func(k, save_params)
to_remove.add(key)
for key in to_remove:
save_params.pop(key, None)
gal.all_params = save_params
galaxies.append((galaxy_idx, gal))
return galaxies
class GalaxyBasis:
"""Class to create a basis of galaxies with different SFHs, redshifts, and parameters.
It can support two modes of operation.
The first case is the simplest: you have some set of priors controlling e.g. mass,
redshift, SFH parameters, and metallicity distribution parameters, which you draw
randomly or from a Latin hypercube, and then you create a galaxy for each set of
parameters. So if you are drawing 1000 galaxies, you would pass in 1000 redshifts,
1000 SFHs, 1000 metallicity distributions, and 1000 sets of galaxy parameters.
If however you are sampling on a grid, or want dependent priors, you can instead
pass in a few basis SFHs, redshifts, and metallicity distributions, and then
the class will generate a library of galaxies by combining every combination of
SFH, redshift, and metallicity distributions. This can reduce the number of
galaxies created.
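Example
-------
A sketch of the direct-sampling mode (``build_library=False``); ``grid``,
``em_model``, ``my_sfhs``, ``my_zdists`` and ``zs`` are assumed to be prepared
elsewhere with one entry per galaxy:
>>> basis = GalaxyBasis(
...     model_name="demo",
...     redshifts=zs,
...     grid=grid,
...     emission_model=em_model,
...     sfhs=my_sfhs,
...     metal_dists=my_zdists,
...     log_stellar_masses=np.full(len(zs), 9.0),
... )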
"""
def __init__(
self,
model_name: str,
redshifts: np.ndarray,
grid: Grid,
emission_model: Type[EmissionModel],
sfhs: List[Type[SFH.Common]],
metal_dists: List[Type[ZDist.Common]],
log_stellar_masses: Optional[np.ndarray] = None,
galaxy_params: dict = None,
alt_parametrizations: Dict[str, Tuple[str, callable]] = None,
cosmo: Type[Cosmology] = Planck18,
instrument: Instrument = None,
redshift_dependent_sfh: bool = False,
params_to_ignore: List[str] = None,
build_library: bool = False,
) -> None:
"""Initialize the GalaxyBasis object with SFHs, redshifts, and other parameters.
Parameters
----------
sfhs : List[Type[SFH.Common]]
List of SFH objects.
redshifts : np.ndarray
Array of redshift values.
grid : Grid
Grid object containing the SPS grid.
emission_model : Type[EmissionModel]
Emission model class to instantiate.
model_name : str
Name used to label this model in progress bars and stored libraries.
galaxy_params : dict
Dictionary of parameters for the galaxy.
alt_parametrizations : dict
Dictionary of alternative parametrizations for the galaxy parameters -
for parametrizing differently to Synthesizer if wanted. Should be a dictionary
with keys as the parameter names and values as tuples of the new parameter
name and a function which takes the parameter dictionary and returns the new
parameter value (so it can be calculated from the other parameters if needed).
metal_dists : List[Type[ZDist.Common]], optional
List of metallicity distribution objects, by default None
log_stellar_masses : Optional[np.ndarray], optional
Array of logarithmic stellar masses in solar masses, by default None.
cosmo : Type[Cosmology], optional
Cosmology object, by default Planck18
instrument : Instrument, optional
Instrument object containing the filters, by default None
redshift_dependent_sfh : bool, optional
If True, the SFH will depend on redshift, by default False. If True, expect
each SFH to have a redshift attribute.
params_to_ignore : List[str], optional
List of parameters to ignore as being different when calculating
which parameters are varying.
E.g. max_age may be dependent on redshift, so we don't want to include it
in the varying parameters as the model can learn this.
build_library : bool, optional
If True, build the library of galaxies, by default False.
If False, assume all dimensions of parameters are the same size and
build the library from the parameters. I.e don't generate combinations of
parameters, just use the parameters as they are.
"""
if galaxy_params is None:
galaxy_params = {}
if alt_parametrizations is None:
alt_parametrizations = {}
if params_to_ignore is None:
params_to_ignore = []
if isinstance(redshifts, (float, int)) and not build_library:
redshifts = np.full(len(sfhs), redshifts)
self.model_name = model_name
self.sfhs = sfhs
self.redshifts = redshifts
self.grid = grid
self.emission_model = emission_model
self.galaxy_params = galaxy_params
self.alt_parametrizations = alt_parametrizations
self.metal_dists = metal_dists
self.cosmo = cosmo
self.instrument = instrument
self.redshift_dependent_sfh = redshift_dependent_sfh
self.log_stellar_masses = log_stellar_masses
self.params_to_ignore = params_to_ignore
self.build_library = build_library
self.galaxies = []
if isinstance(self.metal_dists, ZDist.Common):
self.metal_dists = [self.metal_dists]
if isinstance(self.sfhs, SFH.Common):
self.sfhs = [self.sfhs]
# if self.stellar_masses is not None:
# assert isinstance(self.stellar_masses, (unyt_array, unyt_quantity)), (
# "stellar_masses must be a unyt array or quantity"
# )
self.per_particle = False
# Check if any galaxy parameters are dictionaries with keys like 'prior', 'min'
for key in list(galaxy_params.keys()):
value = galaxy_params[key]
if isinstance(value, dict):
if key == "bh" or key == "gas":
temp_dict = galaxy_params.pop(key)
for subkey, subvalue in temp_dict.items():
galaxy_params[f"{key}_{subkey}"] = subvalue
else:
# If the value is a dictionary, process it as a prior
self.galaxy_params[key] = self.process_priors(value)
if not build_library:
logger.info("Generating library directly from provided parameter samples.")
elif self.redshift_dependent_sfh:
# Check if the SFHs have a redshift attribute
for sfh in self.sfhs:
if not hasattr(sfh, "redshift"):
raise ValueError(
"SFH must have a redshift attr if redshift_dependent_sfh==True"
)
if not isinstance(self.redshifts, np.ndarray):
self.redshifts = np.array(self.redshifts)
# Check all redshifts are in self.redshifts
for sfh in self.sfhs:
if sfh.redshift not in self.redshifts:
raise ValueError(f"SFH redshift {sfh.redshift} not in redshifts array")
# Make self.SFHs a dictionary with redshift as key
sfh_dict = {}
for sfh in self.sfhs:
if sfh.redshift not in sfh_dict:
sfh_dict[sfh.redshift] = []
sfh_dict[sfh.redshift].append(sfh)
self.sfhs = sfh_dict
else:
self.sfhs = {z: self.sfhs for z in self.redshifts}
def process_priors(
self,
prior_dict: Dict[str, Any],
) -> unyt_array:
"""Process priors from dictionary.
Parameters
----------
prior_dict : Dict[str, Any]
Dictionary containing prior information. Must contain:
- 'prior': scipy.stats distribution
- 'size': number of samples to draw
- 'units': unyt unit (optional)
- other parameters required by the distribution
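Examples
--------
A sketch, assuming ``list_parameters`` reports ``loc`` and ``scale`` as the
required arguments of ``scipy.stats.uniform``:
>>> from scipy.stats import uniform
>>> values = basis.process_priors(
...     {"prior": uniform, "size": 100, "loc": 0.0, "scale": 2.0}
... )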
"""
assert isinstance(prior_dict, dict), "prior_dict must be a dictionary"
assert "prior" in prior_dict, "prior_dict must contain a 'prior' key"
assert "size" in prior_dict, "prior_dict must contain a 'size' key"
stats_params = list_parameters(prior_dict["prior"])
# Check required parameters are present
params = {}
for param in stats_params:
assert param in prior_dict, f"prior_dict must contain a '{param}' key"
params[param] = prior_dict[param]
# draw values for this parameter
values = prior_dict["prior"].rvs(size=int(prior_dict["size"]), **params)
if "units" in prior_dict and prior_dict["units"] is not None:
values = unyt_array(values, units=prior_dict["units"])
return values
def _create_galaxies(
self,
log_base_masses: Union[float, list, np.ndarray] = 9,
) -> List[Type[Galaxy]]:
"""Create galaxies with the specified SFHs, redshifts, and other parameters.
Parameters
----------
log_base_masses : Union[float, np.ndarray], optional
Base mass (or array of base masses) to use for the galaxies.
Units of log10 M sun.
Default mass (or mass array) to use for the galaxies.
Returns:
-------
List[Type[Galaxy]]
List of Galaxy objects.
"""
if not self.build_library:
raise ValueError("You probably meant to call_create_matched_galaxies instead.")
varying_param_values = [
i for i in self.galaxy_params.values() if isinstance(i, (list, np.ndarray))
]
if isinstance(log_base_masses, (list, np.ndarray)):
self.per_particle = True
# generate all combinations of the varying parameters
if len(varying_param_values) == 0:
param_list = [{}]
fixed_params = self.galaxy_params
else:
varying_param_combinations = np.array(np.meshgrid(*varying_param_values)).T.reshape(
-1, len(varying_param_values)
)
column_names = [
i
for i, j in zip(self.galaxy_params.keys(), varying_param_values)
if isinstance(j, (list, np.ndarray))
]
fixed_params = {
key: value
for key, value in self.galaxy_params.items()
if not isinstance(value, (list, np.ndarray))
}
param_list = [
{column_names[i]: j for i, j in enumerate(row)}
for row in varying_param_combinations
]
galaxies = []
all_parameters = {}
for i, redshift in tqdm(
enumerate(self.redshifts),
desc=f"Creating {self.model_name} galaxies",
total=len(self.redshifts),
disable=(rank != 0),
):
# get the sfh for this redshift
sfh_models = self.sfhs[redshift]
for sfh_model in sfh_models:
sfh_parameters = sfh_model.parameters
for k, Z_dist in enumerate(self.metal_dists):
Z_parameters = Z_dist.parameters
# Create a new galaxy with the specified parameters
for params in param_list:
params.update(fixed_params)
save_params = copy.deepcopy(params)
gal = create_galaxy(
sfh=sfh_model,
redshift=redshift,
metal_dist=Z_dist,
log_stellar_masses=log_base_masses,
grid=self.grid,
**params,
)
save_params["redshift"] = redshift
save_params.update(sfh_parameters)
save_params.update(Z_parameters)
if len(self.alt_parametrizations) > 0:
to_remove = []
# Apply alternative parametrizations if provided
for key, (
new_key,
func,
) in self.alt_parametrizations.items():
if isinstance(new_key, str):
save_params[new_key] = func(save_params)
to_remove.append(key)
elif isinstance(new_key, (list, tuple)):
for k in new_key:
save_params[k] = func(k, save_params)
to_remove.append(key)
for key in to_remove:
save_params.pop(key)
# This stores all input parameters for the galaxy
# so we can work out which parameters
# are varying and which are fixed later.
gal.all_params = save_params
# add all_parameters to dictionary if that key doesn't exist
for key, value in save_params.items():
if key not in all_parameters:
all_parameters[key] = []
if value not in all_parameters[key]:
all_parameters[key].append(value)
galaxies.append(gal)
self.galaxies = galaxies
# Remove any parameters which are just [None]
to_remove = []
fixed_param_names = []
fixed_param_values = []
fixed_param_units = []
varying_param_names = []
for key, value in all_parameters.items():
if len(value) == 1 and value[0] is None:
to_remove.append(key)
continue
# check if all values are the same.
if len(np.unique(value)) == 1:
fixed_param_names.append(key)
fixed_param_units.append(
str(value[0].units) if isinstance(value[0], unyt_quantity) else ""
)
fixed_param_values.append(value[0])
else:
varying_param_names.append(key)
for param in self.params_to_ignore:
if param in varying_param_names:
varying_param_names.remove(param)
# Sanity check all varying parameters combinations on self.galaxies are unique
logger.info("Checking parameters are unique.")
hashes = []
for gal in self.galaxies:
relevant_params = {
key: gal.all_params[key] for key in varying_param_names if key in gal.all_params
}
# Calculate hash for each parameter and sum them
hash_i = sum(
hash(float(param.value))
if isinstance(param, (unyt_array, unyt_quantity))
else hash(param)
for param in relevant_params.values()
)
hashes.append(hash_i)
if len(hashes) != len(set(hashes)):
raise ValueError(
"""Varying parameters are not unique across galaxies.
Check your input parameters."""
)
self.varying_param_names = varying_param_names
self.fixed_param_names = fixed_param_names
self.fixed_param_values = fixed_param_values
self.fixed_param_units = fixed_param_units
self.all_parameters = all_parameters
for key in to_remove:
all_parameters.pop(key)
logger.info("Finished creating galaxies.")
return self.galaxies
def _check_model_simplicity(self, parameter_transforms_to_save=None, verbose=True) -> bool:
"""Check if the model is simple enough to be stored in a file.
Checks include:
All SFHs are the same underlying SFH class.
All metallicity distributions are the same underlying ZDist class.
Single emission model class.
We understand emitter parameters and can store them.
We can store the library path and filter names.
We can serialize alt_parametrizations.
If the model is simple enough, we can store in HDF5 and
allow creation of a GalaxySimulator from it.
"""
accept = True
# Flatten SFHs regardless of storage (list or {z: [sfh,...]}) before
# checking, since iterating a dict would yield redshift keys, not SFHs.
if isinstance(self.sfhs, dict):
_sfh_list = []
for v in self.sfhs.values():
_sfh_list.extend(v if isinstance(v, (list, tuple)) else [v])
else:
_sfh_list = self.sfhs
sfh_classes = set(type(sfh) for sfh in _sfh_list)
if len(sfh_classes) > 1:
if verbose:
logger.warning(
f"SFH classes are not all the same: {sfh_classes}. Cannot store model."
)
accept = False
_sfh0 = _sfh_list[0]
if type(_sfh0).__name__ not in SFH.parametrisations and not isinstance(
_sfh0, (unyt_quantity, float, int)
):
if verbose:
logger.warning(
f"SFH class {type(_sfh0).__name__} is not in SFH.parametrisations. ({SFH.parametrisations}) " # noqa: E501
"Cannot store model."
)
accept = False
metal_dist_classes = set(type(metal_dist) for metal_dist in self.metal_dists)
if len(metal_dist_classes) > 1:
if verbose:
logger.warning(
f"Metallicity distribution classes are not all the same: {metal_dist_classes}. " # noqa: E501
"Cannot store model."
)
accept = False
if type(self.metal_dists[0]).__name__ not in ZDist.parametrisations:
if verbose:
logger.warning(
f"Metallicity distribution class {type(self.metal_dists[0]).__name__} " # noqa: E501
"is not in ZDist.parametrisations. Cannot store model."
)
accept = False
if not isinstance(self.emission_model, EmissionModel):
if verbose:
logger.warning(
f"Emission model is not an instance of EmissionModel: {self.emission_model}. " # noqa: E501
"Cannot store model."
)
accept = False
# check emission model in synthesizer.emission_models.PREMADE_MODELS
from synthesizer.emission_models import PREMADE_MODELS
# These models aren't in PREMADE_MODELS for some reason; use a local copy so
# repeated calls don't keep growing the imported list.
premade_models = list(PREMADE_MODELS) + [
"PacmanEmissionNoEscapedWithDust",
"PacmanEmissionWithEscapedNoDust",
"PacmanEmissionWithEscapedWithDust",
"EmissionModel",
]
if type(self.emission_model).__name__ not in premade_models:
if verbose:
logger.warning(
f"Emission model {type(self.emission_model).__name__} is not in PREMADE_MODELS. " # noqa: E501
"Cannot store model."
)
accept = False
em_args = inspect.signature(type(self.emission_model)).parameters
forbidden_args = [
"nlr_grid",
"blr_grid",
"covering_fraction",
"covering_fraction_nlr",
"covering_fraction_blr",
"torus_emission_model",
"dust_curve_ism",
"dust_curve_birth",
"dust_emission_ism",
"dust_emission_birth",
]
for arg in forbidden_args:
if arg in em_args:
if verbose:
logger.warning(
f"Emission model {type(self.emission_model).__name__} has forbidden argument '{arg}'. " # noqa: E501
"Cannot store model."
)
accept = False
# Can we convert the functions in alt_parametrizations to strings
# and load them back with ast.literal_eval?
if parameter_transforms_to_save is not None:
for key, value in parameter_transforms_to_save.items():
# value should be (str, callable) or (list/tuple, callable)
if isinstance(value, tuple) or callable(value):
if callable(value):
value = (key, value)
if len(value) != 2 or not callable(value[1]):
accept = False
else:
try:
# Check if we can get the source code of the function
getsource(value[1])
except Exception:
if verbose:
logger.warning(
f"Cannot serialize function for alt_parametrization '{key}': {value[1]}" # noqa: E501
)
accept = False
else:
accept = False
return accept
def _store_model(
self,
model_path: str,
overwrite=False,
group: str = "Model",
other_info: dict = None,
parameter_transforms_to_save=None,
) -> bool:
if not self._check_model_simplicity(parameter_transforms_to_save):
logger.warning("Model is too complex to be stored in a file.")
return False
# if not overwrite, append to existing file
if os.path.exists(model_path) and not overwrite:
open_mode = "a"
else:
open_mode = "w"
with h5py.File(model_path, open_mode) as f:
if group in f:
if overwrite:
del f[group]
else:
logger.warning(f"Group {group} already exists in {model_path}.")
return False
base = f.create_group(group)
# store grid_name and grid_dir
base.attrs["grid_name"] = self.grid.grid_name
base.attrs["grid_dir"] = self.grid.grid_dir
# store emission model class name
em_group = base.create_group("EmissionModel")
em_group.attrs["name"] = type(self.emission_model).__name__
# store emission model parameters
em_model_params = save_emission_model(self.emission_model)
try:
em_group.attrs["parameter_keys"] = em_model_params["fixed_parameter_keys"]
em_group.attrs["parameter_values"] = em_model_params["fixed_parameter_values"]
# if it can be an int or float, store as such
em_group.attrs["parameter_units"] = em_model_params["fixed_parameter_units"]
if em_model_params["dust_law"] is not None:
em_group.attrs["dust_law"] = em_model_params["dust_law"]
em_group.attrs["dust_attenuation_keys"] = em_model_params["dust_attenuation_keys"]
em_group.attrs["dust_attenuation_values"] = em_model_params[
"dust_attenuation_values"
]
em_group.attrs["dust_attenuation_units"] = em_model_params["dust_attenuation_units"]
if em_model_params["dust_emission"] is not None:
em_group.attrs["dust_emission"] = em_model_params["dust_emission"]
em_group.attrs["dust_emission_keys"] = em_model_params["dust_emission_keys"]
em_group.attrs["dust_emission_values"] = em_model_params["dust_emission_values"]
em_group.attrs["dust_emission_units"] = em_model_params["dust_emission_units"]
except Exception as e:
    logger.critical(
        f"Error saving emission model to library: {e}, {em_model_params}"
    )
# Store a version of astropy cosmo
cosmo_yaml = self.cosmo.to_format("yaml")
base.attrs["cosmology"] = cosmo_yaml
# store instrument label and filter codes (h5py attrs cannot store None)
base.attrs["instrument"] = self.instrument.label if self.instrument else ""
if self.instrument:
base.attrs["filters"] = self.instrument.filters.filter_codes
instrument_group = base.create_group("Instrument")
self.instrument.to_hdf5(instrument_group)
# store sfh class
base.attrs["sfh_class"] = type(self.sfhs[0]).__name__
# store metallicity distribution class
base.attrs["metallicity_distribution_class"] = type(self.metal_dists[0]).__name__
base.attrs["model_name"] = self.model_name
if other_info is None:
other_info = {}
for key, value in other_info.items():
if isinstance(value, (list, np.ndarray)):
# Store as a dataset
base.create_dataset(key, data=np.array(value))
# if a unyt quantity, store the units as an attribute
if isinstance(value, unyt_array):
base[key].attrs["units"] = str(value.units)
else:
# Store as an attribute
base.attrs[key] = value
base.attrs["stellar_params"] = list(self.galaxy_params.keys())
if parameter_transforms_to_save is not None:
transforms_group = base.create_group("Transforms")
for key, value in parameter_transforms_to_save.items():
if isinstance(value, tuple) or callable(value):
if callable(value):
value = (key, value)
# Store the new parameter name and the function as a string
transforms_group.create_dataset(
key, data=getsource(value[1]).encode("utf-8")
)
transforms_group[key].attrs["new_parameter_name"] = value[0]
# Store param_order
base.attrs["varying_param_names"] = self.varying_param_names
base.attrs["fixed_param_names"] = self.fixed_param_names
base.attrs["fixed_param_values"] = self.fixed_param_values
base.attrs["fixed_param_units"] = self.fixed_param_units
def create_galaxy(
self,
sfh: Type[SFH.Common],
redshift: float,
metal_dist: Type[ZDist.Common],
log_stellar_masses: Union[float, list] = 9,
grid: Optional[Grid] = None,
**galaxy_kwargs,
) -> Type[Galaxy]:
"""Create a galaxy from parameters.
Parameters:
------------
sfh: SFH model class, e.g. SFH.LogNormal
redshift: redshift of the galaxy
metal_dist: metallicity distriution of the galaxy
log_stellar_masses: float or list of floats, log10 stellar mass in solar masses
grid: Grid object, if None use self.grid
galaxy_kwargs: additional keyword arguments to pass to the Galaxy class
Returns:
---------
Type[Galaxy]
Galaxy object created with the specified parameters.
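
Examples
--------
A minimal sketch; the parameter values are illustrative and assume
synthesizer's parametric SFH and ZDist models::

    from synthesizer.parametric import SFH, ZDist
    from unyt import Myr

    gal = basis.create_galaxy(
        sfh=SFH.Constant(max_age=100 * Myr),
        redshift=5.0,
        metal_dist=ZDist.DeltaConstant(metallicity=0.01),
        log_stellar_masses=9.5,
    )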
"""
return create_galaxy(
sfh=sfh,
redshift=redshift,
metal_dist=metal_dist,
log_stellar_masses=log_stellar_masses,
grid=grid if grid is not None else self.grid,
**galaxy_kwargs,
)
def create_galaxies_optimized(
self,
fixed_params,
varying_param_names,
log_base_masses,
galaxies_mask=None,
n_proc=28,
batch_size=None,
):
"""Optimized version with reduced serialization overhead."""
# Determine optimal batch size if not provided
if batch_size is None:
total_galaxies = len(self.sfhs)
if galaxies_mask is not None and len(galaxies_mask) > 0:
total_galaxies = np.sum(galaxies_mask)
batch_size = max(1, total_galaxies // n_proc)  # one batch per thread
logger.info(f"Total galaxies to create: {total_galaxies}")
# Prepare lightweight job inputs (grouped into batches)
job_batches = []
current_batch = []
for i in tqdm(range(len(self.sfhs)), desc="Preparing galaxy batches", disable=(rank != 0)):
# Skip this galaxy if the mask is provided and is False
if galaxies_mask is not None and len(galaxies_mask) > 0 and not galaxies_mask[i]:
continue
# Create minimal data structure (no large objects)
varying_params = {}
for key in varying_param_names:
varying_params[key] = self.galaxy_params[key][i]
# Determine the stellar mass for this galaxy
try:
mass = log_base_masses[i]
except (IndexError, TypeError):
mass = log_base_masses
galaxy_data = {
"sfh": self.sfhs[i],
"redshift": self.redshifts[i],
"metal_dist": self.metal_dists[i],
"log_stellar_mass": mass,
"varying_params": varying_params, # Only varying parameters
}
current_batch.append((i, galaxy_data))
# Create batch when it reaches desired size
if len(current_batch) >= batch_size:
job_batches.append(current_batch)
current_batch = []
# Add remaining galaxies to final batch
if current_batch:
job_batches.append(current_batch)
if n_proc > 1:
logger.info(f"Creating {len(job_batches)} batches across {n_proc} processes.")
logger.info(
f"Average batch size: {sum(len(batch) for batch in job_batches) / len(job_batches):.1f}" # noqa: E501
)
# Create a wrapper function that includes the shared data
def process_batch_with_shared_data(batch):
# Initialize thread-local data for this worker
_init_worker(self.grid, self.alt_parametrizations, fixed_params)
return _process_galaxy_batch(batch)
# Use threading backend
with parallel_config("threading"):
tasks = [delayed(process_batch_with_shared_data)(batch) for batch in job_batches]
# Process batches
batch_results = Parallel(n_jobs=n_proc)(
tqdm(tasks, desc="Creating galaxy batches", disable=(rank != 0))
)
# Flatten results and sort by original index
all_results = []
for batch_result in batch_results:
all_results.extend(batch_result)
# Sort by original galaxy index to maintain order
all_results.sort(key=lambda x: x[0])
self.galaxies = [galaxy for _, galaxy in all_results]
else:
# Sequential processing - initialize once for the main thread
_init_worker(self.grid, self.alt_parametrizations, fixed_params)
self.galaxies = []
for batch in tqdm(job_batches, desc="Creating galaxies", disable=(rank != 0)):
batch_galaxies = _process_galaxy_batch(batch)
self.galaxies.extend([galaxy for _, galaxy in batch_galaxies])
def _create_matched_galaxies(
self,
log_base_masses: Union[float, np.ndarray] = 9,
galaxies_mask: Optional[np.ndarray] = None,
n_proc: int = 1,
) -> List[Type[Galaxy]]:
"""Creates galaxies where all parameters have been sampled.
A galaxy is created for each SFH, redshift, and metallicity distribution
supplied.
Parameters
----------
log_base_masses : Union[float, np.ndarray], optional
Base mass (or array of base masses) to use for the galaxies.
n_procs : int, optional
Number of processes to use for parallel processing, by default 1.
Returns:
-------
List[Type[Galaxy]]
List of Galaxy objects created with the specified parameters.
"""
if len(self.metal_dists) == 1:
# Just reference the first one
self.metal_dists = [self.metal_dists[0]] * len(self.sfhs)
assert len(self.sfhs) == len(
self.redshifts
), f"""If iterate_redshifts is False, sfhs and redshifts must be the same length,
got {len(self.sfhs)} and {len(self.redshifts)}"""
assert len(self.sfhs) == len(
self.metal_dists
), f"""If iterate_redshifts is False, sfhs and metal_dists must be the same
length, got {len(self.sfhs)} and {len(self.metal_dists)}"""
if galaxies_mask is not None:
assert len(galaxies_mask) == len(self.sfhs), (
"galaxies_mask must be the same length as sfhs, redshifts, and metal_dists"
)
logger.info("Checking parameters inside create_matched_galaxies.")
varying_param_values = [
i for i in self.galaxy_params.values() if isinstance(i, (list, np.ndarray))
]
if len(varying_param_values) == 0:
fixed_params = self.galaxy_params
varying_param_names = []
else:
fixed_params = {
key: value
for key, value in self.galaxy_params.items()
if not isinstance(value, (list, np.ndarray))
}
varying_param_names = [
i for i in self.galaxy_params.keys() if i not in fixed_params.keys()
]
assert all(
len(self.galaxy_params[i]) == len(self.sfhs) for i in varying_param_names
), f"""All varying parameters must be the same length,
got {len(self.sfhs)} and {len(self.galaxy_params)}"""
# This was a side-effect in the original loop. We can detect it here
# before running the jobs.
if isinstance(log_base_masses, (list, np.ndarray)) and isinstance(
log_base_masses[0], (list, np.ndarray)
):
self.per_particle = True
"""
job_inputs: List[Dict[str, Any]] = []
for i in tqdm(range(len(self.sfhs)), desc="Batching galaxy inputs", disable=(rank != 0)):
# Skip this galaxy if the mask is provided and is False
if galaxies_mask is not None and len(galaxies_mask) > 0 and not galaxies_mask[i]:
continue
# Assemble the parameters for this specific galaxy
params = fixed_params.copy()
for key in varying_param_names:
params[key] = self.galaxy_params[key][i]
# Determine the stellar mass for this galaxy
try:
mass = log_base_masses[i]
except (IndexError, TypeError):
mass = log_base_masses
job_inputs.append(
{
"sfh": self.sfhs[i],
"redshift": self.redshifts[i],
"metal_dist": self.metal_dists[i],
"log_stellar_mass": mass,
"params": params,
"grid": self.grid,
"alt_parametrizations": self.alt_parametrizations,
}
)
"""
# Use the optimized version
self.create_galaxies_optimized(
    galaxies_mask=galaxies_mask,
    varying_param_names=varying_param_names,
    log_base_masses=log_base_masses,
    fixed_params=fixed_params,
    n_proc=n_proc,
)
logger.info(f"Created {len(self.galaxies)} galaxies.")
# Use sets instead of lists for faster lookups and unique values
self.all_parameters = defaultdict(set)
self.all_params = {}
param_units = {}
# Process all galaxies in one pass
for i, gal in enumerate(self.galaxies):
# Store the galaxy parameters directly
self.all_params[i] = gal.all_params
# Add unique values to sets using update operation
for key, value in gal.all_params.items():
if key not in self.params_to_ignore:
if isinstance(value, (unyt_quantity, unyt_array)):
unit = str(value.units)
param_units[key] = unit
value = value.value
if isinstance(value, np.ndarray):
value = value.tolist()
if isinstance(value, list):
    # Skip list-valued parameters (e.g. per-particle
    # quantities); only scalar values are tracked here.
    pass
else:
self.all_parameters[key].add(value)
# Convert sets to lists at the end if needed
self.all_parameters = {k: list(v) for k, v in self.all_parameters.items()}
# Remove any parameters which are just [None]
to_remove = []
fixed_param_names = []
fixed_param_values = []
varying_param_names = []
for key, value in tqdm(
self.all_parameters.items(), desc="Processing parameters", disable=(rank != 0)
):
if len(value) == 1 and value[0] is None:
to_remove.append(key)
continue
# check if all values are the same.
if len(np.unique(value)) == 1:
fixed_param_names.append(key)
fixed_param_values.append(value[0])
else:
varying_param_names.append(key)
for param in self.params_to_ignore:
if param in varying_param_names:
varying_param_names.remove(param)
self.varying_param_names = varying_param_names
self.fixed_param_names = fixed_param_names
self.fixed_param_values = fixed_param_values
self.fixed_param_units = []
for key in fixed_param_names:
if key in param_units:
self.fixed_param_units.append(param_units[key])
else:
self.fixed_param_units.append("")
for key in to_remove:
self.all_parameters.pop(key)
logger.info("Finished creating galaxies.")
return self.galaxies
def process_galaxies(
self,
galaxies: List[Type[Galaxy]],
out_name: str = "auto",
out_dir: str = "internal",
n_proc: int = 4,
verbose: int = 1,
save: bool = True,
emission_model_keys=None,
batch_galaxies: bool = True,
batch_size: int = 40_000,
overwrite: bool = False,
multi_node: bool = False,
spectra_to_save: list = None,
em_lines_to_save: list = None,
igm=Inoue14,
**extra_analysis_functions,
) -> Pipeline:
"""Processes galaxies through Synthesizer pipeline.
Parameters
----------
galaxies : List[Type[Galaxy]]
List of Galaxy objects to process.
out_name : str, optional
Name of the output file to save the pipeline results, by default "auto".
If "auto", uses the model_name.
out_dir : str, optional
Directory to save the output file, by default "internal".
If "internal", saves to the Synthesizer grids directory.
n_proc : int, optional
Number of processes to use for parallel processing, by default 4.
verbose : int, optional
Verbosity level for the pipeline, by default 1.
save : bool, optional
If True, saves the pipeline results to disk, by default True.
emission_model_keys : List[str], optional
List of emission model keys to save spectra for, by default None.
If None, saves all spectra.
batch_galaxies : bool, optional
If True, processes galaxies in batches, by default True.
If False, processes all galaxies in a single batch.
batch_size : int, optional
    Size of each batch of galaxies to process, by default 40,000.
overwrite : bool, optional
    If True, overwrites existing batch output files, by default False.
multi_node : bool, optional
    If True, runs the pipeline across multiple nodes with MPI, by default False.
spectra_to_save : list, optional
    List of spectra types to save, by default None. If None, saves only the root spectra.
em_lines_to_save : list, optional
    List of emission lines to save, by default None.
igm : optional
    IGM attenuation model to apply, by default Inoue14.
extra_analysis_functions : dict, optional
Additional analysis functions to add to the pipeline.
Should be a dictionary where keys are function names and values are tuples
of the function and its parameters.
Returns:
-------
Pipeline
The pipeline object after processing the galaxies.
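
Examples
--------
A minimal sketch (file and directory names are illustrative)::

    basis.process_galaxies(
        galaxies,
        out_name="mock_library.hdf5",
        out_dir="./output",
        n_proc=8,
        emission_model_keys="total",
    )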
"""
self.emission_model.set_per_particle(self.per_particle)
if spectra_to_save is None:
spectra_to_save = []
if emission_model_keys is not None:
self.emission_model.save_spectra(emission_model_keys, *spectra_to_save)
logger.info("Creating pipeline.")
if not batch_galaxies:
batch_size = len(galaxies)
n_gal = len(galaxies)
n_batches = int(np.ceil(n_gal / batch_size))
if n_batches > 1:
logger.info(f"Splitting galaxies into {n_batches} batches of size {batch_size}.")
galaxies = [galaxies[i * batch_size : (i + 1) * batch_size] for i in range(n_batches)]
else:
logger.info("Processing all galaxies in a single batch.")
galaxies = [galaxies]
for batch_i, batch_gals in enumerate(galaxies):
skip = False
if save:
if out_dir == "internal":
out_dir = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
"grids/",
)
if out_name == "auto":
out_name = self.model_name
fullpath = os.path.join(out_dir, out_name)
if not os.path.exists(out_dir):
os.makedirs(out_dir)
final_fullpath = fullpath.replace(".hdf5", f"_{batch_i + 1}.hdf5")
init_fullpath = fullpath.replace(".hdf5", "_0.hdf5")
if os.path.exists(final_fullpath) and not overwrite:
logger.warning(
f"Skipping batch {batch_i + 1} as {final_fullpath} already exists."
)
galaxies[batch_i] = None # Clear the batch to free memory
skip = True
if not skip:
if multi_node:
logger.info("Running pipeline in multi-node mode with MPI.")
# logger.debug(f"SIZE: {size}, RANK: {rank}")
else:
logger.info("Running in single-node mode.")
from synthesizer.extensions.openmp_check import check_openmp
if not check_openmp() and n_proc > 1:
logger.warning(
"Synthesizer not installed with OpenMP support. "
"Pipeline will run on a single core"
)
n_proc = 1
pipeline = Pipeline(
emission_model=self.emission_model,
nthreads=n_proc,
verbose=verbose,
comm=comm,
)
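# Bind `key` as a lambda default argument so each analysis function
# captures the current parameter name rather than the final loop value.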
for key in self.all_parameters.keys():
pipeline.add_analysis_func(
lambda gal, key=key: gal.all_params[key],
result_key=key,
)
pipeline.add_analysis_func(lambda gal: gal.stars.initial_mass, result_key="mass")
logger.info("Added analysis functions to pipeline.")
if multi_node:
logger.info(f"Pipeline MPI: {pipeline.using_mpi}")
# Add any extra analysis functions requested by the user.
for key, params in extra_analysis_functions.items():
if callable(params):
func = copy.deepcopy(params)
params = []
else:
func = params[0]
params = params[1:]
pipeline.add_analysis_func(func, f"supp_{key}", *params)
# pipeline.get_spectra() # Switch off so they aren't saved
pipeline.get_observed_spectra(self.cosmo, igm=igm)
if self.instrument.can_do_photometry:
pipeline.get_photometry_fluxes(self.instrument, igm=igm)
if em_lines_to_save is not None and len(em_lines_to_save) > 0:
# ["H 1 6562.80A", "O 3 5006.84A", "H 1 4861.32A"]
pipeline.get_lines(line_ids=em_lines_to_save)
pipeline.get_observed_lines(self.cosmo, igm=igm)
pipeline.add_galaxies(batch_gals)
ngal = len(batch_gals)
start = datetime.now()
logger.info(f"Running pipeline at {start} for {ngal} galaxies")
pipeline.run()
elapsed = datetime.now() - start
logger.info(f"Finished running pipeline at {datetime.now()} for {ngal} galaxies")
logger.info(f"Pipeline took {elapsed} to run.")
if save:
# Save the pipeline to a file
pipeline.write(fullpath, verbose=0)
if multi_node:
logger.info("Combining HDF5 files across nodes.")
pipeline.combine_files() # virtual needs work
else:
elapsed = 0
if multi_node:
if not os.path.exists(final_fullpath) and rank == 0:
from synference.utils import combine_rank_files
galaxies_per_node = n_gal // n_batches
size = n_batches
starts = [int(i * galaxies_per_node) for i in range(size)]
ends = [int(start + galaxies_per_node) for start in starts]
ends[-1] = n_gal # Ensure last end is n_gal
logger.info("Combining rank files on rank 0.")
combine_rank_files(size, final_fullpath, ngal, starts, ends)
if save:
wav = self.grid.lam.to(Angstrom).value
if n_batches == 1:
final_fullpath = fullpath
# IF MPI, only do this on rank 0
add = True
if multi_node:
if rank != 0:
add = False
logger.debug(f"Skipping adding attributes on rank {rank}.")
if add:
try:
os.rename(init_fullpath, final_fullpath)
except FileNotFoundError:
pass
with h5py.File(final_fullpath, "r+") as f:
# Add the varying and fixed parameters to the file
f.attrs["varying_param_names"] = self.varying_param_names
f.attrs["fixed_param_names"] = self.fixed_param_names
f.attrs["fixed_param_values"] = self.fixed_param_values
f.attrs["fixed_param_units"] = self.fixed_param_units
# Write some metadata about the model
f.attrs["model_name"] = self.model_name
f.attrs["grid_name"] = self.grid.grid_name
f.attrs["grid_dir"] = self.grid.grid_dir
f.attrs["date_created"] = str(datetime.now())
f.attrs["pipeline_time"] = str(elapsed)
f.create_dataset("Wavelengths", data=wav)
f["Wavelengths"].attrs["Units"] = "Angstrom"
logger.info(f"Written pipeline to disk at {final_fullpath}.")
else:
logger.warning("Not saving pipeline to disk.")
if not skip:
del pipeline # Clean up the pipeline object to free memory
galaxies[batch_i] = None # Clear the batch to free memory
import gc
gc.collect() # Force garbage collection to free memory
def plot_random_galaxy(self, masses, **kwargs):
"""Plot a random galaxy from the list of galaxies."""
if not self.build_library:
idx = np.random.randint(0, len(self.redshifts))
mass = masses[idx]
return self.plot_galaxy(idx, log_stellar_mass=mass, **kwargs)
def plot_galaxy(
self,
idx,
save: bool = True,
log_stellar_mass: float = 9,
emission_model_keys: List[str] = ["total"],
out_dir="./",
):
"""Plot the galaxy with the given index."""
galaxy_params = {}
for param in self.galaxy_params.keys():
if isinstance(self.galaxy_params[param], dict):
galaxy_params[param] = self.process_priors(self.galaxy_params[param])
elif isinstance(self.galaxy_params[param], (list, np.ndarray)):
galaxy_params[param] = self.galaxy_params[param][idx]
else:
galaxy_params[param] = self.galaxy_params[param]
if not self.build_library and len(self.galaxies) == 0:
# Get idx's from requirements and build galaxy directly
galaxy = create_galaxy(
sfh=self.sfhs[idx],
redshift=self.redshifts[idx],
metal_dist=self.metal_dists[idx],
log_stellar_masses=log_stellar_mass,
grid=self.grid,
**galaxy_params,
)
else:
if idx >= len(self.galaxies):
raise ValueError(f"Index {idx} out of range for galaxies.")
galaxy = self.galaxies[idx]
# Generate spectra
if isinstance(emission_model_keys, str):
emission_model_keys = [emission_model_keys]
galaxy.stars.get_spectra(self.emission_model)
galaxy.get_observed_spectra(cosmo=self.cosmo, igm=Inoue14)
fig, ax = plt.subplots(1, 2, figsize=(10, 5), layout="constrained")
plot_dict = {key: galaxy.stars.spectra[key] for key in emission_model_keys}
colors = {
key: plt.cm.viridis(i / len(emission_model_keys))
for i, key in enumerate(plot_dict.keys())
}
plot_spectra(
plot_dict,
show=False,
fig=fig,
ax=ax[0],
x_units=um,
quantity_to_plot="fnu",
draw_legend=False,
)
# change color of line with label key to the color of the key
for line in ax[0].lines:
label = line.get_label()
test_colors = [i.title() for i in colors.keys()]
if label in test_colors:
pos = test_colors.index(label)
label = list(colors.keys())[pos]
line.set_color(colors[label])
line.set_linewidth(1.5)
line.set_label(label)
ax[0].set_yscale("log")
def custom_xticks(x, pos):
if x == 0:
return "0"
else:
return f"{x / 1e4:.1f}"
ax[0].xaxis.set_major_formatter(FuncFormatter(custom_xticks))
min_x, max_x = 1e10 * um, 0 * um
min_y, max_y = 1e10 * nJy, 0 * nJy
text_gal = {}
for emission_model in emission_model_keys:
sed = galaxy.stars.spectra[emission_model]
# Plot photometry
phot = sed.get_photo_fnu(filters=self.instrument.filters)
min_x = min(min_x, np.nanmin(phot.filters.pivot_lams))
max_x = max(max_x, np.nanmax(phot.filters.pivot_lams))
min_y = min(min_y, np.nanmin(phot.photo_fnu))
max_y = max(max_y, np.nanmax(phot.photo_fnu))
ax[0].plot(
phot.filters.pivot_lams,
phot.photo_fnu,
"+",
color=colors[emission_model],
path_effects=[PathEffects.withStroke(linewidth=4, foreground="white")],
)
# Get the redshift
redshift = galaxy.redshift
# Get the SFH
stars_sfh = galaxy.stars.get_sfh()
stars_sfh = stars_sfh / np.diff(10 ** (self.grid.log10age), prepend=0) / yr
t, sfh = galaxy.stars.sf_hist_func.calculate_sfh()
ax[1].plot(
10 ** (self.grid.log10age - 6),
stars_sfh,
label=f"{emission_model} SFH",
color=colors[emission_model],
)
ax[1].plot(
t / 1e6,
sfh / np.max(sfh) * np.max(stars_sfh),
label=f"Requested {emission_model} SFH",
color=colors[emission_model],
linestyle="--",
)
mass = galaxy.stars.initial_mass
if mass == 0:
text_gal[emission_model] = f"**{emission_model}**\nNo stars"
else:
age = galaxy.stars.calculate_mean_age()
zmet = galaxy.stars.calculate_mean_metallicity()
text_gal[emission_model] = f"""{emission_model}
Age: {age.to(Myr):.0f}
$\\log_{{10}}(Z)$: {np.log10(zmet):.2f}
$\\log_{{10}}(M_\\star/M_\\odot)$: {np.log10(mass):.1f}"""
info_blocks = []
info_blocks.append(f"$z = {redshift:.2f}$")
info_blocks.extend(text_gal.values())
# Format and add galaxy parameters if they exist
if galaxy_params:
param_lines = [f"{key}: {value:.2f}" for key, value in galaxy_params.items()]
info_blocks.append("\n".join(param_lines))
# Join the blocks with double newlines for clear separation
textstr = "\n\n".join(info_blocks)
# Define properties for the text box
props = dict(boxstyle="round", facecolor="wheat", alpha=0.5)
# Place the text box on the axes
ax[1].text(
0.95,
0.95, # Adjusted y-position for better placement with verticalalignment='top'
textstr,
transform=ax[1].transAxes,
fontsize=10, # Adjusted for better fit
horizontalalignment="right",
verticalalignment="top",
bbox=props,
)
# Add a secondary axis with AB magnitudes. With fluxes in nJy the AB
# zero point is 31.4: m_AB = -2.5 log10(f_nJy) + 31.4
# (equivalently 8.9 for fluxes in Jy).
def ab_to_jy(f):
    return 1e9 * 10 ** ((f - 8.9) / (-2.5))

def jy_to_ab(f):
    f = f / 1e9
    return -2.5 * np.log10(f) + 8.9
secax = ax[0].secondary_yaxis("right", functions=(jy_to_ab, ab_to_jy))
# set scalar formatter
max_age = self.cosmo.age(redshift) # - self.cosmo.age(20)
max_age = max_age.to(u.Myr).value
ax[1].set_xlim(0, 2 * max_age)
# vline at max_age
ax[1].axvline(max_age, color="red", linestyle="--", linewidth=0.5)
secax.yaxis.set_major_formatter(ScalarFormatter())
secax.yaxis.set_minor_formatter(ScalarFormatter())
secax.set_ylabel("Flux Density [AB mag]")
ax[1].set_xlabel("Time [Myr]")
ax[1].set_ylabel(r"SFH [M$_\odot$ yr$^{-1}$]")
# ax[1].set_yscale('log')
ax[1].legend()
self.tmp_redshift = redshift
self.tmp_time_unit = u.Myr
# secax = ax[1].secondary_xaxis("top",
# functions=(self._time_convert, self._z_convert))
# secax.set_xlabel("Redshift")
# secax.set_xticks([6, 7, 8, 10, 12, 14, 15, 20])
ax[0].set_xlim(min_x, max_x)
ax[0].set_ylim(min_y, max_y)
if save:
if not os.path.exists(out_dir):
os.makedirs(out_dir)
fig.savefig(f"{out_dir}/{self.model_name}_{idx}.png", dpi=300)
plt.close(fig)
return fig
def _time_convert(self, lookback_time):
time_unit = getattr(self, "tmp_time_unit", u.yr)
lookback_time = lookback_time * time_unit
return z_at_value(
self.cosmo.lookback_time,
self.cosmo.lookback_time(self.tmp_redshift) + lookback_time,
).value
def _z_convert(self, z):
if type(z) in [list, np.ndarray] and len(z) == 0:
return np.array([])
time_unit = getattr(self, "tmp_time_unit", u.yr)
return (
(self.cosmo.lookback_time(z) - self.cosmo.lookback_time(self.tmp_redshift))
.to(time_unit)
.value
)
def process_base(
self,
out_name,
log_stellar_masses: unyt_array = None,
emission_model_key: str = "total",
out_dir: str = library_folder,
n_proc: int = 6,
overwrite: Union[bool, List[bool]] = False,
verbose=False,
batch_size: int = 40_000,
multi_node: bool = False,
spectra_to_save: list = None,
em_lines_to_save: list = None,
**extra_analysis_functions,
):
"""Run pipeline for this base.
Implements functionality of CombinedBasis.process_bases for
a single base. This is a convenience method to allow the
GalaxyBasis to be run seperately.
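
Examples
--------
A minimal sketch (names and values are illustrative)::

    basis.process_base(
        "my_base",
        log_stellar_masses=np.random.uniform(8, 11, len(basis.redshifts)),
        n_proc=8,
    )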
"""
if log_stellar_masses is None:
assert self.log_stellar_masses is not None, (
"log_stellar_masses must be provided or set in the GalaxyBasis"
)
log_stellar_masses = self.log_stellar_masses
assert not isinstance(log_stellar_masses, unyt_array), (
    "log_stellar_masses must not be a unyt_array"
)
assert len(log_stellar_masses) == len(
self.redshifts
), f"""log_stellar_masses must be the same length as redshifts,
got {len(log_stellar_masses)} and {len(self.redshifts)},
Calling this method on GalaxyBasis only supports
the case where all samples have been provided, not
the case where samples are drawn from a prior and
combined directly.
"""
if not isinstance(overwrite, (tuple, list, np.ndarray)):
    overwrite = [overwrite]
else:
if len(overwrite) != 1:
raise ValueError(
"""overwrite must be a boolean or a
list of booleans with the same length as bases"""
)
full_out_path = f"{out_dir}/{out_name}.hdf5"
if (
os.path.exists(full_out_path)
# or os.path.exists(f"{out_dir}/{out_name}_{total_batches}.hdf5")
and not overwrite[0]
):
logger.info(f"File {full_out_path} already exists. Skipping loading.")
return
if os.path.exists(full_out_path) and overwrite[0]:
logger.info(f"Overwriting {full_out_path}.")
if not os.path.exists(out_dir):
os.makedirs(out_dir)
galaxies = self._create_galaxies(log_base_masses=log_stellar_masses)
self.process_galaxies(
galaxies,
f"{out_name}.hdf5",
out_dir=out_dir,
n_proc=n_proc,
verbose=verbose,
save=True,
emission_model_keys=emission_model_key,
batch_size=batch_size,
overwrite=overwrite[0],
multi_node=multi_node,
spectra_to_save=spectra_to_save,
em_lines_to_save=em_lines_to_save,
**extra_analysis_functions,
)
def create_mock_library(
self,
out_name,
log_stellar_masses: np.ndarray = None,
emission_model_key: str = "total",
out_dir: str = library_folder,
n_proc: int = 6,
overwrite: Union[bool, List[bool]] = False,
verbose=False,
batch_size: int = 40_000,
parameter_transforms_to_save: Optional[Dict[str, Tuple[str, callable]]] = None,
cat_type="photometry",
compile_grid: bool = True,
multi_node: bool = False,
spectra_to_save: Optional[list] = None,
em_lines_to_save: Optional[list] = None,
**extra_analysis_functions,
):
"""Convenience method which calls CombinedBasis.
This is a convenience method which allows
you to not have to pass a GalaxyBasis into
CombinedBasis, and instead just call
this method which will run the components for you.
Parameters
----------
out_name : str
Name of the output file to save the mock catalog.
log_stellar_masses : np.ndarray, optional
Array of log stellar masses to use for the mock catalog,
Units of log10(M sun), by default None.
If None, uses the stellar_masses set in the GalaxyBasis.
emission_model_key : str, optional
Emission model key to use for the mock catalog,
by default "total".
out_dir : str, optional
    Directory to save the output file; by default the package
    "libraries" directory (library_folder).
n_proc : int, optional
Number of processes to use for the pipeline, by default 6.
overwrite : Union[bool, List[bool]], optional
If True, overwrites the output file if it exists,
by default False. If a list, must be the same length as bases.
verbose : bool, optional
If True, prints verbose output during processing,
by default False.
batch_size : int, optional
Size of each batch of galaxies to process,
by default 40,000.
parameter_transforms_to_save : Dict[str, (str, callable)], optional
    Dictionary of parameter transforms to save in the output file.
    Only used when saving with the simulator, to allow
    reconstruction of the model later. Lambda functions
    cannot be used here.
    Should be a dictionary where keys are the parameter names
    in the model, and the values are a tuple.
    The tuple should be (str, callable), where the str is the
    new parameter name to save, and the callable is the function
    which takes the model parameters and returns the new parameter value.
    It can also be (List[str], callable) if the function returns multiple
    values (e.g. converting one parameter to many).
    Finally, if you are adding a new parameter which is not in the
    model, you can pass a direct str: callable pair, which will add a new
    parameter to the model based on the callable function.
compile_grid : bool, optional
If True, compiles the library after processing,
by default True.
multi_node : bool, optional
    If True, runs the processing in parallel across multiple nodes,
    by default False. Enabling this only configures MPI; the script
    still needs to be launched with SLURM (or similar).
cat_type : str, optional
Type of catalog to create, either "photometry" or "spectra",
by default "photometry".
spectra_to_save : Optional[list], optional
List of spectra types to save, by default None. If None, saves only the root spectra.
extra_analysis_functions : dict, optional
Additional analysis functions to add to the pipeline.
Each should have the argument name as the key, and the value
should be a tuple of the function and its parameters.
"""
# make a CombinedBasis object with the current GalaxyBasis
if log_stellar_masses is None:
assert self.log_stellar_masses is not None, (
"log_stellar_masses must be provided or set in the GalaxyBasis"
)
log_stellar_masses = self.log_stellar_masses
assert not isinstance(log_stellar_masses, unyt_array), (
    "log_stellar_masses must not be a unyt_array"
)
combined_basis = CombinedBasis(
bases=[self],
log_stellar_masses=log_stellar_masses,
redshifts=self.redshifts,
base_emission_model_keys=[emission_model_key],
combination_weights=None,
out_name=out_name,
out_dir=out_dir,
draw_parameter_combinations=False,
)
if multi_node:
logger.info("Running in multi-node mode. Using MPI for parallel processing.")
galaxy_mask = np.zeros(len(combined_basis.redshifts), dtype=bool)
total_galaxies = len(combined_basis.redshifts)
galaxies_per_node = total_galaxies // size
start_idx = rank * galaxies_per_node
end_idx = start_idx + galaxies_per_node
if rank == size - 1: # Last node gets the remainder
end_idx = total_galaxies
galaxy_mask[start_idx:end_idx] = True
logger.info(f"Node {rank} processing galaxies from {start_idx} to {end_idx}.")
else:
galaxy_mask = None
combined_basis.process_bases(
n_proc=n_proc,
overwrite=overwrite,
verbose=verbose,
batch_size=batch_size,
multi_node=multi_node,
galaxies_mask=galaxy_mask,
spectra_to_save=spectra_to_save,
em_lines_to_save=em_lines_to_save,
**extra_analysis_functions,
)
if compile_grid:
# Make code wait until all bases are processed
logger.info("Compiling the library after processing bases.")
if cat_type == "photometry":
combined_basis.create_library(overwrite=overwrite)
elif cat_type == "spectra":
combined_basis.create_spectral_grid(overwrite=overwrite)
else:
raise ValueError(
f"Unknown catalog type: {cat_type}. Use 'photometry' or 'spectra'."
)
out_path = f"{combined_basis.out_dir}/{combined_basis.out_name}"
if not out_path.endswith(".hdf5"):
out_path += ".hdf5"
self._store_model(
out_path,
other_info={
"emission_model_key": emission_model_key,
"timestamp": datetime.now().isoformat(),
"cat_type": cat_type,
},
parameter_transforms_to_save=parameter_transforms_to_save,
)
logger.info("Processed the bases and saved the output.")
return combined_basis
class CombinedBasis:
"""Class to create a photometry array from Synthesizer pipeline outputs.
This class combines multiple GalaxyBasis objects, processes them,
and saves the output to HDF5 files. The simplest operation is a
single base, where we would provide a single GalaxyBasis object
and this class will handle running the Synthesizer pipeline,
processing the output and saving the results. The reason this
is seperated is because we can combine multiple bases.
So when the GalaxyBasis is run, we determine photometry/spectra
with a filler mass (normalization). Then this class will renormalize
the photometry/spectra to the total stellar masses provided,
and optionally weight them by the combination of bases. This is done
so that you can have a case where there is e.g. two bases with
different SPS grids, and we can build a galaxy where 15% of the mass
is from the first base, and 85% is from the second base. It can
also allow flexiblity when generating galaxies in the pipeline,
as the stellar mass dimension could be neglected in e.g. the LHC
and applied afterwards.
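
As an illustrative example: if two bases were processed with a filler
mass of 10^9 Msun, a galaxy with log_stellar_mass = 10 and base
weights (0.15, 0.85) is assembled as::

    F_total = (0.15 * 1e10 / 1e9) * F_base1 + (0.85 * 1e10 / 1e9) * F_base2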
"""
def __init__(
self,
bases: List[Type[GalaxyBasis]],
log_stellar_masses: list,
redshifts: np.ndarray,
base_emission_model_keys: List[str],
combination_weights: np.ndarray,
out_name: str = "combined_basis",
out_dir: str = library_folder,
log_base_masses: Union[float, np.ndarray] = 9,
draw_parameter_combinations: bool = False,
) -> None:
"""Initialize the CombinedBasis object.
Parameters
----------
bases : List[Type[GalaxyBasis]]
List of GalaxyBasis objects to combine.
log_stellar_masses : list
Array of total stellar masses to renormalize fluxes for.
in log10(M sun) units.
redshifts : np.ndarray
Array of redshifts for the bases.
base_emission_model_keys : List[str]
List of emission model keys for each base.
combination_weights : np.ndarray
Array of combination weights for the bases.
out_name : str
Name of the output file.
out_dir : str
Directory to save the output files.
log_base_masses : Union[float, np.ndarray]
    Default mass (or mass array) to use for the galaxies.
    Units of log10(M sun).
draw_parameter_combinations : bool
If True, draw parameter combinations for the galaxies.
If False, create matched galaxies with the same parameters.
"""
self.bases = bases
self.log_stellar_masses = log_stellar_masses
self.redshifts = redshifts
self.combination_weights = combination_weights
self.out_name = out_name
self.out_dir = out_dir
self.log_base_masses = log_base_masses
self.base_emission_model_keys = base_emission_model_keys
self.draw_parameter_combinations = draw_parameter_combinations
if isinstance(redshifts, (int, float)):
    redshifts = np.full(len(self.log_stellar_masses), redshifts)
    self.redshifts = redshifts
if self.combination_weights is None:
assert len(self.bases) == 1
self.combination_weights = [1.0] * len(redshifts)
def _check_final_hdf5(self, full_out_path: str) -> bool:
"""Checks if custom attributes are in the final HDF5 file."""
if not os.path.exists(full_out_path):
return False
with h5py.File(full_out_path, "r") as f:
required_attrs = [
"varying_param_names",
"fixed_param_names",
"fixed_param_values",
"fixed_param_units",
"model_name",
"grid_name",
"grid_dir",
"date_created",
"pipeline_time",
]
for attr in required_attrs:
if attr not in f.attrs:
logger.warning(f"Attribute {attr} not found in {full_out_path}.")
return False
required_datasets = ["Wavelengths"]
for dataset in required_datasets:
if dataset not in f:
logger.warning(f"Dataset {dataset} not found in {full_out_path}.")
return False
return True
def process_bases(
self,
n_proc: int = 6,
overwrite: Union[bool, List[bool]] = False,
verbose=False,
batch_size: int = 40_000,
multi_node: bool = False,
galaxies_mask: Optional[np.ndarray] = None,
spectra_to_save: Optional[list] = None,
em_lines_to_save: Optional[list] = None,
**extra_analysis_functions,
) -> None:
"""Process the bases and save the output to files.
Parameters
----------
n_proc : int
Number of processes to use for the pipeline.
overwrite : bool or list of bools
If True, overwrite the existing files.
If False, skip the files that already exist.
If a list of bools is provided,
it should have the same length as the number of bases.
verbose : bool, optional
    If True, prints verbose output during processing.
batch_size : int, optional
    Size of each batch of galaxies to process, by default 40,000.
multi_node : bool, optional
    If True, runs the pipeline across multiple nodes with MPI.
galaxies_mask : np.ndarray, optional
    Boolean mask selecting which galaxies to process.
spectra_to_save : list, optional
    List of spectra types to save. If None, saves only the root spectra.
em_lines_to_save : list, optional
    List of emission lines to save.
extra_analysis_functions : dict
Extra analysis functions to add to the pipeline.
The keys should be the names of the functions,
and the values should be the functions themselves,
or a tuple of (function, args). The function
should take a Galaxy object as the first argument,
and the args should be the arguments to pass to the function.
The function should return a single value, an array of values,
or a dictionary of values (with the same keys for all galaxies).
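
Examples
--------
A minimal sketch; ``mean_age`` is an illustrative analysis function
(its results are stored under ``supp_mean_age``)::

    def mean_age(gal):
        return gal.stars.calculate_mean_age()

    combined_basis.process_bases(n_proc=8, mean_age=mean_age)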
"""
if not isinstance(overwrite, (tuple, list, np.ndarray)):
overwrite = [overwrite] * len(self.bases)
else:
if len(overwrite) != len(self.bases):
raise ValueError(
"""overwrite must be a boolean or a
list of booleans with the same length as bases"""
)
for i, base in enumerate(self.bases):
full_out_path = f"{self.out_dir}/{base.model_name}.hdf5"
ngalaxies = len(self.log_stellar_masses)
if galaxies_mask is not None:
ngalaxies = np.sum(galaxies_mask)
if ngalaxies == 0:
logger.warning(f"No galaxies to process for base {base.model_name}. Skipping.")
continue
if (
os.path.exists(full_out_path) and self._check_final_hdf5(full_out_path)
# or os.path.exists(f"{self.out_dir}/{base.model_name}_{total_batches}.hdf5")
) and not overwrite[i]:
logger.warning(f"File {full_out_path} already exists. Skipping.")
continue
elif os.path.exists(full_out_path) and overwrite[i]:
logger.warning(f"File {full_out_path} already exists. Overwriting..")
os.remove(full_out_path)
elif not os.path.exists(self.out_dir):
if rank == 0:
logger.warning(f"Creating output directory {self.out_dir}.")
os.makedirs(self.out_dir)
if self.draw_parameter_combinations:
    galaxies = base._create_galaxies(log_base_masses=self.log_base_masses)
    if galaxies_mask is not None:
        raise NotImplementedError(
            "galaxies_mask is not implemented for draw_parameter_combinations=True."
        )
else:
galaxies = base._create_matched_galaxies(
log_base_masses=self.log_base_masses, galaxies_mask=galaxies_mask, n_proc=n_proc
)
logger.info(f"Created {len(galaxies)} galaxies for base {base.model_name}")
# Process the galaxies
base.process_galaxies(
galaxies,
f"{base.model_name}.hdf5",
out_dir=self.out_dir,
n_proc=n_proc,
verbose=verbose,
save=True,
emission_model_keys=self.base_emission_model_keys[i],
batch_size=batch_size,
overwrite=overwrite[i],
multi_node=multi_node,
spectra_to_save=spectra_to_save,
em_lines_to_save=em_lines_to_save,
**extra_analysis_functions,
)
def load_bases(self, load_spectra=False) -> dict:
"""Load the processed bases from the output directory.
Parameters
----------
load_spectra : bool
If True, load the observed spectra from the output files.
If False, only load the properties and photometry.
Returns:
-------
dict
A dictionary with the base model names as keys and a dictionary
of properties, observed spectra, wavelengths, and observed photometry.
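
Examples
--------
A sketch of accessing the returned structure; the model and filter
names are illustrative::

    outputs = combined_basis.load_bases()
    props = outputs["my_model"]["properties"]
    f200w = outputs["my_model"]["observed_photometry"]["JWST/NIRCam.F200W"]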
"""
outputs = {}
for i, base in enumerate(self.bases):
logger.info(
f"Emission model key for base {base.model_name}:{self.base_emission_model_keys[i]}"
)
full_out_path = f"{self.out_dir}/{base.model_name}.hdf5"
if not os.path.exists(full_out_path):
if os.path.exists(f"{self.out_dir}/{base.model_name}_1.hdf5"):
import glob
# Check if there are multiple files for this base
full_out_paths = glob.glob(f"{self.out_dir}/{base.model_name}_*.hdf5")
logger.info(f"Found {len(full_out_paths)} files for base {base.model_name}.")
full_out_paths = sorted(
full_out_paths,
key=lambda x: int(x.split("_")[-1].split(".")[0]),
)
else:
raise ValueError(
f"Synthesizer pipeline output {full_out_path} does not exist. "
"Have you run the pipeline using `combined_basis.process_bases` first?"
) # noqa E501
else:
full_out_paths = [full_out_path]
for j, path in tqdm(enumerate(full_out_paths), desc="Loading galaxy properties"):
properties = {}
supp_properties = {}
with h5py.File(path, "r") as f:
# Load in which parameters are varying and fixed
base.varying_param_names = f.attrs["varying_param_names"]
base.fixed_param_names = f.attrs["fixed_param_names"]
base.fixed_param_units = f.attrs["fixed_param_units"]
base.fixed_param_values = f.attrs["fixed_param_values"]
galaxies = f["Galaxies"]
property_keys = list(galaxies.keys())
for key in ["Stars", "BlackHoles", "Gas", "Photometry", "Spectra"]:
if key in property_keys:
property_keys.remove(key)
for key in property_keys:
if key.startswith("supp_"):
dic = supp_properties
use_key = key[5:]
else:
dic = properties
use_key = key
if isinstance(galaxies[key], h5py.Group):
dic[use_key] = {}
for subkey in galaxies[key].keys():
dic[use_key][subkey] = galaxies[key][subkey][()]
if hasattr(galaxies[key][subkey], "attrs"):
if "Units" in galaxies[key][subkey].attrs:
unit = galaxies[key][subkey].attrs["Units"]
dic[use_key][subkey] = unyt_array(
dic[use_key][subkey], unit
)
else:
dic[use_key] = galaxies[key][()]
if hasattr(galaxies[key], "attrs"):
if "Units" in galaxies[key].attrs:
unit = galaxies[key].attrs["Units"]
dic[use_key] = unyt_array(dic[use_key], unit)
if load_spectra:
# Get the spectra
spec_path = (
f"Spectra/SpectralFluxDensities/{self.base_emission_model_keys[i]}"
)
if "Spectra" not in galaxies.keys():
if (
"Stars" in galaxies.keys()
and "Spectra" in galaxies["Stars"].keys()
and self.base_emission_model_keys[i]
in galaxies["Stars"]["Spectra"]["SpectralFluxDensities"].keys()
):
spec_path = (
"Stars/Spectra/SpectralFluxDensities/"
f"{self.base_emission_model_keys[i]}"
)
elif (
"BlackHoles" in galaxies.keys()
and "Spectra" in galaxies["BlackHoles"].keys()
and self.base_emission_model_keys[i]
in galaxies["BlackHoles"]["Spectra"]["SpectralFluxDensities"].keys()
):
spec_path = (
"BlackHoles/Spectra/SpectralFluxDensities/"
f"{self.base_emission_model_keys[i]}"
)
elif (
"Gas" in galaxies.keys()
and "Spectra" in galaxies["Gas"].keys()
and self.base_emission_model_keys[i]
in galaxies["Gas"]["Spectra"]["SpectralFluxDensities"].keys()
):
spec_path = (
"Gas/Spectra/SpectralFluxDensities/"
f"{self.base_emission_model_keys[i]}"
)
else:
raise ValueError("No spectra found in the output file.")
spec = galaxies[spec_path]
if isinstance(spec, h5py.Dataset):
observed_spectra = spec
else:
assert (
self.base_emission_model_keys[i] in spec.keys()
), f"""Emission model key {self.base_emission_model_keys[i]}
not found in {spec.keys()}"""
observed_spectra = spec[self.base_emission_model_keys[i]]
observed_spectra = unyt_array(
observed_spectra,
units=observed_spectra.attrs["Units"],
)
else:
observed_spectra = {}
phot_path = f"Photometry/Fluxes/{self.base_emission_model_keys[i]}"
if "Photometry" not in galaxies.keys():
if (
"Stars" in galaxies.keys()
and "Photometry" in galaxies["Stars"].keys()
and self.base_emission_model_keys[i]
in galaxies["Stars"]["Photometry"]["Fluxes"].keys()
):
phot_path = (
f"Stars/Photometry/Fluxes/{self.base_emission_model_keys[i]}"
)
elif (
"BlackHoles" in galaxies.keys()
and "Photometry" in galaxies["BlackHoles"].keys()
and self.base_emission_model_keys[i]
in galaxies["BlackHoles"]["Photometry"]["Fluxes"].keys()
):
phot_path = (
f"BlackHoles/Photometry/Fluxes/{self.base_emission_model_keys[i]}"
)
elif (
"Gas" in galaxies.keys()
and "Photometry" in galaxies["Gas"].keys()
and self.base_emission_model_keys[i]
in galaxies["Gas"]["Photometry"]["Fluxes"].keys()
):
phot_path = f"Gas/Photometry/Fluxes/{self.base_emission_model_keys[i]}"
else:
raise ValueError("No photometry found in the output file.")
observed_photometry = galaxies[f"{phot_path}"]
phot = {}
for observatory in observed_photometry:
phot_inst = observed_photometry[observatory]
if isinstance(phot_inst, h5py.Dataset):
# THIS IS A HACK TO AVOID LOADING
# REST-FRAME FLUXES
continue
for key in phot_inst.keys():
full_key = f"{observatory}/{key}"
phot[full_key] = phot_inst[key][()]
if j == 0:
outputs[base.model_name] = {
"properties": properties,
"observed_spectra": observed_spectra,
"wavelengths": unyt_array(
f["Wavelengths"][()],
units=f["Wavelengths"].attrs["Units"],
),
"observed_photometry": phot,
"supp_properties": supp_properties,
}
else:
# Combine the outputs for this base with the previous ones
for key in properties.keys():
outputs[base.model_name]["properties"][key] = np.concatenate(
(
outputs[base.model_name]["properties"][key],
properties[key],
)
)
for key in phot.keys():
if key not in outputs[base.model_name]["observed_photometry"]:
outputs[base.model_name]["observed_photometry"][key] = []
outputs[base.model_name]["observed_photometry"][key] = np.concatenate(
(
outputs[base.model_name]["observed_photometry"][key],
phot[key],
)
)
if load_spectra:
# Combine the observed spectra (from different files)
if "observed_spectra" not in outputs[base.model_name]:
outputs[base.model_name]["observed_spectra"] = []
outputs[base.model_name]["observed_spectra"] = np.concatenate(
(outputs[base.model_name]["observed_spectra"], observed_spectra)
)
# Combine supplementary properties
for key in supp_properties.keys():
if key not in outputs[base.model_name]["supp_properties"]:
outputs[base.model_name]["supp_properties"][key] = {}
if not isinstance(
supp_properties[key],
dict,
):
val = supp_properties[key]
supp_properties[key] = {self.base_emission_model_keys[i]: val}
if not isinstance(
outputs[base.model_name]["supp_properties"][key],
dict,
):
outputs[base.model_name]["supp_properties"][key] = {
self.base_emission_model_keys[i]: outputs[base.model_name][
"supp_properties"
][key]
}
for subkey in supp_properties[key].keys():
if subkey not in outputs[base.model_name]["supp_properties"][key]:
outputs[base.model_name]["supp_properties"][key][subkey] = []
outputs[base.model_name]["supp_properties"][key][subkey] = (
np.concatenate(
(
outputs[base.model_name]["supp_properties"][key][
subkey
],
supp_properties[key][subkey],
)
)
)
self.pipeline_outputs = outputs
return outputs
def create_library(
self,
override_instrument: Union[Instrument, None] = None,
save: bool = True,
overload_out_name: str = "",
overwrite: bool = False,
) -> dict:
"""Creates a library of SEDs for the given Synthesizer outputs.
This method assumes each input on CombinedBasis, (redshift, mass,
and varying parameters) should be combined (e.g. sampling
every combination of redshift, mass, and varying parameters) to
create the library. The 'create_full_library' method instead assumes
that the input parameters are already combined, and does not
sample every combination. Generally the 'create_full_library' method
is more useful for the case where you have predrawn parameters
randomly or from a prior or Latin Hypercube.
Parameters
----------
override_instrument : Instrument, optional
If provided, overrides the instrument used for the library.
save : bool, optional
If True, saves the library to a file.
overload_out_name : str, optional
If provided, overrides the output name for the library.
overwrite : bool, optional
If True, overwrites the existing library file if it exists.
Returns:
-------
dict
A dictionary containing the library of SEDs, photometry, and properties.
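
Examples
--------
A minimal sketch::

    library = combined_basis.create_library(save=True, overwrite=False)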
"""
if not self.draw_parameter_combinations:
return self.create_full_library(
override_instrument,
overwrite=overwrite,
save=save,
overload_out_name=overload_out_name,
)
if overload_out_name != "":
out_name = overload_out_name
else:
out_name = self.out_name
if os.path.exists(f"{self.out_dir}/{out_name}") and not overwrite:
logger.warning(f"File {self.out_dir}/{out_name} already exists. Skipping.")
self.load_library_from_file(f"{self.out_dir}/{out_name}")
return
pipeline_outputs = self.load_bases()
base_filters = self.bases[0].instrument.filters.filter_codes
for i, base in enumerate(self.bases):
if base.instrument.filters.filter_codes != base_filters:
raise ValueError(
f"""Base {i} has different filters to base 0.
Cannot combine bases with different filters."""
)
if override_instrument is not None:
# Check all filters in override_instrument are in the base filters
for filter_code in override_instrument.filters.filter_codes:
if filter_code not in base_filters:
raise ValueError(
f"""Filter {filter_code} not found in base filters.
Cannot override instrument."""
)
filter_codes = override_instrument.filters.filter_codes
else:
filter_codes = base_filters
# filter_codes = [i.split("/")[-1] for i in filter_codes]
all_outputs = []
all_params = []
all_supp_params = []
ignore_keys = ["redshift"]
total_property_names = {}
for i, base in enumerate(self.bases):
if len(self.bases) > 1:
    total_property_names[base.model_name] = [
        f"{base.model_name}/{name}"
        for name in base.varying_param_names
        if name not in ignore_keys
    ]
else:
total_property_names[base.model_name] = base.varying_param_names
params = pipeline_outputs[base.model_name]["properties"]
rename_keys = [i for i in base.varying_param_names if i not in ignore_keys]
for key in list(params.keys()):
if key in rename_keys:
# rename the key to be the base name + parameter name
params[f"{base.model_name}/{key}"] = params[key]
supp_param_keys = list(pipeline_outputs[self.bases[0].model_name]["supp_properties"].keys())
for other_base in self.bases:
    assert all(
        key in pipeline_outputs[other_base.model_name]["supp_properties"]
        for key in supp_param_keys
    ), f"""Not all bases have the same supplementary parameters.
    {supp_param_keys} not all found in
    {pipeline_outputs[other_base.model_name]["supp_properties"].keys()}"""
# Deal with any supplementary model parameters.
# Currently we require that all bases have the same supplementary parameters
supp_params = {}
supp_param_units = {}
for i, base in enumerate(self.bases):
supp_params[base.model_name] = {}
for key in supp_param_keys:
if isinstance(
pipeline_outputs[base.model_name]["supp_properties"][key],
dict,
):
subkeys = list(pipeline_outputs[base.model_name]["supp_properties"][key].keys())
# Check if the emission model key is in the subkeys
if self.base_emission_model_keys[i] not in subkeys:
raise ValueError(
f"""Emission model key {self.base_emission_model_keys[i]}
not found in {subkeys}.
Don't know how to deal with
dictionary supplementary parameters with other keys."""
)
value = pipeline_outputs[base.model_name]["supp_properties"][key][
self.base_emission_model_keys[i]
]
else:
value = pipeline_outputs[base.model_name]["supp_properties"][key]
supp_params[base.model_name][key] = value
# Collect the varying parameter names from all bases
all_combined_param_names = []
for key, value in total_property_names.items():
all_combined_param_names.extend(value)
# Add our standard parameters that are always included
param_columns = ["redshift", "log_mass"]
param_units = {}
param_units["redshift"] = "dimensionless"
param_units["log_mass"] = "log10(Mstar/Msun)"
if len(self.bases) > 1:
param_columns.append("weight_fraction")
param_units["weight_fraction"] = "dimensionless"
param_columns.extend(all_combined_param_names)
for redshift in tqdm(self.redshifts, desc="Creating library"):
for log_total_mass in self.log_stellar_masses:
total_mass = 10**log_total_mass
for combination in self.combination_weights:
mass_weights = np.array(combination) * total_mass
scaled_photometries = []
base_param_values = []
supp_params_values = []
for i, base in enumerate(self.bases):
outputs = pipeline_outputs[base.model_name]
z_base = outputs["properties"]["redshift"]
mask = z_base == redshift
mass = outputs["properties"]["mass"][mask]
if isinstance(mass, unyt_array):
mass = mass.to(Msun).value
# Calculate the scaling factor for each base
scaling_factors = mass_weights[i] / mass
base_photometry = np.array(
[
pipeline_outputs[base.model_name]["observed_photometry"][
filter_code
][mask]
for filter_code in filter_codes
],
dtype=np.float32,
)
# Scale the photometry by the scaling factor
scaled_photometry = base_photometry * scaling_factors
scaled_photometries.append(scaled_photometry)
# Get the varying parameters for this base
base_params = {}
for param_name in total_property_names[base.model_name]:
# Extract the original parameter name without the base prefix
orig_param = param_name.split("/")[-1]
if f"{base.model_name}/{orig_param}" in outputs["properties"]:
base_params[param_name] = outputs["properties"][
f"{base.model_name}/{orig_param}"
][mask]
elif orig_param in outputs["properties"]:
base_params[param_name] = outputs["properties"][orig_param][mask]
# add units
if isinstance(base_params[param_name], unyt_array):
short_param_name = param_name.split("/")[-1]
if short_param_name in UNIT_DICT:
param_units[param_name] = UNIT_DICT[short_param_name]
else:
param_units[param_name] = str(base_params[param_name].units)
else:
# If it's not a unyt_array, assume it's dimensionless
param_units[param_name] = str(dimensionless)
base_param_values.append(base_params)
# Get the supplementary parameters for this base
# For any supp params that are a flux or luminosity,
# scale them by the scaling factor
scaled_supp_params = {}
for key, value in supp_params[base.model_name].items():
if isinstance(value, dict):
scaled_supp_params[key] = {}
for subkey, subvalue in value.items():
if isinstance(subvalue, unyt_array):
scaled_supp_params[key][subkey] = (
subvalue[mask] * scaling_factors
)
else:
scaled_supp_params[key][subkey] = subvalue[mask]
else:
if isinstance(value, unyt_array):
scaled_supp_params[key] = value[mask] * scaling_factors
else:
scaled_supp_params[key] = value[mask]
supp_params_values.append(scaled_supp_params)
# Calculate the total number of combinations
dimension = np.prod([i.shape[-1] for i in scaled_photometries])
output_array = np.zeros((scaled_photometries[0].shape[0], dimension))
params_array = np.zeros((len(param_columns), dimension))
supp_array = np.zeros((len(supp_param_keys), dimension))
# Create all combinations of indices
combinations = np.meshgrid(
*[np.arange(i.shape[-1]) for i in scaled_photometries],
indexing="ij",
)
combinations = np.array(combinations).T.reshape(-1, len(scaled_photometries))
# Fill the output and parameter arrays
for i, combo_indices in enumerate(combinations):
# Add the scaled photometries for each base
for j, base in enumerate(self.bases):
output_array[:, i] += scaled_photometries[j][:, combo_indices[j]]
# Fill parameter values
param_idx = 0
# Add standard parameters first
params_array[param_idx, i] = redshift
param_idx += 1
params_array[param_idx, i] = log_total_mass
param_idx += 1
if len(self.bases) > 1:
params_array[param_idx, i] = combination[
0
] # assuming this is weight fraction
param_idx += 1
# Add all varying parameters from each base
for j, base in enumerate(self.bases):
for param_name in total_property_names[base.model_name]:
if param_name in base_param_values[j]:
params_array[param_idx, i] = base_param_values[j][param_name][
combo_indices[j]
]
param_idx += 1
# Add supplementary parameters. Sum parameters from all bases
for j, base in enumerate(self.bases):
for k, param_name in enumerate(supp_param_keys):
data = supp_params_values[j][param_name][combo_indices[j]]
if isinstance(data, unyt_array):
if j == 0:
supp_param_units[param_name] = str(data.units)
data = data.value
else:
if j == 0:
supp_param_units[param_name] = str(dimensionless)
supp_array[k, i] += data
all_outputs.append(output_array)
all_params.append(params_array)
all_supp_params.append(supp_array)
supp_param_units = list(supp_param_units.values())
# Combine all outputs and parameters
combined_outputs = np.hstack(all_outputs)
combined_params = np.hstack(all_params)
combined_supp_params = np.hstack(all_supp_params)
param_units = [param_units[i] for i in param_columns]
out = {
"photometry": combined_outputs,
"parameters": combined_params,
"parameter_names": param_columns,
"filter_codes": filter_codes,
"supplementary_parameters": combined_supp_params,
"supplementary_parameter_names": supp_param_keys,
"supplementary_parameter_units": supp_param_units,
"parameter_units": param_units,
}
self.library_photometry = combined_outputs
self.library_parameters = combined_params
self.library_parameter_names = param_columns
self.library_filter_codes = filter_codes
self.library_supplementary_parameters = combined_supp_params
self.library_supplementary_parameter_names = supp_param_keys
if save:
self.save_library(out, overload_out_name=out_name, overwrite=overwrite)
def _validate_library(self, library_dict: dict, check_type="photometry") -> None:
"""Validate the library dictionary.
Parameters
----------
library_dict : dict
Dictionary containing the library data.
Expected keys are 'photometry', 'parameters', 'parameter_names',
and 'filter_codes'.
check_type : str, optional
Type of data to check, either 'photometry' or 'spectra',
by default 'photometry'.
Raises:
------
ValueError
If any of the required keys are missing or if the data is not in the expected format.
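Examples
--------
A minimal sketch of a valid library dictionary; the shapes, names and
the ``maker`` instance are illustrative, not prescriptive:
>>> import numpy as np
>>> library = {
...     "photometry": np.ones((2, 4)),
...     "parameters": np.zeros((2, 4)),
...     "parameter_names": ["redshift", "log_mass"],
...     "filter_codes": ["F115W", "F200W"],
... }
>>> maker._validate_library(library)  # doctest: +SKIP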
"""
required_keys = [
check_type,
"parameters",
"parameter_names",
"filter_codes",
]
for key in required_keys:
if key not in library_dict:
raise ValueError(f"Missing required key: {key}")
if not isinstance(library_dict[check_type], np.ndarray):
raise ValueError(f"{check_type} must be a numpy array.")
if not isinstance(library_dict["parameters"], np.ndarray):
raise ValueError("Parameters must be a numpy array.")
if not isinstance(library_dict["parameter_names"], list):
raise ValueError("Parameter names must be a list.")
if not isinstance(library_dict["filter_codes"], (list, np.ndarray)):
raise ValueError("Filter codes must be a list or numpy array.")
# Check for NAN/INF in photometry and parameters
assert not np.any(np.isnan(library_dict[check_type])), (
f"{check_type} contains NaN values. Please check the input data."
)
assert not np.any(np.isinf(library_dict[check_type])), (
f"{check_type} contains infinite values. Please check the input data."
)
assert not np.any(np.isnan(library_dict["parameters"])), (
"Parameters contain NaN values. Please check the input data."
)
assert not np.any(np.isinf(library_dict["parameters"])), (
"Parameters contain infinite values. Please check the input data."
)
def save_library(
self,
library_dict: dict, # E.g. output from create_library
overload_out_name: str = "",
overwrite: bool = False,
library_params_to_save=["model_name"],
) -> None:
"""Save the library to a file.
Parameters
----------
library_dict : dict
Dictionary containing the library data.
Expected keys are 'photometry', 'parameters', 'parameter_names',
and 'filter_codes'.
overload_out_name : str, optional
If provided, overrides the default output name ``self.out_name``.
overwrite : bool, optional
If True, overwrites an existing file. Defaults to False.
library_params_to_save : list, optional
Base attributes to store as file attributes, by default ['model_name'].
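Examples
--------
A hedged sketch; ``maker`` and the library dictionary ``out`` built by
``create_full_library`` are assumed to exist:
>>> maker.save_library(out, overload_out_name="my_grid", overwrite=True)  # doctest: +SKIP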
"""
check_type = "photometry" if "photometry" in library_dict else "spectra"
self._validate_library(library_dict, check_type=check_type)
# Check if the output directory exists, if not create it
if not os.path.exists(self.out_dir):
os.makedirs(self.out_dir)
if overload_out_name != "":
out_name = overload_out_name
else:
out_name = self.out_name
if not out_name.endswith(".hdf5"):
out_name = f"{out_name}.hdf5"
# Create the full output path
full_out_path = os.path.join(self.out_dir, out_name)
# Check if the file already exists
if os.path.exists(full_out_path) and not overwrite:
logger.warning(f"File {full_out_path} already exists. Skipping.")
return
elif os.path.exists(full_out_path) and overwrite:
logger.warning(f"File {full_out_path} already exists. Overwriting.")
os.remove(full_out_path)
# Create a new HDF5 file
with h5py.File(full_out_path, "w") as f:
# Create a group for the library data
library_group = f.create_group("Grid")
# Create datasets for the photometry and parameters
if "photometry" in library_dict:
library_group.create_dataset(
"Photometry", data=library_dict["photometry"], compression="gzip"
)
if "spectra" in library_dict:
library_group.create_dataset(
"Spectra", data=library_dict["spectra"], compression="gzip"
)
library_group.create_dataset(
"Parameters", data=library_dict["parameters"], compression="gzip"
)
if "supplementary_parameters" in library_dict:
library_group.create_dataset(
"SupplementaryParameters",
data=library_dict["supplementary_parameters"],
compression="gzip",
)
# Create a dataset for the parameter names
f.attrs["ParameterNames"] = library_dict["parameter_names"]
try:
f.attrs["FilterCodes"] = library_dict["filter_codes"]
except (OSError, RuntimeError):
# HDF5 has a limit on string length for attributes
# Save as a dataset instead
library_group.create_dataset(
"FilterCodes",
data=np.array(library_dict["filter_codes"], dtype="S"),
compression="gzip",
)
f.attrs["FilterCodes"] = "/Grid/FilterCodes/"
if "supplementary_parameters" in library_dict:
f.attrs["SupplementaryParameterNames"] = library_dict[
"supplementary_parameter_names"
]
f.attrs["SupplementaryParameterUnits"] = library_dict[
"supplementary_parameter_units"
]
f.attrs["PhotometryUnits"] = "nJy"
if "parameter_units" in library_dict:
f.attrs["ParameterUnits"] = library_dict["parameter_units"]
for param in library_params_to_save:
out = []
for base in self.bases:
out.append(str(getattr(base, param)))
f.attrs[param] = out
# Add a timestamp
f.attrs["Grids"] = [base.grid.grid_name for base in self.bases]
f.attrs["CreationDT"] = datetime.now().strftime("%Y%m%d_%H%M%S")
# Add anything else as a dataset
for key, value in library_dict.items():
if key not in [
"photometry",
"spectra",
"parameters",
"parameter_names",
"filter_codes",
"supplementary_parameters",
"supplementary_parameter_names",
"supplementary_parameter_units",
"parameter_units",
]:
if isinstance(value, (np.ndarray, list)) and len(value) > 0 and isinstance(value[0], str):
f.attrs[key] = value
else:
library_group.create_dataset(key, data=value, compression="gzip")
if isinstance(value, (unyt_array, unyt_quantity)):
library_group[key].attrs["Units"] = str(value.units)
def plot_galaxy_from_library(
self,
index: int,
show: bool = True,
save: bool = False,
override_filter_codes=[],
):
"""Plot a galaxy at a grid index.
Parameters
----------
index : int
Index of the galaxy in the library.
show : bool, optional
If True, shows the plot. Defaults to True.
save : bool, optional
If True, saves the plot. Defaults to False.
override_filter_codes : list, optional
If non-empty, use these filter codes instead of the library's own.
Returns:
-------
matplotlib.figure.Figure
The figure containing the plot.
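Examples
--------
A minimal sketch; ``maker`` is a hypothetical instance whose library
has already been created or loaded:
>>> fig = maker.plot_galaxy_from_library(index=0, show=False, save=True)  # doctest: +SKIP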
"""
if self.library_photometry is None:
raise ValueError("Library photometry not created. Run create_library() first.")
if not hasattr(self, "pipeline_outputs"):
self.load_bases()
# Get the parameters for this index
params = self.library_parameters[:, index]
# Get the photometry for this index
photometry = self.library_photometry[:, index]
# Get the filter codes
if len(override_filter_codes) > 0:
filter_codes = override_filter_codes
else:
filter_codes = self.library_filter_codes
filterset = FilterCollection(filter_codes, verbose=False)
# For each basis, look at which parameters for that basis and match to spectra.
combined_spectra = []
total_wavelengths = []
for base in self.bases:
# Get the varying parameters for this base
base_params = {}
for i, param_name in enumerate(self.library_parameter_names):
# Extract the original parameter name without the base prefix
if "/" in param_name:
basis, orig_param = param_name.split("/")
else:
basis = base.model_name
orig_param = param_name
if basis == base.model_name:
base_params[orig_param] = params[i]
basis_params = list(self.pipeline_outputs[base.model_name]["properties"].keys())
total_mask = np.ones(
len(self.pipeline_outputs[base.model_name]["properties"][basis_params[0]]),
dtype=bool,
)
for key in base_params.keys():
if key not in basis_params:
continue
all_values = self.pipeline_outputs[base.model_name]["properties"][key]
_i = all_values == base_params[key]
total_mask = np.logical_and(total_mask, _i)
assert (
np.sum(total_mask) == 1
), f"""Found {np.sum(total_mask)} matches for {base.model_name}
with parameters {base_params}. Expected 1 match."""
j = np.where(total_mask)[0][0]
# Get the spectra for this index
spectra = self.pipeline_outputs[base.model_name]["observed_spectra"][j]
flux_unit = spectra.units
# get the mass of the spectra and the expected mass to renormalise the spectra
mass = self.pipeline_outputs[base.model_name]["properties"]["mass"][j]
expected_mass = 10 ** params[1] * Msun
scaling_factor = expected_mass / mass
spectra = spectra * scaling_factor
# Append the spectra to the combined spectra
combined_spectra.append(spectra)
wavs = self.pipeline_outputs[base.model_name]["wavelengths"]
total_wavelengths.append(wavs)
for i, wavs in enumerate(total_wavelengths):
if i == 0:
continue
assert np.all(
wavs == total_wavelengths[0]
), f"""Wavelengths for base {i} do not match base 0.
{wavs} != {total_wavelengths[0]}"""
if len(self.bases) == 1:
weights = np.array((1.0,))
elif "weight_fraction" in self.library_parameter_names:
w_idx = self.library_parameter_names.index("weight_fraction")
w = float(params[w_idx])
weights = np.array((w, 1 - w))
else:
# Fallback: equal weights across bases if not encoded
weights = np.ones(len(self.bases)) / len(self.bases)
# Stack the spectra according to the combination weights.
# Spectra has shape (wav, n_bases)
combined_spectra = np.array(combined_spectra)
combined_spectra = combined_spectra * weights[:, np.newaxis]
combined_spectra_summed = np.sum(combined_spectra, axis=0)
# Redshift is applied via the (1 + z) wavelength shift when plotting below.
fig, ax = plt.subplots(figsize=(10, 6))
photwavs = filterset.pivot_lams
ax.scatter(
photwavs,
photometry,
label="Photometry",
color="red",
s=10,
path_effects=[PathEffects.withStroke(linewidth=4, foreground="white")],
)
ax.plot(
wavs * (1 + params[0]),
combined_spectra_summed,
label="Combined Spectra",
color="blue",
)
for i, base in enumerate(self.bases):
# Get the spectra for this index
spectra = combined_spectra[i]
ax.plot(
wavs * (1 + params[0]),
spectra,
label=f"{base.model_name} Spectra",
alpha=0.5,
linestyle="--",
)
ax.set_xlabel("Wavelength (AA)")
ax.set_xlim(0.8 * np.min(photwavs), 1.2 * np.max(photwavs))
ax.set_yscale("log")
ax.set_ylim(1e-2, None)
ax.legend()
def ab_to_jy(f):
# AB mag to nJy: f_nJy = 1e9 * 10 ** ((8.9 - m_AB) / 2.5)
return 1e9 * 10 ** ((8.9 - f) / 2.5)
def jy_to_ab(f):
f = f / 1e9
return -2.5 * np.log10(f) + 8.9
secax = ax.secondary_yaxis("right", functions=(jy_to_ab, ab_to_jy))
secax.yaxis.set_major_formatter(ScalarFormatter())
secax.yaxis.set_minor_formatter(ScalarFormatter())
secax.set_ylabel("Flux Density [AB mag]")
ax.set_xlabel(f"Wavelength ({wavs.units})")
ax.set_ylabel(f"Flux Density ({flux_unit})")
# Text box with parameters and values
textstr = "\n".join(
[f"{key}: {value:.2f}" for key, value in zip(self.library_parameter_names, params)]
)
props = dict(boxstyle="round", facecolor="wheat", alpha=0.5)
ax.text(
0.5,
0.98,
f"index: {index}\n" + textstr,
transform=ax.transAxes,
fontsize=10,
verticalalignment="top",
bbox=props,
horizontalalignment="center",
)
if show:
plt.show()
if save:
fig.savefig(
f"{self.out_dir}/{self.out_name}_{index}.png",
dpi=300,
bbox_inches="tight",
)
plt.close(fig)
return fig
def load_library_from_file(
self,
file_path: str,
) -> dict:
"""Load the library from a file.
Parameters
----------
file_path : str
Path to the HDF5 file containing the library data.
Returns:
-------
dict
Dictionary containing the library data.
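Examples
--------
A hedged sketch (the file path is illustrative):
>>> library = maker.load_library_from_file("libraries/my_grid.hdf5")  # doctest: +SKIP
>>> library["parameters"].shape  # (n_params, n_combinations)  # doctest: +SKIP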
"""
with h5py.File(file_path, "r") as f:
if isinstance(f.attrs["FilterCodes"], (bytes, str)) and str(
f.attrs["FilterCodes"]
).startswith("/Grid/FilterCodes"):
fcodes = f["Grid/FilterCodes"][()].astype(str)
else:
fcodes = f.attrs["FilterCodes"]
library_data = {
"parameters": f["Grid"]["Parameters"][()],
"parameter_names": f.attrs["ParameterNames"],
"filter_codes": fcodes,
}
if "Photometry" in f["Grid"]:
library_data["photometry"] = f["Grid"]["Photometry"][()]
self.library_photometry = library_data["photometry"]
if "Spectra" in f["Grid"]:
library_data["spectra"] = f["Grid"]["Spectra"][()]
self.library_spectra = library_data["spectra"]
self.library_parameters = library_data["parameters"]
self.library_parameter_names = library_data["parameter_names"]
self.library_filter_codes = library_data["filter_codes"]
return library_data
def _validate_bases(self, pipeline_outputs, skip_inst=False) -> None:
# ===== VALIDATION =====
# Check all bases have the same number of galaxies
ngal = len(pipeline_outputs[self.bases[0].model_name]["properties"]["mass"])
for i, base in enumerate(self.bases):
model_name = base.model_name
if len(pipeline_outputs[model_name]["properties"]["mass"]) != ngal:
raise ValueError(
f"""Base {i} has different number of galaxies to base 0.
Cannot combine bases with different number of galaxies."""
)
# Validate input arrays
for array_name, array in [
("redshifts", self.redshifts),
("log_stellar_masses", self.log_stellar_masses),
("combination_weights", self.combination_weights),
]:
if not isinstance(array, (int, float)) and len(array) != ngal:
raise ValueError(
f"""{array_name} length {len(array)} does not match
number of galaxies {ngal}."""
)
if not skip_inst:
# Validate filters
base_filters = self.bases[0].instrument.filters.filter_codes
for i, base in enumerate(self.bases):
if base.instrument.filters.filter_codes != base_filters:
raise ValueError(
f"""Base {i} has different filters to base 0.
Cannot combine bases with different filters."""
)
def create_full_library(
self,
override_instrument: Union[Instrument, None] = None,
save: bool = True,
overload_out_name: str = "",
overwrite: bool = False,
spectral_mode=False,
) -> None:
"""Create a complete grid of SEDs by combining galaxy bases.
This method handles both single-base and multi-base scenarios,
properly scaling according to masses and weights.
It constructs a grid without sampling, generating the full set of
combinations for all specified parameters.
Parameters
----------
override_instrument : Instrument, optional
If provided, use these filters instead of those in the bases
save : bool, default=True
Whether to save the library to disk
overload_out_name : str, default=''
Custom filename for saving the library
overwrite : bool, default=False
Whether to overwrite existing files
spectral_mode : bool, default=False
If True, build the library from observed spectra instead of photometry.
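Examples
--------
A hedged sketch of typical usage; ``maker`` stands in for a configured
instance of this class:
>>> maker.create_full_library(save=True, overwrite=True)  # doctest: +SKIP
>>> maker.library_photometry.shape  # (n_filters, n_combinations)  # doctest: +SKIP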
"""
# Validate initialization conditions
if self.draw_parameter_combinations:
raise AssertionError(
"""Cannot create full grid with draw_parameter_combinations
set to True. Set to False to create full grid."""
)
# Load base model data
pipeline_outputs = self.load_bases(load_spectra=spectral_mode)
self._validate_bases(pipeline_outputs, skip_inst=spectral_mode)
base_filters = self.bases[0].instrument.filters.filter_codes
ngal = len(pipeline_outputs[self.bases[0].model_name]["properties"]["mass"])
# Determine which filters to use
if override_instrument is not None:
# Ensure all requested filters exist in the base
for filter_code in override_instrument.filters.filter_codes:
if filter_code not in base_filters:
raise ValueError(
f"""Filter {filter_code} not found in base filters.
Cannot override instrument."""
)
filter_codes = override_instrument.filters.filter_codes
else:
filter_codes = base_filters
# Strip path info from filter codes
# filter_codes = [code.split("/")[-1] for code in filter_codes]
# ===== PARAMETER SETUP =====
# Set up parameter names
ignore_keys = ["redshift"]
param_columns = ["redshift", "log_mass"]
# Add weight fraction parameter for multiple bases
is_multi_base = len(self.bases) > 1
if is_multi_base:
param_columns.append("weight_fraction")
# Set up per-base parameter tracking
total_property_names = {}
for base in self.bases:
model_name = base.model_name
# Track parameters unique to this base
if len(self.bases) > 1:
varying_params = [
f"{model_name}/{param}"
for param in base.varying_param_names
if param not in ignore_keys
]
else:
varying_params = [
param for param in base.varying_param_names if param not in ignore_keys
]
total_property_names[model_name] = varying_params
param_columns.extend(varying_params)
# Rename parameter keys in the pipeline outputs
params = pipeline_outputs[model_name]["properties"]
for param in base.varying_param_names:
if param not in ignore_keys and param in params:
params[f"{model_name}/{param}"] = params[param]
# ===== SUPPLEMENTARY PARAMETER SETUP =====
# Get the list of supplementary parameters from the first base
supp_param_keys = list(pipeline_outputs[self.bases[0].model_name]["supp_properties"].keys())
# Validate all bases have the same supplementary parameters
for key in supp_param_keys:
for base in self.bases:
if key not in pipeline_outputs[base.model_name]["supp_properties"]:
raise ValueError(
f"""Supplementary parameter {key}
not found in base {base.model_name}."""
)
# Process supplementary parameters for each base
supp_params = {}
supp_param_units = {}
for i, base in enumerate(self.bases):
model_name = base.model_name
supp_params[model_name] = {}
for key in supp_param_keys:
value = pipeline_outputs[model_name]["supp_properties"][key]
# Handle nested dictionary case (typically for emission models)
if isinstance(value, dict):
emission_key = self.base_emission_model_keys[i]
if emission_key not in value:
raise ValueError(
f"""Emission model key {emission_key} not
found in supplementary parameter {key}."""
)
supp_params[model_name][key] = value[emission_key]
else:
supp_params[model_name][key] = value
# ===== PROCESS EACH GALAXY =====
all_outputs = []
all_params = []
all_supp_params = []
for pos in range(ngal):
redshift = (
self.redshifts[pos]
if isinstance(self.redshifts, (np.ndarray, list, tuple))
else self.redshifts
)
log_total_mass = self.log_stellar_masses[pos]
# mass in solar masses
total_mass = 10**log_total_mass
weights = self.combination_weights[pos]
# Create per-base data structures to hold galaxy information at this redshift
base_data = []
# For each base, extract galaxies at this redshift
# and calculate scaling factors
for i, base in enumerate(self.bases):
model_name = base.model_name
model_output = pipeline_outputs[model_name]
masses = np.array([model_output["properties"]["mass"][pos]])
if is_multi_base:
# Multiple bases: scale by weight for this base
scaling_factors = weights[i] * total_mass / masses
else:
# Single base: scale by total mass directly
scaling_factors = total_mass / masses
scaling_factors = (
scaling_factors.value
if isinstance(scaling_factors, unyt_array)
else scaling_factors
)
if not spectral_mode:
# Get photometry for all filters and scale it
photometry = np.array(
[model_output["observed_photometry"][code][pos] for code in filter_codes],
dtype=np.float32,
)
scaled_phot = photometry * scaling_factors
else:
# For spectral mode, use the observed spectra
scaled_phot = np.array(
[model_output["observed_spectra"][pos]],
dtype=np.float32,
)
# Scale the spectra by the scaling factor
scaled_phot = scaled_phot * scaling_factors[:, np.newaxis]
# Extract parameters for this base
params_dict = {}
for param_name in total_property_names.get(model_name, []):
original_param = param_name.split("/")[-1]
# Check both possible parameter locations
if f"{model_name}/{original_param}" in model_output["properties"]:
params_dict[param_name] = model_output["properties"][
f"{model_name}/{original_param}"
][pos]
elif original_param in model_output["properties"]:
params_dict[param_name] = model_output["properties"][original_param][pos]
# Process supplementary parameters
supp_dict = {}
for key, value in supp_params[model_name].items():
if isinstance(value, dict):
supp_dict[key] = {}
for subkey, subvalue in value.items():
if check_scaling(subvalue):
supp_dict[key][subkey] = subvalue[pos] * scaling_factors
elif check_log_scaling(subvalue):
supp_dict[key][subkey] = subvalue[pos] + np.log10(scaling_factors)
else:
supp_dict[key][subkey] = subvalue[pos]
supp_param_units[key] = (
str(subvalue.units)
if isinstance(subvalue, unyt_array)
else "dimensionless"
)
else:
if check_scaling(value):
supp_dict[key] = value[pos] * scaling_factors
elif check_log_scaling(value):
supp_dict[key] = value[pos] + np.log10(scaling_factors)
else:
supp_dict[key] = value[pos]
supp_param_units[key] = (
str(value.units) if isinstance(value, unyt_array) else "dimensionless"
)
# Store all relevant data for this base
base_data.append(
{
"photometry": scaled_phot,
"params": params_dict,
"supp_params": supp_dict,
"num_items": len(masses),
}
)
# Decide how to process based on number of bases
if not is_multi_base:
# SINGLE BASE CASE - just use the data directly
base = base_data[0]
num_items = base["num_items"]
# Use photometry directly
output = base["photometry"]
output = np.squeeze(output)
output_array = output[:, np.newaxis] if output.ndim == 1 else output
# Create parameter array
params_array = np.zeros((len(param_columns), num_items))
param_idx = 0
param_units = []
# Fill standard parameters
params_array[param_idx, :] = redshift
param_units.append("dimensionless")
param_idx += 1
params_array[param_idx, :] = log_total_mass
param_units.append("log10_Msun")
param_idx += 1
# Fill varying parameters
for param_name in total_property_names.get(self.bases[0].model_name, []):
if param_name in base["params"]:
params_array[param_idx, :] = base["params"][param_name]
if param_name.split("/")[-1].lower() in UNIT_DICT.keys():
param_units.append(UNIT_DICT[param_name.split("/")[-1].lower()])
else:
param_units.append(
str(base["params"][param_name].units)
if isinstance(base["params"][param_name], unyt_array)
else "dimensionless"
)
param_idx += 1
# Process supplementary parameters
supp_array = np.zeros((len(supp_param_keys), num_items))
for k, param_name in enumerate(supp_param_keys):
data = base["supp_params"][param_name]
if isinstance(data, unyt_array):
data = data.value
supp_array[k, :] = data
else:
# MULTI-BASE CASE - create combinations
# Get the number of items per base
items_per_base = [base["num_items"] for base in base_data]
# Create meshgrid of indices for combination
mesh_indices = np.meshgrid(*[np.arange(n) for n in items_per_base], indexing="ij")
# Reshape to get all combinations:
# array of shape (total_combinations, num_bases)
combinations = np.array([indices.flatten() for indices in mesh_indices]).T
num_combinations = len(combinations)
# Create output arrays
n = len(filter_codes) if not spectral_mode else base_data[0]["photometry"].shape[1]
output_array = np.zeros((n, num_combinations))
params_array = np.zeros((len(param_columns), num_combinations))
supp_array = np.zeros((len(supp_param_keys), num_combinations))
param_units = []
# Fill arrays
for i, combo_indices in enumerate(combinations):
# Fill photometry - combine the weighted contributions
for j, base_idx in enumerate(combo_indices):
output_array[:, i] += base_data[j]["photometry"][base_idx]
# Fill parameters
param_idx = 0
# Standard parameters first
params_array[param_idx, i] = redshift
if i == 0:
param_units.append("dimensionless")
param_idx += 1
params_array[param_idx, i] = log_total_mass
if i == 0:
param_units.append("log10_Msun")
param_idx += 1
params_array[param_idx, i] = weights[0] # weight fraction
param_idx += 1
if i == 0:
param_units.append("dimensionless")
# Add all varying parameters from each base
for j, base in enumerate(self.bases):
model_name = base.model_name
for param_name in total_property_names.get(model_name, []):
if param_name in base_data[j]["params"]:
params_array[param_idx, i] = base_data[j]["params"][param_name]
if param_name.split("/")[-1].lower() in UNIT_DICT.keys():
param_units.append(UNIT_DICT[param_name.split("/")[-1].lower()])
else:
param_units.append(
str(base_data[j]["params"][param_name].units)
if isinstance(
base_data[j]["params"][param_name], unyt_array
)
else "dimensionless"
)
param_idx += 1
# Fill supplementary parameters
for k, param_name in enumerate(supp_param_keys):
total_value = 0
for j, base_idx in enumerate(combo_indices):
data = base_data[j]["supp_params"][param_name]
if isinstance(data, unyt_array):
total_value += data[base_idx].value
elif hasattr(data, "__len__") and len(data) > base_idx:
total_value += data[base_idx]
else:
total_value += data # Scalar case
supp_array[k, i] = total_value
# Append results for this redshift to the global arrays
# Give output_array shape (n, 1) if it is 1D
# if output_array.ndim == 1:
# output_array = output_array[:, np.newaxis]
all_outputs.append(output_array)
all_params.append(params_array)
all_supp_params.append(supp_array)
# ===== COMBINE RESULTS =====
combined_outputs = np.hstack(all_outputs)
combined_params = np.hstack(all_params)
combined_supp_params = np.hstack(all_supp_params)
# Convert units dict to list
supp_param_units_list = [
supp_param_units.get(name, str(dimensionless)) for name in supp_param_keys
]
# Print summary
logger.info(f"Combined outputs shape: {combined_outputs.shape}")
logger.info(f"Combined parameters shape: {combined_params.shape}")
logger.info(f"Combined supplementary parameters shape: {combined_supp_params.shape}")
if not spectral_mode:
logger.info(f"Filter codes: {filter_codes}")
else:
logger.info("Spectral mode enabled, using wavelengths as filter codes.")
logger.info(f"Parameter names: {param_columns}")
logger.info(f"Parameter units: {param_units}")
# Check combined_outputs is 2D
if combined_outputs.ndim == 1:
raise ValueError(
"Combined outputs should be a 2D array with \
shape (n_filters, n_galaxies)."
)
assert len(param_units) == len(param_columns), (
"Parameter units length does not match parameter columns length."
f"Expected {len(param_columns)}, got {len(param_units)}."
)
if not spectral_mode:
assert combined_outputs.shape[0] == len(filter_codes), (
"Output photometry shape does not match number of filters."
f"Expected {len(filter_codes)}, got {combined_outputs.shape[0]}."
)
assert combined_params.shape[0] == len(param_columns), (
"Output parameters shape does not match number of parameter columns."
f"Expected {len(param_columns)}, got {combined_params.shape[0]}."
)
assert combined_supp_params.shape[0] == len(supp_param_keys), (
"Output supplementary parameters shape does not match number of keys."
f"Expected {len(supp_param_keys)}, got {combined_supp_params.shape[0]}."
)
# Create output dictionary
out = {
"parameters": combined_params,
"parameter_names": param_columns,
"supplementary_parameters": combined_supp_params,
"supplementary_parameter_names": supp_param_keys,
"supplementary_parameter_units": supp_param_units_list,
"parameter_units": param_units,
}
# Update object attributes
if spectral_mode:
out["spectra"] = combined_outputs
self.library_spectra = combined_outputs
# 'Grid filter codes' can just be the wavelength array here
self.library_filter_codes = (
pipeline_outputs[self.bases[0].model_name]["wavelengths"].to("um").value
)
out["filter_codes"] = self.library_filter_codes
else:
out["photometry"] = combined_outputs
out["filter_codes"] = filter_codes
self.library_photometry = combined_outputs
self.library_filter_codes = filter_codes
self.library_parameters = combined_params
self.library_parameter_names = param_columns
self.library_supplementary_parameters = combined_supp_params
self.library_supplementary_parameter_names = supp_param_keys
self.library_supplementary_parameter_units = supp_param_units_list
self.library_parameter_units = param_units
# Save results if requested
if save:
self.save_library(out, overload_out_name=overload_out_name, overwrite=overwrite)
def create_spectral_grid(
self,
save: bool = True,
overload_out_name: str = "",
overwrite: bool = False,
) -> None:
"""Create a parameter grid for spectroscopic observations.
Wrapper for `create_full_library`, but specifically for spectroscopic outputs,
to match e.g. a NIRSpec IFU or DESI spectrum. The only spectral sampling
currently supported is the one used in the instrument/grid combination.
Parameters
----------
save : bool, optional
If True, saves the library to a file.
overload_out_name : str, optional
If provided, overrides the output name for the library.
overwrite : bool, optional
If True, overwrites the existing grid file if it exists.
Notes
-----
Results are stored on the object (``library_spectra``,
``library_parameters``, ...) and saved to disk if ``save`` is True.
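Examples
--------
A hedged sketch (``maker`` is a configured instance):
>>> maker.create_spectral_grid(save=True, overwrite=True)  # doctest: +SKIP
>>> maker.library_spectra.shape  # (n_wavelengths, n_combinations)  # doctest: +SKIP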
"""
return self.create_full_library(
override_instrument=None,
save=save,
overload_out_name=overload_out_name,
overwrite=overwrite,
spectral_mode=True,
)
class GalaxySimulator(object):
"""Class for simulating photometry/spectra of galaxies.
This class is designed to work with a synthesizer Grid, an instrument,
and an emission model. It can simulate photometry and spectra based on
the provided star formation history (SFH) and metallicity distribution
(ZDist) models. The class can also handle uncertainties in the photometry
using an empirical uncertainty model.
It supports various configurations such as normalization methods,
output types, and inclusion of photometric errors.
It can also apply noise models to the simulated photometry.
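Examples
--------
A hedged construction sketch; the SFH/ZDist classes and the
``grid``/``instrument``/``emission_model`` objects are placeholders,
not prescriptions:
>>> sim = GalaxySimulator(
...     sfh_model=SFH.Constant,
...     zdist_model=ZDist.DeltaConstant,
...     grid=grid,
...     instrument=instrument,
...     emission_model=emission_model,
...     emission_model_key="total",
...     param_order=["redshift", "log_mass", "max_age", "metallicity"],
... )  # doctest: +SKIP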
"""
def __init__(
self,
sfh_model: Type[SFH.Common],
zdist_model: Type[ZDist.Common],
grid: Grid,
instrument: Instrument,
emission_model: EmissionModel,
emission_model_key: str,
emitter_params: dict = None,
cosmo: Cosmology = Planck18,
param_order: Union[None, list] = None,
param_units: dict = None,
param_transforms: Dict[str, callable] = None,
out_flux_unit: str = "nJy",
required_keys=["redshift", "log_mass"],
extra_functions: List[callable] = None,
normalize_method: str = None,
output_type: str = "photo_fnu",
include_phot_errors: bool = False,
depths: Union[np.ndarray, unyt_array] = None,
depth_sigma: int = 5,
noise_models: Union[None, Dict[str, EmpiricalUncertaintyModel]] = None,
fixed_params: dict = None,
photometry_to_remove=None,
ignore_params: list = None,
ignore_scatter: bool = False,
return_type="array",
device: str = "cpu",
) -> None:
"""Parameters
----------
sfh_model : Type[SFH.Common]
SFH model to use. Must be a subclass of SFH.Common.
zdist_model : Type[ZDist.Common]
ZDist model to use. Must be a subclass of ZDist.Common.
grid : Grid
Grid object to use. Must be a subclass of Grid.
instrument : Instrument
Instrument object to use for photometry/spectra.
Must be a subclass of Instrument.
emission_model : EmissionModel
Emission model to use. Must be a subclass of EmissionModel.
emission_model_key : str
Emission model key to use e.g. 'total', 'intrinsic'.
emitter_params : dict
Dictionary of parameters to pass to the emitter.
Keys are 'stellar' and 'galaxy'.
cosmo : Cosmology = Planck18
Cosmology object to use. Must be a subclass of Cosmology.
param_order : Union[None, list] = None
Order of parameters to use.
If None, will use the order of the keys in the params dictionary.
param_units : dict
Dictionary of parameter units to use.
Keys are the parameter names and values are the units.
param_transforms : dict
Dictionary of parameter transforms to use.
Keys are the parameter names and values are the transforms functions.
Can be used to fit say Av when model requires tau_v, or for unit conversion.
out_flux_unit : str
Output flux unit to use. Default is 'nJy'. Can be 'AB', 'Jy', 'nJy', 'uJy',...
required_keys : list
List of required keys to use. Default is ['redshift', 'log_mass'].
Typically this can be ignored.
extra_functions : List[callable]
List of extra functions to call after the simulation.
Inputs to each function are determined by the function signature and can be
any of the following:
galaxy, spec, fluxes, sed, stars, emission_model, cosmo.
The output of each function is appended to an array
and returned in the output tuple.
normalize_method : str
Normalization method to use. If None, no normalization is applied.
Only works on photo_fnu currently.
Currently can be either the name of a filter, or a callable function
with the same rules as extra_functions.
output_type : str
Output type to use. Default is 'photo_fnu'. Can be 'photo_fnu', 'photo_lnu',
'fnu', 'lnu', or a list of any of these.
include_phot_errors : bool
Whether to include photometric errors in the output. Default is False.
depths : Union[np.ndarray, unyt_array] = None
Depths to use for the photometry. If None, no depths are applied.
Default is None.
depth_sigma : int
Sigma to use for the depths. Default is 5.
This is used to calculate the 1 sigma error in nJy.
If depths is None, this is ignored.
noise_models : Union[None, Dict[str, EmpiricalUncertaintyModel]] = None
List of noise models to use for the photometry. If None, no noise is applied.
Default is None.
fixed_params : dict
Dictionary of fixed parameters to use.
Keys are the parameter names and values are the fixed values.
This is used to fix parameters in the simulation.
If None, no parameters are fixed. Default is None.
photometry_to_remove : list
List of photometry to remove from the output.
This is used to remove specific filters from the output.
If None, no photometry is removed. Default is None.
Should match filter codes in the instrument filters.
ignore_params : list
List of sampled parameters that won't be checked for use against the model.
ignore_scatter : bool
If True, ignore scatter in the empirical uncertainty model. Default is False.
return_type : str
Return type of the simulate function. Default is 'array'.
Can be 'array' or 'tensor'.
device : str
Device to use for tensor calculations. Default is 'cpu'.
Can be 'cpu' or 'cuda'. Only used if return_type is 'tensor'.
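Examples
--------
A hedged sketch of ``param_transforms``, mapping a sampled Av onto the
tau_v the model expects (using A_V = 1.086 * tau_v; names are
illustrative):
>>> transforms = {"av": ("tau_v", lambda av: av / 1.086)}  # doctest: +SKIP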
"""
if fixed_params is None:
fixed_params = {}
if param_units is None:
param_units = {}
if param_transforms is None:
param_transforms = {}
if emitter_params is None:
emitter_params = {"stellar": [], "galaxy": []}
if extra_functions is None:
extra_functions = []
if photometry_to_remove is None:
photometry_to_remove = []
if not isinstance(required_keys, list):
raise TypeError(f"required_keys must be a list. Got {type(required_keys)} instead.")
if ignore_params is None:
ignore_params = []
assert isinstance(grid, Grid), f"Grid must be a subclass of Grid. Got {type(grid)} instead."
self.grid = grid
assert isinstance(instrument, Instrument), f"""Instrument must be a subclass of Instrument.
Got {type(instrument)} instead."""
assert isinstance(
emission_model, EmissionModel
), f"""Emission model must be a subclass of EmissionModel.
Got {type(emission_model)} instead."""
self.emission_model = emission_model
self.emission_model_key = emission_model_key
self.instrument = instrument
self.cosmo = cosmo
self.param_order = param_order
self.param_units = param_units
self.param_transforms = param_transforms
self.out_flux_unit = out_flux_unit
self.required_keys = required_keys
self.extra_functions = extra_functions
self.normalize_method = normalize_method
self.fixed_params = fixed_params
self.depths = depths
self.ignore_params = ignore_params
self.ignore_scatter = ignore_scatter
self.return_type = return_type
self.device = device
self.unused_params = []
self.reported_unused = False
if len(photometry_to_remove) > 0:
self.update_photo_filters(
photometry_to_remove=photometry_to_remove, photometry_to_add=None
)
if noise_models is not None:
assert isinstance(noise_models, dict), (
f"Noise models must be a dictionary. Got {type(noise_models)} instead."
)
# Check all filters in noise_models are in the instrument filters
for filter_code in instrument.filters.filter_codes:
if filter_code not in noise_models:
raise ValueError(
f"""Filter {filter_code} not found in noise models.
Cannot apply noise models to photometry."""
)
self.noise_models = noise_models
self.include_phot_errors = include_phot_errors
assert not (
self.depths is not None and self.noise_models is not None
), """Cannot use both depths and noise models at
the same time. Choose one or the other."""
assert not (
self.depths is None and self.noise_models is None and self.include_phot_errors
), """Cannot include photometric errors without depths or noise models.
Set include_phot_errors to False or provide depths or noise models."""
if depths is not None:
assert len(depths) == len(
instrument.filters.filter_codes
), f"""Depths array length {len(depths)} does not match number of filters
{len(instrument.filters.filter_codes)}. Cannot create photometry."""
self.depths = depths
self.depth_sigma = depth_sigma
assert isinstance(output_type, list) or output_type in [
"photo_fnu",
"photo_lnu",
"fnu",
"lnu",
], f"""Output type {output_type} not recognised.
Must be one of ['photo_fnu', 'photo_lnu', 'fnu', 'lnu']"""
if not isinstance(output_type, list):
output_type = [output_type]
self.output_type = output_type
self.sfh_model = sfh_model
sig = inspect.signature(sfh_model).parameters
self.sfh_params = []
self.optional_sfh_params = []
for key in sig.keys():
if sig[key].default is not inspect.Parameter.empty:
self.optional_sfh_params.append(key)
else:
self.sfh_params.append(key)
self.zdist_model = zdist_model
sig = inspect.signature(zdist_model).parameters
self.zdist_params = []
self.optional_zdist_params = []
for key in sig.keys():
if sig[key].default is not inspect.Parameter.empty:
self.optional_zdist_params.append(key)
else:
self.zdist_params.append(key)
self.emitter_params = emitter_params
self.num_emitter_params = np.sum([len(emitter_params[key]) for key in emitter_params])
self.emission_model.save_spectra(emission_model_key)
self.total_possible_keys = (
self.sfh_params
+ self.zdist_params
+ self.optional_sfh_params
+ self.optional_zdist_params
+ required_keys
)
def update_photo_filters(self, photometry_to_remove=None, photometry_to_add=None):
"""Update the photometric filters used in the simulation.
This method allows you to modify the set of photometric filters
used in the simulation by removing or adding specific filters.
It updates the instrument's filter collection accordingly.
Parameters
----------
photometry_to_remove : list, optional
List of filter codes to remove from the current set of filters.
If None, no filters will be removed.
photometry_to_add : list, optional
List of filter codes to add to the current set of filters.
If None, no filters will be added.
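Examples
--------
A small sketch (the filter code is illustrative):
>>> sim.update_photo_filters(photometry_to_remove=["JWST/NIRCam.F115W"])  # doctest: +SKIP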
"""
if photometry_to_remove is None:
photometry_to_remove = []
if photometry_to_add is None:
photometry_to_add = []
filter_codes = self.instrument.filters.filter_codes
new_filters = []
# Remove specified filters
for filter_code in filter_codes:
if filter_code not in photometry_to_remove:
new_filters.append(filter_code)
# Add specified filters if they are not already present
for filter_code in photometry_to_add:
if filter_code not in new_filters:
new_filters.append(filter_code)
self.instrument.filters = FilterCollection(filter_codes=new_filters)
logger.info(f"Updated filters: {self.instrument.filters.filter_codes}")
@classmethod
def from_library(
cls,
library_path: str,
override_synthesizer_grid_dir: Union[None, str, bool] = True,
override_emission_model: Union[None, EmissionModel] = None,
**kwargs,
):
"""Create a GalaxySimulator from a library file.
This method reads a grid file in HDF5 format, extracts the necessary
components such as the grid, instrument, cosmology, SFH model,
and emission model, and then instantiates a GalaxySimulator object.
Parameters
----------
library_path : str
Path to the library file in HDF5 format.
override_synthesizer_grid_dir : Union[None, str, bool], optional
If provided, this directory will override the synthesizer grid directory
specified in the library file. This is useful for using a model
on a different computer or environment where the grid directory
is not the same as the one used to create the library file.
If True, and the library_dir saved in the file does not exist,
it will check for a synthesizer_grid_DIR environment variable
and use that as the library directory. If a string is provided,
it will use that as the grid directory.
override_emission_model : Union[None, EmissionModel], optional
If provided, this emission model will override the one in the library file.
**kwargs : dict
Additional keyword arguments to pass to the GalaxySimulator constructor.
Returns:
-------
GalaxySimulator
An instance of GalaxySimulator initialized with the library data.
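Examples
--------
A hedged sketch (paths are illustrative):
>>> sim = GalaxySimulator.from_library(
...     "libraries/my_grid.hdf5",
...     override_synthesizer_grid_dir="/path/to/synthesizer_grids",
... )  # doctest: +SKIP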
"""
# Open the HDF5 file, look for 'Model' and instantiate by reading the library.
if not os.path.exists(library_path):
raise FileNotFoundError(
f"Library path {library_path} does not exist. Cannot create GalaxySimulator."
)
with h5py.File(library_path, "r") as f:
if "Model" not in f:
raise ValueError(
f"""Library file {library_path} does not contain 'Model' group.
Cannot create GalaxySimulator."""
)
model_group = f["Model"]
# Step 1. Make library
lam = unyt_array(
model_group["Instrument/Filters/Header/Wavelengths"][:], units=Angstrom
)
grid_name = model_group.attrs.get("grid_name", None)
grid_dir = model_group.attrs.get("grid_dir", None)
if override_synthesizer_grid_dir is not None and (grid_dir is None or not os.path.exists(grid_dir)):
if isinstance(override_synthesizer_grid_dir, str) and os.path.exists(
override_synthesizer_grid_dir
):
grid_dir = override_synthesizer_grid_dir
else:
# Check for synthesizer_grid_DIR environment variable
grid_dir = os.getenv("SYNTHESIZER_GRID_DIR", None)
if grid_dir is None:
from synthesizer import get_grids_dir
grid_dir = str(get_grids_dir())
if isinstance(override_synthesizer_grid_dir, str) and (
override_synthesizer_grid_dir.endswith(".hdf5")
or override_synthesizer_grid_dir.endswith(".h5")
):
logger.info("Overriding internal library name from provided file path.")
grid_name = (
os.path.basename(override_synthesizer_grid_dir)
.replace(".hdf5", "")
.replace(".h5", "")
)
if isinstance(grid_dir, str) and (
grid_dir.endswith(".hdf5") or grid_dir.endswith(".h5")
):
logger.info("Overriding internal library name from provided file path.")
grid_name = os.path.basename(grid_dir).replace(".hdf5", "").replace(".h5", "")
grid_dir = os.path.dirname(grid_dir)
grid = Grid(grid_name, grid_dir) # new_lam=lam)
# Step 2. Make instrument
instrument = Instrument._from_hdf5(model_group["Instrument"])
# Step 3 - recreate cosmology
cosmo_mapping = model_group.attrs.get("cosmology", None)
try:
cosmo = Cosmology.from_format(cosmo_mapping, format="yaml")
except Exception:
from astropy.cosmology import Planck18 as cosmo
logger.warning("Failed to load cosmology from HDF5. Using Planck18 instead.")
# Step 4 - Collect sfh_model
sfh_model_name = model_group.attrs.get("sfh_class", None)
sfh_model = getattr(SFH, sfh_model_name, None)
if sfh_model is None:
raise ValueError(
f"""SFH model {sfh_model_name} not found in SFH module.
Cannot create GalaxySimulator."""
)
zdist_model_name = model_group.attrs.get("metallicity_distribution_class", None)
zdist_model = getattr(ZDist, zdist_model_name, None)
if zdist_model is None:
raise ValueError(
f"""ZDist model {zdist_model_name} not found in ZDist module.
Cannot create GalaxySimulator."""
)
# recreate emission model
em_group = model_group["EmissionModel"]
emission_model_key = model_group.attrs.get("emission_model_key", "total")
if "emission_model_key" in kwargs:
emission_model_key = kwargs.pop("emission_model_key")
if override_emission_model is not None:
emission_model = override_emission_model
else:
emission_model_name = em_group.attrs["name"]
import synthesizer.emission_models as em
import synthesizer.emission_models.attenuation as dm
import synthesizer.emission_models.generators.dust as dem
from synthesizer.emission_models.stellar.pacman_model import (
PacmanEmissionNoEscapedNoDust,
PacmanEmissionNoEscapedWithDust,
PacmanEmissionWithEscapedNoDust,
PacmanEmissionWithEscapedWithDust,
)
extra = {
"PacmanEmissionNoEscapedWithDust": PacmanEmissionNoEscapedWithDust,
"PacmanEmissionWithEscapedNoDust": PacmanEmissionWithEscapedNoDust,
"PacmanEmissionWithEscapedWithDust": PacmanEmissionWithEscapedWithDust,
"PacmanEmissionNoEscapedNoDust": PacmanEmissionNoEscapedNoDust,
}
emission_model = getattr(em, emission_model_name, None)
if emission_model is None and emission_model_name in extra:
emission_model = extra[emission_model_name]
if emission_model is None:
raise ValueError(
f"Emission model {emission_model_name} not found in synthesizer.emission_models. " # noqa: E501
"Cannot create GalaxySimulator."
)
if "dust_law" in em_group.attrs:
dust_model_name = em_group.attrs["dust_law"]
dust_model = getattr(dm, dust_model_name, None)
if dust_model is None:
raise ValueError(
f"Dust model {dust_model_name} not found in {dm}. Cannot create GalaxySimulator." # noqa: E501
)
dust_model_params = {}
dust_param_keys = em_group.attrs["dust_attenuation_keys"]
dust_param_values = em_group.attrs["dust_attenuation_values"]
dust_param_units = em_group.attrs["dust_attenuation_units"]
for key, value, unit in zip(
dust_param_keys, dust_param_values, dust_param_units
):
if unit != "":
dust_model_params[key] = unyt_array(value, unit)
elif isinstance(value, str) and (
value.isnumeric() or value.replace(".", "", 1).isdigit()
):
dust_model_params[key] = float(value)
else:
dust_model_params[key] = value
dust_model = dust_model(**dust_model_params)
else:
dust_model = None
if "dust_emission" in em_group.attrs:
dust_emission_model_name = em_group.attrs["dust_emission"]
dust_emission_model = getattr(dem, dust_emission_model_name, None)
if dust_emission_model is None:
raise ValueError(
f"Dust emission model {dust_emission_model_name} not found"
" in synthesizer.emission_models. Cannot create from_library."
)
dust_emission_model_params = {}
dust_emission_param_keys = em_group.attrs["dust_emission_keys"]
dust_emission_param_values = em_group.attrs["dust_emission_values"]
dust_emission_param_units = em_group.attrs["dust_emission_units"]
for key, value, unit in zip(
dust_emission_param_keys,
dust_emission_param_values,
dust_emission_param_units,
):
if unit != "":
dust_emission_model_params[key] = unyt_array(value, unit)
else:
dust_emission_model_params[key] = value
cmb = dust_emission_model_params.pop("cmb_factor", None)
dust_emission_model_params.pop("temperature_z", None)
# if cmb is not None:
# dust_emission_model_params["cmb_heating"] = cmb != 1
dust_emission_model = dust_emission_model(**dust_emission_model_params)
else:
dust_emission_model = None
em_keys = em_group.attrs["parameter_keys"]
em_values = em_group.attrs["parameter_values"]
em_units = em_group.attrs["parameter_units"]
emission_model_params = {}
for key, value, unit in zip(em_keys, em_values, em_units):
if unit != "":
emission_model_params[key] = unyt_array(value, unit)
elif isinstance(value, str) and (
value.isnumeric() or value.replace(".", "", 1).isdigit()
):
emission_model_params[key] = float(value)
else:
emission_model_params[key] = value
if dust_model is not None:
emission_model_params["dust_curve"] = dust_model
if dust_emission_model is not None:
emission_model_params["dust_emission_model"] = dust_emission_model
# get arguments from inspect of emission_model
sig = inspect.signature(emission_model).parameters
if "dust_emission" in sig and "dust_emission_model" not in sig:
emission_model_params["dust_emission"] = dust_emission_model
emission_model_params.pop("dust_emission_model", None)
emission_model = emission_model(
grid=grid,
**emission_model_params,
)
# Step 5 - Collect emitter params
stellar_params = model_group.attrs.get("stellar_params", {})
emitter_params = {
"stellar": stellar_params,
"galaxy": {},
}
# Step 7 - work out order and units
param_order = f.attrs["ParameterNames"]
units = f.attrs.get("ParameterUnits")
ignore = ["dimensionless", "log", "mag"]
param_units = {}
for p, u in zip(param_order, units):
skip = False
for i in ignore:
if i in u:
skip = True
break
if not skip:
param_units[p] = Unit(u)
# Step 8 - fixed_params
fixed_param_names = model_group.attrs.get("fixed_param_names", [])
fixed_param_values = model_group.attrs.get("fixed_param_values", [])
fixed_param_units = model_group.attrs.get("fixed_param_units", [])
fixed_params = {}
for name, value, unit in zip(fixed_param_names, fixed_param_values, fixed_param_units):
if unit != "":
fixed_params[name] = unyt_array(value, unit)
else:
fixed_params[name] = value
# Step 9 - Collect param transforms.
# TEMP
param_transforms = {}
if "Transforms" in model_group:
transform_group = model_group["Transforms"]
for key in transform_group.keys():
# need to evaluate the function
code = transform_group[key][()].decode("utf-8")
code = f"\n{code}\n"
# Remove excess indentation
code = inspect.cleandoc(code)
try:
namespace = {}
exec(code, globals(), namespace)
func_name = code.split("def ")[-1].split("(")[0]
func = namespace[func_name]
except Exception as e:
print(e)
logger.error(f"Error evaluating transform function for {key}: {e}")
continue
param_transforms[key] = func
if "new_parameter_name" in transform_group[key].attrs:
new_key = transform_group[key].attrs["new_parameter_name"]
param_transforms[key] = (new_key, func)
dict_create = dict(
sfh_model=sfh_model,
zdist_model=zdist_model,
grid=grid,
emission_model=emission_model,
emission_model_key=emission_model_key,
instrument=instrument,
cosmo=cosmo,
emitter_params=emitter_params,
param_order=param_order,
param_units=param_units,
param_transforms=param_transforms,
fixed_params=fixed_params,
)
dict_create.update(kwargs)
return cls(**dict_create)
def simulate(self, params):
"""Simulate photometry from the given parameters.
Parameters
----------
params : dict or array-like
Dictionary of parameters or an array-like object with the parameters
in the order specified by self.param_order.
If self.param_order is None, params must be a dictionary.
Returns:
-------
dict
Dictionary of photometry outputs. Keys are the output types specified
in self.output_type. Values are the corresponding photometry arrays.
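Examples
--------
A hedged sketch assuming ``param_order`` was set to
['redshift', 'log_mass', 'max_age', 'metallicity'] at construction;
the values are illustrative:
>>> fluxes = sim.simulate([5.0, 9.5, 100.0, 0.01])  # doctest: +SKIP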
"""
params = copy.deepcopy(params)
if isinstance(params, torch.Tensor):
params = params.detach().cpu().numpy()
params = np.squeeze(params)
if not isinstance(params, dict):
if self.param_order is None:
raise ValueError(
"""simulate() input requires a dictionary unless param_order is set.
Cannot create photometry."""
)
assert len(params) == len(
self.param_order
), f"""Parameter array length {len(params)} does not match parameter order
length {len(self.param_order)}. Cannot create photometry."""
params = {i: j for i, j in zip(self.param_order, params)}
params.update(self.fixed_params)
for key in self.required_keys:
if key not in params:
raise ValueError(f"Missing required parameter {key}. Cannot create photometry.")
mass = 10 ** params["log_mass"] * Msun
# Check if we have sfh_params and zdist_params
for key in params:
if key in self.param_units:
params[key] = params[key] * self.param_units[key]
for key in self.param_transforms:
value = self.param_transforms[key]
# check if key is tuple - could be replacing multiple parameters
if isinstance(key, tuple):
name = value[0]
func = value[1]
param_subset = {k: params[k] for k in key if k in params}
try:
params[name] = func(**param_subset)
except Exception as e:
logger.error(f"Error applying transform {value} to {key}: {e}")
elif isinstance(value, tuple):
name = self.param_transforms[key][0]
func = self.param_transforms[key][1]
if key in params:
params[name] = func(params[key])
else:
try:
params[name] = func(params)
except Exception:
continue
# logger.error(f"Error applying transform {value} to {key}: {e}")
elif callable(value):
if key in params:
params[key] = value(params[key])
else:
params[key] = value(params)
# Check if we have all SFH and ZDist parameters
for key in self.sfh_params + self.zdist_params:
if key not in params:
raise ValueError(
f"""Missing required parameter {key} for SFH or ZDist.
Cannot create photometry."""
)
sfh_input = {i: params[i] for i in self.sfh_params}
for i in self.optional_sfh_params:
if i in params:
sfh_input[i] = params[i]
if len(sfh_input) <= 1:
logger.warning(
f"SFH input has only {len(sfh_input)} parameters. "
"Check that all required SFH parameters are provided."
)
sfh = self.sfh_model(
**sfh_input,
)
zdist_input = {i: params[i] for i in self.zdist_params}
for i in self.optional_zdist_params:
if i in params:
zdist_input[i] = params[i]
zdist = self.zdist_model(
**zdist_input,
)
# print({i: params[i] for i in self.sfh_params})
# print({i: params[i] for i in self.zdist_params})
# print({i: params[i] for i in self.optional_zdist_params if i in params})
# print({i: params[i] for i in self.optional_sfh_params if i in params})
# Get param names which aren't in the sfh or zdist models or the required keys
param_names = [i for i in params.keys() if i not in self.total_possible_keys]
# Check if any of these param names are in the emitter param dictionary lists
found_params = []
for key in param_names:
found = False
for param in self.emitter_params:
if key in self.emitter_params[param]:
found = True
break
if key in self.ignore_params:
found = True
break
if not found:
if key not in self.unused_params:
self.unused_params.append(key)
# logger.info(f"Emitter params are {self.emitter_params}")
# raise ValueError(
# f"Parameter {key} not found in emitter params.Cannot create photometry."
# )
else:
found_params.append(key)
if not self.reported_unused and len(self.unused_params) > 0:
logger.warning(
f"The following parameters are not used by the simulator: {self.unused_params}"
)
self.reported_unused = True
# Check we understand all the parameters
# assert len(found_params) - len(self.ignore_params) >= self.num_emitter_params, (
# f"Found {len(found_params)} parameters but expected {self.num_emitter_params}."
# "Cannot create photometry."
# )
stellar_keys = {}
if "stellar" in self.emitter_params:
for key in found_params:
if key in self.emitter_params["stellar"]:
stellar_keys[key] = params[key]
stars = Stars(
log10ages=self.grid.log10ages,
metallicities=self.grid.metallicity,
sf_hist=sfh,
metal_dist=zdist,
initial_mass=mass,
**stellar_keys,
)
galaxy_keys = {}
if "galaxy" in self.emitter_params:
for key in found_params:
if key in self.emitter_params["galaxy"]:
galaxy_keys[key] = params[key]
galaxy = Galaxy(
stars=stars,
redshift=params["redshift"],
**galaxy_keys,
)
# Get the spectra for the galaxy
spec = galaxy.stars.get_spectra(self.emission_model)
outputs = {}
if "sfh" in self.output_type:
stars_sfh = stars.get_sfh()
stars_sfh = stars_sfh / np.diff(10 ** (self.grid.log10age), prepend=0) / yr
time = 10 ** (self.grid.log10age) * yr
time = time.to("Myr")
# Check if any NANS in SFH. If there are, print sfh
# if np.isnan(sfr).any():
# print(f"SFH has NANS: {sfh.parameters}")
outputs["sfh"] = stars_sfh
outputs["sfh_time"] = time
outputs["redshift"] = galaxy.redshift
# Put an absolute SFH time here as well given self.cosmo and redshift
outputs["sfh_time_abs"] = self.cosmo.age(galaxy.redshift).to("Myr").value * Myr
outputs["sfh_time_abs"] = outputs["sfh_time_abs"] - time
if "lnu" in self.output_type:
fluxes = spec.lnu
outputs["lnu"] = copy.deepcopy(fluxes)
if "photo_lnu" in self.output_type:
fluxes = galaxy.stars.spectra[self.emission_model_key].get_photo_lnu(
self.instrument.filters
)
fluxes = fluxes.photo_lnu
outputs["photo_lnu"] = copy.deepcopy(fluxes)
if "fnu" in self.output_type or "photo_fnu" in self.output_type:
# Apply IGM and distance
galaxy.get_observed_spectra(self.cosmo)
if "photo_fnu" in self.output_type:
fluxes = galaxy.stars.spectra[self.emission_model_key].get_photo_fnu(
self.instrument.filters
)
outputs["photo_fnu"] = fluxes.photo_fnu
outputs["photo_wav"] = fluxes.filters.pivot_lams
if "fnu" in self.output_type:
fluxes = galaxy.stars.spectra[self.emission_model_key]
outputs["fnu"] = copy.deepcopy(fluxes.fnu)
# print(np.sum(np.isnan(fluxes)), np.sum(fluxes == 0))
outputs["fnu_wav"] = copy.deepcopy(
galaxy.stars.spectra[self.emission_model_key].lam * (1 + galaxy.redshift)
)
if self.out_flux_unit == "AB":
def convert(f):
return -2.5 * np.log10(f.to(Jy).value) + 8.9
if "photo_fnu" in self.output_type:
fluxes = convert(outputs["photo_fnu"])
outputs["photo_fnu"] = copy.deepcopy(fluxes)
if "fnu" in self.output_type:
fluxes = convert(outputs["fnu"])
# replace infinite magnitudes (zero flux) with a sentinel of 99
fluxes[np.isinf(fluxes)] = 99
outputs["fnu"] = fluxes
elif self.out_flux_unit == "asinh":
raise NotImplementedError(
"""asinh fluxes not implemented yet.
Please use AB or Jy units."""
)
else:
if "photo_fnu" in self.output_type:
outputs["photo_fnu"] = outputs["photo_fnu"].to(self.out_flux_unit).value
if "fnu" in self.output_type:
outputs["fnu"] = outputs["fnu"].to(self.out_flux_unit).value
if len(self.output_type) == 1:
fluxes = outputs[self.output_type[0]]
else:
outputs["filters"] = self.instrument.filters
return outputs
def inspect_func(func, local_vars):
parameters = inspect.signature(func).parameters
inputs = {}
possible_inputs = [
"galaxy",
"spec",
"fluxes",
"sed",
"stars",
"emission_model",
"cosmo",
]
for key in parameters:
if key in possible_inputs:
if hasattr(self, key):
inputs[key] = getattr(self, key)
else:
inputs[key] = local_vars[key]
else:
inputs[key] = params[key]
return inputs
if self.extra_functions is not None:
output = []
for func in self.extra_functions:
if isinstance(func, tuple):
func, args = func
output.append(func(galaxy, *args))
else:
inputs = inspect_func(func, locals())
output.append(func(**inputs))
fluxes, errors = self._scatter(fluxes, flux_units=self.out_flux_unit)
if self.normalize_method is not None:
if callable(self.normalize_method):
args = inspect_func(self.normalize_method, locals())
norm = self.normalize_method(**args)
if isinstance(norm, dict):
if self.emission_model_key in norm:
norm = norm[self.emission_model_key]
else:
norm = self.normalize_method
fluxes = self._normalize(fluxes, method=norm, norm_unit=self.out_flux_unit)
if self.include_phot_errors:
fluxes = np.concatenate((fluxes, errors))
if self.return_type == "tensor":
fluxes = torch.tensor(np.atleast_2d(fluxes), device=self.device)
return fluxes
def _normalize(self, fluxes, method=None, norm_unit="AB", add_norm_pos=-1):
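"""Normalize fluxes by a reference value.
Parameters
----------
fluxes : np.ndarray
The fluxes to normalize.
method : str, unyt_array, unyt_quantity or callable, optional
A filter code (normalize by that band), an explicit normalization
value, or a callable returning one. If None, fluxes are returned
unchanged.
norm_unit : str, optional
If 'AB', normalization is a subtraction (magnitudes); otherwise a
division. Default is 'AB'.
add_norm_pos : int, optional
Position at which to insert the normalization value into the output
array (-1 appends). If None, the value is not added. Default is -1.
"""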
if method is None:
return fluxes
if norm_unit == "AB":
func = np.subtract
else:
func = np.divide
if isinstance(method, str):
if method in self.instrument.filters.filter_codes:
# Get position of filter in filter codes
filter_pos = self.instrument.filters.filter_codes.index(method)
norm = fluxes[filter_pos]
fluxes = func(fluxes, norm)
else:
raise ValueError(
f"""Filter {method} not found in filter codes.
Cannot normalize photometry."""
)
elif isinstance(method, (unyt_array, unyt_quantity)):
if norm_unit == "AB":
# Convert to AB
method = -2.5 * np.log10(method.to(Jy).value) + 8.9
norm = method
fluxes = func(fluxes, norm)
elif callable(method):
norm = method(fluxes)
fluxes = func(fluxes, norm)
if add_norm_pos is not None:
# Insert the normalization value at this position
if add_norm_pos == -1:
fluxes = np.append(fluxes, norm)
else:
fluxes = np.insert(fluxes, add_norm_pos, norm)
return fluxes
def _scatter(
self,
fluxes: np.ndarray,
flux_units: str = "nJy",
):
"""Scatters the fluxes based on the provided depths or noise models.
Parameters
----------
fluxes : np.ndarray
The fluxes to scatter.
flux_units : str, optional
The units of the fluxes. Default is 'nJy'.
Returns
-------
np.ndarray, np.ndarray
The scattered fluxes and their corresponding errors.
If depths are not provided, returns the original fluxes and None for errors.
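
Notes
-----
When the output unit is 'AB', the per-band flux uncertainty is
propagated to magnitudes via sigma_m = (2.5 / ln 10) * sigma_f / f.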
"""
if self.ignore_scatter:
return fluxes, None
if self.depths is not None:
depths = self.depths
if isinstance(depths, (list, tuple)):
depths = np.array(depths)
m = fluxes.shape
if isinstance(fluxes, unyt_array):
assert (
fluxes.units == flux_units
), f"""Fluxes units {fluxes.units} do not match flux units
{flux_units}. Cannot scatter photometry."""
if flux_units == "AB":
fluxes = (10 ** ((fluxes - 23.9) / -2.5)) * uJy
elif not isinstance(fluxes, (unyt_array)):
fluxes = fluxes * Unit(flux_units)
# Convert depths based on units
if self.out_flux_unit == "AB" and not hasattr(depths, "units"):
# Convert from AB magnitudes to microjanskys
depths_converted = (10 ** ((depths - 23.9) / -2.5)) * uJy
depths_std = depths_converted.to(uJy).value / self.depth_sigma
else:
depths_std = depths.to(uJy).value / self.depth_sigma
# Draw one Gaussian noise realization per band, with standard deviation set by the depths
random_values = np.random.normal(loc=0, scale=depths_std, size=m) * uJy
# Add the noise to the fluxes
output_arr = fluxes + random_values
errors = depths_std
if flux_units == "AB":
# Convert back to AB magnitudes
output_arr = -2.5 * np.log10(output_arr.to(Jy).value) + 8.9
# Propagate the flux uncertainty to magnitudes: sigma_m = (2.5 / ln 10) * sigma_f / f
errors = 2.5 * depths_std / (np.log(10) * fluxes.to(uJy).value)
return output_arr, errors
elif self.noise_models is not None:
# Apply noise models to the fluxes
scattered_fluxes = np.zeros_like(fluxes, dtype=float)
errors = np.zeros_like(fluxes, dtype=float)
for i, filter_code in enumerate(self.instrument.filters.filter_codes):
noise_model = self.noise_models.get(filter_code)
noise_model.return_noise = True
flux = fluxes[i]
scattered_flux, sigma = noise_model.apply_noise(
flux=np.atleast_1d(flux),
true_flux_units=flux_units,
out_units=self.out_flux_unit,
)
scattered_fluxes[i] = scattered_flux
errors[i] = sigma
return scattered_fluxes, errors
else:
# No depths or noise models, return the fluxes as is
return fluxes, None
def __call__(self, params):
"""Call the simulator with parameters to get photometry."""
return self.simulate(params)
def __repr__(self):
"""String representation of the PhotometrySimulator."""
return f"""GalaxySimulator({self.sfh_model},
{self.zdist_model},
{self.grid},
{self.instrument},
{self.emission_model},
{self.emission_model_key})"""
def __str__(self):
"""String representation of the GalaxySimulator."""
return self.__repr__()
def test_out_of_distribution(
observed_photometry: np.ndarray,
simulated_photometry: np.ndarray,
sigma_threshold: float = 5.0,
) -> Tuple[np.ndarray, np.ndarray]:
"""Identifies and removes samples from a dataset that are out-of-distribution.
The function calculates the Mahalanobis distance for each simulated sample
to the mean of the observed samples, accounting for the covariance
between filters. Samples with a distance greater than the specified
sigma_threshold are considered outliers and are removed.
Args:
observed_photometry (np.ndarray): The reference dataset, expected to
have a shape of (N_filters, N_obs_samples).
simulated_photometry (np.ndarray): The dataset to be filtered, expected
to have a shape of (N_filters, N_sim_samples).
sigma_threshold (float, optional): The number of standard deviations
(in Mahalanobis distance) to use as the
outlier threshold. Defaults to 5.0.
Returns:
Tuple[np.ndarray, np.ndarray]:
- A NumPy array containing the filtered simulated photometry.
The shape will be (N_filters, N_inliers).
- A 1D NumPy array containing the indices of the rows (samples) that were
identified as outliers and removed from the original simulated_photometry.
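
Example:
A minimal, illustrative sketch with Gaussian mock data:

>>> import numpy as np
>>> rng = np.random.default_rng(0)
>>> obs = rng.normal(size=(5, 1000))  # (N_filters, N_obs_samples)
>>> sim = rng.normal(size=(5, 200))  # (N_filters, N_sim_samples)
>>> filtered, outlier_idx = test_out_of_distribution(obs, sim, sigma_threshold=5.0)
>>> filtered.shape[0]
5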
"""
if (
observed_photometry.shape[1] == simulated_photometry.shape[1]
and observed_photometry.shape[0] != simulated_photometry.shape[0]
):
# Transpose both so they share the same (N_filters, N_samples) orientation
observed_photometry = observed_photometry.T
simulated_photometry = simulated_photometry.T
if observed_photometry.shape[0] != simulated_photometry.shape[0]:
raise ValueError(
"""observed_photometry and simulated_photometry must
have the same number of filters (rows)."""
)
# Transpose data so that samples are rows and filters (features) are columns
# Shape becomes (N_samples, N_filters)
obs_data = observed_photometry.T
sim_data = simulated_photometry.T
# Calculate the mean vector and inverse covariance matrix of the observed data
try:
mean_obs = np.mean(obs_data, axis=0)
cov_obs = np.cov(obs_data, rowvar=False)
inv_cov_obs = inv(cov_obs)
except np.linalg.LinAlgError:
raise ValueError(
"""Could not compute the inverse covariance matrix of
the observed_photometry. This can happen if the data
is not full rank (e.g., filters are perfectly correlated)."""
)
# Calculate the Mahalanobis distance for each simulated sample
# from the observed distribution
mahalanobis_distances = np.zeros(sim_data.shape[0])
for i in range(sim_data.shape[0]):
delta = sim_data[i] - mean_obs
# Manual calculation of Mahalanobis distance: sqrt(delta' * inv_cov * delta)
mahalanobis_distances[i] = np.sqrt(delta @ inv_cov_obs @ delta.T)
# Identify the indices of outliers
outlier_indices = np.where(mahalanobis_distances > sigma_threshold)[0]
inlier_indices = np.where(mahalanobis_distances <= sigma_threshold)[0]
# Filter the original simulated_photometry array using the inlier indices
# We filter the original array with shape (N_filters, N_samples)
filtered_sim_photometry = simulated_photometry[:, inlier_indices]
logger.info(f"Original number of samples: {simulated_photometry.shape[1]}")
logger.info(f"Number of outliers removed ({sigma_threshold}-sigma): {len(outlier_indices)}")
logger.info(f"Number of samples remaining: {filtered_sim_photometry.shape[1]}")
return filtered_sim_photometry, outlier_indices
class LibraryCreator:
"""Class to create libraries for Synference simulations.
Allows creation of a library from parameter and observation arrays without
using Synthesizer.
Parameters
----------
model_name : str
Name of the model/library.
parameter_names : List[str]
List of parameter names.
parameter_grid : np.ndarray
2D array of shape (N_parameters, N_samples) containing parameter values.
observation_grid : np.ndarray
2D array of shape (N_observations, N_samples) containing observation values.
observation_names : List[str]
List of observation names (e.g., filter names).
observation_units : Union[str, unyt_quantity]
Units of the observations (e.g., 'AB', 'Jy', etc.).
parameter_units : List[Union[str, unyt_array]], optional
List of units for each parameter. If None, parameters are assumed to be dimensionless.
supplementary_parameters : np.ndarray, optional
2D array of shape (N_supplementary_parameters, N_samples)
containing supplementary parameter values.
supplementary_parameter_units : List[Union[str, unyt_array]], optional
List of units for each supplementary parameter.
supplementary_parameter_names : List[str], optional
List of names for each supplementary parameter.
compresslevel : int, optional
Compression level for HDF5 datasets (0-9). Default is 4.
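out_folder : str, optional
Folder in which to write the library HDF5 file. Default is the current directory.
overwrite : bool, optional
Whether to overwrite an existing library file. Default is False.

Examples
--------
A minimal sketch with random arrays (illustrative only; the constructor
writes library_toy.h5 into out_folder):

>>> import numpy as np
>>> params = np.random.rand(3, 100)  # (N_parameters, N_samples)
>>> obs = np.random.rand(5, 100)  # (N_observations, N_samples)
>>> creator = LibraryCreator(
...     model_name="toy",
...     parameter_names=["a", "b", "c"],
...     parameter_grid=params,
...     observation_grid=obs,
...     observation_names=[f"band_{i}" for i in range(5)],
...     observation_units="AB",
...     overwrite=True,
... )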
"""
def __init__(
self,
model_name: str,
parameter_names: List[str],
parameter_grid: np.ndarray,
observation_grid: np.ndarray,
observation_names: List[str],
observation_units: Union[str, unyt_quantity],
parameter_units: List[Union[str, unyt_array]] = None,
supplementary_parameters: np.ndarray = None,
supplementary_parameter_units: List[Union[str, unyt_array]] = None,
supplementary_parameter_names: List[str] = None,
compresslevel: int = 4,
out_folder: str = ".",
overwrite: bool = False,
):
"""Initialize the LibraryCreator.
Allows creation of a library from parameter and observation arrays without
using Synthesizer.
See the class docstring for full parameter descriptions.
"""
self.model_name = model_name
self.parameter_names = parameter_names
self.parameter_grid = parameter_grid
self.observation_grid = observation_grid
self.observation_names = observation_names
self.parameter_units = parameter_units
self.observation_units = observation_units
self.supplementary_parameters = supplementary_parameters
self.supplementary_parameter_units = supplementary_parameter_units
self.supplementary_parameter_names = supplementary_parameter_names
self.compresslevel = compresslevel
# Run some checks
if self.parameter_units is not None:
assert len(self.parameter_units) == len(self.parameter_names), (
"Length of parameter_units must match length of parameter_names."
)
nparams = self.parameter_grid.shape[0]
nobs = self.observation_grid.shape[0]
logger.info(f"Number of parameters: {nparams}")
logger.info(f"Number of observations: {nobs}")
logger.info(f"Num rows in parameter library: {self.parameter_grid.shape[1]}")
logger.info(f"Num rows in observation library: {self.observation_grid.shape[1]}")
assert self.parameter_grid.shape[0] == len(self.parameter_names), (
"Number of rows in parameter_grid must match length of parameter_names."
f"Got {self.parameter_grid.shape[0]} rows and"
f" {len(self.parameter_names)} parameter names."
)
assert self.observation_grid.shape[0] == len(self.observation_names), (
"Number of rows in observation_grid must match length of observation_names."
)
if self.supplementary_parameters is not None:
assert self.supplementary_parameter_names is not None, (
"supplementary_parameter_names must be provided if"
" supplementary_parameters is provided."
)
assert self.supplementary_parameters.shape[0] == len(
self.supplementary_parameter_names
), (
"Number of rows in supplementary_parameters must match"
"length of supplementary_parameter_names."
)
if self.supplementary_parameter_units is not None:
assert len(self.supplementary_parameter_units) == len(
self.supplementary_parameter_names
), (
"Length of supplementary_parameter_units must"
" match length of supplementary_parameter_names."
)
assert self.supplementary_parameters.shape[1] == self.parameter_grid.shape[1], (
"Number of columns in supplementary_parameters "
"must match number of columns in parameter_grid."
)
self.create_library(out_folder=out_folder, overwrite=overwrite)
def create_library(self, out_folder=".", overwrite=False) -> None:
"""Create the libraey and save it to an HDF5 file."""
if not os.path.exists(out_folder):
os.makedirs(out_folder)
if (
os.path.exists(os.path.join(out_folder, f"library_{self.model_name}.h5"))
and not overwrite
):
raise FileExistsError(
f"Library file library_{self.model_name}.h5 already exists "
f"in {out_folder}. Use overwrite=True to overwrite."
)
library_path = os.path.join(out_folder, f"library_{self.model_name}.h5")
with h5py.File(library_path, "w") as f:
grid = f.create_group("Grid")
grid.create_dataset(
"Parameters",
data=self.parameter_grid,
compression="gzip",
compression_opts=self.compresslevel,
)
grid.create_dataset(
"Photometry",
data=self.observation_grid,
compression="gzip",
compression_opts=self.compresslevel,
)
if self.supplementary_parameters is not None:
grid.create_dataset(
"SupplementaryParameters",
data=self.supplementary_parameters,
compression="gzip",
compression_opts=self.compresslevel,
)
f.attrs["SupplementaryParameterNames"] = self.supplementary_parameter_names
if self.supplementary_parameter_units is not None:
f.attrs["SupplementaryParameterUnits"] = self.supplementary_parameter_units
# Add attributes
f.attrs["model_name"] = [self.model_name]
f.attrs["ParameterNames"] = self.parameter_names
f.attrs["FilterCodes"] = self.observation_names
if self.parameter_units is not None:
f.attrs["ParameterUnits"] = self.parameter_units
else:
f.attrs["ParameterUnits"] = ["dimensionless"] * len(self.parameter_names)
f.attrs["PhotometryUnits"] = self.observation_units
f.attrs["CreationDT"] = datetime.now().strftime("%Y%m%d_%H%M%S")
logger.info(f"Library saved to {library_path}")
if __name__ == "__main__":
logger.error("This is a module, not a script.")