import json
import logging
import os
import numpy as np
import time
import warnings
from lenstronomy.Data.coord_transforms import Coordinates
logger = logging.getLogger(__name__)
from lenstronomy.Data.pixel_grid import PixelGrid
from lenstronomy.Data.psf import PSF
from lenstronomy.ImSim.image_model import ImageModel
from lenstronomy.Util import data_util
from lenstronomy.Util import util as lenstronomy_util
from mejiro.utils import util
# Version tag written into (and checked when loading) the JSON metadata of
# lightweight ``.npz`` synthetic-image files.
LIGHTWEIGHT_SCHEMA_VERSION = 1
class SyntheticImage:
    """Simulated image of a strong lens observed by one instrument in one band.

    The image is ray-shot (and convolved with the configured PSF) by lenstronomy's
    ``ImageModel`` during construction and stored in ``self.data`` in counts/sec.
    """

    DEFAULT_KWARGS_NUMERICS = {
        'supersampling_factor': 5,  # super sampling factor of (partial) high resolution ray-tracing
        'compute_mode': 'regular',  # 'regular' or 'adaptive'
        'supersampling_convolution': True,  # bool, if True, performs the supersampled convolution (either on regular or adaptive grid)
        'supersampling_kernel_size': None,  # size of the higher resolution kernel region (can be smaller than the original kernel). None leads to use the full size
        'flux_evaluate_indexes': None,  # bool mask, if None, it will evaluate all (sub) pixels
        'supersampled_indexes': None,  # bool mask of pixels to be computed in supersampled grid (only for adaptive mode)
        'compute_indexes': None,  # bool mask of pixels to be computed the PSF response (flux being added to). Only used for adaptive mode and can be set =likelihood mask.
        'point_source_supersampling_factor': 5,
    }

    DEFAULT_KWARGS_PSF = {
        'psf_type': 'NONE'
    }

    def __init__(self,
                 strong_lens,
                 instrument,
                 band,
                 fov_arcsec=5,
                 instrument_params=None,
                 kwargs_numerics=None,
                 kwargs_psf=None,
                 pieces=False
                 ):
        """
        Initialize a SyntheticImage object.

        Parameters
        ----------
        strong_lens : StrongLens
            The strong lens system. Its light and point-source kwargs are updated in
            place (band-specific source image, magnitudes converted to amplitudes).
        instrument : Instrument
            The instrument object, e.g., Roman()
        band : str
            The photometric band to simulate.
        fov_arcsec : float, optional
            Field of view in arcseconds. Default is 5. The stored ``fov_arcsec`` is
            adjusted so that the image has an odd number of pixels per side.
        instrument_params : dict, optional
            Instrument-specific parameters. If not provided (None or empty), the
            instrument's defaults are used; otherwise they are passed through
            ``instrument.validate_instrument_params``.
        kwargs_numerics : dict, optional
            Numerical settings for image simulation. If not provided,
            ``DEFAULT_KWARGS_NUMERICS`` is used. The dict is copied, never modified
            in place.
        kwargs_psf : dict, optional
            PSF settings passed to lenstronomy's ``PSF``. If not provided,
            ``DEFAULT_KWARGS_PSF`` (no PSF) is used.
        pieces : bool, optional
            If True, computes and stores lens/source surface brightness separately.

        Raises
        ------
        ValueError
            If the specified band is not valid for the instrument.

        Warns
        -----
        UserWarning
            If the supersampling factor is less than 5.

        Notes
        -----
        - The image is simulated using lenstronomy's ray-shooting and convolution routines.
        - Magnitudes are converted to lenstronomy amplitudes if not already provided.
        - The pixel grid is set up to ensure an odd number of pixels per side.
        - Adaptive supersampling grid is built if requested.
        - ``self.data`` is in **counts/sec** (surface brightness units from lenstronomy's
          ``ImageModel.image()``). Multiply by exposure time to get counts before passing
          to a detector-effects engine.
        """
        start = time.time()

        # check band is valid for instrument
        if band not in instrument.bands:
            raise ValueError(f'Band "{band}" not valid for instrument {instrument.name}')

        # set up instrument params
        if not instrument_params:
            instrument_params = instrument.default_params()
        else:
            instrument_params = instrument.validate_instrument_params(instrument_params)

        # set up attributes
        self.strong_lens = strong_lens
        self.instrument = instrument
        self.instrument_name = instrument.name
        self.instrument_params = instrument_params
        self.band = band
        self.fov_arcsec = fov_arcsec
        self.pieces = pieces

        # calculate size of scene
        # get_pixel_scale returns an Astropy Quantity (arcsec / pix); keep the bare float
        self.pixel_scale = instrument.get_pixel_scale(self.band).value
        self.num_pix = util.set_odd_num_pix(self.fov_arcsec, self.pixel_scale)  # make sure that final image will have odd number of pixels on a side
        self.fov_arcsec = self.num_pix * self.pixel_scale  # adjust fov (may differ from user-provided input)
        logger.info(f'Scene size: {self.fov_arcsec} arcsec, {self.num_pix} pixels at pixel scale {self.pixel_scale} arcsec/pix')

        # set up pixel grid and coordinates
        x, y, self.ra_at_xy_0, self.dec_at_xy_0, x_at_radec_0, y_at_radec_0, self.Mpix2coord, self.Mcoord2pix = (
            lenstronomy_util.make_grid_with_coordtransform(
                numPix=self.num_pix,
                deltapix=self.pixel_scale,
                subgrid_res=1,
                left_lower=False,
                inverse=False))
        kwargs_pixel = {
            'nx': self.num_pix,
            'ny': self.num_pix,
            'ra_at_xy_0': self.ra_at_xy_0,
            'dec_at_xy_0': self.dec_at_xy_0,
            'transform_pix2angle': self.Mpix2coord
        }
        self.pixel_grid = PixelGrid(**kwargs_pixel)
        self.coords = Coordinates(self.Mpix2coord, self.ra_at_xy_0, self.dec_at_xy_0)

        # check if lenstronomy amplitudes are provided
        amps_provided = self.strong_lens.validate_light_models()

        # update band-specific source image if available (e.g. COSMOS_WEB).
        source_images = self.strong_lens.kwargs_params.get('source_images')
        if source_images and self.band in source_images:
            self.strong_lens.kwargs_source[0]['image'] = source_images[self.band]

        # if not provided, convert magnitudes to lenstronomy amplitudes
        # NOTE(review): the conversion below overwrites the StrongLens kwargs in place;
        # whether a later SyntheticImage in another band built from the same StrongLens
        # then sees "amplitudes provided" depends on validate_light_models — confirm.
        if not amps_provided:
            # retrieve zero-point magnitude (Roman zero-points are per detector)
            if instrument.name == 'Roman':
                self.magnitude_zeropoint = instrument.get_zeropoint_magnitude(self.band,
                                                                              self.instrument_params['detector'])
            else:
                self.magnitude_zeropoint = instrument.get_zeropoint_magnitude(self.band)

            # retrieve lens and source magnitudes
            lens_magnitude = self.strong_lens.get_lens_magnitude(self.band)
            source_magnitude = self.strong_lens.get_source_magnitude(self.band)

            # overwrite the magnitudes in kwargs_lens_light and kwargs_source
            self.strong_lens.kwargs_lens_light[0]['magnitude'] = lens_magnitude
            self.strong_lens.kwargs_source[0]['magnitude'] = source_magnitude

            # convert magnitudes to lenstronomy amplitudes
            self.strong_lens.kwargs_lens_light = data_util.magnitude2amplitude(light_model_class=self.strong_lens.lens_light_model,
                                                                               kwargs_light_mag=self.strong_lens.kwargs_lens_light,
                                                                               magnitude_zero_point=self.magnitude_zeropoint)
            self.strong_lens.kwargs_source = data_util.magnitude2amplitude(light_model_class=self.strong_lens.source_light_model,
                                                                           kwargs_light_mag=self.strong_lens.kwargs_source,
                                                                           magnitude_zero_point=self.magnitude_zeropoint)

            # convert point source magnitudes to amplitudes if present
            if self.strong_lens.kwargs_ps:
                for ps_kwargs in self.strong_lens.kwargs_ps:
                    if 'magnitude' in ps_kwargs:
                        ps_kwargs['point_amp'] = [
                            data_util.magnitude2cps(mag, self.magnitude_zeropoint)
                            for mag in ps_kwargs['magnitude']
                        ]
                        del ps_kwargs['magnitude']
        else:
            self.magnitude_zeropoint = None

        # set kwargs_numerics. Work on a private copy so that neither the caller's dict
        # nor the class-level default is modified by the keys added below.
        if not kwargs_numerics:
            kwargs_numerics = dict(SyntheticImage.DEFAULT_KWARGS_NUMERICS)
        else:
            kwargs_numerics = dict(kwargs_numerics)
            kwargs_numerics.setdefault('compute_mode', 'regular')
        if kwargs_numerics['compute_mode'] == 'adaptive' and 'supersampled_indexes' not in kwargs_numerics:
            logger.info('Building adaptive grid')
            self.supersampled_indexes = self.build_adaptive_grid(pad=40)
            kwargs_numerics['supersampled_indexes'] = self.supersampled_indexes
        # lenstronomy's Numerics falls back to supersampling_factor=1 when the key is absent
        if kwargs_numerics.get('supersampling_factor', 1) < 5:
            warnings.warn('Supersampling factor less than 5 may not be sufficient for accurate results, especially when convolving with a non-trivial PSF')
        self.kwargs_numerics = kwargs_numerics

        # set kwargs_psf
        if not kwargs_psf:
            kwargs_psf = SyntheticImage.DEFAULT_KWARGS_PSF
        self.psf_class = PSF(**kwargs_psf)

        # build point source class if needed
        ps_model_list = self.strong_lens.kwargs_model.get('point_source_model_list', [])
        if ps_model_list:
            from lenstronomy.PointSource.point_source import PointSource
            point_source_class = PointSource(
                point_source_type_list=ps_model_list,
                lens_model=self.strong_lens.lens_model,
                fixed_magnification_list=[False] * len(ps_model_list)
            )
        else:
            point_source_class = None

        # ray-shoot
        image_model = ImageModel(data_class=self.pixel_grid,
                                 psf_class=self.psf_class,
                                 lens_model_class=self.strong_lens.lens_model,
                                 source_model_class=self.strong_lens.source_light_model,
                                 lens_light_model_class=self.strong_lens.lens_light_model,
                                 point_source_class=point_source_class,
                                 kwargs_numerics=kwargs_numerics)
        self.data = image_model.image(kwargs_lens=self.strong_lens.kwargs_lens,
                                      kwargs_source=self.strong_lens.kwargs_source,
                                      kwargs_lens_light=self.strong_lens.kwargs_lens_light,
                                      kwargs_ps=self.strong_lens.kwargs_ps,
                                      kwargs_extinction=self.strong_lens.kwargs_extinction,
                                      kwargs_special=self.strong_lens.kwargs_special,
                                      unconvolved=False,
                                      source_add=True,
                                      lens_light_add=True,
                                      point_source_add=True)

        if self.pieces:
            self.lens_surface_brightness = image_model.lens_surface_brightness(kwargs_lens_light=self.strong_lens.kwargs_lens_light, unconvolved=False)
            self.source_surface_brightness = image_model.source_surface_brightness(kwargs_source=self.strong_lens.kwargs_source, kwargs_lens=self.strong_lens.kwargs_lens, kwargs_extinction=self.strong_lens.kwargs_extinction, kwargs_special=self.strong_lens.kwargs_special, unconvolved=False)
        else:
            self.lens_surface_brightness, self.source_surface_brightness = None, None

        end = time.time()
        self.calc_time = end - start
        logger.info(f'Synthetic image calculation time: {util.calculate_execution_time(start, end, unit="s")}')

    def __getstate__(self):
        """Return the pickle state with the PSF object dropped (``psf_class`` set to None)."""
        state = self.__dict__.copy()
        # drop the PSF kernel on pickle: it's already cached on disk under psf_cache_dir
        # and is only consumed during __init__, so persisting it just duplicates ~2 MB per pickle
        state['psf_class'] = None
        return state

    def save_lightweight(self, path):
        """Write the lightweight ``.npz`` representation used by the romanisim path.

        Stores the image as ``float32`` plus a JSON metadata blob carrying only
        the scalars that downstream consumers (``romanisim_pipeline``,
        ``_06_h5_export_romanisim``, ``calculate_snrs``, ``projects/.../rung_1.py``)
        actually read. The full ``StrongLens`` and lenstronomy plumbing are not
        persisted; loaders should use :func:`mejiro.utils.util.load_synthetic_image`,
        which returns a :class:`LightweightSyntheticImage` for ``.npz`` paths.

        Parameters
        ----------
        path : str or os.PathLike
            Destination path. Should end in ``.npz``.

        Notes
        -----
        Not compatible with the galsim path (``_05_create_exposures.py``); that
        step requires the full SyntheticImage and will raise if it sees
        lightweight outputs. ``self.pieces`` is ignored — per-piece arrays are
        not serialized. The file is written to ``<path>.tmp`` and then renamed
        into place; the temporary file is removed if writing fails.
        """
        path = os.fspath(path)
        sl = self.strong_lens
        band = self.band

        zp = _to_plain_float(self.magnitude_zeropoint)

        ip = self.instrument_params or {}
        det = ip.get('detector')
        if det is not None:
            # Roman detectors may flow through as int (pipeline path) or as
            # an 'SCA01'-style string (some downstream/test paths); normalize.
            from mejiro.utils import roman_util
            try:
                det = roman_util.get_sca_int(det)
            except (TypeError, ValueError):
                det = int(det)
        det_pos = ip.get('detector_position')
        if det_pos is not None:
            det_pos = [int(det_pos[0]), int(det_pos[1])]

        meta = {
            'schema_version': LIGHTWEIGHT_SCHEMA_VERSION,
            'band': band,
            'pixel_scale': float(self.pixel_scale),
            'fov_arcsec': float(self.fov_arcsec),
            'num_pix': int(self.num_pix),
            'instrument_name': str(self.instrument_name),
            'instrument_params': {'detector': det, 'detector_position': det_pos},
            'magnitude_zeropoint': zp,
            'lens': {
                'name': str(sl.name),
                'z_lens': float(sl.z_lens),
                'z_source': float(sl.z_source),
                'has_realization': sl.realization is not None,
                'main_halo_mass': float(sl.get_main_halo_mass()),
                'einstein_radius': float(sl.get_einstein_radius()),
                'velocity_dispersion': float(sl.get_velocity_dispersion()),
                'magnification': float(sl.get_magnification()),
                'lens_magnitude': float(sl.get_lens_magnitude(band)),
                'source_magnitude': float(sl.get_source_magnitude(band)),
                'lensed_source_magnitude': float(sl.get_lensed_source_magnitude(band)),
            },
        }
        meta_bytes = json.dumps(meta).encode('utf-8')

        # Write atomically: np.savez auto-appends ".npz" when given a path string,
        # which makes atomic rename awkward, so pass an open file handle instead.
        tmp_path = path + '.tmp'
        try:
            with open(tmp_path, 'wb') as fh:
                np.savez(
                    fh,
                    data=np.ascontiguousarray(self.data, dtype=np.float32),
                    meta=np.frombuffer(meta_bytes, dtype=np.uint8),
                )
            os.replace(tmp_path, path)
        except BaseException:
            # don't leave a half-written temp file next to the destination
            try:
                os.remove(tmp_path)
            except FileNotFoundError:
                pass
            raise

    def get_flux(self):
        """
        Calculate the total flux in counts/sec of the synthetic image by summing over all pixel values.

        Returns
        -------
        float
            The total flux of the synthetic image in counts/sec.
        """
        return np.sum(self.data)

    def get_maggies(self):
        """
        Calculate the total flux of the synthetic image in maggies.

        The pixel values are summed to get the total flux in counts/sec, which is
        converted to a magnitude with the instrument's zero-point magnitude for this
        band and then to maggies (``10 ** (-0.4 * magnitude)``).

        Returns
        -------
        float
            The total flux of the synthetic image in maggies.

        Raises
        ------
        ValueError
            If no zero-point magnitude is available (lenstronomy amplitudes were
            provided at construction time, so no zero-point was looked up).
        """
        # The zero-point is an astropy Column/Quantity slice for Roman and a plain
        # scalar for other instruments; normalize both to a float.
        zeropoint = _to_plain_float(self.magnitude_zeropoint)
        if zeropoint is None:
            raise ValueError(
                "Cannot compute maggies: no magnitude_zeropoint is available "
                "(lenstronomy amplitudes were provided at construction time)."
            )
        magnitude = data_util.cps2magnitude(self.get_flux(), zeropoint)
        return float(10 ** (-0.4 * magnitude))

    def build_adaptive_grid(self, pad):
        """
        Build an annular supersampling mask around the lens centre that contains the lensed images.

        The distance of every image position from the lens centre is computed (in pixels);
        the mask selects all pixels whose distance from the lens centre lies between the
        smallest image radius minus ``pad`` and the largest image radius plus ``pad``. To
        ensure that the mask includes the image positions, the pad value should be at least
        two pixels but ideally much larger in order to capture the vast majority of the
        lensed source's flux.

        Parameters
        ----------
        pad : int
            Padding (in pixels) to extend the minimum and maximum radii of the grid mask.
            Must be a non-negative integer.

        Returns
        -------
        numpy.ndarray
            A ``(num_pix, num_pix)`` boolean mask where `True` indicates pixels within the
            adaptive grid range and `False` indicates pixels outside the range.

        Raises
        ------
        ValueError
            If ``pad`` is not a non-negative integer, or if the image positions cannot be
            calculated or are empty.

        Notes
        -----
        - Pixel offsets are integers measured from the central pixel of the (odd-sized) scene;
          the lens centre is converted to pixels by dividing by the pixel scale.
        - The inner radius is clamped at 0 and the outer radius at ``num_pix // 2`` so the
          range does not exceed the scene.
        """
        # check the type first so that non-numeric input raises ValueError, not a TypeError from ``pad < 0``
        if not isinstance(pad, int) or pad < 0:
            raise ValueError("Padding value must be a non-negative integer.")

        image_positions = self.get_image_positions()
        if len(image_positions) == 0 or len(image_positions[0]) == 0 or len(image_positions[1]) == 0:
            raise ValueError(f"Failed to calculate image positions: {image_positions}")

        half = self.num_pix // 2

        # lens centre in pixels, relative to the central pixel of the scene
        lens_dx = self.strong_lens.kwargs_lens[0]['center_x'] / self.pixel_scale
        lens_dy = self.strong_lens.kwargs_lens[0]['center_y'] / self.pixel_scale

        # distance of each image from the lens centre; image positions are absolute pixel coordinates
        image_x = np.asarray(image_positions[0], dtype=float)
        image_y = np.asarray(image_positions[1], dtype=float)
        image_radii = np.hypot(image_x - half - lens_dx, image_y - half - lens_dy)

        # integer offsets from the central pixel: -half, ..., +half for an odd num_pix.
        # (``-self.num_pix // 2`` would parse as ``(-num_pix) // 2`` and skew the grid.)
        offsets = np.arange(self.num_pix) - half
        X, Y = np.meshgrid(offsets, offsets)
        distance = np.hypot(X - lens_dx, Y - lens_dy)

        r_min = max(np.min(image_radii) - pad, 0)
        r_max = min(np.max(image_radii) + pad, half)

        return (distance >= r_min) & (distance <= r_max)

    def get_image_positions(self, ignore_substructure=True, pixel=True):
        """
        Calculate the image positions from the source position and lensing mass model.

        Wraps ``StrongLens.get_image_positions()``, with the added functionality of returning
        the positions in pixel coordinates based on the pixel grid defined for this
        synthetic image.

        Parameters
        ----------
        ignore_substructure : bool, optional
            If True (default), ignores substructure in the lens model when computing image positions.
            If False, includes substructure in the calculation.
        pixel : bool, optional
            If True, the image positions are returned in pixel coordinates.
            If False, the image positions are returned in lenstronomy's default angular coordinates. Default is True.

        Returns
        -------
        Tuple of arrays
            ([x coordinates], [y coordinates]) of the image positions. When ``pixel=True`` (default), coordinates are in pixels; when ``pixel=False``, coordinates are in lenstronomy's angular units (often arcseconds).
        """
        image_x, image_y = self.strong_lens.get_image_positions(ignore_substructure=ignore_substructure)
        if not pixel:
            return image_x, image_y

        # ``coords`` is built in __init__; rebuild it from the stored pixel transform if it
        # is missing (e.g. an object restored from an older pickle).
        if getattr(self, 'coords', None) is None:
            self.coords = Coordinates(self.Mpix2coord, self.ra_at_xy_0, self.dec_at_xy_0)
        return self.coords.map_coord2pix(ra=image_x, dec=image_y)

    def plot(self, savepath=None):
        """
        Quickly visualize the synthetic image.

        Parameters
        ----------
        savepath : str, optional
            The file path where the plot will be saved. If None, the plot
            will not be saved. Default is None.

        Notes
        -----
        The image is displayed using a logarithmic scale (base 10).
        """
        import matplotlib.pyplot as plt

        plt.imshow(np.log10(self.data), origin='lower')
        plt.title(f'{self.strong_lens.name} (' + r'$z_{l}=$' + f'{self.strong_lens.z_lens:.2f}, ' + r'$z_{s}=$' + f'{self.strong_lens.z_source:.2f}' + f')\n{self.instrument_name} {self.band}')
        cbar = plt.colorbar()
        cbar.set_label(r'log$_{10}$(Counts/sec)')
        plt.xlabel('x [Pixels]')
        plt.ylabel('y [Pixels]')
        if savepath is not None:
            plt.savefig(savepath)
        plt.show()

    def overplot_subhalos(self, alpha=0.5, savepath=None):
        """
        Display the synthetic image (log10 scale) with the realization's halos overplotted.

        Each halo is drawn at its pixel position and colour-coded by ``halo.mass``:
        below 1e7, from 1e7 up to 1e8, and 1e8 or above (the legend labels these in
        solar masses).

        Parameters
        ----------
        alpha : float, optional
            Transparency of the halo markers. Default is 0.5.
        savepath : str, optional
            The file path where the plot will be saved. If None, the plot
            will not be saved. Default is None.

        Raises
        ------
        ValueError
            If the strong lens has no realization.
        """
        if self.strong_lens.realization is None:
            raise ValueError('No realization has been added to this StrongLens object.')

        import matplotlib.pyplot as plt
        from matplotlib.lines import Line2D

        for halo in self.strong_lens.realization.halos:
            if halo.mass < 1e7:
                color = '#0C5DA5'
            elif halo.mass < 1e8:
                color = '#00B945'
            else:
                color = '#FF9500'
            plt.scatter(*self.coords.map_coord2pix(halo.x, halo.y), marker='.', color=color, alpha=alpha)

        plt.imshow(np.log10(self.data), origin='lower', cmap='binary')
        plt.title(f'{self.strong_lens.name}: {self.instrument_name} {self.band} band {self.data.shape}')
        cbar = plt.colorbar()
        # self.data is in counts/sec (see __init__), matching the label used by plot()
        cbar.set_label(r'log$_{10}$(Counts/sec)')
        plt.xlabel('x [Pixels]')
        plt.ylabel('y [Pixels]')

        custom_legend_labels = [r'$> 10^8 \,M_\odot$', r'$10^7 - 10^8 \,M_\odot$', r'$< 10^7 \,M_\odot$']
        custom_colors = ['#FF9500', '#00B945', '#0C5DA5']
        custom_markers = ['.'] * 3
        custom_lines = [Line2D([0], [0], color=color, marker=marker, lw=4, linestyle='None')
                        for color, marker in zip(custom_colors, custom_markers)]
        plt.legend(custom_lines, custom_legend_labels)

        if savepath is not None:
            plt.savefig(savepath)
        plt.show()
def _to_plain_float(value):
    """Coerce a magnitude_zeropoint to a JSON-safe float, or None.

    Roman's ``get_zeropoint_magnitude`` returns a one-row astropy ``Column`` /
    ``Quantity`` slice; other instruments return plain scalars. Both paths need
    to round-trip through JSON.

    Parameters
    ----------
    value : None, scalar, sequence, numpy array, or object with a ``.value`` attribute
        The zero-point to normalize. Objects exposing ``.value`` (astropy
        ``Quantity`` / ``Column``) are unwrapped first.

    Returns
    -------
    float or None
        ``None`` if ``value`` is None or empty; otherwise the first element
        (or the scalar itself) as a Python ``float``.
    """
    if value is None:
        return None
    if hasattr(value, 'value'):  # astropy Quantity / Column
        value = value.value
    # np.ravel handles plain scalars, 0-d arrays (where ``len()`` raises TypeError)
    # and 1-row arrays/sequences uniformly.
    flat = np.ravel(value)
    if flat.size == 0:
        return None
    return float(flat[0])
class LightweightStrongLens:
    """Minimal ``StrongLens`` stand-in loaded from a lightweight ``.npz``.

    Exposes only the attributes and accessor methods the romanisim-path
    pipeline reads from ``synthetic_image.strong_lens`` — redshifts,
    substructure flag, and the scalar ``get_*`` accessors. Each lightweight
    file is per-band, so the per-band magnitude accessors assert the
    requested band matches the stored band.
    """

    def __init__(self, meta):
        """Build the stand-in from the decoded lightweight metadata dict.

        Parameters
        ----------
        meta : dict
            Metadata as written by ``SyntheticImage.save_lightweight``; reads
            ``meta['band']`` and the scalars under ``meta['lens']``.
        """
        self._band = meta['band']
        lens_meta = meta['lens']
        self.name = lens_meta['name']
        self.z_lens = lens_meta['z_lens']
        self.z_source = lens_meta['z_source']
        # downstream only checks ``lens.realization is None`` to set the
        # substructure flag; a sentinel string preserves that semantics
        # without dragging a pyhalo realization onto disk.
        self.realization = '<lightweight>' if lens_meta['has_realization'] else None
        self._main_halo_mass = lens_meta['main_halo_mass']
        self._einstein_radius = lens_meta['einstein_radius']
        self._velocity_dispersion = lens_meta['velocity_dispersion']
        self._magnification = lens_meta['magnification']
        self._lens_magnitude = lens_meta['lens_magnitude']
        self._source_magnitude = lens_meta['source_magnitude']
        self._lensed_source_magnitude = lens_meta['lensed_source_magnitude']

    def _check_band(self, band):
        """Raise ValueError unless ``band`` is the band this file was written for."""
        if band != self._band:
            raise ValueError(
                f"LightweightStrongLens stores only band {self._band!r}; "
                f"cannot return magnitude for band {band!r}. Load the "
                f"per-band lightweight file you actually want."
            )

    def get_main_halo_mass(self):
        """Return the stored main-halo mass."""
        return self._main_halo_mass

    def get_einstein_radius(self):
        """Return the stored Einstein radius."""
        return self._einstein_radius

    def get_velocity_dispersion(self):
        """Return the stored velocity dispersion."""
        return self._velocity_dispersion

    def get_magnification(self):
        """Return the stored magnification."""
        return self._magnification

    def get_lens_magnitude(self, band):
        """Return the stored lens magnitude; ``band`` must match the stored band."""
        self._check_band(band)
        return self._lens_magnitude

    def get_source_magnitude(self, band):
        """Return the stored (unlensed) source magnitude; ``band`` must match the stored band."""
        self._check_band(band)
        return self._source_magnitude

    def get_lensed_source_magnitude(self, band):
        """Return the stored lensed source magnitude; ``band`` must match the stored band."""
        self._check_band(band)
        return self._lensed_source_magnitude
class LightweightSyntheticImage:
    """In-memory shim returned when loading a lightweight ``.npz``.

    Quacks like ``SyntheticImage`` for the attributes and methods consumed by
    the romanisim path. Not suitable for the galsim path (no lens model, PSF,
    or pixel-grid plumbing).
    """

    def __init__(self, data, meta):
        """Wrap an image array and its decoded lightweight metadata.

        Parameters
        ----------
        data : numpy.ndarray
            The stored image (counts/sec, as written by ``SyntheticImage.save_lightweight``).
        meta : dict
            Decoded JSON metadata blob from the same file.
        """
        self.data = data
        self.band = meta['band']
        self.pixel_scale = meta['pixel_scale']
        self.fov_arcsec = meta['fov_arcsec']
        self.num_pix = meta['num_pix']
        self.instrument_name = meta['instrument_name']
        ip = dict(meta.get('instrument_params') or {})
        det_pos = ip.get('detector_position')
        if det_pos is not None:
            # round-trip as a tuple to match the original SyntheticImage shape
            ip['detector_position'] = tuple(det_pos)
        self.instrument_params = ip
        self.magnitude_zeropoint = meta['magnitude_zeropoint']
        self.strong_lens = LightweightStrongLens(meta)
        self._meta = meta  # retained for debugging / introspection

    @classmethod
    def load(cls, path):
        """Load a lightweight ``.npz`` written by ``SyntheticImage.save_lightweight``.

        Parameters
        ----------
        path : str or os.PathLike
            Path to the ``.npz`` file.

        Returns
        -------
        LightweightSyntheticImage

        Raises
        ------
        ValueError
            If the file's ``schema_version`` does not equal ``LIGHTWEIGHT_SCHEMA_VERSION``.
        """
        with np.load(path) as f:
            data = np.asarray(f['data'])
            meta_bytes = f['meta'].tobytes()
        meta = json.loads(meta_bytes.decode('utf-8'))
        schema_version = meta.get('schema_version')
        if schema_version != LIGHTWEIGHT_SCHEMA_VERSION:
            raise ValueError(
                f"Unsupported lightweight schema_version={schema_version!r} "
                f"(expected {LIGHTWEIGHT_SCHEMA_VERSION}) in {path}"
            )
        return cls(data=data, meta=meta)

    def get_flux(self):
        """Total flux in counts/sec (sum over all pixels)."""
        return float(np.sum(self.data))

    def get_maggies(self):
        """Total flux in maggies. Replicates ``SyntheticImage.get_maggies``.

        Implemented in plain floats so it does not require the stored
        ``magnitude_zeropoint`` to be an astropy ``Quantity``.

        Raises
        ------
        ValueError
            If the file carries no ``magnitude_zeropoint``.
        """
        if self.magnitude_zeropoint is None:
            raise ValueError(
                "Cannot compute maggies: this lightweight file has no "
                "magnitude_zeropoint (lenstronomy amplitudes were provided "
                "at construction time)."
            )
        magnitude = -2.5 * np.log10(self.get_flux()) + self.magnitude_zeropoint
        return float(10 ** (-0.4 * magnitude))