import matplotlib.pyplot as plt
from astropy.visualization import ZScaleInterval
import random
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
from slsim.Microlensing.magmap import MagnificationMap
"""This module contains various plotting definations."""
[docs]
def create_image_montage_from_image_list(
    num_rows,
    num_cols,
    images,
    time=None,
    band=None,
    image_type="other",
    image_center=None,
):
    """Creates an image montage from an image list.

    Images are placed row by row on a ``num_rows`` x ``num_cols`` grid. Grid
    cells beyond ``len(images)`` are left as empty axes.

    :param num_rows: number of rows of the montage grid
    :param num_cols: number of columns of the montage grid
    :param images: list of images
    :param time: array of observation time for point source images. If
        None, considers static case.
    :param band: array of bands corresponding to the observations, or a
        single string that is applied to every image. If None, does not
        display any information regarding the band.
    :param image_type: type of the provided image. It could be 'dp0' or
        any other name. 'dp0' images are shown in gray with a per-image
        zscale stretch; any other type shares one color scale spanning
        the global minimum and maximum of all images.
    :param image_center: center of the source images. Every center is
        marked on every displayed image.
    :type image_center: array. eg: for two image, it should be like
        np.array([[13.71649063, 13.09556121], [16.69249276,
        17.78106655]])
    :return: matplotlib figure containing the image montage.
    """
    # Shared color scale used for all non-"dp0" images.
    global_min = min(np.min(image) for image in images)
    global_max = max(np.max(image) for image in images)

    # If band is one string, extend to list
    if isinstance(band, str):
        band = [band] * len(images)

    # squeeze=False keeps ``axes`` two-dimensional even for a single row or
    # a single column, so the ``axes[i, j]`` indexing below always works.
    fig, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(num_cols * 3, num_rows * 3),
        squeeze=False,
    )
    zscale = ZScaleInterval() if image_type == "dp0" else None

    for i in range(num_rows):
        for j in range(num_cols):
            index = i * num_cols + j
            if index >= len(images):
                continue
            ax = axes[i, j]
            image = images[index]

            if image_type == "dp0":
                vmin, vmax = zscale.get_limits(image)
                ax.imshow(image, origin="lower", cmap="gray", vmin=vmin, vmax=vmax)
            else:
                ax.imshow(image, origin="lower", vmin=global_min, vmax=global_max)
            ax.axis("off")  # Turn off axis labels

            if time is not None:
                ax.text(
                    0.05,
                    0.95,
                    f"Time: {round(time[index],2)} days",
                    fontsize=10,
                    color="white",
                    verticalalignment="top",
                    horizontalalignment="left",
                    transform=ax.transAxes,
                )
            if band is not None:
                ax.text(
                    0.05,
                    0.10,
                    f"Band: {band[index]}",
                    fontsize=10,
                    color="white",
                    verticalalignment="top",
                    horizontalalignment="left",
                    transform=ax.transAxes,
                )
            if image_center is not None:
                for center in image_center:
                    ax.scatter(
                        center[0],
                        center[1],
                        marker="*",
                        color="red",
                        s=30,
                    )

    fig.tight_layout()
    fig.subplots_adjust(wspace=0.0, hspace=0.05)
    return fig
[docs]
def plot_montage_of_random_injected_lens(image_list, num, n_horizont=1, n_vertical=1):
    """Creates an image montage of random lenses from the catalog of injected
    lens.

    Each grid cell shows an image drawn independently at random (so the same
    image may appear more than once).

    :param image_list: list of catalog images
    :param num: length of the injected lens catalog; images are drawn from
        indices ``0 .. num - 1``
    :param n_horizont: number of images to display horizontally
    :param n_vertical: number of images to display vertically
    :return: matplotlib figure with the montage of random injected lenses.
    """
    # squeeze=False keeps ``axes`` two-dimensional, so indexing also works
    # for a single row/column and for the default 1 x 1 montage.
    fig, axes = plt.subplots(
        n_vertical,
        n_horizont,
        figsize=(n_horizont * 3, n_vertical * 3),
        squeeze=False,
    )
    for i in range(n_horizont):
        for j in range(n_vertical):
            ax = axes[j, i]
            # randrange excludes ``num`` itself; randint(0, num) could return
            # ``num`` and index one past the end of the catalog.
            index = random.randrange(num)
            image = image_list[index]
            ax.imshow(image, aspect="equal", origin="lower")
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            ax.autoscale(False)

    fig.tight_layout()
    fig.subplots_adjust(
        left=None, bottom=None, right=None, top=None, wspace=0.0, hspace=0.05
    )
    return fig
[docs]
def plot_lightcurves(lightcurve_dict):
    """Plots lightcurves dynamically for all available images across different
    bands.

    The figure has one row per band and one column per image. Bands and
    images whose magnitudes are empty or entirely NaN are skipped.

    :param lightcurve_dict: Dictionary of lightcurves.
        The format of this dictionary should be following:
        lightcurve_dict = {
        "obs_time": {"i": [63105.42, 63107.41], "r": [63107.39, 63118.22]},
        "magnitudes": {
        "mag_image_1": {"i": [21.21, 20.42], "r": [20.87, 19.31]},
        "mag_image_2": {"i": [23.82, 22.87], "r": [23.45, 23.16]},
        "mag_image_3": {"i": [], "r": []},
        "mag_image_4": {"i": [], "r": []}},
        "errors_low": {
        "mag_error_image_1_low": {"i": [0.04, 0.03], "r": [0.03, 0.02]},
        "mag_error_image_2_low": {"i": [0.06, 0.05], "r": [0.04, 0.03]}},
        "errors_high": {
        "mag_error_image_1_high": {"i": [0.05, 0.04], "r": [0.03, 0.02]},
        "mag_error_image_2_high": {"i": [0.07, 0.06], "r": [0.05, 0.04]}}}
    :return: matplotlib figure with the lightcurve plots.
    :raises ValueError: if no band/image combination contains any data.
    """
    magnitudes = lightcurve_dict["magnitudes"]
    errors_low = lightcurve_dict["errors_low"]
    errors_high = lightcurve_dict["errors_high"]
    obs_time = lightcurve_dict["obs_time"]

    candidate_keys = [key for key in magnitudes if key.startswith("mag_image_")]

    def _has_data(image_key, band):
        # An empty list also counts as "no data": np.all([]) is True.
        return not np.all(np.isnan(magnitudes[image_key][band]))

    # Keep only bands in which at least one image has a non-NaN magnitude.
    bands = [
        band
        for band in obs_time
        if any(_has_data(image_key, band) for image_key in candidate_keys)
    ]
    # Keep only images that have data in at least one of the retained bands.
    image_keys = [
        image_key
        for image_key in candidate_keys
        if any(_has_data(image_key, band) for band in bands)
    ]
    if not bands or not image_keys:
        raise ValueError("lightcurve_dict contains no non-empty lightcurves to plot.")

    # Plot grid: rows for bands, columns for images. squeeze=False keeps
    # ``axs`` two-dimensional even with a single band and/or a single image.
    fig, axs = plt.subplots(
        nrows=len(bands),
        ncols=len(image_keys),
        figsize=(12, 6),
        gridspec_kw={"hspace": 0.6, "wspace": 0.3},
        squeeze=False,
    )

    # Add titles for each column
    for col_idx in range(len(image_keys)):
        axs[0, col_idx].set_title(
            f"Lightcurves of image {col_idx+1}", fontsize=12, loc="center"
        )

    # Plot data for each band
    last_row = len(bands) - 1
    for row_idx, band in enumerate(bands):
        band_time = obs_time[band]
        for col_idx, image_key in enumerate(image_keys):
            ax = axs[row_idx, col_idx]
            # "mag_image_N" -> "mag_error_image_N_low" / "..._high"
            error_key = image_key.replace("mag_image", "mag_error_image")
            # Asymmetric error bars: [lower errors, upper errors].
            err_band = [
                errors_low[f"{error_key}_low"][band],
                errors_high[f"{error_key}_high"][band],
            ]
            ax.errorbar(
                band_time,
                magnitudes[image_key][band],
                yerr=err_band,
                fmt=".",
                label=f"{band}-band",
                color=f"C{row_idx}",
                alpha=0.7,
            )
            ax.set_ylim(None, 30)
            ax.set_ylabel(f"Mag_{band}", fontsize=10)
            # Magnitude axis is flipped after the limits are set.
            ax.invert_yaxis()
            ax.tick_params(axis="both", labelsize=8)
            # Add x-label only for the bottom row
            if row_idx == last_row:
                ax.set_xlabel("MJD [Days]", fontsize=10)

    # Adjust layout to avoid overlaps
    fig.tight_layout()
    return fig
[docs]
def create_montage(images_band, grid_size=None):
    """Creates a montage from a list of images, limited to the first 3 images,
    with consistent scaling. This function is a helper function for
    plot_lightcurves() function.

    All images are normalized to [0, 1] with the global minimum and maximum
    over the (at most three) images used. If every pixel has the same value,
    the normalized images are all zeros instead of NaN.

    :param images_band: List of 2D NumPy arrays representing images. All
        images are expected to have the same shape.
    :param grid_size: Tuple specifying the grid dimensions (rows, cols).
        If None, the images are laid out in a single row.
    :return: 2D NumPy array representing the montage. Grid cells without an
        image are filled with zeros.
    :raises ValueError: if ``images_band`` is empty or ``grid_size`` has
        fewer cells than there are images to place.
    """
    # Limit to the first 3 images and make sure each one is a NumPy array.
    images_band = [np.asarray(img) for img in images_band[:3]]
    if not images_band:
        raise ValueError("images_band must contain at least one image.")

    # Global pixel range across all images, used for a consistent scaling.
    global_min = min(np.min(img) for img in images_band)
    global_max = max(np.max(img) for img in images_band)
    pixel_range = global_max - global_min

    if pixel_range == 0:
        # Constant images: avoid 0/0 (NaN) and map everything to zero.
        normalized_images = [np.zeros(img.shape) for img in images_band]
    else:
        normalized_images = [(img - global_min) / pixel_range for img in images_band]

    # Determine grid size if not provided
    n_images = len(normalized_images)
    if grid_size is None:
        grid_rows, grid_cols = 1, n_images
    else:
        grid_rows, grid_cols = grid_size
    if grid_rows * grid_cols < n_images:
        raise ValueError(
            f"grid_size {tuple(grid_size)} is too small for {n_images} images."
        )

    # Determine the size of each image (all images share the first one's shape)
    img_h, img_w = normalized_images[0].shape

    # Fill the montage row by row.
    montage = np.zeros((grid_rows * img_h, grid_cols * img_w))
    for idx, image in enumerate(normalized_images):
        row, col = divmod(idx, grid_cols)
        montage[row * img_h : (row + 1) * img_h, col * img_w : (col + 1) * img_w] = (
            image
        )
    return montage
# microlensing lightcurve plot along with the magnification maps
[docs]
def plot_lightcurves_and_magmap(
    convolved_map,
    lightcurves,
    time_duration_observer_frame,
    tracks,
    magmap_instance: MagnificationMap,
    lightcurve_type="magnitude",
):
    """Plot the lightcurves and the magnification map.

    The left panel shows the lightcurves against time, the right panel shows
    the map (in units of theta_star) with the tracks overlaid and numbered.

    :param convolved_map: convolved magnification map 2D numpy array.
        This is the map that is used to generate the lightcurves.
    :param lightcurves: list of lightcurves to plot. All lightcurves are
        plotted against a time axis built from the length of the first one.
    :param time_duration_observer_frame: time duration in observer frame
        in days
    :param tracks: list of tracks to plot, or None to draw no tracks. Each
        track is indexed as ``track[0]`` (y pixel coordinates) and
        ``track[1]`` (x pixel coordinates).
    :param magmap_instance: instance of the MagnificationMap class. Must
        be the same as the one used to generate the lightcurves.
    :param lightcurve_type: type of lightcurve to plot. Can be
        'magnitude' or 'magnification'.
    :return: ax: array with the two axes of the plot (lightcurves, map)
    :raises ValueError: if ``lightcurve_type`` is not one of the supported
        values.
    """
    # Resolve everything that depends on the lightcurve type first, so an
    # invalid value fails before any figure is created.
    if lightcurve_type == "magnitude":
        lightcurve_label = (
            "Magnitude $\\Delta m = -2.5 \\log_{10} (\\mu / \\mu_{\\text{av}})$"
        )
        colorbar_label = "Microlensing $\\Delta m = -2.5 \\log_{10} (\\mu / \\mu_{\\text{av}})$ (magnitudes)"
        im_to_show = -2.5 * np.log10(convolved_map / np.abs(magmap_instance.mu_ave))
    elif lightcurve_type == "magnification":
        lightcurve_label = "Magnification $\\mu$"
        colorbar_label = "Microlensing magnification $\\mu$"
        im_to_show = convolved_map
    else:
        raise ValueError(
            "lightcurve_type must be 'magnitude' or 'magnification', "
            f"got {lightcurve_type!r}."
        )

    theta_star = magmap_instance.theta_star
    center_x = magmap_instance.center_x
    center_y = magmap_instance.center_y
    half_length_x = magmap_instance.half_length_x
    half_length_y = magmap_instance.half_length_y

    fig, ax = plt.subplots(1, 2, figsize=(18, 6), width_ratios=[2, 1])
    time_array = np.linspace(
        0, time_duration_observer_frame, len(lightcurves[0])
    )  # in days

    # light curves
    for i, lightcurve in enumerate(lightcurves):
        ax[0].plot(time_array, lightcurve, label=f"Lightcurve {i+1}")
    ax[0].set_xlabel("Time (days)")
    ax[0].set_ylabel(lightcurve_label)
    # The lightcurve panel spans the full value range of the displayed map.
    ax[0].set_ylim(np.nanmin(im_to_show), np.nanmax(im_to_show))
    ax[0].legend()

    # magmap, with axes in units of theta_star
    conts = ax[1].imshow(
        im_to_show,
        cmap="viridis_r",
        extent=[
            (center_x - half_length_x) / theta_star,
            (center_x + half_length_x) / theta_star,
            (center_y - half_length_y) / theta_star,
            (center_y + half_length_y) / theta_star,
        ],
        origin="lower",
    )
    divider = make_axes_locatable(ax[1])
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = plt.colorbar(conts, cax=cax)
    cbar.set_label(colorbar_label)
    ax[1].set_xlabel("$x / \\theta_★$")
    ax[1].set_ylabel("$y / \\theta_★$")

    # tracks are in pixel coordinates
    # to map them to the magmap coordinates, we need to convert them to the physical coordinates
    delta_x = 2 * half_length_x / magmap_instance.num_pixels_x
    delta_y = 2 * half_length_y / magmap_instance.num_pixels_y
    mid_x_pixel = magmap_instance.num_pixels_x // 2
    mid_y_pixel = magmap_instance.num_pixels_y // 2
    if tracks is not None:
        for j, track in enumerate(tracks):
            # Offsets from the central pixel are shifted by the map center so
            # the tracks line up with the imshow extent, which is centered on
            # (center_x, center_y).
            track_x = (
                center_x + (np.asarray(track[1]) - mid_x_pixel) * delta_x
            ) / theta_star
            track_y = (
                center_y + (np.asarray(track[0]) - mid_y_pixel) * delta_y
            ) / theta_star
            ax[1].plot(track_x, track_y, "w-", lw=1)
            # Number each track at its starting point.
            ax[1].text(
                track_x[0],
                track_y[0],
                str(j + 1),
                color="white",
                fontsize=16,
            )
    return ax
[docs]
def plot_magnification_map(magmap_instance, ax=None, plot_magnitude=True, **kwargs):
    """Plot the magnification map on the given axis.

    The map is drawn with axes in units of theta_star and a colorbar is
    attached to the right of the axis.

    :param magmap_instance: instance of the MagnificationMap class.
    :param ax: axis to plot on. If None, a new figure and axis will be
        created.
    :param plot_magnitude: if True, plot the magnitudes. If False, plot
        the magnifications.
    :param kwargs: additional keyword arguments to pass to the imshow
        function.
    :return: ax: the axis of the plot
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(6, 6))

    # Only the attribute that is actually displayed is read from the map.
    if plot_magnitude:
        map_values = magmap_instance.magnitudes
        colorbar_label = "Microlensing $\\Delta m$ (magnitudes)"
    else:
        map_values = magmap_instance.magnifications
        colorbar_label = "Microlensing magnification"

    theta_star = magmap_instance.theta_star
    extent = [
        (magmap_instance.center_x - magmap_instance.half_length_x) / theta_star,
        (magmap_instance.center_x + magmap_instance.half_length_x) / theta_star,
        (magmap_instance.center_y - magmap_instance.half_length_y) / theta_star,
        (magmap_instance.center_y + magmap_instance.half_length_y) / theta_star,
    ]
    im = ax.imshow(map_values, extent=extent, **kwargs)

    ax.set_xlabel("$x / \\theta_★$")
    ax.set_ylabel("$y / \\theta_★$")

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = plt.colorbar(im, cax=cax)
    cbar.set_label(colorbar_label)

    # Return the axis as documented, so callers can keep customizing the plot.
    return ax