import logging
import numpy as np
import os
import warnings
from glob import glob
from stpsf.roman import WFI
from stpsf import NIRCam
from mejiro.engines.engine import Engine
from mejiro.utils import roman_util, lenstronomy_util
logger = logging.getLogger(__name__)
[docs]
class STPSFEngine(Engine):
[docs]
@staticmethod
def defaults(instrument_name):
if instrument_name.lower() == 'roman':
return {} # TODO implement
else:
Engine.instrument_not_supported(instrument_name)
[docs]
@staticmethod
def validate_engine_params(engine_params):
# TODO implement
pass
[docs]
@staticmethod
def get_roman_psf_kwargs(band, detector, detector_position, oversample, num_pix, check_cache=False, psf_cache_dir=None, require_cached=False):
kernel = STPSFEngine.get_roman_psf(band, detector, detector_position, oversample, num_pix,
check_cache=check_cache, psf_cache_dir=psf_cache_dir,
require_cached=require_cached)
return lenstronomy_util.get_pixel_psf_kwargs(kernel, oversample)
[docs]
@staticmethod
def get_roman_psf(band, detector, detector_position, oversample, num_pix, check_cache=False, psf_cache_dir=None,
require_cached=False, **calc_psf_kwargs):
"""
Generate a Roman WFI PSF using STPSF.
Parameters
----------
band : str
The band.
detector : int
The detector number.
detector_position : tuple of int
The (x, y) position on the detector.
oversample : int
The oversampling factor.
num_pix : int
The number of pixels on a side. This parameter is passed to STPSF's `fov_pixels` parameter.
check_cache : bool, optional
If True, check the cache for an existing PSF before generating a new one. Default is True.
psf_cache_dir : str, optional
The directory where cached PSFs are stored. If None, defaults to the directory installed with mejiro. Default is None.
**calc_psf_kwargs : dict
Additional keyword arguments to pass to STPSF's `calc_psf` method.
Returns
-------
np.ndarray
The PSF kernel.
"""
# first, check if it exists in the cache
if check_cache:
assert psf_cache_dir is not None, 'Must provide a PSF cache directory if checking the cache'
psf_id = STPSFEngine.get_psf_id(band, detector, detector_position, oversample, num_pix)
cached_psf = STPSFEngine.get_cached_psf(psf_id, psf_cache_dir)
if cached_psf is not None:
return cached_psf
if require_cached:
raise RuntimeError(f'PSF {psf_id} not found in cache {psf_cache_dir}')
logger.warning('Generating PSF with STPSF, which may be slow. Consider caching frequently-used PSFs.')
# set PSF parameters
wfi = WFI()
wfi.filter = band.upper()
wfi.detector = roman_util.get_sca_string(detector)
wfi.detector_position = detector_position
wfi.options['output_mode'] = 'oversampled'
# generate PSF in STPSF
psf = wfi.calc_psf(fov_pixels=num_pix, oversample=oversample, **calc_psf_kwargs)
return psf['OVERSAMP'].data
[docs]
@staticmethod
def get_jwst_psf_kwargs(band, oversample, num_pix, check_cache=False, psf_cache_dir=None):
kernel = STPSFEngine.get_jwst_psf(band, oversample, num_pix,
check_cache=check_cache, psf_cache_dir=psf_cache_dir)
return lenstronomy_util.get_pixel_psf_kwargs(kernel, oversample)
[docs]
@staticmethod
def get_jwst_psf(band, oversample, num_pix, check_cache=False, psf_cache_dir=None, **calc_psf_kwargs):
# first, check if it exists in the cache
if check_cache:
assert psf_cache_dir is not None, 'Must provide a PSF cache directory if checking the cache'
psf_id = STPSFEngine.get_jwst_psf_id(band, oversample, num_pix)
cached_psf = STPSFEngine.get_cached_psf(psf_id, psf_cache_dir)
if cached_psf is not None:
return cached_psf
logger.warning('Generating PSF with STPSF, which may be slow. Consider caching frequently-used PSFs.')
# set PSF parameters
nircam = NIRCam()
nircam.filter = band.upper()
nircam.options['output_mode'] = 'oversampled'
# generate PSF in STPSF
psf = nircam.calc_psf(fov_pixels=num_pix, oversample=oversample, **calc_psf_kwargs)
return psf['OVERSAMP'].data
[docs]
@staticmethod
def get_psf_id(band, detector, detector_position, oversample, num_pix):
"""
Generate a PSF identifier string. mejiro's Roman simulation uses this under-the-hood to cache and retrieve Roman PSFs.
Parameters
----------
band : str
The band.
detector : str
The detector number.
detector_position : tuple of int
The (x, y) position on the detector.
oversample : int
The oversampling factor.
num_pix : int
The number of pixels on a side.
Returns
-------
str
A unique identifier string for the PSF.
"""
detector = roman_util.get_sca_int(detector)
return f'{band}_{detector}_{detector_position[0]}_{detector_position[1]}_{oversample}_{num_pix}'
[docs]
@staticmethod
def get_jwst_psf_id(band, oversample, num_pix):
"""
Generate a PSF identifier string. mejiro's JWST simulation uses this under-the-hood to cache and retrieve JWST PSFs.
Parameters
----------
band : str
The band.
oversample : int
The oversampling factor.
num_pix : int
The number of pixels on a side.
Returns
-------
str
A unique identifier string for the PSF.
"""
return f'{band}_{oversample}_{num_pix}'
[docs]
@staticmethod
def get_params_from_psf_id(psf_id):
"""
Converts mejiro's Roman PSF identifier string format back to a list of PSF parameters.
Parameters
----------
psf_id : str
mejiro's Roman PSF identifier string.
Returns
-------
tuple
A tuple containing the following elements:
- band (str): The band.
- detector (int): The detector number.
- detector_position (tuple of int): The (x, y) position on the detector.
- oversample (int): The oversampling factor.
- num_pix (int): The number of pixels on a side.
"""
band, detector, detector_position_0, detector_position_1, oversample, num_pix = psf_id.split('_')
return band, int(detector), (int(detector_position_0), int(detector_position_1)), int(oversample), int(num_pix)
[docs]
@staticmethod
def get_roman_psf_from_id(psf_id, check_cache=True, psf_cache_dir=None, **calc_psf_kwargs):
"""
Wrapper method for `get_roman_psf` that accepts the PSF's identifier string.
Parameters
----------
psf_id : str
The identifier for the PSF, which encodes various parameters.
check_cache : bool, optional
If True, check the cache for an existing PSF before generating a new one. Default is True.
psf_cache_dir : str, optional
The directory where cached PSFs are stored. If None, defaults to the directory installed with mejiro. Default is None.
**calc_psf_kwargs : dict
Additional keyword arguments to pass to STPSF's `calc_psf` method.
Returns
-------
np.ndarray
The PSF kernel.
"""
band, detector, detector_position, oversample, num_pix = STPSFEngine.get_params_from_psf_id(psf_id)
return STPSFEngine.get_roman_psf(band, detector, detector_position, oversample, num_pix, check_cache, psf_cache_dir,
**calc_psf_kwargs)
[docs]
@staticmethod
def cache_psf(id_string, psf_cache_dir):
"""
Save a PSF to the provided directory.
Parameters
----------
id_string : str
The PSF identifier string.
psf_cache_dir : str
The directory where cached PSFs are stored.
Returns
-------
None
"""
psf_path = os.path.join(psf_cache_dir, f'{id_string}.npy')
if os.path.exists(psf_path):
logger.info(f'PSF {id_string} already cached to {psf_path}')
else:
psf = STPSFEngine.get_roman_psf_from_id(id_string, check_cache=False)
np.save(psf_path, psf)
logger.info(f'Cached PSF to {psf_path}')
[docs]
@staticmethod
def get_cached_psf(id_string, psf_cache_dir):
"""
Check if a PSF exists in the provided cache directory. If found, load and return it. Otherwise, return None.
Parameters
----------
id_string : str
The PSF identifier string.
psf_cache_dir : str or None
The directory where cached PSFs are stored. If None, defaults to the directory installed with mejiro.
Returns
-------
numpy.ndarray or None
The cached PSF if found, otherwise None.
"""
# if no psf cache directory provided, default to those installed with mejiro
if psf_cache_dir is None:
import mejiro
module_path = os.path.dirname(mejiro.__file__)
psf_cache_dir = os.path.join(module_path, 'data', 'cached_psfs')
psf_path = os.path.join(psf_cache_dir, f'{id_string}.npy')
if os.path.isfile(psf_path):
logger.debug(f'Loading cached PSF: {psf_path}')
return np.load(psf_path)
else:
logger.warning(f'PSF {id_string} not found in cache {psf_cache_dir}')
return None