import importlib
import os
import yaml
from glob import glob
import logging
import mejiro
from mejiro.utils import util
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Names of the pipeline steps that have a JAX variant. The JAX variant of a
# step writes its output to '<step>_jax/' (e.g. '04' -> '04_jax').
# Consumers should resolve a step's directory via PipelineHelper.step_dir()
# rather than joining paths by hand, so that the chosen variant follows the
# jaxtronomy.use_jax setting of the loaded configuration.
JAX_VARIANT_STEPS = {'01b', '04'}
[docs]
class PipelineHelper:
    """Shared setup and file-lookup helper for a single pipeline script.

    On construction the helper loads the YAML configuration, configures
    logging, resolves the data / PSF-cache / pipeline directories,
    instantiates the configured instrument class and prepares the output
    directory of the current script.

    Attributes set by ``__init__``
    ------------------------------
    config : dict
        The parsed YAML configuration.
    data_dir : str
        Root data directory (``--data_dir`` overrides the config value).
    dev, show_progress_bar, limit : object
        Copied verbatim from the config keys of the same name.
    runs, detectors : object
        Copied from ``config['survey']``.
    name : str
        ``config['pipeline_label']``.
    instrument_name : str
        Lower-cased ``config['instrument']``.
    instrument : object
        Instance returned by :meth:`initialize_instrument_class`.
    psf_cache_dir : str
        Resolved PSF cache directory.
    pipeline_dir : str
        ``<data_dir>/<pipeline_label>`` (with ``_dev`` appended when ``dev``).
    input_dir : str
        Directory of the previous step; only set when ``prev_script_name``
        is not ``None``.
    output_dir : str
        ``<pipeline_dir>/<script_name>``.
    """

    def __init__(self, args, prev_script_name, script_name, supported_instruments,
                 delete_existing_output=False):
        """Load the configuration and prepare directories for one script.

        Parameters
        ----------
        args : object
            Namespace with a ``config`` attribute (path to the YAML file, with
            or without extension) and, optionally, a ``data_dir`` attribute.
            ``args.config`` is updated in place when an extension is appended.
        prev_script_name : str or None
            Name of the step whose output this script consumes. When ``None``
            no ``input_dir`` attribute is set.
        script_name : str
            Name of the current step; used as the output directory name.
        supported_instruments : container of str
            Lower-case instrument names this script accepts.
        delete_existing_output : bool, optional
            When true, ``util.clear_directory`` is called on the output
            directory after it has been created.

        Raises
        ------
        ValueError
            If the config path cannot be resolved to a YAML file, the file
            does not contain a mapping, no data directory is available, the
            instrument is unsupported, or ``script_name`` is ``None``.
        """
        self.prev_script_name = prev_script_name
        self.script_name = script_name
        self.supported_instruments = supported_instruments

        # ensure the configuration file has a .yaml or .yml extension
        args.config = self._resolve_config_path(args.config)

        # read configuration file
        with open(args.config, 'r') as f:
            config = yaml.load(f, Loader=yaml.SafeLoader)
        if not isinstance(config, dict):
            # an empty file parses to None; fail with a clear message instead
            # of an AttributeError on the first config.get() below
            raise ValueError(f"The configuration file '{args.config}' must contain a YAML mapping.")
        self.config = config

        # configure logging level (a missing or null entry falls back to INFO)
        logging_level = str(config.get('logging_level') or 'INFO')
        logging.basicConfig(level=getattr(logging, logging_level.upper(), logging.INFO))

        # set data directory; the command-line value wins over the config file
        self.data_dir = config.get('data_dir')
        cli_data_dir = getattr(args, 'data_dir', None)
        if cli_data_dir is not None:
            logger.warning(f'Overriding data_dir in config file ({self.data_dir}) with provided data_dir ({cli_data_dir})')
            self.data_dir = cli_data_dir
        elif self.data_dir is None:
            raise ValueError("data_dir must be specified either in the config file or via the --data_dir argument.")

        # get attributes from config
        self.dev = config['dev']
        self.show_progress_bar = config['show_progress_bar']
        self.limit = config['limit']
        self.runs = config['survey']['runs']
        self.detectors = config['survey']['detectors']

        # set pipeline name
        self.name = config['pipeline_label']

        # set nice level
        os.nice(config.get('nice', 0))

        # suppress warnings
        if config['suppress_warnings']:
            import warnings
            warnings.filterwarnings("ignore", category=UserWarning)

        # load instrument
        self.instrument_name = config['instrument'].lower()
        if self.instrument_name not in self.supported_instruments:
            raise ValueError(f"Unsupported instrument: {self.instrument_name}. Supported instruments are {self.supported_instruments}.")
        self.instrument = self.initialize_instrument_class()

        # set psf cache directory
        self.psf_cache_dir = self._resolve_psf_cache_dir(config['psf_cache_dir'])

        # set up top directory for all pipeline output
        self.pipeline_dir = os.path.join(self.data_dir, self.name)
        if self.dev:
            self.pipeline_dir += '_dev'

        # set up input directory for current script (resolving the JAX variant of the
        # previous step, e.g. '04' -> '04_jax', when jaxtronomy.use_jax is True)
        if self.prev_script_name is not None:
            self.input_dir = self.step_dir(self.prev_script_name)

        # set up output directory for current script
        if self.script_name is None:
            raise ValueError("script_name must be specified.")
        self.output_dir = os.path.join(self.pipeline_dir, self.script_name)
        util.create_directory_if_not_exists(self.output_dir)
        if delete_existing_output:
            util.clear_directory(self.output_dir)

    @staticmethod
    def _resolve_config_path(path):
        """Return ``path`` with a YAML extension, appending one if a matching file exists."""
        if path.endswith(('.yaml', '.yml')):
            return path
        for extension in ('.yaml', '.yml'):
            if os.path.exists(path + extension):
                return path + extension
        raise ValueError("The configuration file must be a YAML file with extension '.yaml' or '.yml'.")

    def _resolve_psf_cache_dir(self, configured):
        """Turn the configured ``psf_cache_dir`` value into the directory to use."""
        package_dir = os.path.dirname(mejiro.__file__)
        if configured is None:
            # default: PSFs shipped inside the installed package
            return os.path.join(package_dir, 'data', 'psfs', self.instrument_name)
        if os.path.isabs(configured):
            return configured
        # relative path: prefer <data_dir>/<configured>; fall back to the
        # directory above the package only when the former is missing and the
        # latter exists
        candidate = os.path.join(self.data_dir, configured)
        if not os.path.isdir(candidate):
            pkg_candidate = os.path.abspath(os.path.join(package_dir, '..', configured))
            if os.path.isdir(pkg_candidate):
                return pkg_candidate
        return candidate

    def step_dir(self, step):
        """Absolute path to a pipeline step's directory, selecting the JAX variant
        (e.g. '04' -> '04_jax') when jaxtronomy.use_jax is True and one exists."""
        if step in JAX_VARIANT_STEPS and self.config['jaxtronomy']['use_jax']:
            step = f'{step}_jax'
        return os.path.join(self.pipeline_dir, step)

    def calculate_process_count(self, count):
        """Return the number of worker processes to use for ``count`` tasks.

        The result is the value configured under
        ``config['cores']['script_<script_name>']``, capped at ``count``.
        The machine's CPU count is only reported in the log message.
        """
        import multiprocessing
        cpu_count = multiprocessing.cpu_count()
        process_count = min(count, self.config['cores'][f'script_{self.script_name}'])
        logger.info(f'Spinning up {process_count} process(es) on {cpu_count} core(s)')
        return process_count

    def parse_sca_from_filename(self, filename):
        """Return the SCA number encoded in the parent directory of ``filename``.

        The parent directory must be named ``sca<N>`` (e.g. ``.../sca07/x.pkl``
        yields ``7``). Roman only.

        Raises
        ------
        ValueError
            If the parent directory is not of the form ``sca<integer>``.
        """
        self.validate_instrument('roman')
        # use os.path rather than splitting on '/' so the platform separator is honoured
        sca_dirname = os.path.basename(os.path.dirname(filename))
        if not sca_dirname.startswith('sca'):
            raise ValueError(f'Invalid SCA format in filename: {filename}')
        try:
            return int(sca_dirname[3:])
        except ValueError:
            raise ValueError(f'Invalid SCA format in filename: {filename}') from None

    def create_roman_sca_output_directories(self):
        """Create one ``sca<NN>`` directory per detector under ``output_dir``.

        When there are fewer runs than detectors, only the first ``runs``
        detectors get a directory. Roman only.

        Returns
        -------
        list of str
            The created (or already existing) directory paths.
        """
        self.validate_instrument('roman')
        # for a case where e.g. only 2 runs but 18 detectors, only create 2 folders
        detectors_to_use = self.detectors
        if self.runs < len(detectors_to_use):
            detectors_to_use = detectors_to_use[:self.runs]
        output_sca_dirs = []
        for sca in detectors_to_use:
            sca_dir = os.path.join(self.output_dir, f'sca{str(sca).zfill(2)}')
            os.makedirs(sca_dir, exist_ok=True)
            output_sca_dirs.append(sca_dir)
        logger.info(f'Set up output directories {output_sca_dirs}')
        return output_sca_dirs

    def parse_roman_uids(self, prefix, suffix, extension):
        """Return the sorted unique UIDs found in the matching Roman input files.

        The UID is taken as the second-to-last ``_``-separated token of each
        file's basename.
        NOTE(review): this assumes one more token follows the UID (e.g. a
        band suffix) -- confirm for callers that pass an empty ``suffix``.
        """
        roman_pickles = self.retrieve_roman_pickles(prefix=prefix, suffix=suffix, extension=extension)
        uids = {os.path.basename(path).split("_")[-2] for path in roman_pickles}
        return sorted(uids)

    def _filename_pattern(self, prefix, suffix, extension):
        """Build the glob pattern ``<prefix>_<name>_*[_<suffix>]<extension>``."""
        pattern = f'{prefix}_{self.name}_*'
        if suffix:
            pattern += f'_{suffix}'
        return pattern + f'{extension}'

    def retrieve_roman_pickles(self, prefix, suffix, extension):
        """Return the sorted input files matching the pattern inside ``input_dir/sca*/``. Roman only."""
        self.validate_instrument('roman')
        pattern = self._filename_pattern(prefix, suffix, extension)
        return sorted(glob(os.path.join(self.input_dir, 'sca*', pattern)))

    def retrieve_pickles(self, prefix='', suffix='', extension='.pkl'):
        """Return the sorted input files matching the pattern directly inside ``input_dir``."""
        pattern = self._filename_pattern(prefix, suffix, extension)
        return sorted(glob(os.path.join(self.input_dir, pattern)))

    def initialize_instrument_class(self):
        """Import ``mejiro.instruments.<instrument_name>`` and return an instance of its class.

        Raises
        ------
        ValueError
            If ``instrument_name`` is not one of ``hwo``, ``jwst`` or ``roman``.
        """
        base_module_path = "mejiro.instruments"
        class_map = {
            "hwo": "HWO",
            "jwst": "JWST",
            "roman": "Roman",
        }
        key = self.instrument_name.lower()
        if key not in class_map:
            raise ValueError(f"Unknown instrument: {self.instrument_name}")
        module = importlib.import_module(f"{base_module_path}.{key}")
        cls = getattr(module, class_map[key])
        return cls()

    def validate_instrument(self, instrument_name):
        """Raise ``AssertionError`` unless this helper was configured for ``instrument_name``."""
        # raised explicitly (not via `assert`) so the guard is not stripped
        # when Python runs with -O
        if self.instrument_name != instrument_name:
            raise AssertionError(f"This method is only for the {instrument_name} instrument.")

    @staticmethod
    def patch_astropy_for_mejiro_v2_pickles():
        """
        Install ``sys.modules`` aliases so that pickles produced under the
        ``mejiro-v2`` conda environment can be loaded under ``mejiro-v3``.

        When this is needed
        -------------------
        Pickles created under ``mejiro-v2`` (which used astropy where
        ``astropy.cosmology.flrw`` was a package containing the submodules
        ``base``, ``lambdacdm``, ``w0cdm``, ``w0wacdm``, ``w0wzcdm``, ``wcdm``,
        ``wpwazpcdm``, and the compiled ``scalar_inv_efuncs`` Cython extension)
        embed those fully-qualified module paths in the pickle stream. In
        ``mejiro-v3`` (astropy 7.x), the public ``astropy.cosmology.flrw`` is a
        single flat module and the real submodules live under
        ``astropy.cosmology._src.flrw``. Unpickling therefore fails with::

            ModuleNotFoundError: No module named
            'astropy.cosmology.flrw.lambdacdm';
            'astropy.cosmology.flrw' is not a package

        Call this function once, at the top of a pipeline script's ``main`` (or
        before any unpickling), whenever that script consumes pickles produced
        by ``mejiro-v2``. After re-pickling those artifacts under ``mejiro-v3``,
        the call can be removed without further changes -- it is a pure
        ``sys.modules`` shim with no other side effects.
        """
        import sys
        submodules = (
            'base', 'lambdacdm', 'w0cdm', 'w0wacdm', 'w0wzcdm',
            'wcdm', 'wpwazpcdm', 'scalar_inv_efuncs',
        )
        for name in submodules:
            alias = f'astropy.cosmology.flrw.{name}'
            if alias in sys.modules:
                # already importable under the old path; nothing to alias
                continue
            try:
                target = importlib.import_module(f'astropy.cosmology._src.flrw.{name}')
            except ModuleNotFoundError:
                # best effort: this astropy build has no such private module
                continue
            sys.modules.setdefault(alias, target)