# Source code for mejiro.pipeline.romanisim_pipeline

"""
Runs romanisim detector simulation on tiled synthetic images.

This script tiles SyntheticImage pickles into 4088x4088 detector arrays, runs romanisim
to apply detector effects, and extracts individual cutouts. Systems are processed in
batches of 3136 (56x56 grid of 73x73 tiles) until all systems for each SCA/band are
complete. Multiprocessing is used to parallelize batch processing.

Usage:
    python3 romanisim_pipeline.py --config <config.yaml> [--resume]

Arguments:
    --config: Path to the YAML configuration file.
    --resume: Preserve existing output and skip already-completed batches (those with a
        batch_complete_*.txt sentinel). Default is to delete existing output and rebuild
        from scratch.
"""
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

import argparse
import json
import logging
import math
import time
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Pool
from copy import deepcopy
from glob import glob

import galsim
import matplotlib.pyplot as plt
import numpy as np
from astropy import table, time as astro_time
from astropy.coordinates import SkyCoord
from astropy import units as u
from matplotlib.colors import LogNorm
from tqdm import tqdm

import romanisim.bandpass
from romanisim import image, parameters, wcs

import mejiro
from mejiro.utils import util as mejiro_util
from mejiro.utils.pipeline_helper import PipelineHelper

logger = logging.getLogger(__name__)

# Side length, in pixels, of each square tile/cutout placed on the detector array.
TILE_SIZE = 73
# Tiles per side of a batch grid; GRID_SIDE * TILE_SIZE == 4088, the detector array
# side length used throughout this module.
GRID_SIDE = 56
N_TILES = GRID_SIDE * GRID_SIDE  # 3136 tiles in a full batch


def _load_pickle(pickle_path):
    """Load one SyntheticImage from ``pickle_path``.

    Despite the name, this also handles lightweight ``.npz`` files via the unified
    ``mejiro_util.load_synthetic_image`` loader. It is kept as a single entry point
    so the thread pools in this module don't need to branch on extension.

    Raises
    ------
    EOFError
        If the loader hits an ``EOFError`` (e.g. a truncated file). The re-raised
        error names the offending path and chains the original as its cause.
    """
    try:
        return mejiro_util.load_synthetic_image(pickle_path)
    except EOFError as err:
        # explicit chaining keeps the loader's own traceback as __cause__
        raise EOFError(f'Corrupted SyntheticImage file: {pickle_path}') from err


def _lens_id_from_pickle(pickle_path, band):
    bn = os.path.basename(pickle_path)
    for ext in ('.pkl', '.npz'):
        suffix = f'_{band}{ext}'
        if bn.startswith('SyntheticImage_') and bn.endswith(suffix):
            return bn[len('SyntheticImage_'):-len(suffix)]
    raise AssertionError(bn)


def exposure_cutout_name(input_path):
    """Return the ``Exposure_*.npy`` cutout filename for a SyntheticImage input path.

    Works for both full ``.pkl`` pickles and lightweight ``.npz`` inputs. The
    existing extension is dropped before ``.npy`` is appended, so an ``.npz`` input
    never becomes ``.npz.npy``.
    """
    root, _ext = os.path.splitext(os.path.basename(input_path))
    renamed = root.replace('SyntheticImage_', 'Exposure_')
    return f'{renamed}.npy'
def _glob_synthetic_images(directory):
    """Return SyntheticImage files in ``directory`` regardless of extension.

    Matches both ``SyntheticImage_*.pkl`` and ``SyntheticImage_*.npz``; the combined
    list is sorted so downstream ordering is deterministic.
    """
    return sorted(
        glob(os.path.join(directory, 'SyntheticImage_*.pkl'))
        + glob(os.path.join(directory, 'SyntheticImage_*.npz'))
    )


def _read_detector_position(pickle_path):
    """Return the integer ``(x, y)`` detector position recorded for a SyntheticImage.

    Uses the ``<pickle_path>.psfpos.json`` sidecar when it exists (fast path).
    Otherwise loads the full SyntheticImage and reads
    ``instrument_params['detector_position']``.
    """
    sidecar = pickle_path + '.psfpos.json'
    if os.path.exists(sidecar):
        with open(sidecar) as f:
            x, y = json.load(f)['detector_position']
    else:
        x, y = _load_pickle(pickle_path).instrument_params['detector_position']
    return int(x), int(y)


def _band_from_filename(path):
    """Return the band token (last ``_``-separated field) of a SyntheticImage filename."""
    # splitext strips exactly the extension, unlike str.replace which would also
    # remove '.pkl'/'.npz' occurring elsewhere in the name
    stem = os.path.splitext(os.path.basename(path))[0]
    return stem.split('_')[-1]


def _batch_sentinel_path(sca_output_dir, sca_num, band, batch_idx):
    """Path of the file whose existence marks a batch as complete (checked by --resume)."""
    return os.path.join(sca_output_dir,
                        f'batch_complete_sca{sca_num:02d}_{band}_batch{batch_idx}.txt')


def _log_total_execution_time(start):
    """Log the wall-clock time elapsed since ``start`` (a ``time.time()`` value)."""
    stop = time.time()
    execution_time = mejiro_util.print_execution_time(start, stop, return_string=True)
    logger.info(f'Total execution time: {execution_time}')


def _group_by_sca_and_band(sca_dirs):
    """Map SCA number -> {band: [SyntheticImage paths]} for a list of ``sca<NN>`` dirs."""
    pickles_by_sca_band = {}
    for sca_dir in sca_dirs:
        sca_num = int(os.path.basename(sca_dir)[3:])
        by_band = defaultdict(list)
        for p in _glob_synthetic_images(sca_dir):
            by_band[_band_from_filename(p)].append(p)
        pickles_by_sca_band[sca_num] = dict(by_band)
        for band, ps in sorted(by_band.items()):
            logger.debug(f' SCA {sca_num:02d}, {band}: {len(ps)} pickles')
    return pickles_by_sca_band


def _select_lens_ids(bands_dict, limit, sequential, rng_seed):
    """Pick the lens IDs to process for one SCA, shared by all of its bands.

    Returns ``(selected_ids, limited)``. Without a limit (``limit`` is None or not
    smaller than the number of distinct IDs), ``selected_ids`` is the sorted union
    of lens IDs over every band. Otherwise it is the first ``limit`` IDs
    (``sequential``) or a sorted random sample seeded with ``rng_seed``.
    """
    union_lens_ids = sorted({
        _lens_id_from_pickle(p, band)
        for band, ps in bands_dict.items()
        for p in ps
    })
    if limit is None or limit >= len(union_lens_ids):
        return union_lens_ids, False
    if sequential:
        return union_lens_ids[:limit], True
    rng = np.random.default_rng(rng_seed)
    return sorted(rng.choice(union_lens_ids, limit, replace=False).tolist()), True


def _pack_into_batches(pickle_paths, positions, divide_up_detector, sub_px, tps, tiles_per_bucket):
    """Assign images to PSF buckets and lay them out as per-batch tile grids.

    Each image goes to bucket ``(bi, bj)`` from its detector ``(x, y)``. Each batch
    then takes up to ``tiles_per_bucket`` images from every bucket, preserving input
    order, and places them in that bucket's ``tps x tps`` sub-grid.

    Returns ``(batches, bucket_sizes)``. Each batch is a list of
    ``(pickle_path, tile_r, tile_c)`` triples; ``bucket_sizes`` is the sorted list of
    per-bucket image counts.
    """
    by_bucket = defaultdict(deque)
    for pickle_path, (x, y) in zip(pickle_paths, positions):
        # clamp both ends so edge/out-of-range positions still map to a valid bucket
        bi = min(divide_up_detector - 1, max(0, x // sub_px))
        bj = min(divide_up_detector - 1, max(0, y // sub_px))
        by_bucket[(bi, bj)].append(pickle_path)

    bucket_sizes = sorted(len(q) for q in by_bucket.values())
    n_batches = max((math.ceil(n / tiles_per_bucket) for n in bucket_sizes), default=0)

    batches = []
    for _ in range(n_batches):
        batch_items = []
        for (bi, bj), q in by_bucket.items():
            for k in range(min(tiles_per_bucket, len(q))):
                pickle_path = q.popleft()
                # x -> bucket column (bi), y -> bucket row (bj)
                tile_r = bj * tps + (k // tps)
                tile_c = bi * tps + (k % tps)
                batch_items.append((pickle_path, tile_r, tile_c))
        batches.append(batch_items)
    return batches, bucket_sizes


def main(args):
    """Run the romanisim detector-simulation step over every SCA and band.

    Reads the YAML config named by ``args.config`` from mejiro's bundled
    ``data/mejiro_config`` directory. It groups the step-04 SyntheticImage files
    (``04_jax`` when ``jaxtronomy.use_jax`` is set) by SCA and band, and assigns each
    image to a PSF bucket from its detector position. It then packs the images into
    4088x4088 batches and processes them in a multiprocessing pool via
    ``process_batch``. Outputs are written to ``<data_dir>/<pipeline_label>/05_romanisim``.

    Parameters
    ----------
    args : argparse.Namespace
        Provides ``config`` (str) and ``resume`` (bool). ``sequential`` (bool) is
        optional and treated as False when absent.

    Raises
    ------
    ValueError
        If ``psf.divide_up_detector`` is < 1 or does not evenly divide both
        ``GRID_SIDE`` and the 4088-pixel detector side.
    """
    start = time.time()
    PipelineHelper.patch_astropy_for_mejiro_v2_pickles()  # remove after re-pickling inputs under mejiro-v3
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)s %(name)s: %(message)s'
    )

    # read config
    config_file = os.path.join(os.path.dirname(mejiro.__file__), 'data', 'mejiro_config', args.config)
    import yaml
    with open(config_file, 'r') as f:
        config = yaml.safe_load(f)
    if config['dev']:
        config['pipeline_label'] += '_dev'

    limit = config.get('limit')
    num_workers = config['cores']['script_05_romanisim']
    # split a fixed 64-thread budget across workers, but never fewer than 2 each
    threads_per_worker = max(2, 64 // num_workers)
    bands = config['synthetic_image']['bands']
    seed = config['seed']

    # explicit checks instead of assert so validation survives `python -O`
    divide_up_detector = int(config['psf']['divide_up_detector'])
    if divide_up_detector < 1 or GRID_SIDE % divide_up_detector != 0:
        raise ValueError(
            f'GRID_SIDE ({GRID_SIDE}) must be divisible by divide_up_detector '
            f'({divide_up_detector}); valid values: 1, 2, 4, 7, 8, 14, 28, 56'
        )
    if 4088 % divide_up_detector != 0:
        raise ValueError(f'4088 must be divisible by divide_up_detector ({divide_up_detector})')
    tps = GRID_SIDE // divide_up_detector  # tiles per bucket along one axis
    tiles_per_bucket = tps * tps  # e.g. 196 for N=4
    sub_px = 4088 // divide_up_detector  # detector pixels per bucket along one axis

    ma_table_number = config['exposure']['ma_table_number']
    date = config['exposure']['date']
    if not isinstance(date, str):
        date = date.isoformat()
    coord = SkyCoord(ra=config['exposure']['coordinates']['ra'] * u.deg,
                     dec=config['exposure']['coordinates']['dec'] * u.deg)
    read_pattern = parameters.read_pattern[ma_table_number]
    exptime = parameters.read_time * read_pattern[-1][-1]
    logger.info(f'Total exposure time (MA table {ma_table_number}): {exptime:.1f} s')

    # discover SCA directories and group SyntheticImage pickles by SCA and band
    # (read from the JAX step-04 variant when jaxtronomy.use_jax is set)
    synth_step = '04_jax' if config['jaxtronomy']['use_jax'] else '04'
    data_dir = os.path.join(config['data_dir'], config['pipeline_label'], synth_step)
    sca_dirs = sorted(glob(os.path.join(data_dir, 'sca*')))
    logger.info(f'Found {len(sca_dirs)} SCA directories in {data_dir}')
    pickles_by_sca_band = _group_by_sca_and_band(sca_dirs)

    # output directory
    output_dir = os.path.join(config['data_dir'], config['pipeline_label'], '05_romanisim')
    os.makedirs(output_dir, exist_ok=True)
    if not args.resume:
        existing = [p for p in glob(os.path.join(output_dir, '**', '*'), recursive=True)
                    if os.path.isfile(p)]
        if existing:
            logger.warning(
                f'Deleting {len(existing)} existing output file(s) in '
                f'{output_dir} and rebuilding from scratch. Pass --resume to keep them.'
            )
        mejiro_util.clear_directory(output_dir)

    logger.info(f'Tile size: {TILE_SIZE}x{TILE_SIZE}')
    logger.info(f'Grid: {GRID_SIDE}x{GRID_SIDE} = {N_TILES} tiles per batch')
    logger.info(f'PSF buckets: {divide_up_detector}x{divide_up_detector} = {divide_up_detector * divide_up_detector} '
                f'({tps}x{tps} = {tiles_per_bucket} tiles per bucket)')
    logger.info(f'Bands: {bands}')

    sequential = getattr(args, 'sequential', False)

    # build task list and prepare output directories
    tasks = []
    for sca_num, bands_dict in sorted(pickles_by_sca_band.items()):
        sca_output_dir = os.path.join(output_dir, f'sca{str(sca_num).zfill(2)}')
        os.makedirs(sca_output_dir, exist_ok=True)

        # pick the lens-ID subsample once per SCA so every band processes the same
        # systems in the same order — otherwise a per-band np.random.choice puts the
        # same lens at different tile positions in each band's tiled PNG
        selected_lens_ids, limited = _select_lens_ids(bands_dict, limit, sequential, seed + sca_num)
        if limited:
            logger.info(f'SCA {sca_num:02d}: limiting to {limit} lens system(s)')
        selected_set = set(selected_lens_ids)

        for band_idx, band in enumerate(bands):
            if band not in bands_dict:
                logger.info(f'Skipping SCA {sca_num:02d}, {band}: no pickles found')
                continue
            all_pickles = sorted(p for p in bands_dict[band]
                                 if _lens_id_from_pickle(p, band) in selected_set)
            logger.info(f'SCA {sca_num:02d}, {band}: processing {len(all_pickles)} image(s)')

            # resolve detector_position for each pickle (sidecar JSON fast-path, pickle-load fallback)
            with ThreadPoolExecutor(max_workers=threads_per_worker) as exe:
                positions = list(exe.map(_read_detector_position, all_pickles))

            batches, bucket_sizes = _pack_into_batches(
                all_pickles, positions, divide_up_detector, sub_px, tps, tiles_per_bucket)
            n_batches = len(batches)
            logger.info(f'SCA {sca_num:02d}, {band}: {len(all_pickles)} images in {n_batches} batch(es); '
                        f'bucket sizes: {bucket_sizes}')

            for batch_idx, batch_items in enumerate(batches):
                tasks.append((
                    batch_items, sca_num, band, batch_idx, n_batches, sca_output_dir, exptime,
                    seed, ma_table_number, date, coord, band_idx, threads_per_worker,
                ))

    # resume: skip batches whose completion sentinel already exists
    if args.resume:
        total_batches = len(tasks)
        # task layout: (batch_items, sca_num, band, batch_idx, n_batches, sca_output_dir, ...)
        tasks = [t for t in tasks
                 if not os.path.exists(_batch_sentinel_path(t[5], t[1], t[2], t[3]))]
        skipped = total_batches - len(tasks)
        logger.info(
            f'Resuming: {skipped} of {total_batches} batch(es) already complete, '
            f'{len(tasks)} remaining.'
        )
        if not tasks:
            logger.info('All batches already complete. Nothing to do.')
            _log_total_execution_time(start)
            return

    logger.info(f'Submitting {len(tasks)} batch(es) with {num_workers} workers')

    # process tasks in parallel; maxtasksperchild=1 recycles each worker after one batch,
    # releasing all C-extension caches (romanisim, galsim) that gc.collect() cannot touch.
    # NOTE(review): the astropy patch and logging config applied above reach the workers
    # only under the 'fork' start method — confirm if running where spawn/forkserver is default.
    n_failed = 0
    with Pool(processes=num_workers, maxtasksperchild=1) as pool:
        for ok in tqdm(pool.imap_unordered(process_batch, tasks), total=len(tasks)):
            if not ok:
                n_failed += 1
    if n_failed:
        logger.error(
            f'{n_failed} of {len(tasks)} batch(es) failed; see the errors logged above. '
            f'Re-run with --resume to retry only the incomplete batches.'
        )

    _log_total_execution_time(start)


def _tile_electrons(synth, abflux, exptime):
    """Convert one SyntheticImage into an electron-count tile, or None if degenerate.

    The smoothed pixels are normalized to unit sum and scaled by the image's total
    electrons (``maggies * abflux * exptime``). Returns None when the pixel sum is
    not positive, or when the result contains non-finite or negative values. Such
    values would make the batch-wide Poisson draw raise.
    """
    # deal with negative and nan pixels
    smooth_data = np.asarray(mejiro_util.smooth_pixels(synth.data), dtype=np.float64)
    synth_sum = np.sum(smooth_data, dtype=np.float64)
    if not synth_sum > 0:  # also rejects a NaN sum
        return None
    total_electrons = synth.get_maggies() * abflux * exptime
    lens_electrons = (smooth_data / synth_sum) * total_electrons
    if not np.all(np.isfinite(lens_electrons)) or np.any(lens_electrons < 0):
        return None
    return lens_electrons


def process_batch(task):
    """Simulate one 4088x4088 batch with romanisim and save its per-image cutouts.

    ``task`` is the tuple built by ``main``: ``(batch_items, sca_num, band, batch_idx,
    n_batches, sca_output_dir, exptime, seed, ma_table_number, date, coord, band_idx,
    threads_per_worker)``. ``batch_items`` is a list of
    ``(pickle_path, tile_r, tile_c)`` triples.

    Writes the tiled-input PNG, the romanisim ``im``/``extras`` pickles, the full
    result array, one ``Exposure_*.npy`` cutout per usable image, and finally the
    batch's completion sentinel. Images with degenerate flux are skipped with a
    warning.

    Returns True on success. Returns False if the batch had no usable images or any
    step raised; the error is logged and no sentinel is written, so ``--resume`` will
    retry it.
    """
    (batch_items, sca_num, band, batch_idx, n_batches, sca_output_dir, exptime, seed,
     ma_table_number, date, coord, band_idx, threads_per_worker) = task

    try:
        n_images = len(batch_items)
        logger.info(f'SCA {sca_num:02d}, {band}, batch {batch_idx + 1}/{n_batches}: {n_images} images')

        # per-batch deterministic RNGs
        batch_seed = seed + sca_num * 10000 + band_idx * 1000 + batch_idx
        rng = galsim.UniformDeviate(batch_seed)
        rng_np = np.random.default_rng(batch_seed)

        # get AB flux
        abflux = romanisim.bandpass.get_abflux(band, sca_num)
        logger.info(f' AB flux: {abflux:.6e} e-/s per maggy, exptime: {exptime:.6f} s')

        # 1. tile synthetic images into a 4088x4088 extra_counts array
        counts = np.zeros((4088, 4088), dtype=np.float64)
        batch_pickles = [item[0] for item in batch_items]
        logger.info(f' Loading {len(batch_pickles)} pickles with {threads_per_worker} threads...')
        load_start = time.time()
        with ThreadPoolExecutor(max_workers=threads_per_worker) as exe:
            synth_images = list(exe.map(_load_pickle, batch_pickles))
        load_stop = time.time()
        logger.info(f' Loaded {len(batch_pickles)} pickles in {mejiro_util.print_execution_time(load_start, load_stop, return_string=True)}')

        placed_items = []
        for (pickle_path, tile_r, tile_c), synth in zip(batch_items, synth_images):
            lens_electrons = _tile_electrons(synth, abflux, exptime)
            if lens_electrons is None:
                # one bad image must not sink the other tiles in this batch
                logger.warning(f' Skipping {os.path.basename(pickle_path)}: non-positive or non-finite flux; '
                               f'no cutout will be saved')
                continue
            # place inside the PSF-bucket sub-grid assigned to this image
            r0 = tile_r * TILE_SIZE
            c0 = tile_c * TILE_SIZE
            counts[r0:r0 + TILE_SIZE, c0:c0 + TILE_SIZE] = lens_electrons
            placed_items.append((pickle_path, tile_r, tile_c))
        del synth_images  # release the loaded inputs before the memory-heavy simulation

        if not placed_items:
            logger.error(f'SCA {sca_num:02d}, {band}, batch {batch_idx + 1}/{n_batches}: '
                         f'no usable images; batch not simulated')
            return False

        plt.imshow(counts, norm=LogNorm(), origin='lower')
        plt.colorbar(label='Electrons')
        plt.title(f'SCA {sca_num:02d}, {band} batch {batch_idx} - Tiled Synthetic Images')
        plt.savefig(os.path.join(sca_output_dir, f'sca{sca_num:02d}_{band}_batch{batch_idx}_tiled.png'))
        plt.close()

        # 2. add Poisson noise
        realized = rng_np.poisson(counts).astype(np.int32)
        extra_counts = galsim.ImageI(realized)

        # 3. set up romanisim metadata with the correct SCA
        meta = deepcopy(parameters.default_parameters_dictionary)
        meta['instrument']['detector'] = f'WFI{sca_num:02d}'
        meta['instrument']['optical_element'] = band
        meta['exposure']['ma_table_number'] = ma_table_number
        meta['exposure']['read_pattern'] = parameters.read_pattern[ma_table_number]
        meta['exposure']['start_time'] = astro_time.Time(date)
        wcs.fill_in_parameters(meta, coord, boresight=True)

        # 4. create empty source table (all flux arrives via extra_counts)
        source_catalog = table.Table({
            'ra': np.array([], dtype='f8'),
            'dec': np.array([], dtype='f8'),
            'type': np.array([], dtype='U3'),
            'n': np.array([], dtype='f4'),
            'half_light_radius': np.array([], dtype='f4'),
            'pa': np.array([], dtype='f4'),
            'ba': np.array([], dtype='f4'),
            band: np.array([], dtype='f4'),
        })

        # 5. simulate
        logger.info(f' Running romanisim for SCA {sca_num:02d}, {band} batch {batch_idx}...')
        im, extras = image.simulate(
            meta, source_catalog,
            usecrds=False,
            psftype='galsim',
            level=2,
            rng=rng,
            crparam=dict(),
            extra_counts=extra_counts,
        )

        # save the romanisim output
        mejiro_util.pickle(os.path.join(sca_output_dir, f'im_sca{sca_num:02d}_{band}_batch{batch_idx}.pkl'), im)
        mejiro_util.pickle(os.path.join(sca_output_dir, f'extras_sca{sca_num:02d}_{band}_batch{batch_idx}.pkl'), extras)

        # 6. extract cutouts and save as .npy
        result_data = im.data
        np.save(os.path.join(sca_output_dir, f'full_array_sca{sca_num:02d}_{band}_batch{batch_idx}.npy'), result_data)
        for pickle_path, tile_r, tile_c in placed_items:
            r0 = tile_r * TILE_SIZE
            c0 = tile_c * TILE_SIZE
            cutout = result_data[r0:r0 + TILE_SIZE, c0:c0 + TILE_SIZE]
            np.save(os.path.join(sca_output_dir, exposure_cutout_name(pickle_path)), cutout)
        logger.info(f' Saved {len(placed_items)} cutouts to {sca_output_dir}')

        # write the completion sentinel only after all artifacts for this batch are on disk
        with open(_batch_sentinel_path(sca_output_dir, sca_num, band, batch_idx), 'w') as f:
            f.write(str(len(placed_items)))
        return True
    except Exception:
        logger.exception(f'Batch failed for SCA {sca_num:02d}, {band}, batch {batch_idx + 1}/{n_batches}')
        return False


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Run romanisim detector simulation on tiled synthetic images.")
    parser.add_argument('--config', type=str, required=True, help='Name of the yaml configuration file.')
    parser.add_argument('--sequential', action='store_true', default=False,
                        help='Process systems sequentially from the start instead of randomly when limit is imposed.')
    parser.add_argument('--resume', action='store_true', default=False,
                        help='Preserve existing output and skip already-completed items. Default is to delete and rebuild from scratch.')
    args = parser.parse_args()
    main(args)