import copy
from astropy.coordinates import SkyCoord
import datetime
import numpy as np
from lenstronomy.SimulationAPI.sim_api import SimAPI
from slsim.ImageSimulation import image_quality_lenstronomy
import os.path
import pickle
from stpsf.roman import WFI
import random
import warnings
# galsim is optional: it is only needed for the Roman image simulation in this
# module, and (per the warning text below) it is not supported on Windows.
# If it cannot be imported, a UserWarning is emitted instead of failing, so this
# module can still be imported; the functions below that use ``galsim``,
# ``Image``, ``InterpolatedImage`` or ``roman`` will then raise NameError.
try:
    import galsim
    from galsim import Image, InterpolatedImage, roman
except ModuleNotFoundError:
    warning_msg = (
        "If you want to simulate images with Roman filters, please install the galsim module.\n"
        "Note that this module is not supported on Windows"
    )
    # stacklevel=2 attributes the warning to the code importing this module
    warnings.warn(warning_msg, category=UserWarning, stacklevel=2)
# NOTE: The galsim module is required, and it is not supported on Windows.
# Additionally, simulation is very slow when the PSF has to be generated on the
# fly by stpsf. Alternatively, the user can download pre-computed PSFs from
# cached_webb_psf (https://github.com/LSST-strong-lensing/data_public/stpsf),
# where the PSFs have been generated ahead of time so that they can be loaded
# from a file. The directory containing these PSFs should be passed as the
# "psf_directory" parameter of simulate_roman_image below.
[docs]
def simulate_roman_image(
    lens_class,
    band,
    num_pix,
    oversample=3,
    add_noise=True,
    subtract_mean_background=True,
    with_source=True,
    with_deflector=True,
    exposure_time=None,
    num_exposures=None,
    t_obs=None,
    survey_mode="wide_area",
    detector=None,
    detector_pos=None,
    seed=None,
    ra=None,
    dec=None,
    date=datetime.datetime(year=2027, month=7, day=7, hour=0, minute=0, second=0),
    psf_directory=None,
):
    """Creates a Roman-simulated image of a selected lens with noise, using
    galsim's noise settings and PSFs from STPSF.

    The unconvolved lens image is drawn by lenstronomy on a grid oversampled by
    ``oversample``, converted to a galsim ``InterpolatedImage``, convolved with
    the STPSF PSF and drawn at the native pixel scale. If ``add_noise`` is True,
    sky + thermal backgrounds are added and galsim's Roman detector effects are
    applied to each of the exposures, which are then summed. The result is
    divided by the total exposure time.

    :param lens_class: class object containing all information of the lensing system
        (e.g., Lens())
    :param band: imaging band.
    :type band: str
    :param num_pix: number of pixels per axis of the returned image.
    :type num_pix: int
    :param oversample: number of times each pixel side is subdivided when drawing
        the unconvolved image and the PSF, for a more accurate convolution.
    :type oversample: int
    :param add_noise: determines whether sky background and detector effects are added or not.
        See https://galsim-developers.github.io/GalSim/_build/html/roman.html#galsim.roman.allDetectorEffects
        for specific details about the detector effects.
    :type add_noise: bool
    :param subtract_mean_background: whether to subtract the mean count of photons on the
        background (not the noise). Only used when ``add_noise`` is True.
    :type subtract_mean_background: bool
    :param with_source: determines whether source is included in image.
    :type with_source: bool
    :param with_deflector: determines whether deflector is included in image.
    :type with_deflector: bool
    :param exposure_time: exposure time of a single exposure. If None, a default exposure time
        will be retrieved from lenstronomy's SimulationAPI.ObservationConfig based on the Roman
        survey mode. Must be positive.
    :type exposure_time: int, optional
    :param num_exposures: number of exposures. If None, a default number will be retrieved from
        lenstronomy's SimulationAPI.ObservationConfig based on the Roman survey mode. Must be at
        least 1 when ``add_noise`` is True.
    :type num_exposures: int, optional
    :param t_obs: an observation time in units of days. This is applicable only for
        variable source. In case of point source, if we do not provide
        t_obs, considers no variability in the lens.
    :type t_obs: float, optional
    :param survey_mode: survey mode of the Roman detector. Can be "wide_area" or "time_domain".
    :type survey_mode: str
    :param detector: The specific WFI detector being used to generate the psf (from 1 to 18).
        If None, one will be selected at random.
    :type detector: int, optional
    :param detector_pos: The pixel on the detector being used to generate the psf.
        Must be a 2-tuple of integers between 4 + num_pix * oversample and 4092 - num_pix * oversample.
        If None, one will be selected at random.
    :type detector_pos: tuple, optional
    :param seed: An rng seed used for generating detector effects in galsim. It does not
        control the random choice of ``detector``, ``detector_pos``, ``ra`` or ``dec``,
        which are drawn from Python's global ``random`` module.
    :type seed: int, optional
    :param ra: Coordinate in space used to generate sky background. For possible coordinates, see
        https://roman-docs.stsci.edu/files/215024143/215024145/2/1768495040130/outlines.png
        If None, drawn uniformly from [5, 60].
    :type ra: float, optional
    :param dec: Coordinate in space used to generate sky background. For possible coordinates, see
        https://roman-docs.stsci.edu/files/215024143/215024145/2/1768495040130/outlines.png
        If None, drawn uniformly from [-40, -20].
    :type dec: float, optional
    :param date: Date used to generate sky background. The date must be consistent with the ra and dec coordinates.
    :type date: datetime.datetime
    :param psf_directory: Path to directory containing psf file(s) where the psf can be loaded.
        Otherwise, the psf will be generated by stpsf on the fly, which is very slow.
        See the note in the ``get_psf`` method's docstring for details on the PSF file naming convention.
    :type psf_directory: str
    :raises ValueError: if the exposure time is not positive, or if ``add_noise`` is True
        and the number of exposures is less than 1.
    :return: simulated image: the (summed) image divided by the total exposure time,
        i.e. a per-second rate in each pixel.
    :rtype: numpy.ndarray
    """
    # Draw any unspecified observing parameters at random. The order of the
    # draws (detector, x, y, ra, dec) is kept fixed for reproducibility under
    # a global random.seed().
    if detector is None:
        detector = random.randint(1, 18)
    if detector_pos is None:
        x_pos = random.randint(4 + num_pix * oversample, 4092 - num_pix * oversample)
        y_pos = random.randint(4 + num_pix * oversample, 4092 - num_pix * oversample)
        detector_pos = (x_pos, y_pos)
    if ra is None:
        ra = random.uniform(5, 60)
    if dec is None:
        dec = random.uniform(-40, -20)

    # Perform all operations with an additional pixel buffer on each side
    # to avoid edge effects; the buffer is cropped out at the end.
    pad = 3
    num_pix += 2 * pad

    kwargs_model, kwargs_params = lens_class.lenstronomy_kwargs(band=band, time=t_obs)
    # Shallow copy so the in-place pixel-scale change below can never leak into
    # a dictionary that image_quality_lenstronomy might share between calls.
    kwargs_single_band = dict(
        image_quality_lenstronomy.kwargs_single_band(
            observatory="Roman", band=band, survey_mode=survey_mode
        )
    )

    # Will use galsim to handle individual exposures then add them up
    _exposure_time = (
        kwargs_single_band["exposure_time"] if exposure_time is None else exposure_time
    )
    _num_exposures = (
        kwargs_single_band["num_exposures"] if num_exposures is None else num_exposures
    )
    # Validate before the (potentially very slow) PSF generation below; invalid
    # values would otherwise end in a division by zero or an empty exposure sum.
    if _exposure_time <= 0:
        raise ValueError(f"exposure_time must be positive, got {_exposure_time}.")
    if add_noise and _num_exposures < 1:
        raise ValueError(
            f"num_exposures must be at least 1 when add_noise is True, got {_num_exposures}."
        )

    galsim_psf = get_psf(band, detector, detector_pos, oversample, psf_directory)

    # Native detector pixel scale (arcsec/pixel) as configured for lenstronomy;
    # reused for the galsim images so both libraries agree on the pixel grid.
    pixel_scale = kwargs_single_band["pixel_scale"]
    # Unconvolved image will be drawn at the oversampled pixel scale
    kwargs_single_band["pixel_scale"] = pixel_scale / oversample
    sim_api = SimAPI(
        numpix=num_pix * oversample,
        kwargs_single_band=kwargs_single_band,
        kwargs_model=kwargs_model,
    )
    kwargs_lens_light, kwargs_source, kwargs_ps = sim_api.magnitude2amplitude(
        kwargs_lens_light_mag=kwargs_params.get("kwargs_lens_light", None),
        kwargs_source_mag=kwargs_params.get("kwargs_source", None),
        kwargs_ps_mag=kwargs_params.get("kwargs_ps", None),
    )
    kwargs_lens = kwargs_params.get("kwargs_lens", None)
    kwargs_numerics = {
        "point_source_supersampling_factor": 1,
        "supersampling_factor": 1,
    }
    image_model = sim_api.image_model_class(kwargs_numerics)
    # Draws the unconvolved image (scaled to one exposure) with the point
    # source painted on a single pixel
    array = _exposure_time * image_model.image(
        kwargs_lens=kwargs_lens,
        kwargs_source=kwargs_source,
        kwargs_lens_light=kwargs_lens_light,
        kwargs_ps=kwargs_ps,
        unconvolved=True,
        source_add=with_source,
        lens_light_add=with_deflector,
        point_source_add=True,
    )

    # Converts image to the galsim InterpolatedImage class
    interp = InterpolatedImage(
        Image(array, xmin=0, ymin=0),
        scale=pixel_scale / oversample,
        flux=np.sum(array),
    )
    # Convolve with the PSF, then draw at the original (not oversampled) pixel scale
    convolved = galsim.Convolve(interp, galsim_psf)
    im = galsim.ImageF(num_pix, num_pix, scale=pixel_scale)
    im.setOrigin(0, 0)
    image_no_noise = convolved.drawImage(im)

    if add_noise:
        # Obtain sky and thermal background corresponding to certain band and add it to the image.
        # Poisson noise realization is not handled until later.
        image_with_background = add_sky_plus_thermal_background(
            image_no_noise,
            band,
            detector,
            num_pix,
            _exposure_time,
            ra,
            dec,
            date,
        )
        # Add noise realizations and detector effects. Each exposure is handled
        # separately to properly take into account read noise and persistence.
        rng = galsim.UniformDeviate(seed)
        # each entry includes all noise
        exposures = []
        # does not include readout noise; necessary to include the effects of persistence
        prev_exposures = []
        for _ in range(_num_exposures):
            # Create a new realization of image + noise
            exposure = copy.deepcopy(image_with_background)
            prev_exposures = roman.allDetectorEffects(
                exposure,  # this gets modified in-place
                prev_exposures=prev_exposures,  # this gets updated with each call
                rng=rng,  # rng updates are automatically done
                exptime=_exposure_time,
            )
            if subtract_mean_background:
                mean_noise = np.mean(exposure.array - image_no_noise.array)
                exposure.array -= mean_noise
            exposures.append(exposure)
        # Combine exposures and compute flux per second
        array_list = np.array([exposure.array for exposure in exposures])
        array = np.sum(array_list, axis=0) / (_exposure_time * _num_exposures)
    else:
        array = image_no_noise.array / _exposure_time

    return array[pad:-pad, pad:-pad]
# The following functions have been copy-pasted from the mejiro repo
# Credit to Bryce Wedig
[docs]
def get_psf(band, detector, detector_pos, oversample, psf_directory):
    """Return a galsim PSF for one Roman WFI detector, detector position, band
    and oversampling factor.

    A pre-computed PSF is loaded from a pickle file when one exists (either in
    ``psf_directory`` or in the package's ``data/stpsf`` directory); otherwise
    the PSF is computed on the fly with stpsf, which is very slow.

    :param band: The specific band corresponding to the psf (e.g. "F106").
    :type band: str
    :param detector: The specific Roman detector (SCA) being used to generate the psf (from 1 to 18).
    :type detector: int
    :param detector_pos: The (x, y) pixel position on the detector at which the psf is evaluated.
    :type detector_pos: tuple
    :param oversample: Number of times that each pixel's side is subdivided for higher
        accuracy psf convolution.
    :type oversample: int
    :param psf_directory: Path to directory containing psf file(s) where the psf can be loaded.
        If None, the package's ``data/stpsf`` directory is searched instead. If no file is
        found, the psf is generated by stpsf on the fly.
    :type psf_directory: str or None
    :return: The psf, sampled at ``0.11 / oversample`` arcsec per pixel.
    :rtype: galsim.InterpolatedImage

    **Notes on psf naming convention:**

    The name of the psf file inside the directory follows this convention::

        psf_file_name = f"{band}_{detector}_{detector_pos[0]}_{detector_pos[1]}_{oversample}.pkl"

    where ``detector`` is written as ``SCA`` followed by the two-digit detector number.
    For example::

        psf_file_name = "F106_SCA03_1934_1293_5.pkl"
    """
    # stpsf and the cache file names both identify the detector as "SCAnn"
    sca_name = f"SCA{str(detector).zfill(2)}"
    cache_name = f"{band}_{sca_name}_{detector_pos[0]}_{detector_pos[1]}_{oversample}.pkl"

    if psf_directory is None:
        cache_dir = os.path.join(os.path.dirname(__file__), "../..", "data", "stpsf")
    else:
        cache_dir = psf_directory
    cache_path = os.path.join(cache_dir, cache_name)

    if not os.path.exists(cache_path):
        # No cached psf: compute it with stpsf (slow)
        wfi = WFI()
        wfi.filter = band.upper()
        wfi.detector = sca_name
        wfi.detector_position = detector_pos
        # 45x45 pixel field of view, oversampled by ``oversample``
        psf = wfi.calc_psf(fov_pixels=45, oversample=oversample)
    else:
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # psf_directory at trusted, locally generated files.
        with open(cache_path, "rb") as cache_file:
            psf = pickle.load(cache_file)

    # Wrap the first (oversampled) extension as a galsim image
    return galsim.InterpolatedImage(
        galsim.Image(psf[0].data, scale=0.11 / oversample)
    )
[docs]
def add_sky_plus_thermal_background(
    image,
    band,
    detector,
    num_pix,
    exposure_time,
    ra,
    dec,
    date,
):
    """Return ``image`` with a sky background (including stray light) and a
    thermal background added, for a specific band, detector, date and pointing.

    :param image: image to add the background to
    :type image: galsim Image class
    :param band: imaging band
    :type band: string
    :param detector: The specific Roman detector (SCA) whose WCS is used
    :type detector: integer from 1 to 18
    :param num_pix: number of pixels per axis of the background image; must match ``image``
    :type num_pix: integer
    :param exposure_time: exposure time; passed to galsim's sky-level calculation and
        multiplied with the band's thermal background rate
    :type exposure_time: float
    :param ra: Coordinate in space (degrees) used to generate sky background
    :type ra: float
    :param dec: Coordinate in space (degrees) used to generate sky background
    :type dec: float
    :param date: Date used to generate sky background. The date must be
        consistent with the ra and dec coordinates.
    :type date: datetime.datetime class
    :return: a new image equal to ``image`` + sky background + thermal background
    :rtype: galsim Image class
    """
    bandpass = get_bandpass(band)
    # galsim.GSFitsWCS of the chosen detector for a pointing at (ra, dec) on ``date``
    sca_wcs = _get_wcs_dict(ra, dec, date)[detector]

    # Blank image carrying the detector WCS; filled with the sky level below
    sky_image = galsim.ImageF(num_pix, num_pix, wcs=sca_wcs)
    # Sky level (electrons per arcsec^2) evaluated at the world position of the image centre
    sky_level = roman.getSkyLevel(
        bandpass,
        world_pos=sca_wcs.toWorld(sky_image.true_center),
        exptime=exposure_time,
        date=date,
    )
    # Stray light is treated as a fractional boost of the sky level
    sky_level *= 1.0 + roman.stray_light_fraction
    # Converts electrons/arcsec^2 into electrons/pixel using the WCS pixel areas
    sca_wcs.makeSkyImage(sky_image, sky_level)

    thermal_level = roman.thermal_backgrounds[get_bandpass_key(band)] * exposure_time
    return image + sky_image + thermal_level
[docs]
def get_bandpass(band):
    """Look up the galsim Roman bandpass for a Roman imaging band.

    :param band: imaging band (e.g. "F106"); translated with ``get_bandpass_key``
    :type band: string
    :return: galsim bandpass object corresponding to specific band
    :rtype: galsim Bandpass class
    """
    galsim_key = get_bandpass_key(band)
    roman_bandpasses = roman.getBandpasses()
    return roman_bandpasses[galsim_key]
[docs]
def get_bandpass_key(band):
    """Translates the Roman bands to keys used in galsim.

    The lookup is case-insensitive. The returned key indexes
    ``galsim.roman.getBandpasses()`` and ``galsim.roman.thermal_backgrounds``.

    :param band: The Roman band to be translated (e.g. "F106")
    :type band: string
    :return: Translated band (e.g. "Y106")
    :rtype: string
    :raises KeyError: if ``band`` is not one of the Roman imaging filters below
    """
    translate = {
        "F062": "R062",
        "F087": "Z087",
        "F106": "Y106",
        "F129": "J129",
        "F158": "H158",
        "F184": "F184",
        # galsim's Phase C Roman data (the data set that also provides K213)
        # names the wide filter W146; W149 was its legacy name.
        "F146": "W146",
        "F213": "K213",
    }
    try:
        return translate[band.upper()]
    except KeyError:
        raise KeyError(
            f"Unknown Roman band {band!r}; expected one of {sorted(translate)}"
        ) from None
def _get_wcs_dict(ra, dec, date):
    """Build the Roman WCS of every detector for a pointing centred on (ra, dec).

    :param ra: Right ascension (ICRS, degrees) of the centre of the focal plane array.
        For possible coordinates, see
        https://roman-docs.stsci.edu/files/215024143/215024145/2/1768495040130/outlines.png
    :type ra: float
    :param dec: Declination (ICRS, degrees) of the centre of the focal plane array.
        For possible coordinates, see
        https://roman-docs.stsci.edu/files/215024143/215024145/2/1768495040130/outlines.png
    :type dec: float
    :param date: Date used to generate sky background. The date must be consistent with the ra and dec coordinates.
    :type date: datetime.datetime class
    :return: WCS corresponding to date and coordinate in space
    :rtype: dictionary, where the keys are the detectors and the
        values are the WCS corresponding to each detector
    """
    # Go through astropy's sexagesimal strings to build galsim angles
    sexagesimal = SkyCoord(ra, dec, frame="icrs", unit="deg").to_string("hmsdms")
    ra_hms, dec_dms = sexagesimal.split(" ")
    focal_plane_centre = galsim.CelestialCoord(
        ra=galsim.Angle.from_hms(ra_hms),
        dec=galsim.Angle.from_dms(dec_dms),
    )
    return roman.getWCS(world_pos=focal_plane_centre, date=date)