"""Source code for ``mejiro.synthetic_image``: simulate single-band strong-lens images with lenstronomy."""

import json
import logging
import os
import numpy as np
import time
import warnings
from lenstronomy.Data.coord_transforms import Coordinates

from lenstronomy.Data.pixel_grid import PixelGrid
from lenstronomy.Data.psf import PSF
from lenstronomy.ImSim.image_model import ImageModel
from lenstronomy.Util import data_util
from lenstronomy.Util import util as lenstronomy_util

from mejiro.utils import util

# Module-level logger; SyntheticImage reports scene size and calculation time through it.
logger = logging.getLogger(__name__)

# Version tag written into the lightweight ``.npz`` metadata by
# ``SyntheticImage.save_lightweight``; ``LightweightSyntheticImage.load`` raises
# ValueError for any file whose stored version differs from this value.
LIGHTWEIGHT_SCHEMA_VERSION = 1


class SyntheticImage:
    """Ray-traced, PSF-convolved image of a strong lens in a single instrument band.

    The image is simulated with lenstronomy's ``ImageModel`` on a square pixel
    grid with an odd number of pixels per side. ``data`` holds surface
    brightness in counts/sec.
    """

    DEFAULT_KWARGS_NUMERICS = {
        'supersampling_factor': 5,  # super sampling factor of (partial) high resolution ray-tracing
        'compute_mode': 'regular',  # 'regular' or 'adaptive'
        'supersampling_convolution': True,  # bool, if True, performs the supersampled convolution (either on regular or adaptive grid)
        'supersampling_kernel_size': None,  # size of the higher resolution kernel region (can be smaller than the original kernel). None leads to use the full size
        'flux_evaluate_indexes': None,  # bool mask, if None, it will evaluate all (sub) pixels
        'supersampled_indexes': None,  # bool mask of pixels to be computed in supersampled grid (only for adaptive mode)
        'compute_indexes': None,  # bool mask of pixels to be computed the PSF response (flux being added to). Only used for adaptive mode and can be set =likelihood mask.
        'point_source_supersampling_factor': 5,
    }
    DEFAULT_KWARGS_PSF = {
        'psf_type': 'NONE'
    }

    def __init__(self, strong_lens, instrument, band, fov_arcsec=5, instrument_params=None,
                 kwargs_numerics=None, kwargs_psf=None, pieces=False):
        """
        Initialize a SyntheticImage object.

        Parameters
        ----------
        strong_lens : StrongLens
            The strong lens system. Its light kwargs are modified in place
            (band-specific source image, magnitudes converted to amplitudes).
        instrument : Instrument
            The instrument object, e.g., Roman()
        band : str
            The photometric band to simulate.
        fov_arcsec : float, optional
            Field of view in arcseconds. Default is 5. Adjusted so that the
            image has an odd number of pixels per side.
        instrument_params : dict, optional
            Instrument-specific parameters. If None or empty,
            ``instrument.default_params()`` is used.
        kwargs_numerics : dict, optional
            Numerical settings for image simulation. If None or empty,
            ``DEFAULT_KWARGS_NUMERICS`` is used. The dict is copied; the
            caller's object is never modified.
        kwargs_psf : dict, optional
            PSF settings passed to lenstronomy's ``PSF``. If None or empty,
            ``DEFAULT_KWARGS_PSF`` (no PSF) is used.
        pieces : bool, optional
            If True, computes and stores lens/source surface brightness separately.

        Raises
        ------
        ValueError
            If the specified band is not valid for the instrument.

        Notes
        -----
        - The image is simulated using lenstronomy's ray-shooting and convolution routines.
        - Magnitudes are converted to lenstronomy amplitudes if not already provided.
        - The pixel grid is set up to ensure an odd number of pixels per side.
        - Adaptive supersampling grid is built if requested.
        - ``self.data`` is in **counts/sec** (surface brightness units from
          lenstronomy's ``ImageModel.image()``). Multiply by exposure time to get
          counts before passing to a detector-effects engine.
        """
        start = time.time()

        # check band is valid for instrument
        if band not in instrument.bands:
            raise ValueError(f'Band "{band}" not valid for instrument {instrument.name}')

        # set up instrument params
        if not instrument_params:
            instrument_params = instrument.default_params()
        else:
            instrument_params = instrument.validate_instrument_params(instrument_params)

        # set up attributes
        self.strong_lens = strong_lens
        self.instrument = instrument
        self.instrument_name = instrument.name
        self.instrument_params = instrument_params
        self.band = band
        self.fov_arcsec = fov_arcsec
        self.pieces = pieces

        # calculate size of scene
        self.pixel_scale = instrument.get_pixel_scale(self.band).value  # strip the astropy unit (arcsec / pix)
        self.num_pix = util.set_odd_num_pix(self.fov_arcsec, self.pixel_scale)  # make sure that final image will have odd number of pixels on a side
        self.fov_arcsec = self.num_pix * self.pixel_scale  # adjust fov (may differ from user-provided input)
        logger.info(f'Scene size: {self.fov_arcsec} arcsec, {self.num_pix} pixels at pixel scale {self.pixel_scale} arcsec/pix')

        self._set_up_pixel_grid()
        self._set_up_light_amplitudes()

        # must run after the pixel grid exists: adaptive mode builds its mask from image positions
        self.kwargs_numerics = self._resolve_kwargs_numerics(kwargs_numerics)

        # set kwargs_psf
        if not kwargs_psf:
            kwargs_psf = SyntheticImage.DEFAULT_KWARGS_PSF
        self.psf_class = PSF(**kwargs_psf)

        # ray-shoot
        image_model = ImageModel(data_class=self.pixel_grid,
                                 psf_class=self.psf_class,
                                 lens_model_class=self.strong_lens.lens_model,
                                 source_model_class=self.strong_lens.source_light_model,
                                 lens_light_model_class=self.strong_lens.lens_light_model,
                                 point_source_class=self._build_point_source_class(),
                                 kwargs_numerics=self.kwargs_numerics)
        self.data = image_model.image(kwargs_lens=self.strong_lens.kwargs_lens,
                                      kwargs_source=self.strong_lens.kwargs_source,
                                      kwargs_lens_light=self.strong_lens.kwargs_lens_light,
                                      kwargs_ps=self.strong_lens.kwargs_ps,
                                      kwargs_extinction=self.strong_lens.kwargs_extinction,
                                      kwargs_special=self.strong_lens.kwargs_special,
                                      unconvolved=False,
                                      source_add=True,
                                      lens_light_add=True,
                                      point_source_add=True)

        if self.pieces:
            self.lens_surface_brightness = image_model.lens_surface_brightness(
                kwargs_lens_light=self.strong_lens.kwargs_lens_light,
                unconvolved=False)
            self.source_surface_brightness = image_model.source_surface_brightness(
                kwargs_source=self.strong_lens.kwargs_source,
                kwargs_lens=self.strong_lens.kwargs_lens,
                kwargs_extinction=self.strong_lens.kwargs_extinction,
                kwargs_special=self.strong_lens.kwargs_special,
                unconvolved=False)
        else:
            self.lens_surface_brightness, self.source_surface_brightness = None, None

        end = time.time()
        self.calc_time = end - start
        logger.info(f'Synthetic image calculation time: {util.calculate_execution_time(start, end, unit="s")}')

    def _set_up_pixel_grid(self):
        """Build the lenstronomy pixel grid and coordinate transforms from ``num_pix`` and ``pixel_scale``.

        Sets ``ra_at_xy_0``, ``dec_at_xy_0``, ``Mpix2coord``, ``Mcoord2pix``,
        ``pixel_grid`` and ``coords``.
        """
        (_, _, self.ra_at_xy_0, self.dec_at_xy_0, _, _,
         self.Mpix2coord, self.Mcoord2pix) = lenstronomy_util.make_grid_with_coordtransform(
            numPix=self.num_pix,
            deltapix=self.pixel_scale,
            subgrid_res=1,
            left_lower=False,
            inverse=False)
        self.pixel_grid = PixelGrid(nx=self.num_pix,
                                    ny=self.num_pix,
                                    ra_at_xy_0=self.ra_at_xy_0,
                                    dec_at_xy_0=self.dec_at_xy_0,
                                    transform_pix2angle=self.Mpix2coord)
        self.coords = Coordinates(self.Mpix2coord, self.ra_at_xy_0, self.dec_at_xy_0)

    def _set_up_light_amplitudes(self):
        """Prepare the lens's light kwargs for this band, converting magnitudes to amplitudes.

        Mutates ``self.strong_lens`` in place and sets ``self.magnitude_zeropoint``
        (``None`` when lenstronomy amplitudes were already provided).
        """
        sl = self.strong_lens
        amps_provided = sl.validate_light_models()

        # update band-specific source image if available (e.g. COSMOS_WEB).
        source_images = sl.kwargs_params.get('source_images')
        if source_images and self.band in source_images:
            sl.kwargs_source[0]['image'] = source_images[self.band]

        if amps_provided:
            self.magnitude_zeropoint = None
            return

        # retrieve zero-point magnitude (Roman's depends on the detector)
        if self.instrument.name == 'Roman':
            self.magnitude_zeropoint = self.instrument.get_zeropoint_magnitude(
                self.band, self.instrument_params['detector'])
        else:
            self.magnitude_zeropoint = self.instrument.get_zeropoint_magnitude(self.band)

        # retrieve both magnitudes before overwriting either kwargs dict
        lens_magnitude = sl.get_lens_magnitude(self.band)
        source_magnitude = sl.get_source_magnitude(self.band)
        sl.kwargs_lens_light[0]['magnitude'] = lens_magnitude
        sl.kwargs_source[0]['magnitude'] = source_magnitude

        # convert magnitudes to lenstronomy amplitudes
        sl.kwargs_lens_light = data_util.magnitude2amplitude(light_model_class=sl.lens_light_model,
                                                             kwargs_light_mag=sl.kwargs_lens_light,
                                                             magnitude_zero_point=self.magnitude_zeropoint)
        sl.kwargs_source = data_util.magnitude2amplitude(light_model_class=sl.source_light_model,
                                                         kwargs_light_mag=sl.kwargs_source,
                                                         magnitude_zero_point=self.magnitude_zeropoint)

        # convert point source magnitudes to amplitudes if present
        if sl.kwargs_ps:
            for ps_kwargs in sl.kwargs_ps:
                if 'magnitude' in ps_kwargs:
                    ps_kwargs['point_amp'] = [
                        data_util.magnitude2cps(mag, self.magnitude_zeropoint)
                        for mag in ps_kwargs['magnitude']
                    ]
                    del ps_kwargs['magnitude']

    def _resolve_kwargs_numerics(self, kwargs_numerics):
        """Return a private copy of the numerics settings with defaults and the adaptive mask filled in."""
        # copy so neither the caller's dict nor the shared class default is mutated
        if not kwargs_numerics:
            kwargs_numerics = dict(SyntheticImage.DEFAULT_KWARGS_NUMERICS)
        else:
            kwargs_numerics = dict(kwargs_numerics)
            kwargs_numerics.setdefault('compute_mode', 'regular')

        if kwargs_numerics['compute_mode'] == 'adaptive' and 'supersampled_indexes' not in kwargs_numerics:
            logger.info('Building adaptive grid')
            self.supersampled_indexes = self.build_adaptive_grid(pad=40)
            kwargs_numerics['supersampled_indexes'] = self.supersampled_indexes

        # an absent key means no supersampling (lenstronomy's default factor is 1 -- confirm if upgrading)
        if kwargs_numerics.get('supersampling_factor', 1) < 5:
            warnings.warn('Supersampling factor less than 5 may not be sufficient for accurate results, especially when convolving with a non-trivial PSF')
        return kwargs_numerics

    def _build_point_source_class(self):
        """Return a lenstronomy ``PointSource`` for the lens's point-source models, or None if there are none."""
        ps_model_list = self.strong_lens.kwargs_model.get('point_source_model_list', [])
        if not ps_model_list:
            return None
        from lenstronomy.PointSource.point_source import PointSource
        return PointSource(
            point_source_type_list=ps_model_list,
            lens_model=self.strong_lens.lens_model,
            fixed_magnification_list=[False] * len(ps_model_list)
        )

    def __getstate__(self):
        """Return the pickle state with the PSF object dropped."""
        state = self.__dict__.copy()
        # drop the PSF kernel on pickle: it's already cached on disk under psf_cache_dir
        # and is only consumed during __init__, so persisting it just duplicates ~2 MB per pickle
        state['psf_class'] = None
        return state

    def save_lightweight(self, path):
        """Write the lightweight ``.npz`` representation used by the romanisim path.

        Stores the image as ``float32`` plus a JSON metadata blob carrying only
        the scalars that downstream consumers (``romanisim_pipeline``,
        ``_06_h5_export_romanisim``, ``calculate_snrs``,
        ``projects/.../rung_1.py``) actually read. The full ``StrongLens`` and
        lenstronomy plumbing are not persisted; loaders should use
        :func:`mejiro.utils.util.load_synthetic_image`, which returns a
        :class:`LightweightSyntheticImage` for ``.npz`` paths.

        Parameters
        ----------
        path : str or os.PathLike
            Destination path. Should end in ``.npz``.

        Notes
        -----
        Not compatible with the galsim path (``_05_create_exposures.py``); that
        step requires the full SyntheticImage and will raise if it sees
        lightweight outputs. ``self.pieces`` is ignored — per-piece arrays are
        not serialized. The file is written to ``path + '.tmp'`` and renamed
        into place; the temporary file is removed if writing fails.
        """
        path = os.fspath(path)
        sl = self.strong_lens
        band = self.band
        zp = _to_plain_float(self.magnitude_zeropoint)

        ip = self.instrument_params or {}
        det = ip.get('detector')
        if det is not None:
            # Roman detectors may flow through as int (pipeline path) or as
            # an 'SCA01'-style string (some downstream/test paths); normalize.
            from mejiro.utils import roman_util
            try:
                det = roman_util.get_sca_int(det)
            except (TypeError, ValueError):
                det = int(det)
        det_pos = ip.get('detector_position')
        if det_pos is not None:
            det_pos = [int(det_pos[0]), int(det_pos[1])]

        meta = {
            'schema_version': LIGHTWEIGHT_SCHEMA_VERSION,
            'band': band,
            'pixel_scale': float(self.pixel_scale),
            'fov_arcsec': float(self.fov_arcsec),
            'num_pix': int(self.num_pix),
            'instrument_name': str(self.instrument_name),
            'instrument_params': {'detector': det, 'detector_position': det_pos},
            'magnitude_zeropoint': zp,
            'lens': {
                'name': str(sl.name),
                'z_lens': float(sl.z_lens),
                'z_source': float(sl.z_source),
                'has_realization': sl.realization is not None,
                'main_halo_mass': float(sl.get_main_halo_mass()),
                'einstein_radius': float(sl.get_einstein_radius()),
                'velocity_dispersion': float(sl.get_velocity_dispersion()),
                'magnification': float(sl.get_magnification()),
                'lens_magnitude': float(sl.get_lens_magnitude(band)),
                'source_magnitude': float(sl.get_source_magnitude(band)),
                'lensed_source_magnitude': float(sl.get_lensed_source_magnitude(band)),
            },
        }
        meta_bytes = json.dumps(meta).encode('utf-8')

        # Write atomically: np.savez auto-appends ".npz" when given a path string,
        # which makes atomic rename awkward, so pass an open file handle instead.
        tmp_path = path + '.tmp'
        try:
            with open(tmp_path, 'wb') as fh:
                np.savez(
                    fh,
                    data=np.ascontiguousarray(self.data, dtype=np.float32),
                    meta=np.frombuffer(meta_bytes, dtype=np.uint8),
                )
            os.replace(tmp_path, path)
        except BaseException:
            # don't leave a partially written temp file behind
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def get_flux(self):
        """
        Calculate the total flux in counts/sec of the synthetic image by summing over all pixel values.

        Returns
        -------
        float
            The total flux of the synthetic image in counts/sec.
        """
        return np.sum(self.data)

    def get_maggies(self):
        """
        Calculate the total flux of the synthetic image in maggies.

        The total flux in counts/sec is converted to an AB-style magnitude with
        the stored zero-point (``m = -2.5 log10(flux) + zp``), then to maggies
        (``10 ** (-0.4 m)``). The zero-point may be a plain number or a
        one-element astropy ``Quantity``/``Column``.

        Returns
        -------
        float
            The total flux of the synthetic image in maggies.

        Raises
        ------
        ValueError
            If no ``magnitude_zeropoint`` is available (lenstronomy amplitudes
            were provided at construction time).
        """
        if self.magnitude_zeropoint is None:
            raise ValueError(
                "Cannot compute maggies: no magnitude_zeropoint is available "
                "(lenstronomy amplitudes were provided at construction time)."
            )
        # strip astropy units if present; works for scalars and one-element columns alike
        zeropoint = np.asarray(getattr(self.magnitude_zeropoint, 'value', self.magnitude_zeropoint), dtype=float)
        magnitude = -2.5 * np.log10(self.get_flux()) + zeropoint
        # `item()` collapses a one-element array (Roman zero-points) to a Python float
        return (10 ** (-0.4 * magnitude)).item()

    def build_adaptive_grid(self, pad):
        """
        Build an adaptive-supersampling mask: an annulus around the lens centre containing the images.

        To ensure that the mask includes the image positions, the pad value
        should be at least two pixels but ideally much larger in order to
        capture the vast majority of the lensed source's flux.

        Parameters
        ----------
        pad : int
            Padding (in pixels) that extends the minimum and maximum image radii.
            Must be a non-negative integer (Python or NumPy integer).

        Returns
        -------
        numpy.ndarray
            A ``(num_pix, num_pix)`` boolean mask where `True` indicates pixels
            inside the annulus.

        Raises
        ------
        ValueError
            If ``pad`` is not a non-negative integer, or if the image positions
            cannot be calculated or are empty.

        Notes
        -----
        - Pixel offsets are measured from the central pixel (``num_pix`` is odd),
          one pixel per step, and the annulus is centred on the lens centre
          (``kwargs_lens[0]['center_x'/'center_y']`` divided by the pixel scale).
        - Image radii are measured from that same lens centre.
        - The outer radius is clamped to half the scene width.
        """
        # check the type first so non-numeric input raises ValueError rather than TypeError
        if not isinstance(pad, (int, np.integer)) or pad < 0:
            raise ValueError("Padding value must be a non-negative integer.")

        image_positions = self.get_image_positions()
        if len(image_positions) == 0 or len(image_positions[0]) == 0 or len(image_positions[1]) == 0:
            raise ValueError(f"Failed to calculate image positions: {image_positions}")

        center = self.num_pix // 2  # index of the central pixel
        lens_x = self.strong_lens.kwargs_lens[0]['center_x'] / self.pixel_scale
        lens_y = self.strong_lens.kwargs_lens[0]['center_y'] / self.pixel_scale

        # image radii, in pixels, from the lens centre (same origin as the mask below)
        image_x = np.asarray(image_positions[0], dtype=float) - center
        image_y = np.asarray(image_positions[1], dtype=float) - center
        image_radii = np.hypot(image_x - lens_x, image_y - lens_y)

        # integer pixel offsets -center..center (the old linspace(-n // 2, ...) parsed as (-n) // 2)
        offsets = np.arange(self.num_pix) - center
        X, Y = np.meshgrid(offsets, offsets)
        distance = np.hypot(X - lens_x, Y - lens_y)

        r_min = max(np.min(image_radii) - pad, 0)
        r_max = min(np.max(image_radii) + pad, center)
        return (distance >= r_min) & (distance <= r_max)

    def get_image_positions(self, ignore_substructure=True, pixel=True):
        """
        Calculate the image positions from the source position and lensing mass model.

        Wraps ``GalaxyGalaxy.get_image_positions()``, with the added
        functionality of returning the positions in pixel coordinates based on
        the pixel grid defined for this synthetic image.

        Parameters
        ----------
        ignore_substructure : bool, optional
            If True (default), ignores substructure in the lens model when
            computing image positions. If False, includes substructure in the calculation.
        pixel : bool, optional
            If True, the image positions are returned in pixel coordinates. If
            False, the image positions are returned in lenstronomy's default
            angular coordinates. Default is True.

        Returns
        -------
        Tuple of arrays
            ([x coordinates], [y coordinates]) of the image positions. When
            ``pixel=True`` (default), coordinates are in pixels; when
            ``pixel=False``, coordinates are in lenstronomy's angular units
            (often arcseconds).
        """
        image_x, image_y = self.strong_lens.get_image_positions(ignore_substructure=ignore_substructure)
        if not pixel:
            return image_x, image_y
        # rebuild the coordinate transform if it is missing (e.g. an object restored without it)
        if getattr(self, 'coords', None) is None:
            self._set_up_pixel_grid()
        return self.coords.map_coord2pix(ra=image_x, dec=image_y)

    def plot(self, savepath=None):
        """
        Quickly visualize the synthetic image.

        Parameters
        ----------
        savepath : str, optional
            The file path where the plot will be saved. If None, the plot will
            not be saved. Default is None.

        Notes
        -----
        The image is displayed using a logarithmic scale (base 10).
        """
        import matplotlib.pyplot as plt

        plt.imshow(np.log10(self.data), origin='lower')
        plt.title(f'{self.strong_lens.name} (' + r'$z_{l}=$' + f'{self.strong_lens.z_lens:.2f}, ' + r'$z_{s}=$' + f'{self.strong_lens.z_source:.2f}' + f')\n{self.instrument_name} {self.band}')
        cbar = plt.colorbar()
        cbar.set_label(r'log$_{10}$(Counts/sec)')
        plt.xlabel('x [Pixels]')
        plt.ylabel('y [Pixels]')
        if savepath is not None:
            plt.savefig(savepath)
        plt.show()

    def overplot_subhalos(self, alpha=0.5, savepath=None):
        """
        Plot the image (log scale) with the realization's subhalo positions overlaid, coloured by mass.

        Parameters
        ----------
        alpha : float, optional
            Marker transparency. Default is 0.5.
        savepath : str, optional
            If given, the figure is saved to this path before being shown.

        Raises
        ------
        ValueError
            If the strong lens has no substructure realization.
        """
        if self.strong_lens.realization is None:
            raise ValueError('No realization has been added to this StrongLens object.')

        import matplotlib.pyplot as plt
        from matplotlib.lines import Line2D

        # mass bins in M_sun: < 1e7, 1e7 - 1e8, >= 1e8; colours must match the legend below
        low_color, mid_color, high_color = '#0C5DA5', '#00B945', '#FF9500'
        for halo in self.strong_lens.realization.halos:
            if halo.mass < 1e7:
                color = low_color
            elif halo.mass < 1e8:
                color = mid_color
            else:
                color = high_color
            plt.scatter(*self.coords.map_coord2pix(halo.x, halo.y), marker='.', color=color, alpha=alpha)

        plt.imshow(np.log10(self.data), origin='lower', cmap='binary')
        plt.title(f'{self.strong_lens.name}: {self.instrument_name} {self.band} band {self.data.shape}')
        cbar = plt.colorbar()
        # self.data is surface brightness in counts/sec, matching plot()
        cbar.set_label(r'log$_{10}$(Counts/sec)')
        plt.xlabel('x [Pixels]')
        plt.ylabel('y [Pixels]')

        custom_legend_labels = [r'$> 10^8 \,M_\odot$', r'$10^7 - 10^8 \,M_\odot$', r'$< 10^7 \,M_\odot$']
        custom_colors = [high_color, mid_color, low_color]
        custom_lines = [Line2D([0], [0], color=c, marker='.', lw=4, linestyle='None') for c in custom_colors]
        plt.legend(custom_lines, custom_legend_labels)

        if savepath is not None:
            plt.savefig(savepath)
        plt.show()
def _to_plain_float(value): """Coerce a magnitude_zeropoint to a JSON-safe float, or None. Roman's ``get_zeropoint_magnitude`` returns a one-row astropy ``Column`` / ``Quantity`` slice; other instruments return plain scalars. Both paths need to round-trip through JSON. """ if value is None: return None if hasattr(value, 'value'): # astropy Quantity value = value.value if hasattr(value, '__len__'): if len(value) == 0: return None value = value[0] return float(value)
class LightweightStrongLens:
    """Minimal ``StrongLens`` stand-in rebuilt from lightweight ``.npz`` metadata.

    Only the attributes and accessors that the romanisim-path pipeline reads
    from ``synthetic_image.strong_lens`` are provided: redshifts, the
    substructure flag and the scalar ``get_*`` accessors. A lightweight file
    describes exactly one band, so the magnitude accessors reject any other
    band with ``ValueError``.
    """

    def __init__(self, meta):
        self._band = meta['band']
        lens_meta = meta['lens']
        self.name = lens_meta['name']
        self.z_lens = lens_meta['z_lens']
        self.z_source = lens_meta['z_source']
        # Consumers only test ``realization is None``; a sentinel string keeps
        # that check meaningful without storing a pyhalo realization.
        self.realization = '<lightweight>' if lens_meta['has_realization'] else None
        # scalar summaries are stored under the same names with a leading underscore
        for key in ('main_halo_mass', 'einstein_radius', 'velocity_dispersion', 'magnification',
                    'lens_magnitude', 'source_magnitude', 'lensed_source_magnitude'):
            setattr(self, f'_{key}', lens_meta[key])

    def _check_band(self, band):
        """Raise ValueError unless ``band`` is the band stored in this file."""
        if band == self._band:
            return
        raise ValueError(
            f"LightweightStrongLens stores only band {self._band!r}; "
            f"cannot return magnitude for band {band!r}. Load the "
            f"per-band lightweight file you actually want."
        )

    def _banded(self, band, value):
        """Return ``value`` after confirming ``band`` matches the stored band."""
        self._check_band(band)
        return value

    def get_main_halo_mass(self):
        return self._main_halo_mass

    def get_einstein_radius(self):
        return self._einstein_radius

    def get_velocity_dispersion(self):
        return self._velocity_dispersion

    def get_magnification(self):
        return self._magnification

    def get_lens_magnitude(self, band):
        return self._banded(band, self._lens_magnitude)

    def get_source_magnitude(self, band):
        return self._banded(band, self._source_magnitude)

    def get_lensed_source_magnitude(self, band):
        return self._banded(band, self._lensed_source_magnitude)


class LightweightSyntheticImage:
    """In-memory shim produced when a lightweight ``.npz`` is loaded.

    Mimics ``SyntheticImage`` for the attributes and methods that the
    romanisim path uses. It carries no lens model, PSF or pixel grid, so it
    cannot be used on the galsim path.
    """

    def __init__(self, data, meta):
        self.data = data
        for key in ('band', 'pixel_scale', 'fov_arcsec', 'num_pix', 'instrument_name'):
            setattr(self, key, meta[key])

        params = dict(meta.get('instrument_params') or {})
        position = params.get('detector_position')
        if position is not None:
            # JSON yields a list; the full SyntheticImage carries a tuple
            params['detector_position'] = tuple(position)
        self.instrument_params = params

        self.magnitude_zeropoint = meta['magnitude_zeropoint']
        self.strong_lens = LightweightStrongLens(meta)
        self._meta = meta  # kept for debugging / introspection

    @classmethod
    def load(cls, path):
        """Read a lightweight ``.npz`` written by ``SyntheticImage.save_lightweight``.

        Raises ValueError when the stored ``schema_version`` differs from
        ``LIGHTWEIGHT_SCHEMA_VERSION``.
        """
        with np.load(path) as archive:
            data = np.asarray(archive['data'])
            raw_meta = archive['meta'].tobytes()
        meta = json.loads(raw_meta.decode('utf-8'))

        version = meta.get('schema_version')
        if version != LIGHTWEIGHT_SCHEMA_VERSION:
            raise ValueError(
                f"Unsupported lightweight schema_version={version!r} "
                f"(expected {LIGHTWEIGHT_SCHEMA_VERSION}) in {path}"
            )
        return cls(data=data, meta=meta)

    def get_flux(self):
        """Total flux in counts/sec (sum over all pixels)."""
        total = np.sum(self.data)
        return float(total)

    def get_maggies(self):
        """Total flux in maggies, mirroring ``SyntheticImage.get_maggies``.

        Works on the plain-float zero-point stored in the metadata, so no
        astropy ``Quantity`` is needed. Raises ValueError when no zero-point
        was stored.
        """
        if self.magnitude_zeropoint is None:
            raise ValueError(
                "Cannot compute maggies: this lightweight file has no "
                "magnitude_zeropoint (lenstronomy amplitudes were provided "
                "at construction time)."
            )
        flux = self.get_flux()
        magnitude = self.magnitude_zeropoint - 2.5 * np.log10(flux)
        return float(10 ** (-0.4 * magnitude))