diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py new file mode 100644 index 000000000..fe5088369 --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py @@ -0,0 +1,2 @@ +from .brightStarCutout import * +from .brightStarStack import * diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py new file mode 100644 index 000000000..3be878eb0 --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py @@ -0,0 +1,691 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Extract bright star cutouts; normalize and warp, optionally fit the PSF.""" + +__all__ = ["BrightStarCutoutConnections", "BrightStarCutoutConfig", "BrightStarCutoutTask"] + +from typing import Any, Iterable, cast + +import astropy.units as u +import numpy as np +from astropy.coordinates import SkyCoord +from astropy.table import Table +from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS +from lsst.afw.detection import Footprint, FootprintSet, Threshold +from lsst.afw.geom import SkyWcs, SpanSet, makeModifiedWcs +from lsst.afw.geom.transformFactory import makeTransform +from lsst.afw.image import ExposureF, ImageD, ImageF, MaskedImageF +from lsst.afw.math import BackgroundList, FixedKernel, WarpingControl, warpImage +from lsst.daf.butler import DataCoordinate +from lsst.geom import ( + AffineTransform, + Box2I, + Extent2D, + Extent2I, + Point2D, + Point2I, + SpherePoint, + arcseconds, + floor, + radians, +) +from lsst.meas.algorithms import ( + BrightStarStamp, + BrightStarStamps, + KernelPsf, + LoadReferenceObjectsConfig, + ReferenceObjectLoader, + WarpedPsf, +) +from lsst.pex.config import ChoiceField, ConfigField, Field, ListField +from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct +from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput +from lsst.utils.timer import timeMethod +from copy import deepcopy + +NEIGHBOR_MASK_PLANE = "NEIGHBOR" + + +class BrightStarCutoutConnections( + PipelineTaskConnections, + dimensions=("instrument", "visit", "detector"), +): + """Connections for BrightStarCutoutTask.""" + + refCat = PrerequisiteInput( + name="gaia_dr3_20230707", + storageClass="SimpleCatalog", + doc="Reference catalog that contains bright star positions.", + dimensions=("skypix",), + multiple=True, + deferLoad=True, + ) + inputExposure = Input( + name="calexp", + storageClass="ExposureF", + doc="Background-subtracted input exposure from which to extract bright star stamp cutouts.", + dimensions=("visit", "detector"), + ) + inputBackground = Input( + name="calexpBackground", + storageClass="Background", + doc="Background model for the input exposure, to be added back on during processing.", + dimensions=("visit", "detector"), + ) + extendedPsf = Input( + name="extendedPsf2", + storageClass="ImageF", + doc="Extended PSF model, built from stacking bright star cutouts.", + dimensions=("band",), + ) + brightStarStamps = Output( + name="brightStarStamps", + storageClass="BrightStarStamps", + doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.", + dimensions=("visit", "detector"), + ) + + def __init__(self, *, config: "BrightStarCutoutConfig | None" = None): + super().__init__(config=config) + assert config is not None + if not config.useExtendedPsf: + self.inputs.remove("extendedPsf") + + +class BrightStarCutoutConfig( + PipelineTaskConfig, + pipelineConnections=BrightStarCutoutConnections, +): + """Configuration parameters for BrightStarCutoutTask.""" + + # Star selection + magRange = ListField[float]( + doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.", + default=[0, 18], + ) + excludeArcsecRadius = Field[float]( + doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", + default=5, + ) + excludeMagRange = ListField[float]( + doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.", + default=[0, 20], + ) + minAreaFraction = Field[float]( + doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.", + default=0.1, + ) + badMaskPlanes = ListField[str]( + doc="Mask planes that identify excluded pixels for the calculation of ``minAreaFraction`` and, " + "optionally, fitting of the PSF.", + default=[ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + "SAT", + "SUSPECT", + "UNMASKEDNAN", + NEIGHBOR_MASK_PLANE, + ], + ) + + # Cutout geometry + stampSize = ListField[int]( + doc="Size of the stamps to be extracted, in pixels.", + default=(251, 251), + ) + stampSizePadding = Field[float]( + doc="Multiplicative factor applied to the cutout stamp size, to guard against post-warp data loss.", + default=1.1, + ) + warpingKernelName = ChoiceField[str]( + doc="Warping kernel.", + default="lanczos5", + allowed={ + "bilinear": "bilinear interpolation", + "lanczos3": "Lanczos kernel of order 3", + "lanczos4": "Lanczos kernel of order 4", + "lanczos5": "Lanczos kernel of order 5", + }, + ) + maskWarpingKernelName = ChoiceField[str]( + doc="Warping kernel for mask.", + default="bilinear", + allowed={ + "bilinear": "bilinear interpolation", + "lanczos3": "Lanczos kernel of order 3", + "lanczos4": "Lanczos kernel of order 4", + "lanczos5": "Lanczos kernel of order 5", + }, + ) + + scalePsfModel = Field[bool]( + doc="If True, uses a scale factor to bring the PSF model data to the same level of the star data.", + default=True, + ) + + # PSF Fitting + useExtendedPsf = Field[bool]( + doc="Use the extended PSF model to normalize bright star cutouts.", + default=False, + ) + doFitPsf = Field[bool]( + doc="Fit a scaled PSF and a pedestal to each bright star cutout.", + default=True, + ) + useMedianVariance = Field[bool]( + doc="Use the median of the variance plane for PSF fitting.", + default=False, + ) + psfMaskedFluxFracThreshold = Field[float]( + doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.", + default=0.97, + ) + + # Misc + loadReferenceObjectsConfig = ConfigField[LoadReferenceObjectsConfig]( + doc="Reference object loader for astrometric calibration.", + ) + + +class BrightStarCutoutTask(PipelineTask): + """Extract bright star cutouts; normalize and warp to the same pixel grid. + + The BrightStarCutoutTask is used to extract, process, and store small image + cutouts (or "postage stamps") around bright stars. + This task essentially consists of three principal steps. + First, it identifies bright stars within an exposure using a reference + catalog and extracts a stamp around each. + Second, it shifts and warps each stamp to remove optical distortions and + sample all stars on the same pixel grid. + Finally, it optionally fits a PSF plus plane flux model to the cutout. + This final fitting procedure may be used to normalize each bright star + stamp prior to stacking when producing extended PSF models. + """ + + ConfigClass = BrightStarCutoutConfig + _DefaultName = "brightStarCutout" + config: BrightStarCutoutConfig + + def __init__(self, initInputs=None, *args, **kwargs): + super().__init__(*args, **kwargs) + stampSize = Extent2D(*self.config.stampSize.list()) + stampRadius = floor(stampSize / 2) + self.stampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stampRadius) + paddedStampSize = stampSize * self.config.stampSizePadding + self.paddedStampRadius = floor(paddedStampSize / 2) + self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( + self.paddedStampRadius + ) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + inputs = butlerQC.get(inputRefs) + inputs["dataId"] = butlerQC.quantum.dataId + refObjLoader = ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], + refCats=inputs.pop("refCat"), + name=self.config.connections.refCat, + config=self.config.loadReferenceObjectsConfig, + ) + extendedPsf = inputs.pop("extendedPsf", None) + output = self.run(**inputs, extendedPsf=extendedPsf, refObjLoader=refObjLoader) + # Only ingest Stamp if it exists; prevents ingesting an empty FITS file + if output: + butlerQC.put(output, outputRefs) + + @timeMethod + def run( + self, + inputExposure: ExposureF, + inputBackground: BackgroundList, + extendedPsf: ImageF | None, + refObjLoader: ReferenceObjectLoader, + dataId: dict[str, Any] | DataCoordinate, + ): + """Identify bright stars within an exposure using a reference catalog, + extract stamps around each, warp/shift stamps onto a common frame and + then optionally fit a PSF plus plane model. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The background-subtracted image to extract bright star stamps. + inputBackground : `~lsst.afw.math.BackgroundList` + The background model associated with the input exposure. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure that bright stars are extracted from. + Both 'visit' and 'detector' will be persisted in the output data. + + Returns + ------- + brightStarResults : `~lsst.pipe.base.Struct` + Results as a struct with attributes: + + ``brightStarStamps`` + (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) + """ + wcs = inputExposure.getWcs() + bbox = inputExposure.getBBox() + warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName) + + refCatBright = self._getRefCatBright(refObjLoader, wcs, bbox) + zipRaDec = zip(refCatBright["coord_ra"] * radians, refCatBright["coord_dec"] * radians) + spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec] + pixCoords = wcs.skyToPixel(spherePoints) + + # Restore original subtracted background + inputMI = inputExposure.getMaskedImage() + inputMI += inputBackground.getImage() + # Amir: the above addition to inputMI, also adds to the inputExposure. + # Amir: but the calibration, three lines later, only is applied to the inputMI. + + # Set up NEIGHBOR mask plane; associate footprints with stars + inputExposure.mask.addMaskPlane(NEIGHBOR_MASK_PLANE) + allFootprints, associations = self._associateFootprints(inputExposure, pixCoords, plane="DETECTED") + + # TODO: If we eventually have better PhotoCalibs (eg FGCM), apply here + inputMI = inputExposure.getPhotoCalib().calibrateImage(inputMI, False) + + # Set up transform + detector = inputExposure.detector + pixelScale = wcs.getPixelScale().asArcseconds() * arcseconds + pixToFocalPlaneTan = detector.getTransform(PIXELS, FIELD_ANGLE).then( + makeTransform(AffineTransform.makeScaling(1 / pixelScale.asRadians())) + ) + + # Loop over each bright star + stamps, goodFracs, stamps_fitPsfResults = [], [], [] + for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore + footprintIndex = associations.get(starIndex, None) + stampMI = MaskedImageF(self.paddedStampBBox) + + # Set NEIGHBOR footprints in the mask plane + if footprintIndex: + neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex] + # self._setFootprints(inputExposure, neighborFootprints, NEIGHBOR_MASK_PLANE) + self._setFootprints(inputMI, neighborFootprints, NEIGHBOR_MASK_PLANE) + else: + # self._setFootprints(inputExposure, allFootprints, NEIGHBOR_MASK_PLANE) + self._setFootprints(inputMI, allFootprints, NEIGHBOR_MASK_PLANE) + + # Define linear shifting to recenter stamps + coordFocalPlaneTan = pixToFocalPlaneTan.applyForward(pixCoord) # center of warped star + shift = makeTransform(AffineTransform(Point2D(0, 0) - coordFocalPlaneTan)) + angle = np.arctan2(coordFocalPlaneTan.getY(), coordFocalPlaneTan.getX()) * radians + rotation = makeTransform(AffineTransform.makeRotation(-angle)) + pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation) + + # Apply the warp to the star stamp (in-place) + # warpImage(stampMI, inputExposure.maskedImage, pixToPolar, warpingControl) + warpImage(stampMI, inputMI, pixToPolar, warpingControl) + + # Trim to the base stamp size, check mask coverage, update metadata + stampMI = stampMI[self.stampBBox] + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) + goodFrac = np.sum(stampMI.mask.array & badMaskBitMask == 0) / stampMI.mask.array.size + goodFracs.append(goodFrac) + if goodFrac < self.config.minAreaFraction: + continue + + # Fit a scaled PSF and a pedestal to each bright star cutout + psf = WarpedPsf(inputExposure.getPsf(), pixToPolar, warpingControl) + constantPsf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0)))) + # TODO: discuss with Lee whether we should warp the psf here as well? + if self.config.useExtendedPsf: + psfImage = extendedPsf # Assumed to be warped, center at [0,0] + else: + psfImage = constantPsf.computeKernelImage(constantPsf.getAveragePosition()) + if self.config.scalePsfModel: + psfNeg = psfImage.array < 0 + self.modelScale = np.nanmean(stampMI.image.array) / np.nanmean(psfImage.array[~psfNeg]) + psfImage.array *= self.modelScale ######## model scale correction ######## + else: + self.modelScale = 1 + + fitPsfResults = {} + if self.config.doFitPsf: + fitPsfResults = self._fitPsf(stampMI, psfImage) + stamps_fitPsfResults.append(fitPsfResults) + + # Save the stamp if the PSF fit was successful or no fit requested + if fitPsfResults or not self.config.doFitPsf: + stamp = BrightStarStamp( + maskedImage=stampMI, + # TODO: what to do about this PSF? + psf=constantPsf, + wcs=makeModifiedWcs(pixToPolar, wcs, False), + visit=cast(int, dataId["visit"]), + detector=cast(int, dataId["detector"]), + refId=obj["id"], + refMag=obj["mag"], + position=pixCoord, + scale=fitPsfResults.get("scale", None), + scaleErr=fitPsfResults.get("scaleErr", None), + pedestal=fitPsfResults.get("pedestal", None), + pedestalErr=fitPsfResults.get("pedestalErr", None), + pedestalScaleCov=fitPsfResults.get("pedestalScaleCov", None), + xGradient=fitPsfResults.get("xGradient", None), + yGradient=fitPsfResults.get("yGradient", None), + globalReducedChiSquared=fitPsfResults.get("globalReducedChiSquared", None), + globalDegreesOfFreedom=fitPsfResults.get("globalDegreesOfFreedom", None), + psfReducedChiSquared=fitPsfResults.get("psfReducedChiSquared", None), + psfDegreesOfFreedom=fitPsfResults.get("psfDegreesOfFreedom", None), + psfMaskedFluxFrac=fitPsfResults.get("psfMaskedFluxFrac", None), + ) + stamps.append(stamp) + + self.log.info( + "Extracted %i bright star stamp%s. " + "Excluded %i star%s: insufficient area (%i), PSF fit failure (%i).", + len(stamps), + "" if len(stamps) == 1 else "s", + len(refCatBright) - len(stamps), + "" if len(refCatBright) - len(stamps) == 1 else "s", + np.sum(np.array(goodFracs) < self.config.minAreaFraction), + ( + np.sum(np.isnan([x.get("pedestal", np.nan) for x in stamps_fitPsfResults])) + if self.config.doFitPsf + else 0 + ), + ) + brightStarStamps = BrightStarStamps(stamps) + return Struct(brightStarStamps=brightStarStamps) + + def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: + """Get a bright star subset of the reference catalog. + + Trim the reference catalog to only those objects within the exposure + bounding box dilated by half the bright star stamp size. + This ensures all stars that overlap the exposure are included. + + Parameters + ---------- + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader` + Loader to find objects within a reference catalog. + wcs : `~lsst.afw.geom.SkyWcs` + World coordinate system. + bbox : `~lsst.geom.Box2I` + Bounding box of the exposure. + + Returns + ------- + refCatBright : `~astropy.table.Table` + Bright star subset of the reference catalog. + """ + dilatedBBox = bbox.dilatedBy(self.paddedStampRadius) + withinExposure = refObjLoader.loadPixelBox(dilatedBBox, wcs, filterName="phot_g_mean") + refCatFull = withinExposure.refCat + fluxField: str = withinExposure.fluxField + + proxFluxRange = sorted(((self.config.excludeMagRange * u.ABmag).to(u.nJy)).to_value()) + brightFluxRange = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value()) + + subsetStars = (refCatFull[fluxField] > np.min((proxFluxRange[0], brightFluxRange[0]))) & ( + refCatFull[fluxField] < np.max((proxFluxRange[1], brightFluxRange[1])) + ) + refCatSubset = Table(refCatFull.extract("id", "coord_ra", "coord_dec", fluxField, where=subsetStars)) + + proxStars = (refCatSubset[fluxField] >= proxFluxRange[0]) & ( + refCatSubset[fluxField] <= proxFluxRange[1] + ) + brightStars = (refCatSubset[fluxField] >= brightFluxRange[0]) & ( + refCatSubset[fluxField] <= brightFluxRange[1] + ) + + coords = SkyCoord(refCatSubset["coord_ra"], refCatSubset["coord_dec"], unit="rad") + excludeArcsecRadius = self.config.excludeArcsecRadius * u.arcsec # type: ignore + refCatBrightIsolated = [] + for coord in cast(Iterable[SkyCoord], coords[brightStars]): + neighbors = coords[proxStars] + seps = coord.separation(neighbors).to(u.arcsec) + tooClose = (seps > 0) & (seps <= excludeArcsecRadius) # not self matched + refCatBrightIsolated.append(not tooClose.any()) + + refCatBright = cast(Table, refCatSubset[brightStars][refCatBrightIsolated]) + + fluxNanojansky = refCatBright[fluxField][:] * u.nJy # type: ignore + refCatBright["mag"] = fluxNanojansky.to(u.ABmag).to_value() # AB magnitudes + + self.log.info( + "Identified %i of %i star%s which satisfy: frame overlap; in the range %s mag; no neighboring " + "stars within %s arcsec.", + len(refCatBright), + len(refCatFull), + "" if len(refCatFull) == 1 else "s", + self.config.magRange, + self.config.excludeArcsecRadius, + ) + + return refCatBright + + def _associateFootprints( + self, inputExposure: ExposureF, pixCoords: list[Point2D], plane: str + ) -> tuple[list[Footprint], dict[int, int]]: + """Associate footprints from a given mask plane with specific objects. + + Footprints from the given mask plane are associated with objects at the + coordinates provided, where possible. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The input exposure with a mask plane. + pixCoords : `list` [`~lsst.geom.Point2D`] + The pixel coordinates of the objects. + plane : `str` + The mask plane used to identify masked pixels. + + Returns + ------- + footprints : `list` [`~lsst.afw.detection.Footprint`] + The footprints from the input exposure. + associations : `dict`[int, int] + Association indices between objects (key) and footprints (value). + """ + detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK) + footprintSet = FootprintSet(inputExposure.mask, detThreshold) + footprints = footprintSet.getFootprints() + associations = {} + for starIndex, pixCoord in enumerate(pixCoords): + for footprintIndex, footprint in enumerate(footprints): + if footprint.contains(Point2I(pixCoord)): + associations[starIndex] = footprintIndex + break + self.log.debug( + "Associated %i of %i star%s to one each of the %i %s footprint%s.", + len(associations), + len(pixCoords), + "" if len(pixCoords) == 1 else "s", + len(footprints), + plane, + "" if len(footprints) == 1 else "s", + ) + return footprints, associations + + def _setFootprints(self, inputExposure: ExposureF, footprints: list, maskPlane: str): + """Set footprints in a given mask plane. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The input exposure to modify. + footprints : `list` [`~lsst.afw.detection.Footprint`] + The footprints to set in the mask plane. + maskPlane : `str` + The mask plane to set the footprints in. + + Notes + ----- + This method modifies the ``inputExposure`` object in-place. + """ + detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(maskPlane), Threshold.BITMASK) + detThresholdValue = int(detThreshold.getValue()) + footprintSet = FootprintSet(inputExposure.mask, detThreshold) + + # Wipe any existing footprints in the mask plane + inputExposure.mask.clearMaskPlane(int(np.log2(detThresholdValue))) + + # Set the footprints in the mask plane + footprintSet.setFootprints(footprints) + footprintSet.setMask(inputExposure.mask, maskPlane) + + def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, Any]: + """Fit a scaled PSF and a pedestal to each bright star cutout. + + Parameters + ---------- + stampMI : `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + The PSF model to fit. + + Returns + ------- + fitPsfResults : `dict`[`str`, `float`] + The result of the PSF fitting, with keys: + + ``scale`` : `float` + The scale factor. + ``scaleErr`` : `float` + The error on the scale factor. + ``pedestal`` : `float` + The pedestal value. + ``pedestalErr`` : `float` + The error on the pedestal value. + ``pedestalScaleCov`` : `float` + The covariance between the pedestal and scale factor. + ``xGradient`` : `float` + The gradient in the x-direction. + ``yGradient`` : `float` + The gradient in the y-direction. + ``globalReducedChiSquared`` : `float` + The global reduced chi-squared goodness-of-fit. + ``globalDegreesOfFreedom`` : `int` + The global number of degrees of freedom. + ``psfReducedChiSquared`` : `float` + The PSF BBox reduced chi-squared goodness-of-fit. + ``psfDegreesOfFreedom`` : `int` + The PSF BBox number of degrees of freedom. + ``psfMaskedFluxFrac`` : `float` + The fraction of the PSF image flux masked by bad pixels. + """ + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) + + # Calculate the fraction of the PSF image flux masked by bad pixels + psfMaskedPixels = ImageF(psfImage.getBBox()) + psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool) + # TODO: This is np.float64, else FITS metadata serialization fails + # Amir: the following tries to find the fraction of the psf flux in the masked area of the psf image. + psfMaskedFluxFrac = np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) / psfImage.array.sum() + # psfMaskedFluxFrac = np.dot(psfImage.array.astype(bool).flat, psfMaskedPixels.array.flat).astype(np.float64) / psfImage.array.astype(bool).sum() + if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold: + return {} # Handle cases where the PSF image is mostly masked + + # Create a padded version of the input constant PSF image + paddedPsfImage = ImageF(stampMI.getBBox()) + paddedPsfImage[psfImage.getBBox()] = psfImage.convertF() + + mask = self.add_psf_mask(paddedPsfImage, stampMI) + # Create consistently masked data + # badSpans = SpanSet.fromMask(stampMI.mask, badMaskBitMask) + badSpans = SpanSet.fromMask(mask, badMaskBitMask) + goodSpans = SpanSet(stampMI.getBBox()).intersectNot(badSpans) + varianceData = goodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) + if self.config.useMedianVariance: + varianceData = np.median(varianceData) + sigmaData = np.sqrt(varianceData) + imageData = goodSpans.flatten(stampMI.image.array, stampMI.getXY0()) # B + imageData /= sigmaData + psfData = goodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0()) + psfData /= sigmaData + # Fit the PSF scale factor and global pedestal + nData = len(imageData) + coefficientMatrix = np.ones((nData, 4), dtype=float) # A + coefficientMatrix[:, 0] = psfData + coefficientMatrix[:, 1] /= sigmaData + coefficientMatrix[:, 2:] = goodSpans.indices().T + coefficientMatrix[:, 2] /= sigmaData + coefficientMatrix[:, 3] /= sigmaData + try: + solutions, sumSquaredResiduals, *_ = np.linalg.lstsq(coefficientMatrix, imageData, rcond=None) + covarianceMatrix = np.linalg.inv(np.dot(coefficientMatrix.transpose(), coefficientMatrix)) # C + except np.linalg.LinAlgError: + return {} # Handle singular matrix errors + if sumSquaredResiduals.size == 0: + return {} # Handle cases where sum of the squared residuals are empty + # scale = solutions[0] + scale = solutions[0] * self.modelScale ######## model scale correction ######## + if scale <= 0: + return {} # Handle cases where the PSF scale fit has failed + scaleErr = np.sqrt(covarianceMatrix[0, 0]) * self.modelScale ######## model scale correction ######## + pedestal = solutions[1] + pedestalErr = np.sqrt(covarianceMatrix[1, 1]) + scalePedestalCov = covarianceMatrix[0, 1] * self.modelScale ######## model scale correction ######## + xGradient = solutions[3] + yGradient = solutions[2] + + # Calculate global (whole image) reduced chi-squared + globalChiSquared = np.sum(sumSquaredResiduals) + globalDegreesOfFreedom = nData - 4 + globalReducedChiSquared = globalChiSquared / globalDegreesOfFreedom + + # Calculate PSF BBox reduced chi-squared + psfBBoxGoodSpans = goodSpans.clippedTo(psfImage.getBBox()) + psfBBoxGoodSpansX, psfBBoxGoodSpansY = psfBBoxGoodSpans.indices() + psfBBoxData = psfBBoxGoodSpans.flatten(stampMI.image.array, stampMI.getXY0()) + paddedPsfImage.array /= self.modelScale ######## model scale correction ######## + psfBBoxModel = ( + psfBBoxGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale + + pedestal + + psfBBoxGoodSpansX * xGradient + + psfBBoxGoodSpansY * yGradient + ) + psfBBoxVariance = psfBBoxGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) + psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 / psfBBoxVariance + psfBBoxChiSquared = np.sum(psfBBoxResiduals) + psfBBoxDegreesOfFreedom = len(psfBBoxData) - 4 + psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom + return dict( + scale=scale, + scaleErr=scaleErr, + pedestal=pedestal, + pedestalErr=pedestalErr, + xGradient=xGradient, + yGradient=yGradient, + pedestalScaleCov=scalePedestalCov, + globalReducedChiSquared=globalReducedChiSquared, + globalDegreesOfFreedom=globalDegreesOfFreedom, + psfReducedChiSquared=psfBBoxReducedChiSquared, + psfDegreesOfFreedom=psfBBoxDegreesOfFreedom, + psfMaskedFluxFrac=psfMaskedFluxFrac, + ) + + def add_psf_mask(self, psfImage, stampMI): + cond = np.isnan(psfImage.array) + cond |= psfImage.array < 0 + mask = deepcopy(stampMI.mask) + mask.array[cond] = np.bitwise_or(mask.array[cond], 1) + return mask diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py new file mode 100644 index 000000000..a99d9be3e --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py @@ -0,0 +1,240 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Stack bright star postage stamp cutouts to produce an extended PSF model.""" + +__all__ = ["BrightStarStackConnections", "BrightStarStackConfig", "BrightStarStackTask"] + +import numpy as np +from lsst.afw.image import ImageF +from lsst.afw.math import StatisticsControl, statisticsStack, stringToStatisticsProperty +from lsst.geom import Point2I +from lsst.meas.algorithms import BrightStarStamps +from lsst.pex.config import Field, ListField +from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct +from lsst.pipe.base.connectionTypes import Input, Output +from lsst.utils.timer import timeMethod + +NEIGHBOR_MASK_PLANE = "NEIGHBOR" + + +class BrightStarStackConnections( + PipelineTaskConnections, + dimensions=("instrument", "detector"), +): + """Connections for BrightStarStackTask.""" + + brightStarStamps = Input( + name="brightStarStamps", + storageClass="BrightStarStamps", + doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.", + dimensions=("visit", "detector"), + multiple=True, + deferLoad=True, + ) + extendedPsf = Output( + name="extendedPsf2", # extendedPsfDetector ??? + storageClass="ImageF", # MaskedImageF + doc="Extended PSF model, built from stacking bright star cutouts.", + dimensions=("band",), + ) + + +class BrightStarStackConfig( + PipelineTaskConfig, + pipelineConnections=BrightStarStackConnections, +): + """Configuration parameters for BrightStarStackTask.""" + + subsetStampNumber = Field[int]( + doc="Number of stamps per subset to generate stacked images for.", + default=2, + ) + globalReducedChiSquaredThreshold = Field[float]( + doc="Threshold for global reduced chi-squared for bright star stamps.", + default=5.0, + ) + psfReducedChiSquaredThreshold = Field[float]( + doc="Threshold for PSF reduced chi-squared for bright star stamps.", + default=50.0, + ) + + badMaskPlanes = ListField[str]( + doc="Mask planes that identify excluded (masked) pixels.", + default=[ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + # "SAT", + # "SUSPECT", + "UNMASKEDNAN", + NEIGHBOR_MASK_PLANE, + ], + ) + stackType = Field[str]( + default="MEANCLIP", + doc="Statistic name to use for stacking (from `~lsst.afw.math.Property`)", + ) + stackNumSigmaClip = Field[float]( + doc="Number of sigma to use for clipping when stacking.", + default=3.0, + ) + stackNumIter = Field[int]( + doc="Number of iterations to use for clipping when stacking.", + default=5, + ) + + +class BrightStarStackTask(PipelineTask): + """Stack bright star postage stamps to produce an extended PSF model.""" + + ConfigClass = BrightStarStackConfig + _DefaultName = "brightStarStack" + config: BrightStarStackConfig + + def __init__(self, initInputs=None, *args, **kwargs): + super().__init__(*args, **kwargs) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + inputs = butlerQC.get(inputRefs) + output = self.run(**inputs) + butlerQC.put(output, outputRefs) + + def _applyStampFit(self, stamp): + """Apply fitted stamp components to a single bright star stamp.""" + stampMI = stamp.maskedImage + stampBBox = stampMI.getBBox() + xGrid, yGrid = np.meshgrid(stampBBox.getX().arange(), stampBBox.getY().arange()) + xPlane = ImageF((xGrid * stamp.xGradient).astype(np.float32), xy0=stampMI.getXY0()) + yPlane = ImageF((yGrid * stamp.yGradient).astype(np.float32), xy0=stampMI.getXY0()) + stampMI -= stamp.pedestal + stampMI -= xPlane + stampMI -= yPlane + stampMI /= stamp.scale + + @timeMethod + def run( + self, + brightStarStamps: BrightStarStamps, + ): + """Identify bright stars within an exposure using a reference catalog, + extract stamps around each, then preprocess them. + + Bright star preprocessing steps are: shifting, warping and potentially + rotating them to the same pixel grid; computing their annular flux, + and; normalizing them. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The image from which bright star stamps should be extracted. + inputBackground : `~lsst.afw.image.Background` + The background model for the input exposure. + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional + Loader to find objects within a reference catalog. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The dataId of the exposure (including detector) that bright stars + should be extracted from. + + Returns + ------- + brightStarResults : `~lsst.pipe.base.Struct` + Results as a struct with attributes: + + ``brightStarStamps`` + (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`) + """ + stackTypeProperty = stringToStatisticsProperty(self.config.stackType) + statisticsControl = StatisticsControl( + numSigmaClip=self.config.stackNumSigmaClip, + numIter=self.config.stackNumIter, + ) + + subsetStampMIs = [] + tempStampMIs = [] + all_stars = 0 + used_stars = 0 + for stampsDDH in brightStarStamps: + stamps = stampsDDH.get() + all_stars += len(stamps) + for stamp in stamps: + # print("globalReducedChiSquared: stamp ", stamp.globalReducedChiSquared, "config ", self.config.globalReducedChiSquaredThreshold) + # print("psfReducedChiSquared: stamp ", stamp.psfReducedChiSquared, "config ", self.config.psfReducedChiSquaredThreshold) + if ( + stamp.globalReducedChiSquared > self.config.globalReducedChiSquaredThreshold + or stamp.psfReducedChiSquared > self.config.psfReducedChiSquaredThreshold + ): + continue + stampMI = stamp.maskedImage + self._applyStampFit(stamp) + tempStampMIs.append(stampMI) + + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) + statisticsControl.setAndMask(badMaskBitMask) + + # Amir: In case the total number of stamps is less than 20, the following will result in an + # empty subsetStampMIs list. + if len(tempStampMIs) == self.config.subsetStampNumber: + subsetStampMIs.append(statisticsStack(tempStampMIs, stackTypeProperty, statisticsControl)) + # TODO: what to do with remaining temp stamps? + tempStampMIs = [] + used_stars += self.config.subsetStampNumber + + self.metadata["psfStarCount"] = {} + self.metadata["psfStarCount"]["all"] = all_stars + self.metadata["psfStarCount"]["used"] = used_stars + # TODO: which stamp mask plane to use here? + # TODO: Amir: there might be cases where subsetStampMIs is an empty list. What do we want to do then? + # Currently, we get an "IndexError: list index out of range" + badMaskBitMask = subsetStampMIs[0].mask.getPlaneBitMask(self.config.badMaskPlanes) + statisticsControl.setAndMask(badMaskBitMask) + extendedPsfMI = statisticsStack(subsetStampMIs, stackTypeProperty, statisticsControl) + + extendedPsfExtent = extendedPsfMI.getBBox().getDimensions() + extendedPsfOrigin = Point2I(-1 * (extendedPsfExtent.x // 2), -1 * (extendedPsfExtent.y // 2)) + extendedPsfMI.setXY0(extendedPsfOrigin) + # return Struct(extendedPsf=[extendedPsfMI]) + + return Struct(extendedPsf=extendedPsfMI.getImage()) + + # stack = [] + # chiStack = [] + # for loop over all groups: + # load up all visits for this detector + # drop all with GOF > thresh + # sigma-clip mean stack the rest + # append to stack + # compute the scatter (MAD/sigma-clipped var, etc) of the rest + # divide by sqrt(var plane), and append to chiStack + # after for-loop, combine images in median stack for final result + # also combine chi-images, save separately + + # idea: run with two different thresholds, and compare the results + + # medianStack = [] + # for loop over all groups: + # load up all visits for this detector + # drop all with GOF > thresh + # median/sigma-clip stack the rest + # append to medianStack + # after for-loop, combine images in median stack for final result diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py b/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py new file mode 100644 index 000000000..04de11116 --- /dev/null +++ b/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py @@ -0,0 +1,858 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Retrieve extended PSF model and subtract bright stars at visit level.""" + +__all__ = ["BrightStarSubtractConnections", "BrightStarSubtractConfig", "BrightStarSubtractTask"] + +import logging +from typing import Any +import astropy.units as u +import numpy as np +from astropy.table import Table, Column +from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS, TAN_PIXELS +from lsst.afw.detection import Footprint, FootprintSet, Threshold +from lsst.afw.geom import SkyWcs, SpanSet +from lsst.afw.geom.transformFactory import makeTransform +from lsst.afw.image import ExposureF, ImageD, ImageF, MaskedImageF +from lsst.afw.math import BackgroundList, WarpingControl, warpImage +from lsst.daf.butler import DataCoordinate +from lsst.geom import ( + AffineTransform, + Box2I, + Extent2D, + Extent2I, + Point2D, + Point2I, + SpherePoint, + arcseconds, + floor, + radians, +) +from lsst.meas.algorithms import ( + LoadReferenceObjectsConfig, + ReferenceObjectLoader, +) +from lsst.pex.config import ChoiceField, ConfigField, Field, ListField +from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct +from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput +from lsst.utils.timer import timeMethod +from copy import deepcopy + +NEIGHBOR_MASK_PLANE = "NEIGHBOR" + +logger = logging.getLogger(__name__) + + +class BrightStarSubtractConnections( + PipelineTaskConnections, + dimensions=("instrument", "visit", "detector"), + defaultTemplates={ + # "outputExposureName": "brightStar_subtracted", + "outputExposureName": "postISRCCD", + "outputBackgroundName": "brightStars", + "badStampsName": "brightStars", + }, +): + inputCalexp = Input( + name="calexp", + storageClass="ExposureF", + doc="Background-subtracted input exposure from which to extract bright star stamp cutouts.", + dimensions=("visit", "detector"), + ) + inputBackground = Input( + name="calexpBackground", + storageClass="Background", + doc="Background model for the input exposure, to be added back on during processing.", + dimensions=("visit", "detector"), + ) + inputExposure = Input( + doc="Input exposure from which to subtract bright star stamps.", + name="postISRCCD", + storageClass="Exposure", + dimensions=( + "exposure", + "detector", + ), + ) + inputExtendedPsf = Input( + name="extendedPsf2", # extendedPsfDetector ??? + storageClass="ImageF", # MaskedImageF + doc="Extended PSF model, built from stacking bright star cutouts.", + dimensions=("band",), + ) + refCat = PrerequisiteInput( + doc="Reference catalog that contains bright star positions", + name="gaia_dr3_20230707", + storageClass="SimpleCatalog", + dimensions=("skypix",), + multiple=True, + deferLoad=True, + ) + # outputBadStamps = Output( + # doc="The stamps that are not normalized and consequently not subtracted from the exposure.", + # name="{badStampsName}_unsubtracted_stamps", + # storageClass="BrightStarStamps", + # dimensions=( + # "visit", + # "detector", + # ), + # ) + + outputExposure = Output( + doc="Exposure with bright stars subtracted.", + name="{outputExposureName}_subtracted", + storageClass="ExposureF", + dimensions=( + "exposure", + "detector", + ), + ) + outputBackgroundExposure = Output( + doc="Exposure containing only the modelled bright stars.", + name="{outputBackgroundName}_background", + storageClass="ExposureF", + dimensions=( + "visit", + "detector", + ), + ) + # scaledModels = Output( + # doc="Stamps containing models scaled to the level of stars", + # name="scaledModels", + # storageClass="BrightStarStamps", + # dimensions=( + # "visit", + # "detector", + # ), + # ) + + +class BrightStarSubtractConfig(PipelineTaskConfig, pipelineConnections=BrightStarSubtractConnections): + """Configuration parameters for BrightStarSubtractTask""" + + doWriteSubtractor = Field[bool]( + doc="Should an exposure containing all bright star models be written to disk?", + default=True, + ) + doWriteSubtractedExposure = Field[bool]( + doc="Should an exposure with bright stars subtracted be written to disk?", + default=True, + ) + magLimit = Field[float]( + doc="Magnitude limit, in Gaia G; all stars brighter than this value will be subtracted", + default=18, + ) + minValidAnnulusFraction = Field[float]( + doc="Minimum number of valid pixels that must fall within the annulus for the bright star to be " + "saved for subsequent generation of a PSF.", + default=0.0, + ) + numSigmaClip = Field[float]( + doc="Sigma for outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", + default=4, + ) + numIter = Field[int]( + doc="Number of iterations of outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.", + default=3, + ) + warpingKernelName = ChoiceField[str]( + doc="Warping kernel", + default="lanczos5", + allowed={ + "bilinear": "bilinear interpolation", + "lanczos3": "Lanczos kernel of order 3", + "lanczos4": "Lanczos kernel of order 4", + "lanczos5": "Lanczos kernel of order 5", + "lanczos6": "Lanczos kernel of order 6", + "lanczos7": "Lanczos kernel of order 7", + }, + ) + scalingType = ChoiceField[str]( + doc="How the model should be scaled to each bright star; implemented options are " + "`annularFlux` to reuse the annular flux of each stamp, or `leastSquare` to perform " + "least square fitting on each pixel with no bad mask plane set.", + default="leastSquare", + allowed={ + "annularFlux": "reuse BrightStarStamp annular flux measurement", + "leastSquare": "find least square scaling factor", + }, + ) + annularFluxStatistic = ChoiceField[str]( + doc="Type of statistic to use to compute annular flux.", + default="MEANCLIP", + allowed={ + "MEAN": "mean", + "MEDIAN": "median", + "MEANCLIP": "clipped mean", + }, + ) + badMaskPlanes = ListField[str]( + doc="Mask planes that, if set, lead to associated pixels not being included in the computation of " + "the scaling factor (`BAD` should always be included). Ignored if scalingType is `annularFlux`, " + "as the stamps are expected to already be normalized.", + # Note that `BAD` should always be included, as secondary detected + # sources (i.e., detected sources other than the primary source of + # interest) also get set to `BAD`. + # Lee: find out the value of "BAD" and set the nan values into that number in the mask plane(?) + default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"), + ) + subtractionBox = ListField[int]( + doc="Size of the stamps to be extracted, in pixels.", + default=(250, 250), + ) + subtractionBoxBuffer = Field[float]( + doc=( + "'Buffer' (multiplicative) factor to be applied to determine the size of the stamp the " + "processed stars will be saved in. This is also the size of the extended PSF model. The buffer " + "region is masked and contain no data and subtractionBox determines the region where contains " + "the data." + ), + default=1.1, + ) + doApplySkyCorr = Field[bool]( + doc="Apply full focal plane sky correction before extracting stars?", + default=True, + ) + min_iterations = Field[int]( + doc="Minimum number of iterations to complete before evaluating changes in each iteration.", + default=3, + ) + refObjLoader = ConfigField[LoadReferenceObjectsConfig]( + doc="Reference object loader for astrometric calibration.", + ) + maskWarpingKernelName = ChoiceField[str]( + doc="Warping kernel for mask.", + default="bilinear", + allowed={ + "bilinear": "bilinear interpolation", + "lanczos3": "Lanczos kernel of order 3", + "lanczos4": "Lanczos kernel of order 4", + "lanczos5": "Lanczos kernel of order 5", + }, + ) + # Cutout geometry + stampSize = ListField[int]( + doc="Size of the stamps to be extracted, in pixels.", + default=(251, 251), + ) + stampSizePadding = Field[float]( + doc="Multiplicative factor applied to the cutout stamp size, to guard against post-warp data loss.", + default=1.1, + ) + # Star selection + magRange = ListField[float]( + doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.", + default=[0, 18], + ) + minAreaFraction = Field[float]( + doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.", + default=0.1, + ) + useMedianVariance = Field[bool]( + doc="Use the median of the variance plane for PSF fitting.", + default=False, + ) + psfMaskedFluxFracThreshold = Field[float]( + doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.", + default=0.97, + ) + scalePsfModel = Field[bool]( + doc="If True, uses a scale factor to bring the PSF model data to the same level of the star data.", + default=True, + ) + + +class BrightStarSubtractTask(PipelineTask): + """Use an extended PSF model to subtract bright stars from a calibrated + exposure (i.e. at single-visit level). + + This task uses both a set of bright star stamps produced by + `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask` + and an extended PSF model produced by + `~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`. + """ + + ConfigClass = BrightStarSubtractConfig + _DefaultName = "subtractBrightStars" + + def __init__(self, *args, **kwargs): + # super().__init__(*args, **kwargs) + # # Placeholders to set up Statistics if scalingType is leastSquare. + # self.statsControl, self.statsFlag = None, None + # # Warping control; only contains shiftingALg provided in config. + + super().__init__(*args, **kwargs) + stampSize = Extent2D(*self.config.stampSize.list()) + stampRadius = floor(stampSize / 2) + # Define a central bounding box of the configured stamp size. + self.stampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stampRadius) + paddedStampSize = stampSize + self.paddedStampRadius = floor(paddedStampSize / 2) + self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy( + self.paddedStampRadius + ) + + def runQuantum(self, butlerQC, inputRefs, outputRefs): + # Docstring inherited. + inputs = butlerQC.get(inputRefs) + dataId = butlerQC.quantum.dataId + refObjLoader = ReferenceObjectLoader( + dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat], + refCats=inputs.pop("refCat"), + name=self.config.connections.refCat, + config=self.config.refObjLoader, + ) + # TODO: include the un-subtracted stars here! + subtractor = self.run(**inputs, dataId=dataId, refObjLoader=refObjLoader) + if self.config.doWriteSubtractedExposure: + outputExposure = inputs["inputExposure"].clone() + outputExposure.image -= subtractor.image + else: + outputExposure = None + outputBackgroundExposure = ExposureF(subtractor) if self.config.doWriteSubtractor else None + output = Struct( + outputExposure=outputExposure, + outputBackgroundExposure=outputBackgroundExposure, + ) + butlerQC.put(output, outputRefs) + + @timeMethod + def run( + self, + inputExposure: ExposureF, + inputCalexp: ExposureF, + inputBackground: BackgroundList, + inputExtendedPsf: ImageF, + dataId: dict[str, Any] | DataCoordinate, + refObjLoader: ReferenceObjectLoader, + ): + """Generate a bright star subtractor image using scaled extended PSF models. + + Identifies bright stars within the calibrated exposure using a + reference catalog, extracts stamps around each, warps the extended PSF + model onto the stamp frame, fits for a scale factor and pedestal for + each star iteratively, and combines the scaled models into a single + subtractor exposure. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The Post-ISR CCD frame. Note: Currently appears unused directly + within this method's main logic, but required by the pipeline + definition. Cutouts are based on `inputCalexp` + `inputBackground`. + inputCalexp: `~lsst.afw.image.ExposureF` + The background-subtracted calibrated exposure used for identifying + stars, extracting stamps, and fitting models. + inputBackground : `~lsst.afw.math.BackgroundList` + The background model associated with `inputCalexp`. Added back + before processing stamps. + inputExtendedPsf : `~lsst.afw.image.ImageF` + The extended PSF model (e.g., from MeasureExtendedPsfTask). + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader` + Loader used to query the reference catalog for bright stars within + the exposure footprint. + dataId : `dict` or `~lsst.daf.butler.DataCoordinate` + The data identifier for the input exposure. + + Returns + ------- + subtractor : `~lsst.afw.image.ExposureF` + An exposure containing the combined, scaled models of the bright + stars identified and processed. This image can be subtracted from + the original `inputExposure` (or `inputCalexp` + `inputBackground`). + The image plane contains the model flux, while the variance and + mask planes are typically empty or minimal unless specifically populated. + """ + wcs = inputCalexp.getWcs() + bbox = inputCalexp.getBBox() + warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName) + + refCatBright = self._getRefCatBright(refObjLoader, wcs, bbox) + refCatBright.sort("mag") + zipRaDec = zip(refCatBright["coord_ra"] * radians, refCatBright["coord_dec"] * radians) + spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec] + pixCoords = wcs.skyToPixel(spherePoints) + + # Create image with background added back + # Using calibrated exposure for finding the scale factor and creating subtrator. + # The generated subtractor will be subtracted from PostISRCCd. + inputFixed = inputCalexp.getMaskedImage() + inputFixed += inputBackground.getImage() + inputCalexp.mask.addMaskPlane(NEIGHBOR_MASK_PLANE) + # Associate detected footprints (from DETECTED plane) with the bright reference stars. + allFootprints, associations = self._associateFootprints(inputCalexp, pixCoords, plane="DETECTED") + + subtractorExp = ExposureF(bbox=bbox) + templateSubtractor = subtractorExp.maskedImage + + detector = inputCalexp.detector + pixelScale = wcs.getPixelScale().asArcseconds() * arcseconds + pixToTan = detector.getTransform(PIXELS, TAN_PIXELS) + pixToFocalPlaneTan = detector.getTransform(PIXELS, FIELD_ANGLE).then( + makeTransform(AffineTransform.makeScaling(1 / pixelScale.asRadians())) + ) + + self.warpedDataDict = {} + removalIndices = [] + for j in range(self.config.min_iterations): + scaleList = [] + for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore + # Start with the background-added image for each star + inputMI = deepcopy(inputFixed) + restSubtractor = deepcopy(templateSubtractor) + myNumber = 0 + for key in self.warpedDataDict.keys(): + # Subtract the *current best models* of all *other* stars before fitting this one. + if self.warpedDataDict[key]["subtractor"] is not None and key != obj['id']: + restSubtractor.image += self.warpedDataDict[key]["subtractor"].image + myNumber += 1 + self.log.debug(f"Number of stars subtracted before finding the scale factor for {obj['id']}: ", myNumber) + inputMI.image -= restSubtractor.image + + footprintIndex = associations.get(starIndex, None) + + if footprintIndex: + neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex] + self._setFootprints(inputMI, neighborFootprints, NEIGHBOR_MASK_PLANE) + else: + self._setFootprints(inputMI, allFootprints, NEIGHBOR_MASK_PLANE) + # Define linear shifting to recenter stamps + coordFocalPlaneTan = pixToFocalPlaneTan.applyForward(pixCoord) # center of warped star + shift = makeTransform(AffineTransform(Point2D(0, 0) - coordFocalPlaneTan)) + angle = np.arctan2(coordFocalPlaneTan.getY(), coordFocalPlaneTan.getX()) * radians + rotation = makeTransform(AffineTransform.makeRotation(-angle)) + pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation) + rawStamp= self._getCutout(inputExposure=inputMI, coordPix=pixCoord, stampSize=self.config.stampSize.list()) + if rawStamp is None: + self.log.debug(f"No stamp for star with refID {obj['id']}") + removalIndices.append(starIndex) + continue + warpedStamp = self._warpRawStamp(rawStamp, warpingControl, pixToTan, pixCoord) + warpedModel = ImageF(warpedStamp.getBBox()) + inputExtendedPsfGeneral = deepcopy(inputExtendedPsf) + good_pixels = warpImage(warpedModel, inputExtendedPsfGeneral, pixToPolar.inverted(), warpingControl) + self.warpedDataDict[obj["id"]] = {"stamp": warpedStamp, "model": warpedModel, "starIndex": starIndex, "pixCoord": pixCoord} + if j == 0: + self.warpedDataDict[obj["id"]]["scale"] = None + self.warpedDataDict[obj["id"]]["subtractor"] = None + fitPsfResults = {} + if self.config.scalePsfModel: + psfNeg = warpedModel.array < 0 + self.modelScale = np.nanmean(warpedStamp.image.array) / np.nanmean(warpedModel.array[~psfNeg]) + warpedModel.array *= self.modelScale ######## model scale correction ######## + fitPsfResults = self._fitPsf( warpedStamp, warpedModel) + if fitPsfResults: + scaleList.append(fitPsfResults["scale"]) + self.warpedDataDict[obj["id"]]["scale"] = fitPsfResults["scale"] + + + cond = np.isnan(warpedModel.array) + warpedModel.array[cond] = 0 + warpedModel.array *= fitPsfResults["scale"] + overlapBBox = Box2I(warpedStamp.getBBox()) + overlapBBox.clip(inputCalexp.getBBox()) + + subtractor = deepcopy(templateSubtractor) + subtractor[overlapBBox] += warpedModel[overlapBBox] + self.warpedDataDict[obj["id"]]["subtractor"] = subtractor + + + else: + scaleList.append(np.nan) + if "subtractor" not in self.warpedDataDict[obj["id"]].keys(): + self.warpedDataDict[obj["id"]]["subtractor"] = None + self.warpedDataDict[obj["id"]]["scale"] = None + if j == 0: + refCatBright.remove_rows(removalIndices) + updatedPixCoords = [item for i, item in enumerate(pixCoords) if i not in removalIndices] + pixCoords = updatedPixCoords + new_scale_column = Column(scaleList, name=f'scale_0{j}') + # The following is handy when developing, not sure if we want to do that in the final version! + refCatBright.add_columns([new_scale_column]) + + subtractor = deepcopy(templateSubtractor) + for key in self.warpedDataDict.keys(): + if self.warpedDataDict[key]["scale"] is not None: + subtractor.image.array += self.warpedDataDict[key]["subtractor"].image.array + return subtractor + + def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table: + """Get a bright star subset of the reference catalog. + + Trim the reference catalog to only those objects within the exposure + bounding box dilated by half the bright star stamp size. + This ensures all stars that overlap the exposure are included. + + Parameters + ---------- + refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader` + Loader to find objects within a reference catalog. + wcs : `~lsst.afw.geom.SkyWcs` + World coordinate system. + bbox : `~lsst.geom.Box2I` + Bounding box of the exposure. + + Returns + ------- + refCatBright : `~astropy.table.Table` + Bright star subset of the reference catalog. + """ + dilatedBBox = bbox.dilatedBy(self.paddedStampRadius) + withinExposure = refObjLoader.loadPixelBox(dilatedBBox, wcs, filterName="phot_g_mean") + refCatFull = withinExposure.refCat + fluxField: str = withinExposure.fluxField + + brightFluxRange = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value()) + + subsetStars = (refCatFull[fluxField] > brightFluxRange[0]) & ( + refCatFull[fluxField] < brightFluxRange[1] + ) + refCatSubset = Table(refCatFull.extract("id", "coord_ra", "coord_dec", fluxField, where=subsetStars)) + fluxNanojansky = refCatSubset[fluxField][:] * u.nJy # type: ignore + refCatSubset["mag"] = fluxNanojansky.to(u.ABmag).to_value() # AB magnitudes + return refCatSubset + + def _associateFootprints( + self, inputExposure: ExposureF, pixCoords: list[Point2D], plane: str + ) -> tuple[list[Footprint], dict[int, int]]: + """Associate footprints from a given mask plane with specific objects. + + Footprints from the given mask plane are associated with objects at the + coordinates provided, where possible. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The input exposure with a mask plane. + pixCoords : `list` [`~lsst.geom.Point2D`] + The pixel coordinates of the objects. + plane : `str` + The mask plane used to identify masked pixels. + + Returns + ------- + footprints : `list` [`~lsst.afw.detection.Footprint`] + The footprints from the input exposure. + associations : `dict`[int, int] + Association indices between objects (key) and footprints (value). + """ + detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK) + footprintSet = FootprintSet(inputExposure.mask, detThreshold) + footprints = footprintSet.getFootprints() + associations = {} + for starIndex, pixCoord in enumerate(pixCoords): + for footprintIndex, footprint in enumerate(footprints): + if footprint.contains(Point2I(pixCoord)): + associations[starIndex] = footprintIndex + break + self.log.debug( + "Associated %i of %i star%s to one each of the %i %s footprint%s.", + len(associations), + len(pixCoords), + "" if len(pixCoords) == 1 else "s", + len(footprints), + plane, + "" if len(footprints) == 1 else "s", + ) + return footprints, associations + + def _setFootprints(self, inputExposure: ExposureF, footprints: list, maskPlane: str): + """Set footprints in a given mask plane. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The input exposure to modify. + footprints : `list` [`~lsst.afw.detection.Footprint`] + The footprints to set in the mask plane. + maskPlane : `str` + The mask plane to set the footprints in. + + Notes + ----- + This method modifies the ``inputExposure`` object in-place. + """ + detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(maskPlane), Threshold.BITMASK) + detThresholdValue = int(detThreshold.getValue()) + footprintSet = FootprintSet(inputExposure.mask, detThreshold) + + # Wipe any existing footprints in the mask plane + inputExposure.mask.clearMaskPlane(int(np.log2(detThresholdValue))) + + # Set the footprints in the mask plane + footprintSet.setFootprints(footprints) + footprintSet.setMask(inputExposure.mask, maskPlane) + + def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, Any]: + """Fit a scaled PSF and a pedestal to each bright star cutout. + + Parameters + ---------- + stampMI : `~lsst.afw.image.MaskedImageF` + The masked image of the bright star cutout. + psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF` + The PSF model to fit. + + Returns + ------- + fitPsfResults : `dict`[`str`, `float`] + The result of the PSF fitting, with keys: + + ``scale`` : `float` + The scale factor. + ``scaleErr`` : `float` + The error on the scale factor. + ``pedestal`` : `float` + The pedestal value. + ``pedestalErr`` : `float` + The error on the pedestal value. + ``pedestalScaleCov`` : `float` + The covariance between the pedestal and scale factor. + ``xGradient`` : `float` + The gradient in the x-direction. + ``yGradient`` : `float` + The gradient in the y-direction. + ``globalReducedChiSquared`` : `float` + The global reduced chi-squared goodness-of-fit. + ``globalDegreesOfFreedom`` : `int` + The global number of degrees of freedom. + ``psfReducedChiSquared`` : `float` + The PSF BBox reduced chi-squared goodness-of-fit. + ``psfDegreesOfFreedom`` : `int` + The PSF BBox number of degrees of freedom. + ``psfMaskedFluxFrac`` : `float` + The fraction of the PSF image flux masked by bad pixels. + """ + badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes) + + # Calculate the fraction of the PSF image flux masked by bad pixels + psfMaskedPixels = ImageF(psfImage.getBBox()) + psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool) + # TODO: This is np.float64, else FITS metadata serialization fails + # Amir: what do we want to do for subtraction? we do not have the luxury of removing the star from the process here! + psfMaskedFluxFrac = np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) + # if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold: + # return {} # Handle cases where the PSF image is mostly masked + + # Create a padded version of the input constant PSF image + paddedPsfImage = ImageF(stampMI.getBBox()) + paddedPsfImage[psfImage.getBBox()] = psfImage.convertF() + + # Create consistently masked data + mask = self.add_psf_mask(psfImage, stampMI) + badSpans = SpanSet.fromMask(mask, badMaskBitMask) + goodSpans = SpanSet(stampMI.getBBox()).intersectNot(badSpans) + varianceData = goodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) + if self.config.useMedianVariance: + varianceData = np.median(varianceData) + sigmaData = np.sqrt(varianceData) + imageData = goodSpans.flatten(stampMI.image.array, stampMI.getXY0()) # B + imageData /= sigmaData + psfData = goodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0()) + psfData /= sigmaData + + # Fit the PSF scale factor and global pedestal + nData = len(imageData) + coefficientMatrix = np.ones((nData, 4), dtype=float) # A + coefficientMatrix[:, 0] = psfData + coefficientMatrix[:, 1] /= sigmaData + coefficientMatrix[:, 2:] = goodSpans.indices().T + coefficientMatrix[:, 2] /= sigmaData + coefficientMatrix[:, 3] /= sigmaData + try: + solutions, sumSquaredResiduals, *_ = np.linalg.lstsq(coefficientMatrix, imageData, rcond=None) + covarianceMatrix = np.linalg.inv(np.dot(coefficientMatrix.transpose(), coefficientMatrix)) # C + except np.linalg.LinAlgError: + return {} # Handle singular matrix errors + if sumSquaredResiduals.size == 0: + return {} + scale = solutions[0] + if scale <= 0: + return {} # Handle cases where the PSF scale fit has failed + scaleErr = np.sqrt(covarianceMatrix[0, 0]) + pedestal = solutions[1] + pedestalErr = np.sqrt(covarianceMatrix[1, 1]) + scalePedestalCov = covarianceMatrix[0, 1] + xGradient = solutions[3] + yGradient = solutions[2] + + # Calculate global (whole image) reduced chi-squared + globalChiSquared = np.sum(sumSquaredResiduals) + globalDegreesOfFreedom = nData - 4 + globalReducedChiSquared = globalChiSquared / globalDegreesOfFreedom + + # Calculate PSF BBox reduced chi-squared + psfBBoxGoodSpans = goodSpans.clippedTo(psfImage.getBBox()) + psfBBoxGoodSpansX, psfBBoxGoodSpansY = psfBBoxGoodSpans.indices() + psfBBoxData = psfBBoxGoodSpans.flatten(stampMI.image.array, stampMI.getXY0()) + psfBBoxModel = ( + psfBBoxGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale + + pedestal + + psfBBoxGoodSpansX * xGradient + + psfBBoxGoodSpansY * yGradient + ) + psfBBoxVariance = psfBBoxGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0()) + psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 / psfBBoxVariance + psfBBoxChiSquared = np.sum(psfBBoxResiduals) + psfBBoxDegreesOfFreedom = len(psfBBoxData) - 4 + psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom + + return dict( + scale=scale, + scaleErr=scaleErr, + pedestal=pedestal, + pedestalErr=pedestalErr, + xGradient=xGradient, + yGradient=yGradient, + pedestalScaleCov=scalePedestalCov, + globalReducedChiSquared=globalReducedChiSquared, + globalDegreesOfFreedom=globalDegreesOfFreedom, + psfReducedChiSquared=psfBBoxReducedChiSquared, + psfDegreesOfFreedom=psfBBoxDegreesOfFreedom, + psfMaskedFluxFrac=psfMaskedFluxFrac, + ) + + def add_psf_mask(self, psfImage, stampMI): + """Add problematic PSF pixels to the stamp's mask. + + Identifies pixels in the PSF model image that are NaN or negative + and sets the corresponding bits (hardcoded as plane 0, likely 'BAD') + in a copy of the input stamp's mask. + + Parameters + ---------- + psfImage : `~lsst.afw.image.ImageF` + PSF model image defined on the stamp grid. + stampMI : `~lsst.afw.image.MaskedImageF` + The masked image stamp of the star being fitted. Its mask is used + as the base. + + Returns + ------- + mask : `~lsst.afw.image.Mask` + A mask object based on the input stamp's mask, updated to include + masked pixels derived from the PSF model image. + """ + cond = np.isnan(psfImage.array) + cond |= psfImage.array < 0 + mask = deepcopy(stampMI.mask) + mask.array[cond] = np.bitwise_or(mask.array[cond], 1) + return mask + + + def _getCutout(self, inputExposure, coordPix: Point2D, stampSize: list[int]): + """Get a cutout from an input exposure, handling edge cases. + + Generate a cutout from an input exposure centered on a given position + and with a given size. + If any part of the cutout is outside the input exposure bounding box, + the cutout is padded with NaNs. + + Parameters + ---------- + inputExposure : `~lsst.afw.image.ExposureF` + The image to extract bright star stamps from. + coordPix : `~lsst.geom.Point2D` + Center of the cutout in pixel space. + stampSize : `list` [`int`] + Size of the cutout, in pixels. + + Returns + ------- + stamp : `~lsst.afw.image.ExposureF` or `None` + The cutout, or `None` if the cutout is entirely outside the input + exposure bounding box. + + Notes + ----- + This method is a short-term workaround until DM-40042 is implemented. + At that point, it should be replaced by a call to the Exposure method + ``getCutout``, which will handle edge cases automatically. + """ + corner = Point2I(np.array(coordPix) - np.array(stampSize) / 2) + dimensions = Extent2I(stampSize) + stampBBox = Box2I(corner, dimensions) + overlapBBox = Box2I(stampBBox) + overlapBBox.clip(inputExposure.getBBox()) + if overlapBBox.getArea() > 0: + # Create full-sized stamp with pixels initially flagged as NO_DATA. + stamp = ExposureF(bbox=stampBBox) + stamp.image[:] = np.nan + stamp.mask.set(inputExposure.mask.getPlaneBitMask("NO_DATA")) + # # Restore pixels which overlap the input exposure. + overlap = inputExposure.Factory(inputExposure, overlapBBox) + stamp.maskedImage[overlapBBox] = overlap + else: + stamp = None + return stamp + + def _warpRawStamp(self, rawStamp, warpingControl, pixToTan, pixCoord): + """Warps a raw image stamp onto a common tangent plane projection. + + Applies a transformation (`pixToTan`) followed by a shift + transform to warp the input `rawStamp` onto a destination + `MaskedImageF` aligned with a tangent plane centered near the object. + The shift aims to place the object center at the center of the + destination image. + + Parameters + ---------- + rawStamp : `~lsst.afw.image.ExposureF` + The raw cutout image stamp (e.g., from `_getCutout`). + warpingControl : `~lsst.afw.math.WarpingControl` + Configuration for the warping process. + pixToTan : `~lsst.afw.geom.Transform` + Transformation from the raw stamp's pixel coordinates to the + common tangent plane coordinates. + pixCoord : `~lsst.geom.Point2D` + Pixel coordinates of the object center in the original exposure, + used to calculate the centering shift. + + Returns + ------- + warped_stamp : `~lsst.afw.image.MaskedImageF` or `None` + The warped and shifted masked image, or None if warping failed + (e.g., due to insufficient good pixels). + """ + destImage = MaskedImageF(*self.config.stampSize) + bottomLeft = Point2D(rawStamp.getXY0()) + newBottomLeft = pixToTan.applyForward(bottomLeft) + newBottomLeft = Point2I(newBottomLeft) + destImage.setXY0(newBottomLeft) + # Define linear shifting to recenter stamps + newCenter = pixToTan.applyForward(pixCoord) + self.modelCenter = self.config.stampSize[0] // 2, self.config.stampSize[1] // 2 + shift = (self.modelCenter[0] + newBottomLeft[0] - newCenter[0], self.modelCenter[1] + newBottomLeft[1] - newCenter[1]) + affineShift = AffineTransform(shift) + shiftTransform = makeTransform(affineShift) + + # Define full transform (warp and shift) + starWarper = pixToTan.then(shiftTransform) + + # Apply it + goodPix = warpImage(destImage, rawStamp.getMaskedImage(), starWarper, warpingControl) + if not goodPix: + return None + return destImage + + # # Arbitrarily set origin of shifted star to 0 + # destImage.setXY0(0, 0) \ No newline at end of file diff --git a/tests/test_brightStarCutout.py b/tests/test_brightStarCutout.py new file mode 100644 index 000000000..67b88d02f --- /dev/null +++ b/tests/test_brightStarCutout.py @@ -0,0 +1,102 @@ +# This file is part of pipe_tasks. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import unittest + +import lsst.afw.cameraGeom.testUtils +import lsst.afw.image +import lsst.utils.tests +import numpy as np +from lsst.afw.image import ImageD, ImageF, MaskedImageF +from lsst.afw.math import FixedKernel +from lsst.geom import Point2I +from lsst.meas.algorithms import KernelPsf +from lsst.pipe.tasks.brightStarSubtraction import BrightStarCutoutConfig, BrightStarCutoutTask + + +class BrightStarCutoutTestCase(lsst.utils.tests.TestCase): + def setUp(self): + # Fit values + self.scale = 2.34e5 + self.pedestal = 3210.1 + self.xGradient = 5.432 + self.yGradient = 10.987 + + # Create a pedestal + 2D plane + xCoords = np.linspace(-50, 50, 101) + yCoords = np.linspace(-50, 50, 101) + xPlane, yPlane = np.meshgrid(xCoords, yCoords) + pedestal = np.ones_like(xPlane) * self.pedestal + + # Create a pseudo-PSF + dist_from_center = np.sqrt(xPlane**2 + yPlane**2) + psfArray = np.exp(-dist_from_center / 5) + psfArray /= np.sum(psfArray) + fixedKernel = FixedKernel(ImageD(psfArray)) + self.psf = KernelPsf(fixedKernel) + + # Bring everything together to construct a stamp masked image + stampArray = psfArray * self.scale + pedestal + xPlane * self.xGradient + yPlane * self.yGradient + stampIm = ImageF((stampArray).astype(np.float32)) + stampVa = ImageF(stampIm.getBBox(), 654.321) + self.stampMI = MaskedImageF(image=stampIm, variance=stampVa) + self.stampMI.setXY0(Point2I(-50, -50)) + + # Ensure that all mask planes required by the task are in-place; + # new mask plane entries will be created as necessary + badMaskPlanes = [ + "BAD", + "CR", + "CROSSTALK", + "EDGE", + "NO_DATA", + "SAT", + "SUSPECT", + "UNMASKEDNAN", + "NEIGHBOR", + ] + _ = [self.stampMI.mask.addMaskPlane(mask) for mask in badMaskPlanes] + + def test_fitPsf(self): + """Test the PSF fitting method.""" + brightStarCutoutConfig = BrightStarCutoutConfig() + brightStarCutoutTask = BrightStarCutoutTask(config=brightStarCutoutConfig) + fitPsfResults = brightStarCutoutTask._fitPsf( + self.stampMI, + self.psf, + ) + self.assertAlmostEqual(fitPsfResults["scale"], self.scale, delta=1e-3) + self.assertAlmostEqual(fitPsfResults["pedestal"], self.pedestal, delta=1e-5) + self.assertAlmostEqual(fitPsfResults["xGradient"], self.xGradient, delta=1e-7) + self.assertAlmostEqual(fitPsfResults["yGradient"], self.yGradient, delta=1e-7) + + +def setup_module(module): + lsst.utils.tests.init() + + +class MemoryTestCase(lsst.utils.tests.MemoryTestCase): + pass + + +if __name__ == "__main__": + lsst.utils.tests.init() + unittest.main()