import h5py
import warnings
import numpy as np
import astropy.units as u
from typing import Callable
from scipy.integrate import quad
from slsim.Util import param_util
from dataclasses import dataclass
from slsim.Sources.source import Source
from astropy.cosmology import Cosmology
from slsim.Sources.SourcePopulation.source_pop_base import SourcePopBase
# Photometric bands for which the SCOTCH tables provide "mag_{band}" datasets.
BANDS = tuple("ugrizY")
# SCOTCH host-table dataset names that are renamed to slsim naming when a
# host dictionary is built.
SCOTCH_MAPPINGS = dict(
    n0="n_sersic_0",
    n1="n_sersic_1",
    e0="ellipticity0",
    e1="ellipticity1",
)
# Full-sky solid angle (4*pi sr) expressed in square degrees (~41253 deg^2).
SKY_AREA = (4 * np.pi * u.rad**2).to_value(u.deg**2)
[docs]
def d08(z: float | np.ndarray) -> float | np.ndarray:
"""Redshift Evolution of SNIa Rates from Dilday et al.
2008 Sec. 6.4.1
https://arxiv.org/abs/0801.3297
"""
return (1 + z) ** 1.5
[docs]
def md14(z: float | np.ndarray) -> float | np.ndarray:
"""Redshift Evolution of Cosmic Star Formation Rate from Madau & Dickinson
2014 Eq.
15.
https://arxiv.org/abs/1403.0007
"""
return (1 + z) ** 2.7 / (1 + ((1 + z) / 2.9) ** 5.6)
[docs]
def s15(z: float | np.ndarray) -> float | np.ndarray:
"""Redshift Evolution of CCSNe Rates from Strolger et al.
2015 Eq. 9.
https://arxiv.org/abs/1509.06574
"""
return (1 + z) ** 5.0 / (1 + ((1 + z) / 1.5) ** 6.1)
[docs]
def snia_rate(z: float | np.ndarray) -> float | np.ndarray:
    """Volumetric rate of normal SNe Ia, in units of 1e-6 Mpc^-3 yr^-1.

    Broken power law: ``r0 * (1 + z)**1.5`` (Dilday et al. 2008, see
    :func:`d08`) for z < 1, and ``r_high * (1 + z)**-0.5`` for z >= 1.
    ``r_high`` is chosen so that the rate is continuous at the z = 1 break.

    Parameters
    ----------
    z : float or np.ndarray
        Redshift(s).

    Returns
    -------
    np.ndarray
        Rate(s) with the shape of ``np.asarray(z)``.
    """
    r0 = 25  # in units of 10^-6 Mpc^-3 yr^-1
    # Continuity at z = 1: r0 * 2**1.5 == r_high * 2**-0.5  =>  r_high = 4 * r0.
    # Re-using r0 for the high-z branch would make the rate drop by a factor
    # of 4 at the break.
    # NOTE(review): confirm this normalisation against Kessler et al. 2019,
    # Table 2 (the reference cited for RATE_FUNCS).
    r_high = 4 * r0
    z = np.asarray(z)
    rate = np.where(z < 1, r0 * d08(z), r_high * (1 + z) ** -0.5)
    return rate
[docs]
def snia_91bg_rate(z: float | np.ndarray) -> float | np.ndarray:
    """Volumetric rate of 91bg-like SNe Ia, in units of 1e-6 Mpc^-3 yr^-1.

    Follows the Dilday et al. 2008 evolution (:func:`d08`) with a local
    normalisation of 3.
    """
    normalization = 3  # in units of 10^-6 Mpc^-3 yr^-1
    return normalization * d08(np.asarray(z))
[docs]
def sniax_rate(z: float | np.ndarray) -> float | np.ndarray:
    """Volumetric rate of SNe Iax, in units of 1e-6 Mpc^-3 yr^-1.

    Follows the Madau & Dickinson 2014 star-formation history (:func:`md14`)
    with a normalisation of 6.
    """
    normalization = 6  # in units of 10^-6 Mpc^-3 yr^-1
    return normalization * md14(np.asarray(z))
[docs]
def snii_rate(z: float | np.ndarray) -> float | np.ndarray:
    """Volumetric rate of SNe II, in units of 1e-6 Mpc^-3 yr^-1.

    Follows the Strolger et al. 2015 evolution (:func:`s15`) with a
    normalisation of 45.
    """
    normalization = 45  # in units of 10^-6 Mpc^-3 yr^-1
    return normalization * s15(np.asarray(z))
[docs]
def snibc_rate(z: float | np.ndarray) -> float | np.ndarray:
    """Volumetric rate of SNe Ib/c, in units of 1e-6 Mpc^-3 yr^-1.

    Follows the Strolger et al. 2015 evolution (:func:`s15`) with a
    normalisation of 19.
    """
    normalization = 19  # in units of 10^-6 Mpc^-3 yr^-1
    return normalization * s15(np.asarray(z))
[docs]
def slsn_rate(z: float | np.ndarray) -> float | np.ndarray:
    """Volumetric rate of superluminous SNe, in units of 1e-6 Mpc^-3 yr^-1.

    Follows the Madau & Dickinson 2014 star-formation history (:func:`md14`)
    with a normalisation of 0.02.
    """
    normalization = 0.02  # in units of 10^-6 Mpc^-3 yr^-1
    return normalization * md14(np.asarray(z))
[docs]
def tde_rate(z: float | np.ndarray) -> float | np.ndarray:
r0 = 1 # in units of 10^-6 Mpc^-3 yr^-1
z = np.asarray(z)
rate = r0 * 10 ** (-5 * z / 6)
return rate
[docs]
def kn_rate(z: float | np.ndarray) -> float | np.ndarray:
r0 = 6
z = np.asarray(z)
rate = r0 * np.ones_like(z)
return rate
def _scaled_rate(fraction: float, base_rate: Callable) -> Callable:
    """Return the rate function ``z -> fraction * base_rate(z)``.

    Used to assign a subclass its fraction of a parent-class rate.
    """

    def rate(z: float | np.ndarray) -> float | np.ndarray:
        return fraction * base_rate(z)

    return rate


# Volumetric rate model per SCOTCH transient subclass, in units of
# 1e-6 Mpc^-3 yr^-1 (expected_number multiplies by 1e-6).
# Subclass rates are calculated using Lokken et al. 2022 Table B1 if not
# given in Kessler et al. 2019 (https://arxiv.org/abs/1903.11756, Table 2).
# NOTE(review): the SNII-family fractions below sum to 1.00501, while the
# SNIb/c-family fractions sum to exactly 1.0. If SNII-NMF were 0.19448 (like
# SNII-Templates), the SNII sum would be ~1.0. Verify 0.19448 / 0.19948
# against the source table.
RATE_FUNCS = {
    "SNIa-SALT2": snia_rate,
    "SNIax": sniax_rate,
    "SNIa-91bg": snia_91bg_rate,
    "SNII-Templates": _scaled_rate(0.19448, snii_rate),
    "SNII-NMF": _scaled_rate(0.19948, snii_rate),
    "SNII+HostXT_V19": _scaled_rate(0.39016, snii_rate),
    "SNIIn+HostXT_V19": _scaled_rate(0.04502, snii_rate),
    "SNIIn-MOSFIT": _scaled_rate(0.04502, snii_rate),
    "SNIb-Templates": _scaled_rate(0.27835, snibc_rate),
    "SNIb+HostXT_V19": _scaled_rate(0.27835, snibc_rate),
    "SNIc-Templates": _scaled_rate(0.19330, snibc_rate),
    "SNIc+HostXT_V19": _scaled_rate(0.19330, snibc_rate),
    "SNIcBL+HostXT_V19": _scaled_rate(0.05670, snibc_rate),
    "SNIIb+HostXT_V19": _scaled_rate(0.13085, snii_rate),
    "SLSN-I": slsn_rate,
    "KN_K17": _scaled_rate(0.5, kn_rate),
    "KN_B19": _scaled_rate(0.5, kn_rate),
    "TDE": tde_rate,
}
[docs]
def expected_number(
    rate_fn: Callable,
    cosmo: Cosmology,
    z_min: float = 0.0,
    z_max: float = 3.0,
) -> float:
    """Integrate a volumetric rate over the full-sky comoving volume.

    Computes ``integral_{z_min}^{z_max} 1e-6 * rate_fn(z) * 4*pi *
    dV_c/dz/dOmega dz`` with :func:`scipy.integrate.quad`. No 1/(1+z)
    time-dilation factor is included in the integrand.

    Parameters
    ----------
    rate_fn : callable
        Rate as a function of redshift, in units of 1e-6 per unit comoving
        volume per year (the RATE_FUNCS convention).
    cosmo : astropy.cosmology.Cosmology
        Cosmology providing ``differential_comoving_volume``. Only the
        numerical ``.value`` of its result is used.
    z_min, z_max : float, optional
        Integration limits in redshift. Defaults are 0.0 and 3.0.

    Returns
    -------
    float
        The integral (the quadrature error estimate is discarded).
    """
    four_pi = 4 * np.pi

    def _dn_dz(z):
        shell_volume = four_pi * cosmo.differential_comoving_volume(z).value
        return 1e-6 * rate_fn(z) * shell_volume

    result, _abserr = quad(_dn_dz, z_min, z_max)
    return result
def _norm_band_names(bands: list[str]) -> list[str]:
"""Normalize band names to lowercase, except for 'Y' which is uppercase.
Parameters
----------
bands : list of str
List of band names to normalize.
Returns
-------
list of str
Normalized band names.
"""
out = []
for b in bands:
b = b.strip()
if b.lower() == "y":
out.append("Y")
else:
out.append(b.lower())
return out
[docs]
def galaxy_projected_eccentricity(
    ellipticity: float,
    rotation_angle: float | None = None,
    rng: np.random.Generator | None = None,
) -> tuple[float, float]:
    """Compute the projected eccentricity components (e1, e2) of an elliptical
    galaxy given its ellipticity and rotation angle. If the rotation angle is
    not provided, it is drawn randomly from a uniform distribution between 0
    and π.

    Parameters
    ----------
    ellipticity : float
        Ellipticity of the galaxy; converted to an eccentricity amplitude
        with ``param_util.epsilon2e``.
    rotation_angle : float or None, optional
        Rotation angle of the major axis in radians. The reference is the +RA axis
        (towards the East direction) and it increases from East to North. If None
        (the default), a random angle is drawn.
    rng : np.random.Generator or None, optional
        Generator used to draw the random angle when ``rotation_angle`` is
        None. If None, NumPy's global random state is used. Default is None.

    Returns
    -------
    e1 : float
        First component of the projected eccentricity.
    e2 : float
        Second component of the projected eccentricity.
    """
    # The previous signature read ``rotation_angle=float | None``, which made
    # the *default value* the UnionType object rather than None, so omitting
    # the angle crashed in np.cos.
    if rotation_angle is None:
        if rng is None:
            phi = np.random.uniform(0, np.pi)
        else:
            phi = rng.uniform(0, np.pi)
    else:
        phi = rotation_angle
    e = param_util.epsilon2e(ellipticity)
    e1 = e * np.cos(2 * phi)
    e2 = e * np.sin(2 * phi)
    return e1, e2
@dataclass
class _SubclassShard:
    """Selected rows of one transient subclass within one input file."""

    # Position of the source file in ScotchSources.files.
    file_index: int
    # HDF5 group /TransientTable/{class}/{subclass} in that file.
    grp: h5py.Group
    # Total number of rows in the subclass group.
    N: int
    # Number of rows passing all selection cuts.
    n_ok: int
    # Row indices of the eligible rows, or the int N when every row passes
    # (np.random.Generator.choice then samples from range(N)).
    eligible: np.ndarray | int  # If int then eligible = n_ok = N, so all rows valid
    weight_sum: float  # S_{f,rl} = sum d_{rl}(z_i) over eligible rows
    weights: np.ndarray | None  # Normalized weights over eligible rows
@dataclass
class _SubclassIndex:
    """One transient subclass, merged across all input files as shards."""

    # Subclass name (key into RATE_FUNCS unless it is an AGN subclass).
    name: str
    # One shard per input file that contains eligible rows of this subclass.
    shards: list[_SubclassShard]
    # from RATE_FUNCS integral over z (for AGN subclasses: the number of
    # eligible rows instead)
    n_expected: int
@dataclass
class _ClassIndex:
    """Sampling index for one transient class, merged across input files."""

    # One host table per input file for this class; entries are None for
    # files that have no HostTable for this class.
    host_grp: list[h5py.Group | None]
    # Per file: host GIDs sorted ascending (for np.searchsorted lookups).
    host_gid_sorted: list[np.ndarray]
    # Per file: argsort permutation mapping sorted position -> host row.
    host_gid_sort_idx: list[np.ndarray]
    # Per file: host cut mask, aligned with host_gid_sorted.
    host_mask_sorted: list[np.ndarray]
    # Per-subclass info (merged across files), ordered by subclass name
    subclasses: list[_SubclassIndex]
    subclass_total: np.ndarray  # total rows per subclass (sum across files)
    subclass_selected: np.ndarray  # eligible rows per subclass (sum across files)
    subclass_expected: np.ndarray  # expected counts per subclass (RATE_FUNCS)
    subclass_weights: np.ndarray  # sampling weights p(s | class)
    total: int  # total rows over all subclasses
    total_expected: int  # sum of subclass_expected
    total_selected: int = 0  # sum of subclass_selected
[docs]
class ScotchSources(SourcePopBase):
    def __init__(
        self,
        cosmo: Cosmology,
        scotch_path: list[str] | str,
        sky_area=None,
        transient_types: list[str] | str | None = None,
        transient_subtypes: dict[str, list[str]] | None = None,
        kwargs_cut: dict | None = None,
        rng: np.random.Generator | int | None = None,
        sample_uniformly: bool = False,
        exclude_agn: bool = False,
    ):
        """Class for SCOTCH transient source population. Allows for sampling of
        transients and their hosts from the SCOTCH HDF5 catalogs.

        Parameters
        ----------
        cosmo : astropy.cosmology instance
            An instance of an astropy cosmology model (e.g., FlatLambdaCDM(H0=70, Om0=0.3)).
        scotch_path : str or list of str
            Path(s) to one or more SCOTCH HDF5 files. The files are opened
            read-only and stay open until :meth:`close` is called.
        sky_area : astropy.units.Quantity, optional
            Sky area over which sources are sampled. Must be in units of solid
            angle (any such unit; it is converted to deg^2). If None, the
            effective sky area of the selected catalog rows is used (see
            Notes). Default is None.
        transient_types : str or list of str, optional
            Transient type(s) to include. If None, all available types are used.
            Default is None.
        transient_subtypes : dict of list of str, optional
            Dict with transient types as keys and lists of transient subtypes to
            include. Types without an entry use all of their available subtypes.
            If None, all available subtypes for the chosen transient types are
            used. The passed dictionary is not modified. Default is None.
        kwargs_cut : dict, optional
            Dictionary of selection criteria to filter the sources. Supported keys:
            - 'z_min': Minimum redshift (float). Default 0.0.
            - 'z_max': Maximum redshift (float). Default 3.0.
            - 'band': Band name(s) (str or list of str) for magnitude cuts.
            - 'band_max': Maximum magnitude(s) (float or list of float)
              corresponding to 'band'.
            The lengths of 'band' and 'band_max' must be equal. The passed
            dictionary is not modified. Default is None.
        rng : np.random.Generator, int, or None, optional
            Random number generator or seed for reproducibility. If None, a new
            generator is created. Default is None.
        sample_uniformly : bool, optional
            If False, sampling is done according to the expected rates of transient
            subclasses within the given redshift range. If True, sampling is done
            uniformly over all transient subclasses, while the redshift of the
            transient is still sampled according to the volumetric rate. Default is False.
        exclude_agn : bool, optional
            If True, AGN are excluded from the source population. Default is False.

        Raises
        ------
        ValueError
            If transient_types or transient_subtypes contain unknown names, if
            kwargs_cut is invalid, if no sources pass the selection criteria, or
            if the expected number of transients in the redshift range is zero.
        KeyError
            If a selected non-AGN subclass has no entry in RATE_FUNCS.

        Warns
        -----
        UserWarning
            If any transient class has no objects passing the provided kwargs_cut filters.

        Notes
        -----
        The SCOTCH catalogs contain multiple transient classes, each with its own
        host galaxy table. A class is drawn with probability proportional to the
        summed expected number of its surviving subclasses, and a subclass with
        probability proportional to its own expected number (with
        sample_uniformly=True every surviving subclass is equally likely).
        Expected numbers are the full-sky integrals of RATE_FUNCS over
        [z_min, z_max]. Within a subclass, rows are drawn with probability
        proportional to the subclass rate at the row's redshift (uniformly for
        AGN). Selection cuts are applied as specified in kwargs_cut.

        The effective sky area is SKY_AREA * n_selected / n_expected, i.e. the
        area over which the expected number of transients equals the number
        of selected catalog rows. If sky_area is given, the number of selected
        sources is rescaled by sky_area / effective sky area.

        Hosts are included if their redshift is not 999.0; otherwise, the transient
        is considered hostless. Host magnitude cuts are applied to every host row,
        including the hostless placeholders.

        The transient lightcurves are provided as "general_lightcurve" point sources,
        and hosts (if any) as "double_sersic" extended sources. If the
        transient is hostless, the Source is an instance of PointSource; otherwise,
        it is an instance of PointPlusExtendedSource.

        The SCOTCH HDF5 file is expected to have the following structure:
        - /TransientTable/{transient_class}/{subclass}/
            - Datasets: "z", "GID", "ra_off", "dec_off", "MJD", "mag_{band}" for each band
        - /HostTable/{transient_class}/
            - Datasets: "GID", "z", "mag_{band}" for each band, "a_rot", and the
              per-component Sersic parameters ("a0", "b0", "n0", "e0", "a1", "b1",
              "n1", "e1", ...); "n*" and "e*" are renamed via SCOTCH_MAPPINGS.
        The "GID" fields are used to link transients to their hosts.
        The "mag_{band}" datasets contain magnitudes, with 99.0 indicating missing data.
        """
        super().__init__(cosmo=cosmo, sky_area=sky_area)
        paths = (
            list(scotch_path)
            if isinstance(scotch_path, (list, tuple))
            else [scotch_path]
        )
        # Open the files one at a time inside the try block so that a failure
        # anywhere in __init__ closes every handle opened so far.
        self.files = []
        try:
            for path in paths:
                self.files.append(h5py.File(path, "r"))
            self.sample_uniformly = sample_uniformly
            transient_types = self._parse_transient_types(transient_types)
            if exclude_agn:
                transient_types = [
                    transient_type
                    for transient_type in transient_types
                    if transient_type != "AGN"
                ]
            self.transient_types = transient_types
            self.transient_subtypes = self._parse_transient_subtypes(
                transient_subtypes
            )
            zmin, zmax, bands_to_filter, band_maxes = self._parse_kwargs_cut(
                kwargs_cut
            )
            self.zmin, self.zmax = zmin, zmax
            self.bands_to_filter = bands_to_filter
            self.band_maxes = band_maxes
            self.rng = (
                rng
                if isinstance(rng, np.random.Generator)
                else np.random.default_rng(rng)
            )
            # Build indices per class
            self._index: dict[str, _ClassIndex] = {
                transient_type: self._build_transient_index(
                    transient_type=transient_type
                )
                for transient_type in self.transient_types
            }
            self._collect_active_types()
            self._setup_sampling_weights()
            self._setup_sky_area()
        except BaseException:
            self.close()
            raise

    @property
    def source_number(self) -> int:
        """Number of sources in the population before any selection cuts.

        Returns
        -------
        int
            Number of sources.
        """
        return self.n_source

    @property
    def source_number_selected(self) -> int:
        """Number of sources in the population after applying selection cuts
        (rescaled to ``sky_area`` when one was given).

        Returns
        -------
        int
            Number of sources passing the selection criteria.
        """
        # NOTE: stored as n_source_selected for consistency with
        # SourcePopBase and its other subclasses.
        return self.n_source_selected

    # -------------------- init helpers --------------------

    def _collect_active_types(self) -> None:
        """Keep classes with surviving rows and accumulate population totals.

        Sets ``n_source``, ``n_source_selected``, ``total_expected`` and
        ``active_transient_types``.

        Raises
        ------
        ValueError
            If no objects pass the cuts, or if the summed expected number of
            the surviving classes is not positive.
        """
        active_types = []
        total = 0
        total_selected = 0
        total_expected = 0
        for c in self.transient_types:
            class_index = self._index[c]
            if class_index.total_selected > 0:
                active_types.append(c)
                total += class_index.total
                total_selected += class_index.total_selected
                total_expected += class_index.total_expected
            else:
                warnings.warn(
                    f"Transient class '{c}' has no objects passing "
                    + "the provided kwargs_cut filters and will be ignored.",
                )
        self.n_source = total
        self.n_source_selected = total_selected
        self.total_expected = total_expected
        self.active_transient_types = active_types
        if self.n_source_selected == 0:
            raise ValueError("No objects satisfy the provided kwargs_cut filters.")
        if self.total_expected <= 0:
            # Both the sampling weights and the effective sky area divide by
            # total_expected.
            raise ValueError(
                "The expected number of transients for the selected classes is "
                "zero in the requested redshift range; check kwargs_cut "
                "'z_min' and 'z_max'."
            )

    def _setup_sampling_weights(self) -> None:
        """Set ``class_weights`` p(c) and each class's ``subclass_weights`` p(s|c)."""
        class_weights = np.zeros(len(self.active_transient_types))
        for i, c in enumerate(self.active_transient_types):
            class_index = self._index[c]
            subclass_expected = class_index.subclass_expected
            if self.sample_uniformly:
                n_subclasses = len(subclass_expected)
                class_weight = n_subclasses
                subclass_weights = np.ones(n_subclasses) / n_subclasses
            else:
                # The probability of sampling a transient class c and a subclass
                # s is p(c, s) = n^{expected}_{c,s} / n^{expected}_total.
                # We factorize this such that we first sample the class c with
                # p(c) = sum_s p(c, s), and then the subclass with
                # p(s | c) = p(c, s) / p(c), so that p(c, s) = p(s | c) * p(c).
                global_subclass_weights = subclass_expected / self.total_expected
                class_weight = np.sum(global_subclass_weights)
                subclass_weights = global_subclass_weights / class_weight
            class_index.subclass_weights = subclass_weights
            class_weights[i] = class_weight
        if self.sample_uniformly:
            # Class weight = number of subclasses, so every subclass is equally likely.
            class_weights = class_weights / np.sum(class_weights)
        self.class_weights = class_weights

    def _setup_sky_area(self) -> None:
        """Derive the catalog's effective sky area and reconcile it with
        ``sky_area``.

        If no sky area was requested, the effective area (in deg^2) becomes
        ``sky_area``; otherwise ``n_source_selected`` is rescaled by the ratio
        of the requested to the effective area.
        """
        self._effective_sky_area = (
            SKY_AREA * self.n_source_selected / self.total_expected
        )
        if self.sky_area is None:
            self.sky_area = self._effective_sky_area * u.deg**2
        else:
            # _effective_sky_area is in deg^2, so convert the requested area
            # explicitly; taking ``.value`` of the raw ratio would silently use
            # whatever unit the caller supplied (e.g. sr).
            requested_deg2 = u.Quantity(self.sky_area, u.deg**2).value
            scaling_factor = requested_deg2 / self._effective_sky_area
            new_number_selected = int(scaling_factor * self.source_number_selected)
            self.n_source_selected = new_number_selected

    def _parse_transient_types(self, transient_types: list[str] | str | None) -> list:
        """Validate the requested transient classes against the input files.

        Parameters
        ----------
        transient_types : str, list of str or None
            Requested classes; None selects every class found under
            /TransientTable in any file.

        Returns
        -------
        list of str
            Sorted class names.

        Raises
        ------
        ValueError
            If a requested class is not present in any file.
        """
        if isinstance(transient_types, str):
            transient_types = [transient_types]
        avail = set()
        for f in self.files:
            avail |= set(f["TransientTable"].keys())
        if transient_types is None:
            return sorted(avail)
        missing = [t for t in transient_types if t not in avail]
        if missing:
            raise ValueError(
                f"Unknown transient_types {missing}. Available: {sorted(avail)}"
            )
        return sorted(transient_types)

    def _parse_transient_subtypes(
        self, transient_subtypes: dict[str, list[str]] | None
    ) -> dict[str, list[str]]:
        """Resolve and validate the subclasses to use for each selected class.

        Parameters
        ----------
        transient_subtypes : dict or None
            Requested subclasses per class (a single str is accepted as a
            one-element list). Classes without an entry get every subclass
            found in any file. The passed dictionary is not modified.

        Returns
        -------
        dict
            New dictionary mapping each class to its list of subclasses
            (entries for classes that were not selected are passed through).

        Raises
        ------
        ValueError
            If a requested subclass is not present for its class.
        """
        # Work on a copy so the caller's dictionary is never mutated.
        parsed = {} if transient_subtypes is None else dict(transient_subtypes)
        for transient_type in self.transient_types:
            sub_union = set()
            for f in self.files:
                if transient_type in f["TransientTable"]:
                    sub_union |= set(f["TransientTable"][transient_type].keys())
            provided = parsed.get(transient_type, None)
            if provided is None:
                parsed[transient_type] = sorted(sub_union)
                continue
            if isinstance(provided, str):
                # Avoid iterating over the characters of a single name.
                provided = [provided]
            missing = [t for t in provided if t not in sub_union]
            if missing:
                raise ValueError(
                    f"Unknown transient_subtypes {missing} for transient_type {transient_type}. "
                    f"Available: {sorted(sub_union)}"
                )
            parsed[transient_type] = list(provided)
        return parsed

    def _parse_kwargs_cut(
        self, kwargs_cut: dict | None
    ) -> tuple[float, float, list, list]:
        """Validate kwargs_cut and extract the redshift and magnitude cuts.

        Parameters
        ----------
        kwargs_cut : dict or None
            See the class docstring. The passed dictionary is not modified.

        Returns
        -------
        z_min, z_max : float
            Redshift limits (defaults 0.0 and 3.0).
        bands : list of str
            Normalized band names to cut on (empty if none given).
        band_maxes : list of float
            Maximum magnitudes aligned with ``bands``.

        Raises
        ------
        ValueError
            If z_min > z_max, if only one of 'band'/'band_max' is given, if
            they are not lists of equal length, or if a band is unsupported.
        """
        # Shallow copy: the caller's dictionary must not be mutated.
        kwargs_cut = {} if kwargs_cut is None else dict(kwargs_cut)
        z_min = float(kwargs_cut.get("z_min", 0.0))
        z_max = float(kwargs_cut.get("z_max", 3.0))
        if z_min > z_max:
            raise ValueError(
                f"kwargs_cut['z_min'] ({z_min}) must not exceed "
                f"kwargs_cut['z_max'] ({z_max})."
            )
        has_bands = "band" in kwargs_cut
        has_band_max = "band_max" in kwargs_cut
        if has_bands != has_band_max:
            raise ValueError(
                'If "band" is provided in kwargs_cut then "band_max" must also be '
                + "provided, and vice versa. Currently provided keys in kwargs_cut"
                + f" are {list(kwargs_cut.keys())}."
            )
        if not has_bands:
            return z_min, z_max, [], []
        band = kwargs_cut["band"]
        band_max = kwargs_cut["band_max"]
        # A single band / threshold may be given as a scalar.
        if isinstance(band, str):
            band = [band]
        if isinstance(band_max, (int, float)):
            band_max = [band_max]
        # Type checks first so len() is only called on sequences.
        is_valid = (
            isinstance(band, (list, tuple))
            and isinstance(band_max, (list, tuple))
            and len(band) == len(band_max)
        )
        if not is_valid:
            raise ValueError(
                "kwargs_cut['band'] and ['band_max'] must be lists of equal length."
            )
        bands = _norm_band_names(list(band))
        band_maxes = list(map(float, band_max))
        for b in bands:
            if b not in BANDS:
                raise ValueError(f"Unsupported band '{b}'. Allowed: {BANDS}")
        return z_min, z_max, bands, band_maxes

    def _host_pass_mask(self, host_grp: h5py.Group) -> np.ndarray:
        """Create a boolean mask for hosts passing cuts on redshift and
        magnitude.

        Hostless placeholders (z == 999.0) are exempt from the redshift cut,
        but not from the magnitude cuts.

        Parameters
        ----------
        host_grp : h5py.Group
            HDF5 group for the host table of a given transient class. Must contain
            datasets "z" and "mag_{band}" for each band in self.bands_to_filter,
            all of shape (Nh,).

        Returns
        -------
        mask : np.ndarray
            Boolean array with shape (Nh,) where True indicates the host passes all cuts.
        """
        z = host_grp["z"][...]
        is_hostless = z == 999.0
        passes_redshift_cut = (z >= self.zmin) & (z <= self.zmax)
        mask = np.isfinite(z) & (is_hostless | passes_redshift_cut)
        # NOTE(review): hostless rows must also satisfy these host magnitude
        # cuts; confirm this is intended.
        for b, mmax in zip(self.bands_to_filter, self.band_maxes):
            mags = host_grp[f"mag_{b}"][...]
            mask &= np.isfinite(mags) & (mags <= mmax)
        return mask

    def _build_index_host_info(
        self, transient_type: str
    ) -> tuple[list, list, list, list]:
        """Collect per-file host lookup tables for one transient class.

        Returns
        -------
        host_grps : list
            HostTable group per file, or None if the file lacks this class.
        gids_sorted_list : list of np.ndarray
            Host GIDs per file, sorted ascending.
        sort_idx_list : list of np.ndarray
            Argsort permutation per file (sorted position -> host row).
        host_mask_sorted_list : list of np.ndarray
            Host cut mask per file, aligned with the sorted GIDs.
        """
        host_grps = []
        gids_sorted_list = []
        sort_idx_list = []
        host_mask_sorted_list = []
        for f in self.files:
            if transient_type not in f["HostTable"]:
                # If a file lacks this class, create empty stubs to keep indexing aligned
                host_grps.append(None)
                gids_sorted_list.append(np.array([], dtype="|S8"))
                sort_idx_list.append(np.array([], dtype=int))
                host_mask_sorted_list.append(np.array([], dtype=bool))
                continue
            host_grp = f["HostTable"][transient_type]
            host_gids = host_grp["GID"][...]
            sort_idx = np.argsort(host_gids)
            host_mask = self._host_pass_mask(host_grp)
            host_grps.append(host_grp)
            gids_sorted_list.append(host_gids[sort_idx])
            sort_idx_list.append(sort_idx)
            host_mask_sorted_list.append(host_mask[sort_idx])
        return host_grps, gids_sorted_list, sort_idx_list, host_mask_sorted_list

    def _transient_pass_mask(
        self,
        subgrp: h5py.Group,
        host_gid_sorted: np.ndarray,
        host_mask_sorted: np.ndarray,
        batch: int = 100_000,
    ) -> np.ndarray:
        """Create a boolean mask for transients passing cuts on redshift,
        magnitude, and host validity. A lightcurve passes a magnitude cut if
        any epoch is <= the threshold (NaN epochs never pass).

        Parameters
        ----------
        subgrp : h5py.Group
            HDF5 group for a transient subclass. Must contain datasets "z", "GID", and
            "mag_{band}" for each band in self.bands_to_filter, where "z" and "GID"
            have shape (N,) and "mag_{band}" has shape (N, T).
        host_gid_sorted : np.ndarray
            Sorted array of host GIDs (|S8) for the corresponding transient class.
        host_mask_sorted : np.ndarray
            Boolean array aligned with host_gid_sorted indicating valid hosts.
        batch : int, optional
            Number of transient rows to process in each chunk, by default 100_000.

        Returns
        -------
        mask : np.ndarray
            Boolean array with shape (N,) where True indicates the transient passes all cuts.
        """
        N = subgrp["z"].shape[0]
        # transient redshift
        z = subgrp["z"][...]
        mask = np.isfinite(z) & (z >= self.zmin) & (z <= self.zmax)
        # transient bands: at least one epoch must be <= threshold per band.
        # Comparisons with NaN are False, so all-NaN lightcurves fail.
        for b, mmax in zip(self.bands_to_filter, self.band_maxes):
            ds = subgrp[f"mag_{b}"]  # shape (N, T)
            for start in range(0, N, batch):
                sl = slice(start, min(start + batch, N))
                mask[sl] &= np.any(ds[sl] <= mmax, axis=1)
        # host pass via GID membership (vectorized searchsorted per chunk)
        gid_ds = subgrp["GID"]
        n_hosts = host_gid_sorted.size
        for start in range(0, N, batch):
            sl = slice(start, min(start + batch, N))
            host_ok = np.zeros(sl.stop - sl.start, dtype=bool)
            if n_hosts > 0:
                gids = gid_ds[sl]
                pos = np.searchsorted(host_gid_sorted, gids)
                # searchsorted returns n_hosts for GIDs beyond the last host;
                # clip before indexing to avoid an IndexError, then discard
                # those positions through the in_range test.
                pos_clipped = np.minimum(pos, n_hosts - 1)
                in_range = pos < n_hosts
                match = in_range & (host_gid_sorted[pos_clipped] == gids)
                host_ok[match] = host_mask_sorted[pos[match]]
            mask[sl] &= host_ok
        return mask

    def _row_weights(self, subname: str, redshifts: np.ndarray) -> np.ndarray:
        """Return un-normalized per-row sampling weights d(z_i) for a subclass.

        AGN subclasses get a weight of 1 per row; other subclasses get their
        RATE_FUNCS rate at each row's redshift, with negative values set to 0.

        Raises
        ------
        KeyError
            If a non-AGN subclass has no entry in RATE_FUNCS.
        """
        if "AGN" in subname:
            # Un-normalized so that a shard's weight_sum equals its number of
            # eligible rows, making file (shard) selection proportional to
            # the rows it contributes.
            return np.ones(np.shape(redshifts), dtype=np.float64)
        try:
            rate_func = RATE_FUNCS[subname]
        except KeyError:
            raise KeyError(
                f"Transient subclass {subname} not found in rate functions."
            ) from None
        weights = np.asarray(rate_func(redshifts), dtype=np.float64)
        return np.clip(weights, 0.0, None)

    def _build_subtype_shards(
        self,
        transient_type: str,
        subname: str,
        host_grps: list,
        gids_sorted_list: list,
        host_mask_sorted_list: list,
    ) -> tuple[list[_SubclassShard], int, int]:
        """Collect, per input file, the rows of one subclass passing all cuts.

        Returns
        -------
        shards : list of _SubclassShard
            One shard per file with at least one eligible row.
        total_rows : int
            Rows of this subclass summed over those files.
        total_ok : int
            Eligible rows summed over those files.
        """
        shards: list[_SubclassShard] = []
        total_rows = 0
        total_ok = 0
        for f_idx, f in enumerate(self.files):
            # skip if class, subclass or host table is missing in this file
            transient_table = f["TransientTable"]
            if transient_type not in transient_table:
                continue
            grp = transient_table[transient_type]
            if subname not in grp:
                continue
            if host_grps[f_idx] is None:
                continue
            subgrp = grp[subname]
            eligible_mask = self._transient_pass_mask(
                subgrp,
                gids_sorted_list[f_idx],
                host_mask_sorted_list[f_idx],
            )
            n_ok = int(eligible_mask.sum())
            if n_ok == 0:
                continue
            N = eligible_mask.size
            weights = self._row_weights(subname, subgrp["z"][:])
            if n_ok == N:
                # Every row is eligible: store the count instead of indices.
                eligible_idx = N
                weights_ok = weights
            else:
                eligible_idx = np.flatnonzero(eligible_mask).astype(np.int64)
                weights_ok = weights[eligible_idx]
            weight_sum = np.sum(weights_ok)
            shards.append(
                _SubclassShard(
                    file_index=f_idx,
                    grp=subgrp,
                    N=N,
                    n_ok=n_ok,
                    eligible=eligible_idx,
                    weight_sum=weight_sum,
                    weights=weights_ok / weight_sum,
                )
            )
            total_rows += N
            total_ok += n_ok
        return shards, total_rows, total_ok

    def _get_expected_number(self, subname: str, total_ok: int) -> int:
        """Expected full-sky number of a subclass in [zmin, zmax].

        For subclasses in RATE_FUNCS this is the truncated integral from
        :func:`expected_number`; AGN subclasses use ``total_ok`` instead.

        Raises
        ------
        KeyError
            If the subclass is neither in RATE_FUNCS nor an AGN subclass.
        """
        if subname in RATE_FUNCS:
            rate_fn = RATE_FUNCS[subname]
            return int(
                expected_number(
                    rate_fn=rate_fn,
                    cosmo=self._cosmo,
                    z_min=self.zmin,
                    z_max=self.zmax,
                )
            )
        if "AGN" in subname:
            return total_ok
        raise KeyError(
            f"Transient Subclass {subname} not found in rate functions. "
            + f"Rate functions are available for {list(RATE_FUNCS.keys())}."
        )

    def _build_subtype_indeces(
        self,
        transient_type: str,
        host_grps: list,
        gids_sorted_list: list,
        host_mask_sorted_list: list,
    ) -> tuple[
        list[_SubclassIndex],
        np.ndarray,
        np.ndarray,
        np.ndarray,
    ]:
        """Build the index of every surviving subclass of a class.

        Returns
        -------
        tuple
            Subclass indices sorted by name, and the aligned arrays of total
            rows, eligible rows and expected numbers per subclass.
        """
        sub_list: list[_SubclassIndex] = []
        subclass_total = []
        subclass_selected = []
        subclass_expected = []
        for subname in self.transient_subtypes[transient_type]:
            shards, total_rows, total_ok = self._build_subtype_shards(
                transient_type=transient_type,
                subname=subname,
                host_grps=host_grps,
                gids_sorted_list=gids_sorted_list,
                host_mask_sorted_list=host_mask_sorted_list,
            )
            # keep only if any shard has survivors
            if not shards:
                continue
            # expected number for this subclass (same across files)
            n_expected = self._get_expected_number(subname=subname, total_ok=total_ok)
            sub_list.append(
                _SubclassIndex(name=subname, shards=shards, n_expected=n_expected)
            )
            subclass_total.append(total_rows)
            subclass_selected.append(total_ok)
            subclass_expected.append(n_expected)
        # sort subclasses by name for determinism
        idx_ordered = np.argsort([s.name for s in sub_list])
        ordered_sub_list = [sub_list[i] for i in idx_ordered]
        subclass_total = np.asarray(subclass_total)[idx_ordered]
        subclass_selected = np.asarray(subclass_selected)[idx_ordered]
        subclass_expected = np.asarray(subclass_expected)[idx_ordered]
        return (ordered_sub_list, subclass_total, subclass_selected, subclass_expected)

    def _build_transient_index(self, transient_type: str) -> _ClassIndex:
        """Build the host lookup tables and subclass index of one class."""
        host_grps, gids_sorted_list, sort_idx_list, host_mask_sorted_list = (
            self._build_index_host_info(transient_type)
        )
        # Subclasses across files (as shards)
        ordered_sub_list, subclass_total, subclass_selected, subclass_expected = (
            self._build_subtype_indeces(
                transient_type=transient_type,
                host_grps=host_grps,
                gids_sorted_list=gids_sorted_list,
                host_mask_sorted_list=host_mask_sorted_list,
            )
        )
        return _ClassIndex(
            host_grp=host_grps,
            host_gid_sorted=gids_sorted_list,
            host_gid_sort_idx=sort_idx_list,
            host_mask_sorted=host_mask_sorted_list,
            subclasses=ordered_sub_list,
            subclass_total=subclass_total,
            subclass_selected=subclass_selected,
            subclass_expected=subclass_expected,
            subclass_weights=subclass_expected,  # placeholder; set in __init__
            total=int(np.sum(subclass_total)),
            total_expected=int(np.sum(subclass_expected)),
            total_selected=int(np.sum(subclass_selected)),
        )

    # -------------------- sampling --------------------

    def _sample_from_class(
        self, cls: str
    ) -> tuple[_SubclassIndex, _SubclassShard, int]:
        """Sample a transient subclass, a subclass shard and an index within
        that subclass shard over all surviving subclasses within the provided
        class.

        Parameters
        ----------
        cls : str
            Transient class name.

        Returns
        -------
        s : _SubclassIndex
            The sampled transient subclass.
        sh : _SubclassShard
            The sampled subclass shard.
        i : int
            Row index within the subclass group of the sampled shard's file.
        """
        ci = self._index[cls]
        # Ensure subclass weights are normalized (E_{rl} / sum E)
        p_sub = ci.subclass_weights
        p_sub = p_sub / p_sub.sum()
        s = ci.subclasses[self.rng.choice(len(ci.subclasses), p=p_sub)]
        # P(file | leaf) ∝ S_{f,rl} = shard.weight_sum
        shard_weights = np.array([sh.weight_sum for sh in s.shards], dtype=float)
        shard_weights /= shard_weights.sum()
        sh = s.shards[self.rng.choice(len(s.shards), p=shard_weights)]
        # P(row | file, leaf) ∝ d_{rl}(z_i):
        i = int(self.rng.choice(sh.eligible, p=sh.weights))
        return s, sh, i

    def _host_lookup(self, cls: str, file_index: int, gid_bytes: bytes) -> int:
        """Given a transient class, a file index and a GID (as bytes), return
        the index of the corresponding host in the HostTable for that class.

        Parameters
        ----------
        cls : str
            Transient class name.
        file_index : int
            Sampled file index.
        gid_bytes : bytes
            GID of the host as bytes (|S8).

        Returns
        -------
        int
            Index of the host in the HostTable for the given class.

        Raises
        ------
        KeyError
            If the GID is not present in that file's host table.
        """
        ci = self._index[cls]
        gids_sorted = ci.host_gid_sorted[file_index]
        sort_idx = ci.host_gid_sort_idx[file_index]
        pos = int(np.searchsorted(gids_sorted, gid_bytes))
        if pos >= len(gids_sorted) or gids_sorted[pos] != gid_bytes:
            raise KeyError(f"GID {gid_bytes!r} not found in HostTable/{cls}")
        return int(sort_idx[pos])

    def _scotch_to_slsim_host(self, host: dict) -> dict:
        """Convert a host dictionary from SCOTCH naming and conventions to
        slsim naming and conventions. Adds projected eccentricity components
        and average angular size for each component.

        Parameters
        ----------
        host : dict
            Dictionary with host parameters using SCOTCH naming. Must include
            keys "ellipticity0", "ellipticity1", "a_rot", "a0", "b0", "a1", "b1".
            It is not modified.

        Returns
        -------
        dict
            New dictionary with host parameters using slsim naming, including
            "e0_1", "e0_2", "angular_size_0", "e1_1", "e1_2", "angular_size_1".
        """
        host = host.copy()
        a_rot = host["a_rot"]
        for comp in [0, 1]:
            e1, e2 = galaxy_projected_eccentricity(
                ellipticity=host[f"ellipticity{comp}"], rotation_angle=a_rot
            )
            host[f"e{comp}_1"] = e1
            host[f"e{comp}_2"] = e2
            host[f"angular_size_{comp}"] = param_util.average_angular_size(
                a=host[f"a{comp}"], b=host[f"b{comp}"]
            )
        return host

    def _build_host_dict(self, host_grp: h5py.Group, host_idx: int) -> dict:
        """Build an slsim-compatible host dictionary from the host group and
        index. If the host redshift is 999.0 (corresponding to a hostless
        transient), return an empty dictionary.

        Parameters
        ----------
        host_grp : h5py.Group
            HDF5 group for the host table of a given transient class.
        host_idx : int
            Index of the host within the host group.

        Returns
        -------
        dict
            Dictionary with host parameters using slsim naming. Empty if
            the transient is hostless (host redshift = 999.0).
        """
        host = {}
        if host_grp["z"][host_idx] == 999.0:
            return host
        for name, ds in host_grp.items():
            if not isinstance(ds, h5py.Dataset):
                continue
            val = ds[host_idx]
            if ds.dtype.kind == "S":
                val = val.decode("utf-8")
            if name == "a_rot":
                # SCOTCH stores the rotation angle in degrees.
                val = np.deg2rad(val)
            host[SCOTCH_MAPPINGS.get(name, name)] = val
        return self._scotch_to_slsim_host(host)

    def _draw_source_dict(self, *args, **kwargs) -> tuple[dict, bool]:
        """Draw a transient and its host (if any), returning a combined
        dictionary of parameters.

        The transient class is drawn with probabilities ``self.class_weights``;
        the subclass, shard and row are then drawn by
        :meth:`_sample_from_class`. Magnitudes equal to 99.0 (missing data)
        are replaced by +inf, and the MJD grid is shifted so that the epoch of
        the brightest magnitude over all bands is at 0. If no band has a
        finite magnitude, the grid is shifted relative to its first epoch.

        Returns
        -------
        dict
            Dictionary with transient and host parameters using slsim naming.
        bool
            True if the transient has a host, False if hostless.
        """
        cls = self.rng.choice(self.active_transient_types, p=self.class_weights)
        s, sh, i = self._sample_from_class(cls)
        file_index = sh.file_index
        g = sh.grp
        transient_dict = {
            "name": f"{s.name}",
            "z": float(g["z"][i]),
            "ra_off": float(g["ra_off"][i]),
            "dec_off": float(g["dec_off"][i]),
        }
        idx_min = 0
        min_mag = np.inf
        for band in BANDS:
            mags = g[f"mag_{band}"][i]
            mags = np.where(mags == 99.0, np.inf, mags)
            transient_dict[f"ps_mag_{band}"] = mags
            # Treat NaN like missing data so it never wins the argmin.
            searchable = np.where(np.isnan(mags), np.inf, mags)
            idx_band = int(np.argmin(searchable))
            if searchable[idx_band] < min_mag:
                min_mag = searchable[idx_band]
                idx_min = idx_band
        mjd = g["MJD"][i]
        transient_dict["MJD"] = mjd - mjd[idx_min]
        gid_b = g["GID"][i]
        host_idx = self._host_lookup(cls=cls, file_index=file_index, gid_bytes=gid_b)
        host_grp = self._index[cls].host_grp[file_index]
        host_dict = self._build_host_dict(host_grp, host_idx)
        has_host = bool(host_dict)
        source_dict = transient_dict | host_dict
        return source_dict, has_host

    def draw_source(self, *args, **kwargs) -> Source:
        """Draw a source from the population, returning a Source object.

        Transients are instantiated as "general_lightcurve" point sources, and
        hosts (if any) as "double_sersic" extended sources. If the transient is
        hostless, the Source is an instance of PointSource, otherwise it is an
        instance of PointPlusExtendedSource.

        Returns
        -------
        Source
            The drawn source object. If hostless, an instance of PointSource;
            otherwise, an instance of PointPlusExtendedSource.
        """
        source_dict, has_host = self._draw_source_dict()
        extended_source_type = "double_sersic" if has_host else None
        return Source(
            cosmo=self._cosmo,
            extended_source_type=extended_source_type,
            point_source_type="general_lightcurve",
            **source_dict,
        )

    def close(self):
        """Close every SCOTCH HDF5 file opened by this instance.

        Closing is best-effort: a failure to close one file is reported as a
        warning and the remaining files are still closed.
        """
        for f in getattr(self, "files", []):
            try:
                f.close()
            except Exception as err:  # best-effort cleanup; never raise here
                warnings.warn(f"Failed to close SCOTCH file {f!r}: {err}")