import numpy as np
from abc import ABC, abstractmethod
from copy import deepcopy
from lenstronomy.Cosmo.lens_cosmo import LensCosmo
from lenstronomy.LensModel.lens_model import LensModel
from lenstronomy.LightModel.light_model import LightModel
from mejiro.analysis import regions
from mejiro.cosmo import cosmo
# NOTE: a stray Sphinx "[docs]" link marker (an HTML-extraction artifact, not Python) was removed here.
class StrongLens(ABC):
    """
    Parent class for strong lenses.

    At minimum, a unique name and the parameterization in lenstronomy (`kwargs_model` and `kwargs_params`)
    must be provided. If the light models have amplitudes (`amp`), they will be used. If not, magnitudes with
    their corresponding filters must be provided in the 'physical_params' dictionary. If both are provided,
    the `amp` values will be used.

    Parameters
    ----------
    name : str
        The name of the galaxy-galaxy strong lens. Should be unique.
    coords : astropy.coordinates.SkyCoord or None
        The coordinates of the system.
    kwargs_model : dict
        In Lenstronomy format: see Lenstronomy documentation. Must contain a 'cosmo' entry (an
        astropy.cosmology instance) and a 'lens_model_list' entry.
    kwargs_params : dict
        In Lenstronomy format: see Lenstronomy documentation.
    physical_params : dict or None
        A dictionary of physical parameters (``None`` is treated as an empty dictionary). The following are
        populated when importing from SLSim:

        - 'lens_stellar_mass': the stellar mass of the lens galaxy in solar masses.
        - 'lens_velocity_dispersion': the velocity dispersion of the lens galaxy in km/s.
        - 'magnification': the magnification of the source galaxy.
        - 'magnitudes': a dictionary of magnitudes with keys 'lens' and 'source' (and, optionally,
          'lensed_source'). Each value should be a dictionary with keys corresponding to the filter names
          (e.g., 'F062', 'F087', etc.) and values as the magnitudes in those filters.
    use_jax : bool, list of bool, or None
        Whether to use JAXtronomy for calculations. Default is None, then set to False in the constructor for
        all lens model element(s). See the `lenstronomy documentation
        <https://lenstronomy.readthedocs.io/en/latest/lenstronomy.LensModel.html#module-lenstronomy.LensModel.lens_model>`__
        for details.

    Raises
    ------
    ValueError
        If 'cosmo' or 'lens_model_list' is missing from `kwargs_model`, if `use_jax` is neither None, a
        boolean, nor a list whose length matches the number of lens models, or if the light models fail
        :meth:`validate_light_models`.

    Notes
    -----
    - Note that JAXtronomy is disabled by default. To enable it, either set `use_jax` to True (which will
      enable it for all lens model elements) or provide a list of booleans specifying which lens model(s) to
      enable JAXtronomy for.
    - Several methods read ``self.z_lens`` and ``self.z_source``, which are not defined in this class;
      presumably subclasses provide them.

    Examples
    --------
    Here is a sample ``physical_params`` dictionary:

    .. code-block:: python

        {
            "lens_stellar_mass": 977409654003.7206,
            "lens_velocity_dispersion": 318.1035069180407,
            "magnification": 2.1572878148080696,
            "magnitudes": {
                "lens": {
                    "F062": 25.31638290929369,
                    "F087": 24.96369787259651,
                    "F106": 23.9596060901365,
                    "F129": 22.97223751280789,
                    "F146": 22.29807037949136,
                    "F158": 21.906108678706293,
                    "F184": 21.45793326374795,
                    "F213": 21.118396579379453
                },
                "lensed_source": {
                    "F062": 25.500925653675434,
                    "F087": 25.177803863752867,
                    "F106": 25.00159434812006,
                    "F129": 24.8952958657035,
                    "F146": 24.766725445488536,
                    "F158": 24.667238282726974,
                    "F184": 24.465433844571177,
                    "F213": 24.029905125404884
                },
                "source": {
                    "F062": 26.33569587971921,
                    "F087": 26.012574089796644,
                    "F106": 25.836364574163838,
                    "F129": 25.730066091747275,
                    "F146": 25.601495671532312,
                    "F158": 25.50200850877075,
                    "F184": 25.300204070614953,
                    "F213": 24.86467535144866
                }
            }
        }
    """

    def __init__(
            self,
            name,
            coords,
            kwargs_model,
            kwargs_params,
            physical_params,
            use_jax=None,
    ):
        self.name = name
        self.coords = coords
        self.kwargs_model = kwargs_model
        self.kwargs_params = kwargs_params
        # every accessor below uses dict methods, so normalise a missing dictionary to an empty one
        self.physical_params = {} if physical_params is None else physical_params

        # set cosmology
        if 'cosmo' not in kwargs_model:
            raise ValueError("Set astropy.cosmology instance in kwargs_model['cosmo']")
        self.cosmo = kwargs_model['cosmo']

        # set JAXtronomy flag: always stored as one boolean per lens model
        if self.lens_model_list is None:
            raise ValueError("Set the list of lens models in kwargs_model['lens_model_list']")
        num_lens_models = len(self.lens_model_list)
        if use_jax is None:
            self.use_jax = [False] * num_lens_models  # default to False for all lens models
        elif isinstance(use_jax, list):
            if len(use_jax) != num_lens_models:
                raise ValueError("Length of use_jax list must match the number of lens models.")
            # copy: `add_realization()` extends this list and must not modify the caller's object
            self.use_jax = list(use_jax)
        elif isinstance(use_jax, bool):
            self.use_jax = [use_jax] * num_lens_models
        else:
            raise ValueError("use_jax must be a boolean or a list of booleans.")

        # check that brightnesses are provided somewhere
        _ = self.validate_light_models()

        # fields to initialize: these can be computed on demand
        self.lens_cosmo = None
        self.realization = None
        self.interpol_deflection_map = None

        # snapshots of the macromodel, populated by `add_realization()` before subhalos are appended.
        # These are plain attributes on purpose: properties with the same names would recurse on access.
        self.kwargs_lens_macromodel = None
        self.lens_redshift_list_macromodel = None
        self.lens_model_list_macromodel = None
        self.lens_model_macromodel = None

    @staticmethod
    def _flattened_grid(fov_arcsec, num_pix):
        """Return the flattened x and y coordinates of a square grid centered on (0, 0)."""
        _r = np.linspace(-fov_arcsec / 2, fov_arcsec / 2, num_pix)
        xx, yy = np.meshgrid(_r, _r)
        return xx.ravel(), yy.ravel()

    def get_f_sub(self, fov_arcsec=5, num_pix=100, plot=False):
        """
        Computes the ratio of the summed subhalo convergence to the summed macromodel convergence within an
        annulus around the Einstein radius.

        The annulus spans the Einstein radius -/+ 0.2 (in the same angular units as `fov_arcsec`), converted
        to pixels with the pixel scale ``fov_arcsec / num_pix``.

        Parameters
        ----------
        fov_arcsec : float, optional
            The field of view in arcseconds. Default is 5.
        num_pix : int, optional
            The number of pixels along each axis of the convergence maps. Default is 100.
        plot : bool, optional
            If True, also draw the two masked convergence maps with matplotlib. Default is False.

        Returns
        -------
        f_sub : float
            Sum of the realization convergence in the annulus divided by the sum of the macromodel
            convergence in the annulus.
        ax : numpy.ndarray of matplotlib axes or None
            The two axes that were drawn if `plot` is True, otherwise None.

        Raises
        ------
        ValueError
            If no realization has been added prior to calling this method.
        """
        if self.realization is None:
            raise ValueError("No realization has been added. Use `add_realization()` to add a pyHalo realization.")

        einstein_radius = self.get_einstein_radius()
        r_in = (einstein_radius - 0.2) / (fov_arcsec / num_pix)  # units of pixels
        r_out = (einstein_radius + 0.2) / (fov_arcsec / num_pix)  # units of pixels

        # Prefer the macromodel snapshot taken by `add_realization()`: counting halos here with
        # `add_mass_sheet_correction=False` does not match what was appended when the realization was added
        # with the mass sheet correction.
        if self.lens_model_list_macromodel is not None and self.kwargs_lens_macromodel is not None:
            macrolens_lens_model_list = self.lens_model_list_macromodel
            macrolens_kwargs_lens = self.kwargs_lens_macromodel
        else:
            halo_lens_model_list, _, _, _ = self.realization.lensing_quantities(add_mass_sheet_correction=False)
            # an explicit stop index avoids `[:-0]`, which would yield an empty list when there are no halos
            num_macro = max(len(self.lens_model_list) - len(halo_lens_model_list), 0)
            macrolens_lens_model_list = self.lens_model_list[:num_macro]
            macrolens_kwargs_lens = self.kwargs_lens[:num_macro]

        lens_model_macro = LensModel(lens_model_list=macrolens_lens_model_list,
                                     z_lens=self.z_lens,
                                     z_source=self.z_source,
                                     lens_redshift_list=[self.z_lens] * len(macrolens_lens_model_list),
                                     cosmo=self.cosmo,
                                     multi_plane=False)

        x_flat, y_flat = self._flattened_grid(fov_arcsec, num_pix)
        macrolens_kappa = lens_model_macro.kappa(x_flat, y_flat, macrolens_kwargs_lens).reshape(num_pix, num_pix)
        # evaluate the subhalo map on the same grid as the macromodel map and the mask
        subhalo_kappa = self.get_realization_kappa(fov_arcsec=fov_arcsec, num_pix=num_pix,
                                                   add_mass_sheet_correction=False)

        # the mask is inverted below, so `annular_mask` presumably returns True inside the annulus
        mask = regions.annular_mask(num_pix, num_pix, (num_pix // 2, num_pix // 2), r_in, r_out)
        masked_kappa_subhalos = np.ma.masked_array(subhalo_kappa, mask=~mask)
        masked_kappa_macro = np.ma.masked_array(macrolens_kappa, mask=~mask)

        f_sub = masked_kappa_subhalos.compressed().sum() / masked_kappa_macro.compressed().sum()

        if not plot:
            return f_sub, None

        import matplotlib.pyplot as plt
        _, ax = plt.subplots(1, 2, figsize=(6, 3), constrained_layout=True)
        ax[0].imshow(masked_kappa_subhalos, cmap='bwr')
        ax[1].imshow(masked_kappa_macro, cmap='bwr')
        ax[0].set_title(
            r'$\sum_n \int_{\mathrm{annulus}}d^2\theta\,\kappa_n=$' + f'{masked_kappa_subhalos.compressed().sum():.6f}')
        ax[1].set_title(
            r'$\int_{\mathrm{annulus}}d^2\theta\,\kappa_{\mathrm{host}}=$' + f'{masked_kappa_macro.compressed().sum():.6f}')
        for a in ax:
            a.axis('off')
        plt.suptitle(r'$f_{\mathrm{sub}}=$' + f'{f_sub:.6f}')
        return f_sub, ax

    def get_kappa(self, fov_arcsec=5, num_pix=100):
        """
        Computes the convergence (kappa) map of the lens model over a specified field of view.

        Parameters
        ----------
        fov_arcsec : float, optional
            The field of view in arcseconds. Default is 5.
        num_pix : int, optional
            The number of pixels along each axis for the output grid. Default is 100.

        Returns
        -------
        kappa_map : ndarray
            A 2D array of shape (num_pix, num_pix) representing the convergence (kappa) values
            computed over the grid.

        Notes
        -----
        The method creates a square grid centered at (0, 0) in arcseconds, evaluates the lens model's
        convergence at each grid point, and returns the resulting map.
        """
        x_flat, y_flat = self._flattened_grid(fov_arcsec, num_pix)
        return self.lens_model.kappa(x_flat, y_flat, self.kwargs_lens).reshape(num_pix, num_pix)

    def get_realization_kappa(self, fov_arcsec=5, num_pix=100, add_mass_sheet_correction=False):
        """
        Computes the convergence (kappa) map for the current realization over a specified field of view.

        Parameters
        ----------
        fov_arcsec : float, optional
            The field of view in arcseconds for the kappa map. Default is 5.
        num_pix : int, optional
            The number of pixels per side for the output kappa map. Default is 100.
        add_mass_sheet_correction : bool, optional
            Whether to include the mass sheet correction in the realization. Default is False.

        Returns
        -------
        kappa_map : numpy.ndarray
            A 2D array of shape (num_pix, num_pix) representing the convergence (kappa) map.

        Raises
        ------
        ValueError
            If no realization has been added prior to calling this method.

        Notes
        -----
        This method requires a pyHalo realization to be added using `add_realization()`.
        """
        if self.realization is None:
            raise ValueError("No realization has been added. Use `add_realization()` to add a pyHalo realization.")

        halo_lens_model_list, halo_redshift_array, kwargs_halos, _ = self.realization.lensing_quantities(
            add_mass_sheet_correction=add_mass_sheet_correction)
        # halo_redshift_array is converted to a list before being handed to lenstronomy
        halo_redshift_list = list(halo_redshift_array)

        lens_model_realization = LensModel(lens_model_list=halo_lens_model_list,
                                           z_lens=self.z_lens,
                                           z_source=self.z_source,
                                           lens_redshift_list=halo_redshift_list,
                                           cosmo=self.cosmo,
                                           multi_plane=True)

        x_flat, y_flat = self._flattened_grid(fov_arcsec, num_pix)
        return lens_model_realization.kappa(x_flat, y_flat, kwargs_halos).reshape(num_pix, num_pix)

    def add_realization(self, realization, add_mass_sheet_correction=True, use_jax=False):
        """
        Add a pyHalo dark matter subhalo realization to the mass model of the system. See the
        `pyHalo documentation <https://github.com/dangilman/pyHalo>`__ for details.

        The macromodel (`kwargs_lens`, `lens_redshift_list`, `lens_model_list`, and `lens_model`) is copied
        into the corresponding ``*_macromodel`` attributes before the halos are appended.

        Parameters
        ----------
        realization : pyHalo realization object
            See the pyHalo documentation for details.
        add_mass_sheet_correction : bool, optional
            See the pyHalo documentation for details. Default is True.
        use_jax : bool, optional
            Whether to use JAXtronomy for calculations. If True, it is enabled only for halo lens models
            listed in JAXtronomy's ``_JAXXED_MODELS``. Default is False.

        Raises
        ------
        ValueError
            If `use_jax` is not a boolean. This is checked before any state is modified.
        """
        # validate up front so that a bad argument cannot leave the lens model half-updated
        if not isinstance(use_jax, bool):
            raise ValueError("use_jax must be a boolean.")

        # get lenstronomy lensing quantities
        halo_lens_model_list, halo_redshift_array, kwargs_halos, _ = realization.lensing_quantities(
            add_mass_sheet_correction=add_mass_sheet_correction)
        # halo_lens_model_list and kwargs_halos are lists, but halo_redshift_array is ndarray
        halo_redshift_list = list(halo_redshift_array)

        self.realization = realization

        # before adding the realization, save the macromodel parameters
        self.kwargs_lens_macromodel = deepcopy(self.kwargs_lens)
        self.lens_redshift_list_macromodel = deepcopy(self.lens_redshift_list)
        self.lens_model_list_macromodel = deepcopy(self.lens_model_list)
        self.lens_model_macromodel = deepcopy(self.lens_model)

        # add subhalos to lenstronomy objects that model the strong lens
        self.kwargs_lens += kwargs_halos
        self.lens_redshift_list += halo_redshift_list
        self.lens_model_list += halo_lens_model_list

        # use JAXtronomy for supported lens models
        if use_jax:
            from jaxtronomy.LensModel.profile_list_base import _JAXXED_MODELS
            self.use_jax += [halo_lens_model in _JAXXED_MODELS for halo_lens_model in halo_lens_model_list]
        else:
            self.use_jax += [False] * len(halo_lens_model_list)

    def quick_add(self, model='CDM', model_kwargs=None, add_mass_sheet_correction=True, use_jax=False):
        """
        Generate a pyHalo dark matter subhalo realization from a preset model and add it to the mass model
        of the system. See the `pyHalo documentation <https://github.com/dangilman/pyHalo>`__ for details.

        The preset is called with the lens and source redshifts rounded to two decimals and with
        ``log_m_host`` set to the base-10 logarithm of :meth:`get_main_halo_mass`.

        Parameters
        ----------
        model : str, optional
            The name of the pyHalo preset model to use. Default is 'CDM'.
        model_kwargs : dict, optional
            Additional keyword arguments to pass to the model class.
        add_mass_sheet_correction : bool, optional
            Whether to add a mass sheet correction. Default is True.
        use_jax : bool, optional
            Whether to use JAXtronomy for calculations. Default is False.
        """
        from pyHalo.preset_models import preset_model_from_name

        ModelClass = preset_model_from_name(model)
        if model_kwargs is None:
            model_kwargs = {}
        realization = ModelClass(round(self.z_lens, 2), round(self.z_source, 2),
                                 log_m_host=np.log10(self.get_main_halo_mass()), **model_kwargs)
        self.add_realization(realization, add_mass_sheet_correction=add_mass_sheet_correction, use_jax=use_jax)

    def get_lens_magnitude(self, band):
        """
        Returns the magnitude of the lens in the specified photometric band.

        Parameters
        ----------
        band : str
            The name of the photometric band for which to retrieve the lens magnitude.

        Returns
        -------
        float
            Magnitude of the lens in the specified photometric band.
        """
        return self.get_magnitude('lens', band)

    def get_source_magnitude(self, band):
        """
        Returns the magnitude of the source in the specified photometric band.

        Parameters
        ----------
        band : str
            The name of the photometric band for which to retrieve the source magnitude.

        Returns
        -------
        float
            The magnitude of the source in the specified band.
        """
        return self.get_magnitude('source', band)

    def get_lensed_source_magnitude(self, band):
        """
        Returns the magnitude of the lensed source in the specified photometric band.

        Parameters
        ----------
        band : str
            The name of the photometric band for which to retrieve the lensed source magnitude.

        Returns
        -------
        float
            The magnitude of the lensed source in the specified band.
        """
        return self.get_magnitude('lensed_source', band)

    def get_magnitude(self, kind, band):
        """
        Retrieve the magnitude value for a specified kind and photometric band from the physical parameters
        dictionary.

        Parameters
        ----------
        kind : str
            Options are 'lens', 'source', 'lensed_source'.
        band : str
            e.g., 'F129' (Roman), 'J' (HWO), etc.

        Returns
        -------
        float
            Magnitude

        Raises
        ------
        ValueError
            If magnitudes are not provided in the `physical_params` dictionary.
            If the specified kind is not present in the magnitudes.
            If the specified band is not present for the given kind.
        """
        magnitudes = self.physical_params.get('magnitudes')
        if magnitudes is None:
            raise ValueError("Magnitudes are not provided in the `physical_params` dictionary.")
        magnitudes_of_kind = magnitudes.get(kind)
        if magnitudes_of_kind is None:
            raise ValueError(f"{kind} magnitudes are not provided in the `physical_params` dictionary.")
        magnitude = magnitudes_of_kind.get(band)
        if magnitude is None:
            raise ValueError(f"{kind} magnitudes for band {band} are not provided in the `physical_params` dictionary.")
        return magnitude

    def get_maggies(self, kind, band):
        """
        Retrieve the maggies value for a specified kind and photometric band, computed from the stored
        magnitude as ``10 ** (-0.4 * magnitude)``.

        Parameters
        ----------
        kind : str
            Options are 'lens', 'source', 'lensed_source'.
        band : str
            e.g., 'F129' (Roman), 'J' (HWO), etc.

        Returns
        -------
        float
            Maggies

        Raises
        ------
        ValueError
            If magnitudes are not provided in the `physical_params` dictionary.
            If the specified kind is not present in the magnitudes.
            If the specified band is not present for the given kind.
        """
        magnitude = self.get_magnitude(kind, band)
        return 10 ** (-0.4 * magnitude)

    def get_einstein_radius(self):
        """
        Returns the Einstein radius.

        The value is looked up first under 'einstein_radius' in `physical_params`, then under 'theta_E' in
        the first entry of `kwargs_lens`.

        Returns
        -------
        float
            The Einstein radius (``theta_E``) in angular units (often, arcseconds).

        Raises
        ------
        ValueError
            This method currently does not calculate the Einstein radius. Rather, it retrieves it from the
            attributes. If it has not been stored in these attributes, a ValueError will be raised.
        """
        einstein_radius = self.physical_params.get('einstein_radius', None)
        if einstein_radius is None:
            kwargs_lens = self.kwargs_lens
            # a missing or empty `kwargs_lens` must end in the documented ValueError, not an indexing error
            if kwargs_lens:
                einstein_radius = kwargs_lens[0].get('theta_E', None)
        if einstein_radius is None:
            raise ValueError("Could not find `einstein_radius` in `physical_params` or `theta_E` in `kwargs_lens`")
        return einstein_radius

    def get_velocity_dispersion(self):
        """
        Get the velocity dispersion of the lensing galaxy in km/s.

        Returns
        -------
        float
            The velocity dispersion of the lensing galaxy in km/s.

        Raises
        ------
        ValueError
            If 'lens_velocity_dispersion' is not present in `self.physical_params`.
        """
        if 'lens_velocity_dispersion' not in self.physical_params:
            raise ValueError("Velocity dispersion not found in physical_params. Please provide 'lens_velocity_dispersion' in physical_params.")
        return self.physical_params['lens_velocity_dispersion']

    def get_stellar_mass(self):
        """
        Get the stellar mass of the lensing galaxy in solar masses.

        Returns
        -------
        float
            The stellar mass of the lensing galaxy in solar masses.

        Raises
        ------
        ValueError
            If 'lens_stellar_mass' is not present in `self.physical_params`.
        """
        if 'lens_stellar_mass' not in self.physical_params:
            raise ValueError("Stellar mass not found in physical_params. Please provide 'lens_stellar_mass' in physical_params.")
        return self.physical_params['lens_stellar_mass']

    def get_main_halo_mass(self):
        """
        Returns the main halo mass of the lensing galaxy in solar masses.

        This method first attempts to retrieve the main halo mass from the ``physical_params`` dictionary
        using the ``main_halo_mass`` key. If this value is not present, it will attempt to estimate the
        main halo mass using the stellar mass (``lens_stellar_mass``) and the lens redshift (``z_lens``)
        via ``cosmo.stellar_to_main_halo_mass``.

        Returns
        -------
        float
            The mass of the main halo in solar masses.

        Raises
        ------
        ValueError
            If neither ``main_halo_mass`` nor ``lens_stellar_mass`` are present in ``physical_params``.

        Notes
        -----
        The estimate is requested with ``sample=True`` and is not stored, so it is recomputed on every call.
        NOTE(review): ``sample=True`` presumably draws from a scatter relation, in which case repeated calls
        can return different values -- confirm against ``mejiro.cosmo.cosmo``.
        """
        main_halo_mass = self.physical_params.get('main_halo_mass', None)
        if main_halo_mass is not None:
            return main_halo_mass

        lens_stellar_mass = self.physical_params.get('lens_stellar_mass', None)
        if lens_stellar_mass is None:
            raise ValueError("Could not find `main_halo_mass` or `lens_stellar_mass` in physical_params")
        return cosmo.stellar_to_main_halo_mass(stellar_mass=lens_stellar_mass, z=self.z_lens, sample=True)

    def get_lens_cosmo(self):
        """
        Get or create the LensCosmo instance, a lenstronomy class that supports physical unit calculations.

        The instance is created on first use and cached in ``self.lens_cosmo``.

        Returns
        -------
        LensCosmo
            The lenstronomy.Cosmo.lens_cosmo.LensCosmo instance for the system.
        """
        if self.lens_cosmo is None:
            self.lens_cosmo = LensCosmo(self.z_lens, self.z_source, cosmo=self.cosmo)
        return self.lens_cosmo

    def validate_light_models(self):
        """
        Validates the presence of the lenstronomy amplitude ('amp') parameter in all lens and source light
        model keyword arguments.

        If magnitude information is not present in ``self.physical_params``, or if it is incomplete (missing
        'lens' or 'source' magnitudes), the method ensures that each light model dictionary in
        ``self.kwargs_lens_light`` and ``self.kwargs_source`` contains an 'amp' key. A light model list that
        is missing (None) is treated as empty.

        Returns
        -------
        amps_provided : bool
            True if all required amplitude parameters are provided. False if magnitudes are provided for both
            the lens and source.

        Raises
        ------
        ValueError
            If any light model dictionary is missing the 'amp' parameter.
        """
        magnitudes = self.physical_params.get('magnitudes')
        if magnitudes is not None and 'lens' in magnitudes and 'source' in magnitudes:
            # brightness comes from the magnitudes, so amplitudes are not required
            return False

        light_kwargs_list = list(self.kwargs_lens_light or ()) + list(self.kwargs_source or ())
        for light_kwargs in light_kwargs_list:
            if 'amp' not in light_kwargs:
                raise ValueError(f"Missing 'amp' in {light_kwargs}")
        # if loop completes without raising an error, then amps are provided
        return True

    # ------------------------------------------------------------------
    # Accessors for entries of `kwargs_params` (None when the key is absent)
    # ------------------------------------------------------------------
    @property
    def kwargs_lens(self):
        return self.kwargs_params.get('kwargs_lens', None)

    @kwargs_lens.setter
    def kwargs_lens(self, value):
        self.kwargs_params['kwargs_lens'] = value

    @property
    def kwargs_lens_light(self):
        return self.kwargs_params.get('kwargs_lens_light', None)

    @kwargs_lens_light.setter
    def kwargs_lens_light(self, value):
        self.kwargs_params['kwargs_lens_light'] = value

    @property
    def kwargs_source(self):
        return self.kwargs_params.get('kwargs_source', None)

    @kwargs_source.setter
    def kwargs_source(self, value):
        self.kwargs_params['kwargs_source'] = value

    @property
    def kwargs_ps(self):
        return self.kwargs_params.get('kwargs_ps', None)

    @kwargs_ps.setter
    def kwargs_ps(self, value):
        self.kwargs_params['kwargs_ps'] = value

    @property
    def kwargs_extinction(self):
        return self.kwargs_params.get('kwargs_extinction', None)

    @kwargs_extinction.setter
    def kwargs_extinction(self, value):
        self.kwargs_params['kwargs_extinction'] = value

    @property
    def kwargs_special(self):
        return self.kwargs_params.get('kwargs_special', None)

    @kwargs_special.setter
    def kwargs_special(self, value):
        self.kwargs_params['kwargs_special'] = value

    # ------------------------------------------------------------------
    # Accessors for entries of `kwargs_model` (None when the key is absent)
    # ------------------------------------------------------------------
    @property
    def lens_model_list(self):
        return self.kwargs_model.get('lens_model_list', None)

    @lens_model_list.setter
    def lens_model_list(self, value):
        self.kwargs_model['lens_model_list'] = value

    @property
    def lens_light_model_list(self):
        return self.kwargs_model.get('lens_light_model_list', None)

    @lens_light_model_list.setter
    def lens_light_model_list(self, value):
        self.kwargs_model['lens_light_model_list'] = value

    @property
    def source_light_model_list(self):
        return self.kwargs_model.get('source_light_model_list', None)

    @source_light_model_list.setter
    def source_light_model_list(self, value):
        self.kwargs_model['source_light_model_list'] = value

    @property
    def lens_redshift_list(self):
        return self.kwargs_model.get('lens_redshift_list', None)

    @lens_redshift_list.setter
    def lens_redshift_list(self, value):
        self.kwargs_model['lens_redshift_list'] = value

    @property
    def source_redshift_list(self):
        return self.kwargs_model.get('source_redshift_list', None)

    @source_redshift_list.setter
    def source_redshift_list(self, value):
        self.kwargs_model['source_redshift_list'] = value

    # ------------------------------------------------------------------
    # lenstronomy model objects: a new instance is built on every access
    # ------------------------------------------------------------------
    @property
    def lens_model(self):
        return LensModel(self.lens_model_list, use_jax=self.use_jax)

    @property
    def lens_light_model(self):
        return LightModel(self.lens_light_model_list)

    @property
    def source_light_model(self):
        return LightModel(self.source_light_model_list)

    def __str__(self):
        return f"StrongLens(name={self.name}, coords={self.coords}, z_lens={getattr(self, 'z_lens', None)}, z_source={getattr(self, 'z_source', None)})"