diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py
new file mode 100644
index 000000000..fe5088369
--- /dev/null
+++ b/python/lsst/pipe/tasks/brightStarSubtraction/__init__.py
@@ -0,0 +1,2 @@
+from .brightStarCutout import *
+from .brightStarStack import *
diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py
new file mode 100644
index 000000000..3be878eb0
--- /dev/null
+++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py
@@ -0,0 +1,691 @@
+# This file is part of pipe_tasks.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""Extract bright star cutouts; normalize and warp, optionally fit the PSF."""
+
+__all__ = ["BrightStarCutoutConnections", "BrightStarCutoutConfig", "BrightStarCutoutTask"]
+
+from typing import Any, Iterable, cast
+
+import astropy.units as u
+import numpy as np
+from astropy.coordinates import SkyCoord
+from astropy.table import Table
+from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS
+from lsst.afw.detection import Footprint, FootprintSet, Threshold
+from lsst.afw.geom import SkyWcs, SpanSet, makeModifiedWcs
+from lsst.afw.geom.transformFactory import makeTransform
+from lsst.afw.image import ExposureF, ImageD, ImageF, MaskedImageF
+from lsst.afw.math import BackgroundList, FixedKernel, WarpingControl, warpImage
+from lsst.daf.butler import DataCoordinate
+from lsst.geom import (
+ AffineTransform,
+ Box2I,
+ Extent2D,
+ Extent2I,
+ Point2D,
+ Point2I,
+ SpherePoint,
+ arcseconds,
+ floor,
+ radians,
+)
+from lsst.meas.algorithms import (
+ BrightStarStamp,
+ BrightStarStamps,
+ KernelPsf,
+ LoadReferenceObjectsConfig,
+ ReferenceObjectLoader,
+ WarpedPsf,
+)
+from lsst.pex.config import ChoiceField, ConfigField, Field, ListField
+from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
+from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput
+from lsst.utils.timer import timeMethod
+from copy import deepcopy
+
+NEIGHBOR_MASK_PLANE = "NEIGHBOR"
+
+
+class BrightStarCutoutConnections(
+ PipelineTaskConnections,
+ dimensions=("instrument", "visit", "detector"),
+):
+ """Connections for BrightStarCutoutTask."""
+
+ refCat = PrerequisiteInput(
+ name="gaia_dr3_20230707",
+ storageClass="SimpleCatalog",
+ doc="Reference catalog that contains bright star positions.",
+ dimensions=("skypix",),
+ multiple=True,
+ deferLoad=True,
+ )
+ inputExposure = Input(
+ name="calexp",
+ storageClass="ExposureF",
+ doc="Background-subtracted input exposure from which to extract bright star stamp cutouts.",
+ dimensions=("visit", "detector"),
+ )
+ inputBackground = Input(
+ name="calexpBackground",
+ storageClass="Background",
+ doc="Background model for the input exposure, to be added back on during processing.",
+ dimensions=("visit", "detector"),
+ )
+ extendedPsf = Input(
+ name="extendedPsf2",
+ storageClass="ImageF",
+ doc="Extended PSF model, built from stacking bright star cutouts.",
+ dimensions=("band",),
+ )
+ brightStarStamps = Output(
+ name="brightStarStamps",
+ storageClass="BrightStarStamps",
+ doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.",
+ dimensions=("visit", "detector"),
+ )
+
+ def __init__(self, *, config: "BrightStarCutoutConfig | None" = None):
+ super().__init__(config=config)
+ assert config is not None
+ if not config.useExtendedPsf:
+ self.inputs.remove("extendedPsf")
+
+
+class BrightStarCutoutConfig(
+ PipelineTaskConfig,
+ pipelineConnections=BrightStarCutoutConnections,
+):
+ """Configuration parameters for BrightStarCutoutTask."""
+
+ # Star selection
+ magRange = ListField[float](
+ doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.",
+ default=[0, 18],
+ )
+ excludeArcsecRadius = Field[float](
+ doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.",
+ default=5,
+ )
+ excludeMagRange = ListField[float](
+ doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.",
+ default=[0, 20],
+ )
+ minAreaFraction = Field[float](
+ doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.",
+ default=0.1,
+ )
+ badMaskPlanes = ListField[str](
+ doc="Mask planes that identify excluded pixels for the calculation of ``minAreaFraction`` and, "
+ "optionally, fitting of the PSF.",
+ default=[
+ "BAD",
+ "CR",
+ "CROSSTALK",
+ "EDGE",
+ "NO_DATA",
+ "SAT",
+ "SUSPECT",
+ "UNMASKEDNAN",
+ NEIGHBOR_MASK_PLANE,
+ ],
+ )
+
+ # Cutout geometry
+ stampSize = ListField[int](
+ doc="Size of the stamps to be extracted, in pixels.",
+ default=(251, 251),
+ )
+ stampSizePadding = Field[float](
+ doc="Multiplicative factor applied to the cutout stamp size, to guard against post-warp data loss.",
+ default=1.1,
+ )
+ warpingKernelName = ChoiceField[str](
+ doc="Warping kernel.",
+ default="lanczos5",
+ allowed={
+ "bilinear": "bilinear interpolation",
+ "lanczos3": "Lanczos kernel of order 3",
+ "lanczos4": "Lanczos kernel of order 4",
+ "lanczos5": "Lanczos kernel of order 5",
+ },
+ )
+ maskWarpingKernelName = ChoiceField[str](
+ doc="Warping kernel for mask.",
+ default="bilinear",
+ allowed={
+ "bilinear": "bilinear interpolation",
+ "lanczos3": "Lanczos kernel of order 3",
+ "lanczos4": "Lanczos kernel of order 4",
+ "lanczos5": "Lanczos kernel of order 5",
+ },
+ )
+
+ scalePsfModel = Field[bool](
+ doc="If True, uses a scale factor to bring the PSF model data to the same level of the star data.",
+ default=True,
+ )
+
+ # PSF Fitting
+ useExtendedPsf = Field[bool](
+ doc="Use the extended PSF model to normalize bright star cutouts.",
+ default=False,
+ )
+ doFitPsf = Field[bool](
+ doc="Fit a scaled PSF and a pedestal to each bright star cutout.",
+ default=True,
+ )
+ useMedianVariance = Field[bool](
+ doc="Use the median of the variance plane for PSF fitting.",
+ default=False,
+ )
+ psfMaskedFluxFracThreshold = Field[float](
+ doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.",
+ default=0.97,
+ )
+
+ # Misc
+ loadReferenceObjectsConfig = ConfigField[LoadReferenceObjectsConfig](
+ doc="Reference object loader for astrometric calibration.",
+ )
+
+
+class BrightStarCutoutTask(PipelineTask):
+ """Extract bright star cutouts; normalize and warp to the same pixel grid.
+
+ The BrightStarCutoutTask is used to extract, process, and store small image
+ cutouts (or "postage stamps") around bright stars.
+ This task essentially consists of three principal steps.
+ First, it identifies bright stars within an exposure using a reference
+ catalog and extracts a stamp around each.
+ Second, it shifts and warps each stamp to remove optical distortions and
+ sample all stars on the same pixel grid.
+ Finally, it optionally fits a PSF plus plane flux model to the cutout.
+ This final fitting procedure may be used to normalize each bright star
+ stamp prior to stacking when producing extended PSF models.
+ """
+
+ ConfigClass = BrightStarCutoutConfig
+ _DefaultName = "brightStarCutout"
+ config: BrightStarCutoutConfig
+
+ def __init__(self, initInputs=None, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ stampSize = Extent2D(*self.config.stampSize.list())
+ stampRadius = floor(stampSize / 2)
+ self.stampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stampRadius)
+ paddedStampSize = stampSize * self.config.stampSizePadding
+ self.paddedStampRadius = floor(paddedStampSize / 2)
+ self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(
+ self.paddedStampRadius
+ )
+
+ def runQuantum(self, butlerQC, inputRefs, outputRefs):
+ inputs = butlerQC.get(inputRefs)
+ inputs["dataId"] = butlerQC.quantum.dataId
+ refObjLoader = ReferenceObjectLoader(
+ dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat],
+ refCats=inputs.pop("refCat"),
+ name=self.config.connections.refCat,
+ config=self.config.loadReferenceObjectsConfig,
+ )
+ extendedPsf = inputs.pop("extendedPsf", None)
+ output = self.run(**inputs, extendedPsf=extendedPsf, refObjLoader=refObjLoader)
+ # Only ingest Stamp if it exists; prevents ingesting an empty FITS file
+ if output:
+ butlerQC.put(output, outputRefs)
+
+ @timeMethod
+ def run(
+ self,
+ inputExposure: ExposureF,
+ inputBackground: BackgroundList,
+ extendedPsf: ImageF | None,
+ refObjLoader: ReferenceObjectLoader,
+ dataId: dict[str, Any] | DataCoordinate,
+ ):
+ """Identify bright stars within an exposure using a reference catalog,
+ extract stamps around each, warp/shift stamps onto a common frame and
+ then optionally fit a PSF plus plane model.
+
+ Parameters
+ ----------
+ inputExposure : `~lsst.afw.image.ExposureF`
+ The background-subtracted image to extract bright star stamps.
+ inputBackground : `~lsst.afw.math.BackgroundList`
+ The background model associated with the input exposure.
+ refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional
+ Loader to find objects within a reference catalog.
+ dataId : `dict` or `~lsst.daf.butler.DataCoordinate`
+ The dataId of the exposure that bright stars are extracted from.
+ Both 'visit' and 'detector' will be persisted in the output data.
+
+ Returns
+ -------
+ brightStarResults : `~lsst.pipe.base.Struct`
+ Results as a struct with attributes:
+
+ ``brightStarStamps``
+ (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`)
+ """
+ wcs = inputExposure.getWcs()
+ bbox = inputExposure.getBBox()
+ warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName)
+
+ refCatBright = self._getRefCatBright(refObjLoader, wcs, bbox)
+ zipRaDec = zip(refCatBright["coord_ra"] * radians, refCatBright["coord_dec"] * radians)
+ spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec]
+ pixCoords = wcs.skyToPixel(spherePoints)
+
+ # Restore original subtracted background
+ inputMI = inputExposure.getMaskedImage()
+ inputMI += inputBackground.getImage()
+ # Amir: the above addition to inputMI, also adds to the inputExposure.
+ # Amir: but the calibration, three lines later, only is applied to the inputMI.
+
+ # Set up NEIGHBOR mask plane; associate footprints with stars
+ inputExposure.mask.addMaskPlane(NEIGHBOR_MASK_PLANE)
+ allFootprints, associations = self._associateFootprints(inputExposure, pixCoords, plane="DETECTED")
+
+ # TODO: If we eventually have better PhotoCalibs (eg FGCM), apply here
+ inputMI = inputExposure.getPhotoCalib().calibrateImage(inputMI, False)
+
+ # Set up transform
+ detector = inputExposure.detector
+ pixelScale = wcs.getPixelScale().asArcseconds() * arcseconds
+ pixToFocalPlaneTan = detector.getTransform(PIXELS, FIELD_ANGLE).then(
+ makeTransform(AffineTransform.makeScaling(1 / pixelScale.asRadians()))
+ )
+
+ # Loop over each bright star
+ stamps, goodFracs, stamps_fitPsfResults = [], [], []
+ for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore
+ footprintIndex = associations.get(starIndex, None)
+ stampMI = MaskedImageF(self.paddedStampBBox)
+
+ # Set NEIGHBOR footprints in the mask plane
+ if footprintIndex:
+ neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex]
+ # self._setFootprints(inputExposure, neighborFootprints, NEIGHBOR_MASK_PLANE)
+ self._setFootprints(inputMI, neighborFootprints, NEIGHBOR_MASK_PLANE)
+ else:
+ # self._setFootprints(inputExposure, allFootprints, NEIGHBOR_MASK_PLANE)
+ self._setFootprints(inputMI, allFootprints, NEIGHBOR_MASK_PLANE)
+
+ # Define linear shifting to recenter stamps
+ coordFocalPlaneTan = pixToFocalPlaneTan.applyForward(pixCoord) # center of warped star
+ shift = makeTransform(AffineTransform(Point2D(0, 0) - coordFocalPlaneTan))
+ angle = np.arctan2(coordFocalPlaneTan.getY(), coordFocalPlaneTan.getX()) * radians
+ rotation = makeTransform(AffineTransform.makeRotation(-angle))
+ pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation)
+
+ # Apply the warp to the star stamp (in-place)
+ # warpImage(stampMI, inputExposure.maskedImage, pixToPolar, warpingControl)
+ warpImage(stampMI, inputMI, pixToPolar, warpingControl)
+
+ # Trim to the base stamp size, check mask coverage, update metadata
+ stampMI = stampMI[self.stampBBox]
+ badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes)
+ goodFrac = np.sum(stampMI.mask.array & badMaskBitMask == 0) / stampMI.mask.array.size
+ goodFracs.append(goodFrac)
+ if goodFrac < self.config.minAreaFraction:
+ continue
+
+ # Fit a scaled PSF and a pedestal to each bright star cutout
+ psf = WarpedPsf(inputExposure.getPsf(), pixToPolar, warpingControl)
+ constantPsf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0))))
+ # TODO: discuss with Lee whether we should warp the psf here as well?
+ if self.config.useExtendedPsf:
+ psfImage = extendedPsf # Assumed to be warped, center at [0,0]
+ else:
+ psfImage = constantPsf.computeKernelImage(constantPsf.getAveragePosition())
+ if self.config.scalePsfModel:
+ psfNeg = psfImage.array < 0
+ self.modelScale = np.nanmean(stampMI.image.array) / np.nanmean(psfImage.array[~psfNeg])
+ psfImage.array *= self.modelScale ######## model scale correction ########
+ else:
+ self.modelScale = 1
+
+ fitPsfResults = {}
+ if self.config.doFitPsf:
+ fitPsfResults = self._fitPsf(stampMI, psfImage)
+ stamps_fitPsfResults.append(fitPsfResults)
+
+ # Save the stamp if the PSF fit was successful or no fit requested
+ if fitPsfResults or not self.config.doFitPsf:
+ stamp = BrightStarStamp(
+ maskedImage=stampMI,
+ # TODO: what to do about this PSF?
+ psf=constantPsf,
+ wcs=makeModifiedWcs(pixToPolar, wcs, False),
+ visit=cast(int, dataId["visit"]),
+ detector=cast(int, dataId["detector"]),
+ refId=obj["id"],
+ refMag=obj["mag"],
+ position=pixCoord,
+ scale=fitPsfResults.get("scale", None),
+ scaleErr=fitPsfResults.get("scaleErr", None),
+ pedestal=fitPsfResults.get("pedestal", None),
+ pedestalErr=fitPsfResults.get("pedestalErr", None),
+ pedestalScaleCov=fitPsfResults.get("pedestalScaleCov", None),
+ xGradient=fitPsfResults.get("xGradient", None),
+ yGradient=fitPsfResults.get("yGradient", None),
+ globalReducedChiSquared=fitPsfResults.get("globalReducedChiSquared", None),
+ globalDegreesOfFreedom=fitPsfResults.get("globalDegreesOfFreedom", None),
+ psfReducedChiSquared=fitPsfResults.get("psfReducedChiSquared", None),
+ psfDegreesOfFreedom=fitPsfResults.get("psfDegreesOfFreedom", None),
+ psfMaskedFluxFrac=fitPsfResults.get("psfMaskedFluxFrac", None),
+ )
+ stamps.append(stamp)
+
+ self.log.info(
+ "Extracted %i bright star stamp%s. "
+ "Excluded %i star%s: insufficient area (%i), PSF fit failure (%i).",
+ len(stamps),
+ "" if len(stamps) == 1 else "s",
+ len(refCatBright) - len(stamps),
+ "" if len(refCatBright) - len(stamps) == 1 else "s",
+ np.sum(np.array(goodFracs) < self.config.minAreaFraction),
+ (
+ np.sum(np.isnan([x.get("pedestal", np.nan) for x in stamps_fitPsfResults]))
+ if self.config.doFitPsf
+ else 0
+ ),
+ )
+ brightStarStamps = BrightStarStamps(stamps)
+ return Struct(brightStarStamps=brightStarStamps)
+
+ def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table:
+ """Get a bright star subset of the reference catalog.
+
+ Trim the reference catalog to only those objects within the exposure
+ bounding box dilated by half the bright star stamp size.
+ This ensures all stars that overlap the exposure are included.
+
+ Parameters
+ ----------
+ refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`
+ Loader to find objects within a reference catalog.
+ wcs : `~lsst.afw.geom.SkyWcs`
+ World coordinate system.
+ bbox : `~lsst.geom.Box2I`
+ Bounding box of the exposure.
+
+ Returns
+ -------
+ refCatBright : `~astropy.table.Table`
+ Bright star subset of the reference catalog.
+ """
+ dilatedBBox = bbox.dilatedBy(self.paddedStampRadius)
+ withinExposure = refObjLoader.loadPixelBox(dilatedBBox, wcs, filterName="phot_g_mean")
+ refCatFull = withinExposure.refCat
+ fluxField: str = withinExposure.fluxField
+
+ proxFluxRange = sorted(((self.config.excludeMagRange * u.ABmag).to(u.nJy)).to_value())
+ brightFluxRange = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value())
+
+ subsetStars = (refCatFull[fluxField] > np.min((proxFluxRange[0], brightFluxRange[0]))) & (
+ refCatFull[fluxField] < np.max((proxFluxRange[1], brightFluxRange[1]))
+ )
+ refCatSubset = Table(refCatFull.extract("id", "coord_ra", "coord_dec", fluxField, where=subsetStars))
+
+ proxStars = (refCatSubset[fluxField] >= proxFluxRange[0]) & (
+ refCatSubset[fluxField] <= proxFluxRange[1]
+ )
+ brightStars = (refCatSubset[fluxField] >= brightFluxRange[0]) & (
+ refCatSubset[fluxField] <= brightFluxRange[1]
+ )
+
+ coords = SkyCoord(refCatSubset["coord_ra"], refCatSubset["coord_dec"], unit="rad")
+ excludeArcsecRadius = self.config.excludeArcsecRadius * u.arcsec # type: ignore
+ refCatBrightIsolated = []
+ for coord in cast(Iterable[SkyCoord], coords[brightStars]):
+ neighbors = coords[proxStars]
+ seps = coord.separation(neighbors).to(u.arcsec)
+ tooClose = (seps > 0) & (seps <= excludeArcsecRadius) # not self matched
+ refCatBrightIsolated.append(not tooClose.any())
+
+ refCatBright = cast(Table, refCatSubset[brightStars][refCatBrightIsolated])
+
+ fluxNanojansky = refCatBright[fluxField][:] * u.nJy # type: ignore
+ refCatBright["mag"] = fluxNanojansky.to(u.ABmag).to_value() # AB magnitudes
+
+ self.log.info(
+ "Identified %i of %i star%s which satisfy: frame overlap; in the range %s mag; no neighboring "
+ "stars within %s arcsec.",
+ len(refCatBright),
+ len(refCatFull),
+ "" if len(refCatFull) == 1 else "s",
+ self.config.magRange,
+ self.config.excludeArcsecRadius,
+ )
+
+ return refCatBright
+
+ def _associateFootprints(
+ self, inputExposure: ExposureF, pixCoords: list[Point2D], plane: str
+ ) -> tuple[list[Footprint], dict[int, int]]:
+ """Associate footprints from a given mask plane with specific objects.
+
+ Footprints from the given mask plane are associated with objects at the
+ coordinates provided, where possible.
+
+ Parameters
+ ----------
+ inputExposure : `~lsst.afw.image.ExposureF`
+ The input exposure with a mask plane.
+ pixCoords : `list` [`~lsst.geom.Point2D`]
+ The pixel coordinates of the objects.
+ plane : `str`
+ The mask plane used to identify masked pixels.
+
+ Returns
+ -------
+ footprints : `list` [`~lsst.afw.detection.Footprint`]
+ The footprints from the input exposure.
+ associations : `dict`[int, int]
+ Association indices between objects (key) and footprints (value).
+ """
+ detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK)
+ footprintSet = FootprintSet(inputExposure.mask, detThreshold)
+ footprints = footprintSet.getFootprints()
+ associations = {}
+ for starIndex, pixCoord in enumerate(pixCoords):
+ for footprintIndex, footprint in enumerate(footprints):
+ if footprint.contains(Point2I(pixCoord)):
+ associations[starIndex] = footprintIndex
+ break
+ self.log.debug(
+ "Associated %i of %i star%s to one each of the %i %s footprint%s.",
+ len(associations),
+ len(pixCoords),
+ "" if len(pixCoords) == 1 else "s",
+ len(footprints),
+ plane,
+ "" if len(footprints) == 1 else "s",
+ )
+ return footprints, associations
+
+ def _setFootprints(self, inputExposure: ExposureF, footprints: list, maskPlane: str):
+ """Set footprints in a given mask plane.
+
+ Parameters
+ ----------
+ inputExposure : `~lsst.afw.image.ExposureF`
+ The input exposure to modify.
+ footprints : `list` [`~lsst.afw.detection.Footprint`]
+ The footprints to set in the mask plane.
+ maskPlane : `str`
+ The mask plane to set the footprints in.
+
+ Notes
+ -----
+ This method modifies the ``inputExposure`` object in-place.
+ """
+ detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(maskPlane), Threshold.BITMASK)
+ detThresholdValue = int(detThreshold.getValue())
+ footprintSet = FootprintSet(inputExposure.mask, detThreshold)
+
+ # Wipe any existing footprints in the mask plane
+ inputExposure.mask.clearMaskPlane(int(np.log2(detThresholdValue)))
+
+ # Set the footprints in the mask plane
+ footprintSet.setFootprints(footprints)
+ footprintSet.setMask(inputExposure.mask, maskPlane)
+
+ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, Any]:
+ """Fit a scaled PSF and a pedestal to each bright star cutout.
+
+ Parameters
+ ----------
+ stampMI : `~lsst.afw.image.MaskedImageF`
+ The masked image of the bright star cutout.
+ psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF`
+ The PSF model to fit.
+
+ Returns
+ -------
+ fitPsfResults : `dict`[`str`, `float`]
+ The result of the PSF fitting, with keys:
+
+ ``scale`` : `float`
+ The scale factor.
+ ``scaleErr`` : `float`
+ The error on the scale factor.
+ ``pedestal`` : `float`
+ The pedestal value.
+ ``pedestalErr`` : `float`
+ The error on the pedestal value.
+ ``pedestalScaleCov`` : `float`
+ The covariance between the pedestal and scale factor.
+ ``xGradient`` : `float`
+ The gradient in the x-direction.
+ ``yGradient`` : `float`
+ The gradient in the y-direction.
+ ``globalReducedChiSquared`` : `float`
+ The global reduced chi-squared goodness-of-fit.
+ ``globalDegreesOfFreedom`` : `int`
+ The global number of degrees of freedom.
+ ``psfReducedChiSquared`` : `float`
+ The PSF BBox reduced chi-squared goodness-of-fit.
+ ``psfDegreesOfFreedom`` : `int`
+ The PSF BBox number of degrees of freedom.
+ ``psfMaskedFluxFrac`` : `float`
+ The fraction of the PSF image flux masked by bad pixels.
+ """
+ badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes)
+
+ # Calculate the fraction of the PSF image flux masked by bad pixels
+ psfMaskedPixels = ImageF(psfImage.getBBox())
+ psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool)
+ # TODO: This is np.float64, else FITS metadata serialization fails
+ # Amir: the following tries to find the fraction of the psf flux in the masked area of the psf image.
+ psfMaskedFluxFrac = np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) / psfImage.array.sum()
+ # psfMaskedFluxFrac = np.dot(psfImage.array.astype(bool).flat, psfMaskedPixels.array.flat).astype(np.float64) / psfImage.array.astype(bool).sum()
+ if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold:
+ return {} # Handle cases where the PSF image is mostly masked
+
+ # Create a padded version of the input constant PSF image
+ paddedPsfImage = ImageF(stampMI.getBBox())
+ paddedPsfImage[psfImage.getBBox()] = psfImage.convertF()
+
+ mask = self.add_psf_mask(paddedPsfImage, stampMI)
+ # Create consistently masked data
+ # badSpans = SpanSet.fromMask(stampMI.mask, badMaskBitMask)
+ badSpans = SpanSet.fromMask(mask, badMaskBitMask)
+ goodSpans = SpanSet(stampMI.getBBox()).intersectNot(badSpans)
+ varianceData = goodSpans.flatten(stampMI.variance.array, stampMI.getXY0())
+ if self.config.useMedianVariance:
+ varianceData = np.median(varianceData)
+ sigmaData = np.sqrt(varianceData)
+ imageData = goodSpans.flatten(stampMI.image.array, stampMI.getXY0()) # B
+ imageData /= sigmaData
+ psfData = goodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0())
+ psfData /= sigmaData
+ # Fit the PSF scale factor and global pedestal
+ nData = len(imageData)
+ coefficientMatrix = np.ones((nData, 4), dtype=float) # A
+ coefficientMatrix[:, 0] = psfData
+ coefficientMatrix[:, 1] /= sigmaData
+ coefficientMatrix[:, 2:] = goodSpans.indices().T
+ coefficientMatrix[:, 2] /= sigmaData
+ coefficientMatrix[:, 3] /= sigmaData
+ try:
+ solutions, sumSquaredResiduals, *_ = np.linalg.lstsq(coefficientMatrix, imageData, rcond=None)
+ covarianceMatrix = np.linalg.inv(np.dot(coefficientMatrix.transpose(), coefficientMatrix)) # C
+ except np.linalg.LinAlgError:
+ return {} # Handle singular matrix errors
+ if sumSquaredResiduals.size == 0:
+ return {} # Handle cases where sum of the squared residuals are empty
+ # scale = solutions[0]
+ scale = solutions[0] * self.modelScale ######## model scale correction ########
+ if scale <= 0:
+ return {} # Handle cases where the PSF scale fit has failed
+ scaleErr = np.sqrt(covarianceMatrix[0, 0]) * self.modelScale ######## model scale correction ########
+ pedestal = solutions[1]
+ pedestalErr = np.sqrt(covarianceMatrix[1, 1])
+ scalePedestalCov = covarianceMatrix[0, 1] * self.modelScale ######## model scale correction ########
+ xGradient = solutions[3]
+ yGradient = solutions[2]
+
+ # Calculate global (whole image) reduced chi-squared
+ globalChiSquared = np.sum(sumSquaredResiduals)
+ globalDegreesOfFreedom = nData - 4
+ globalReducedChiSquared = globalChiSquared / globalDegreesOfFreedom
+
+ # Calculate PSF BBox reduced chi-squared
+ psfBBoxGoodSpans = goodSpans.clippedTo(psfImage.getBBox())
+ psfBBoxGoodSpansX, psfBBoxGoodSpansY = psfBBoxGoodSpans.indices()
+ psfBBoxData = psfBBoxGoodSpans.flatten(stampMI.image.array, stampMI.getXY0())
+ paddedPsfImage.array /= self.modelScale ######## model scale correction ########
+ psfBBoxModel = (
+ psfBBoxGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale
+ + pedestal
+ + psfBBoxGoodSpansX * xGradient
+ + psfBBoxGoodSpansY * yGradient
+ )
+ psfBBoxVariance = psfBBoxGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0())
+ psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 / psfBBoxVariance
+ psfBBoxChiSquared = np.sum(psfBBoxResiduals)
+ psfBBoxDegreesOfFreedom = len(psfBBoxData) - 4
+ psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom
+ return dict(
+ scale=scale,
+ scaleErr=scaleErr,
+ pedestal=pedestal,
+ pedestalErr=pedestalErr,
+ xGradient=xGradient,
+ yGradient=yGradient,
+ pedestalScaleCov=scalePedestalCov,
+ globalReducedChiSquared=globalReducedChiSquared,
+ globalDegreesOfFreedom=globalDegreesOfFreedom,
+ psfReducedChiSquared=psfBBoxReducedChiSquared,
+ psfDegreesOfFreedom=psfBBoxDegreesOfFreedom,
+ psfMaskedFluxFrac=psfMaskedFluxFrac,
+ )
+
+ def add_psf_mask(self, psfImage, stampMI):
+ cond = np.isnan(psfImage.array)
+ cond |= psfImage.array < 0
+ mask = deepcopy(stampMI.mask)
+ mask.array[cond] = np.bitwise_or(mask.array[cond], 1)
+ return mask
diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py
new file mode 100644
index 000000000..a99d9be3e
--- /dev/null
+++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarStack.py
@@ -0,0 +1,240 @@
+# This file is part of pipe_tasks.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""Stack bright star postage stamp cutouts to produce an extended PSF model."""
+
+__all__ = ["BrightStarStackConnections", "BrightStarStackConfig", "BrightStarStackTask"]
+
+import numpy as np
+from lsst.afw.image import ImageF
+from lsst.afw.math import StatisticsControl, statisticsStack, stringToStatisticsProperty
+from lsst.geom import Point2I
+from lsst.meas.algorithms import BrightStarStamps
+from lsst.pex.config import Field, ListField
+from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
+from lsst.pipe.base.connectionTypes import Input, Output
+from lsst.utils.timer import timeMethod
+
+NEIGHBOR_MASK_PLANE = "NEIGHBOR"
+
+
+class BrightStarStackConnections(
+ PipelineTaskConnections,
+ dimensions=("instrument", "detector"),
+):
+ """Connections for BrightStarStackTask."""
+
+ brightStarStamps = Input(
+ name="brightStarStamps",
+ storageClass="BrightStarStamps",
+ doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.",
+ dimensions=("visit", "detector"),
+ multiple=True,
+ deferLoad=True,
+ )
+ extendedPsf = Output(
+ name="extendedPsf2", # extendedPsfDetector ???
+ storageClass="ImageF", # MaskedImageF
+ doc="Extended PSF model, built from stacking bright star cutouts.",
+ dimensions=("band",),
+ )
+
+
+class BrightStarStackConfig(
+ PipelineTaskConfig,
+ pipelineConnections=BrightStarStackConnections,
+):
+ """Configuration parameters for BrightStarStackTask."""
+
+ subsetStampNumber = Field[int](
+ doc="Number of stamps per subset to generate stacked images for.",
+ default=2,
+ )
+ globalReducedChiSquaredThreshold = Field[float](
+ doc="Threshold for global reduced chi-squared for bright star stamps.",
+ default=5.0,
+ )
+ psfReducedChiSquaredThreshold = Field[float](
+ doc="Threshold for PSF reduced chi-squared for bright star stamps.",
+ default=50.0,
+ )
+
+ badMaskPlanes = ListField[str](
+ doc="Mask planes that identify excluded (masked) pixels.",
+ default=[
+ "BAD",
+ "CR",
+ "CROSSTALK",
+ "EDGE",
+ "NO_DATA",
+ # "SAT",
+ # "SUSPECT",
+ "UNMASKEDNAN",
+ NEIGHBOR_MASK_PLANE,
+ ],
+ )
+ stackType = Field[str](
+ default="MEANCLIP",
+ doc="Statistic name to use for stacking (from `~lsst.afw.math.Property`)",
+ )
+ stackNumSigmaClip = Field[float](
+ doc="Number of sigma to use for clipping when stacking.",
+ default=3.0,
+ )
+ stackNumIter = Field[int](
+ doc="Number of iterations to use for clipping when stacking.",
+ default=5,
+ )
+
+
+class BrightStarStackTask(PipelineTask):
+ """Stack bright star postage stamps to produce an extended PSF model."""
+
+ ConfigClass = BrightStarStackConfig
+ _DefaultName = "brightStarStack"
+ config: BrightStarStackConfig
+
+ def __init__(self, initInputs=None, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def runQuantum(self, butlerQC, inputRefs, outputRefs):
+ inputs = butlerQC.get(inputRefs)
+ output = self.run(**inputs)
+ butlerQC.put(output, outputRefs)
+
+ def _applyStampFit(self, stamp):
+ """Apply fitted stamp components to a single bright star stamp."""
+ stampMI = stamp.maskedImage
+ stampBBox = stampMI.getBBox()
+ xGrid, yGrid = np.meshgrid(stampBBox.getX().arange(), stampBBox.getY().arange())
+ xPlane = ImageF((xGrid * stamp.xGradient).astype(np.float32), xy0=stampMI.getXY0())
+ yPlane = ImageF((yGrid * stamp.yGradient).astype(np.float32), xy0=stampMI.getXY0())
+ stampMI -= stamp.pedestal
+ stampMI -= xPlane
+ stampMI -= yPlane
+ stampMI /= stamp.scale
+
+ @timeMethod
+ def run(
+ self,
+ brightStarStamps: BrightStarStamps,
+ ):
+ """Identify bright stars within an exposure using a reference catalog,
+ extract stamps around each, then preprocess them.
+
+ Bright star preprocessing steps are: shifting, warping and potentially
+ rotating them to the same pixel grid; computing their annular flux,
+ and; normalizing them.
+
+ Parameters
+ ----------
+ inputExposure : `~lsst.afw.image.ExposureF`
+ The image from which bright star stamps should be extracted.
+ inputBackground : `~lsst.afw.image.Background`
+ The background model for the input exposure.
+ refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional
+ Loader to find objects within a reference catalog.
+ dataId : `dict` or `~lsst.daf.butler.DataCoordinate`
+ The dataId of the exposure (including detector) that bright stars
+ should be extracted from.
+
+ Returns
+ -------
+ brightStarResults : `~lsst.pipe.base.Struct`
+ Results as a struct with attributes:
+
+ ``brightStarStamps``
+ (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`)
+ """
+ stackTypeProperty = stringToStatisticsProperty(self.config.stackType)
+ statisticsControl = StatisticsControl(
+ numSigmaClip=self.config.stackNumSigmaClip,
+ numIter=self.config.stackNumIter,
+ )
+
+ subsetStampMIs = []
+ tempStampMIs = []
+ all_stars = 0
+ used_stars = 0
+ for stampsDDH in brightStarStamps:
+ stamps = stampsDDH.get()
+ all_stars += len(stamps)
+ for stamp in stamps:
+ # print("globalReducedChiSquared: stamp ", stamp.globalReducedChiSquared, "config ", self.config.globalReducedChiSquaredThreshold)
+ # print("psfReducedChiSquared: stamp ", stamp.psfReducedChiSquared, "config ", self.config.psfReducedChiSquaredThreshold)
+ if (
+ stamp.globalReducedChiSquared > self.config.globalReducedChiSquaredThreshold
+ or stamp.psfReducedChiSquared > self.config.psfReducedChiSquaredThreshold
+ ):
+ continue
+ stampMI = stamp.maskedImage
+ self._applyStampFit(stamp)
+ tempStampMIs.append(stampMI)
+
+ badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes)
+ statisticsControl.setAndMask(badMaskBitMask)
+
+ # Amir: In case the total number of stamps is less than 20, the following will result in an
+ # empty subsetStampMIs list.
+ if len(tempStampMIs) == self.config.subsetStampNumber:
+ subsetStampMIs.append(statisticsStack(tempStampMIs, stackTypeProperty, statisticsControl))
+ # TODO: what to do with remaining temp stamps?
+ tempStampMIs = []
+ used_stars += self.config.subsetStampNumber
+
+ self.metadata["psfStarCount"] = {}
+ self.metadata["psfStarCount"]["all"] = all_stars
+ self.metadata["psfStarCount"]["used"] = used_stars
+ # TODO: which stamp mask plane to use here?
+ # TODO: Amir: there might be cases where subsetStampMIs is an empty list. What do we want to do then?
+ # Currently, we get an "IndexError: list index out of range"
+ badMaskBitMask = subsetStampMIs[0].mask.getPlaneBitMask(self.config.badMaskPlanes)
+ statisticsControl.setAndMask(badMaskBitMask)
+ extendedPsfMI = statisticsStack(subsetStampMIs, stackTypeProperty, statisticsControl)
+
+ extendedPsfExtent = extendedPsfMI.getBBox().getDimensions()
+ extendedPsfOrigin = Point2I(-1 * (extendedPsfExtent.x // 2), -1 * (extendedPsfExtent.y // 2))
+ extendedPsfMI.setXY0(extendedPsfOrigin)
+ # return Struct(extendedPsf=[extendedPsfMI])
+
+ return Struct(extendedPsf=extendedPsfMI.getImage())
+
+ # stack = []
+ # chiStack = []
+ # for loop over all groups:
+ # load up all visits for this detector
+ # drop all with GOF > thresh
+ # sigma-clip mean stack the rest
+ # append to stack
+ # compute the scatter (MAD/sigma-clipped var, etc) of the rest
+ # divide by sqrt(var plane), and append to chiStack
+ # after for-loop, combine images in median stack for final result
+ # also combine chi-images, save separately
+
+ # idea: run with two different thresholds, and compare the results
+
+ # medianStack = []
+ # for loop over all groups:
+ # load up all visits for this detector
+ # drop all with GOF > thresh
+ # median/sigma-clip stack the rest
+ # append to medianStack
+ # after for-loop, combine images in median stack for final result
diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py b/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py
new file mode 100644
index 000000000..04de11116
--- /dev/null
+++ b/python/lsst/pipe/tasks/brightStarSubtraction/subtractBrightStar.py
@@ -0,0 +1,858 @@
+# This file is part of pipe_tasks.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""Retrieve extended PSF model and subtract bright stars at visit level."""
+
+__all__ = ["BrightStarSubtractConnections", "BrightStarSubtractConfig", "BrightStarSubtractTask"]
+
+import logging
+from typing import Any
+import astropy.units as u
+import numpy as np
+from astropy.table import Table, Column
+from lsst.afw.cameraGeom import FIELD_ANGLE, PIXELS, TAN_PIXELS
+from lsst.afw.detection import Footprint, FootprintSet, Threshold
+from lsst.afw.geom import SkyWcs, SpanSet
+from lsst.afw.geom.transformFactory import makeTransform
+from lsst.afw.image import ExposureF, ImageD, ImageF, MaskedImageF
+from lsst.afw.math import BackgroundList, WarpingControl, warpImage
+from lsst.daf.butler import DataCoordinate
+from lsst.geom import (
+ AffineTransform,
+ Box2I,
+ Extent2D,
+ Extent2I,
+ Point2D,
+ Point2I,
+ SpherePoint,
+ arcseconds,
+ floor,
+ radians,
+)
+from lsst.meas.algorithms import (
+ LoadReferenceObjectsConfig,
+ ReferenceObjectLoader,
+)
+from lsst.pex.config import ChoiceField, ConfigField, Field, ListField
+from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
+from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput
+from lsst.utils.timer import timeMethod
+from copy import deepcopy
+
+NEIGHBOR_MASK_PLANE = "NEIGHBOR"
+
+logger = logging.getLogger(__name__)
+
+
+class BrightStarSubtractConnections(
+ PipelineTaskConnections,
+ dimensions=("instrument", "visit", "detector"),
+ defaultTemplates={
+ # "outputExposureName": "brightStar_subtracted",
+ "outputExposureName": "postISRCCD",
+ "outputBackgroundName": "brightStars",
+ "badStampsName": "brightStars",
+ },
+):
+ inputCalexp = Input(
+ name="calexp",
+ storageClass="ExposureF",
+ doc="Background-subtracted input exposure from which to extract bright star stamp cutouts.",
+ dimensions=("visit", "detector"),
+ )
+ inputBackground = Input(
+ name="calexpBackground",
+ storageClass="Background",
+ doc="Background model for the input exposure, to be added back on during processing.",
+ dimensions=("visit", "detector"),
+ )
+ inputExposure = Input(
+ doc="Input exposure from which to subtract bright star stamps.",
+ name="postISRCCD",
+ storageClass="Exposure",
+ dimensions=(
+ "exposure",
+ "detector",
+ ),
+ )
+ inputExtendedPsf = Input(
+ name="extendedPsf2", # extendedPsfDetector ???
+ storageClass="ImageF", # MaskedImageF
+ doc="Extended PSF model, built from stacking bright star cutouts.",
+ dimensions=("band",),
+ )
+ refCat = PrerequisiteInput(
+ doc="Reference catalog that contains bright star positions",
+ name="gaia_dr3_20230707",
+ storageClass="SimpleCatalog",
+ dimensions=("skypix",),
+ multiple=True,
+ deferLoad=True,
+ )
+ # outputBadStamps = Output(
+ # doc="The stamps that are not normalized and consequently not subtracted from the exposure.",
+ # name="{badStampsName}_unsubtracted_stamps",
+ # storageClass="BrightStarStamps",
+ # dimensions=(
+ # "visit",
+ # "detector",
+ # ),
+ # )
+
+ outputExposure = Output(
+ doc="Exposure with bright stars subtracted.",
+ name="{outputExposureName}_subtracted",
+ storageClass="ExposureF",
+ dimensions=(
+ "exposure",
+ "detector",
+ ),
+ )
+ outputBackgroundExposure = Output(
+ doc="Exposure containing only the modelled bright stars.",
+ name="{outputBackgroundName}_background",
+ storageClass="ExposureF",
+ dimensions=(
+ "visit",
+ "detector",
+ ),
+ )
+ # scaledModels = Output(
+ # doc="Stamps containing models scaled to the level of stars",
+ # name="scaledModels",
+ # storageClass="BrightStarStamps",
+ # dimensions=(
+ # "visit",
+ # "detector",
+ # ),
+ # )
+
+
+class BrightStarSubtractConfig(PipelineTaskConfig, pipelineConnections=BrightStarSubtractConnections):
+ """Configuration parameters for BrightStarSubtractTask"""
+
+ doWriteSubtractor = Field[bool](
+ doc="Should an exposure containing all bright star models be written to disk?",
+ default=True,
+ )
+ doWriteSubtractedExposure = Field[bool](
+ doc="Should an exposure with bright stars subtracted be written to disk?",
+ default=True,
+ )
+ magLimit = Field[float](
+ doc="Magnitude limit, in Gaia G; all stars brighter than this value will be subtracted",
+ default=18,
+ )
+ minValidAnnulusFraction = Field[float](
+ doc="Minimum number of valid pixels that must fall within the annulus for the bright star to be "
+ "saved for subsequent generation of a PSF.",
+ default=0.0,
+ )
+ numSigmaClip = Field[float](
+ doc="Sigma for outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.",
+ default=4,
+ )
+ numIter = Field[int](
+ doc="Number of iterations of outlier rejection; ignored if annularFluxStatistic != 'MEANCLIP'.",
+ default=3,
+ )
+ warpingKernelName = ChoiceField[str](
+ doc="Warping kernel",
+ default="lanczos5",
+ allowed={
+ "bilinear": "bilinear interpolation",
+ "lanczos3": "Lanczos kernel of order 3",
+ "lanczos4": "Lanczos kernel of order 4",
+ "lanczos5": "Lanczos kernel of order 5",
+ "lanczos6": "Lanczos kernel of order 6",
+ "lanczos7": "Lanczos kernel of order 7",
+ },
+ )
+ scalingType = ChoiceField[str](
+ doc="How the model should be scaled to each bright star; implemented options are "
+ "`annularFlux` to reuse the annular flux of each stamp, or `leastSquare` to perform "
+ "least square fitting on each pixel with no bad mask plane set.",
+ default="leastSquare",
+ allowed={
+ "annularFlux": "reuse BrightStarStamp annular flux measurement",
+ "leastSquare": "find least square scaling factor",
+ },
+ )
+ annularFluxStatistic = ChoiceField[str](
+ doc="Type of statistic to use to compute annular flux.",
+ default="MEANCLIP",
+ allowed={
+ "MEAN": "mean",
+ "MEDIAN": "median",
+ "MEANCLIP": "clipped mean",
+ },
+ )
+ badMaskPlanes = ListField[str](
+ doc="Mask planes that, if set, lead to associated pixels not being included in the computation of "
+ "the scaling factor (`BAD` should always be included). Ignored if scalingType is `annularFlux`, "
+ "as the stamps are expected to already be normalized.",
+ # Note that `BAD` should always be included, as secondary detected
+ # sources (i.e., detected sources other than the primary source of
+ # interest) also get set to `BAD`.
+ # Lee: find out the value of "BAD" and set the nan values into that number in the mask plane(?)
+ default=("BAD", "CR", "CROSSTALK", "EDGE", "NO_DATA", "SAT", "SUSPECT", "UNMASKEDNAN"),
+ )
+ subtractionBox = ListField[int](
+ doc="Size of the stamps to be extracted, in pixels.",
+ default=(250, 250),
+ )
+ subtractionBoxBuffer = Field[float](
+ doc=(
+ "'Buffer' (multiplicative) factor to be applied to determine the size of the stamp the "
+ "processed stars will be saved in. This is also the size of the extended PSF model. The buffer "
+ "region is masked and contain no data and subtractionBox determines the region where contains "
+ "the data."
+ ),
+ default=1.1,
+ )
+ doApplySkyCorr = Field[bool](
+ doc="Apply full focal plane sky correction before extracting stars?",
+ default=True,
+ )
+ min_iterations = Field[int](
+ doc="Minimum number of iterations to complete before evaluating changes in each iteration.",
+ default=3,
+ )
+ refObjLoader = ConfigField[LoadReferenceObjectsConfig](
+ doc="Reference object loader for astrometric calibration.",
+ )
+ maskWarpingKernelName = ChoiceField[str](
+ doc="Warping kernel for mask.",
+ default="bilinear",
+ allowed={
+ "bilinear": "bilinear interpolation",
+ "lanczos3": "Lanczos kernel of order 3",
+ "lanczos4": "Lanczos kernel of order 4",
+ "lanczos5": "Lanczos kernel of order 5",
+ },
+ )
+ # Cutout geometry
+ stampSize = ListField[int](
+ doc="Size of the stamps to be extracted, in pixels.",
+ default=(251, 251),
+ )
+ stampSizePadding = Field[float](
+ doc="Multiplicative factor applied to the cutout stamp size, to guard against post-warp data loss.",
+ default=1.1,
+ )
+ # Star selection
+ magRange = ListField[float](
+ doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.",
+ default=[0, 18],
+ )
+ minAreaFraction = Field[float](
+ doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.",
+ default=0.1,
+ )
+ useMedianVariance = Field[bool](
+ doc="Use the median of the variance plane for PSF fitting.",
+ default=False,
+ )
+ psfMaskedFluxFracThreshold = Field[float](
+ doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.",
+ default=0.97,
+ )
+ scalePsfModel = Field[bool](
+ doc="If True, uses a scale factor to bring the PSF model data to the same level of the star data.",
+ default=True,
+ )
+
+
+class BrightStarSubtractTask(PipelineTask):
+ """Use an extended PSF model to subtract bright stars from a calibrated
+ exposure (i.e. at single-visit level).
+
+ This task uses both a set of bright star stamps produced by
+ `~lsst.pipe.tasks.processBrightStars.ProcessBrightStarsTask`
+ and an extended PSF model produced by
+ `~lsst.pipe.tasks.extended_psf.MeasureExtendedPsfTask`.
+ """
+
+ ConfigClass = BrightStarSubtractConfig
+ _DefaultName = "subtractBrightStars"
+
+ def __init__(self, *args, **kwargs):
+ # super().__init__(*args, **kwargs)
+ # # Placeholders to set up Statistics if scalingType is leastSquare.
+ # self.statsControl, self.statsFlag = None, None
+ # # Warping control; only contains shiftingALg provided in config.
+
+ super().__init__(*args, **kwargs)
+ stampSize = Extent2D(*self.config.stampSize.list())
+ stampRadius = floor(stampSize / 2)
+ # Define a central bounding box of the configured stamp size.
+ self.stampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stampRadius)
+ paddedStampSize = stampSize
+ self.paddedStampRadius = floor(paddedStampSize / 2)
+ self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(
+ self.paddedStampRadius
+ )
+
+ def runQuantum(self, butlerQC, inputRefs, outputRefs):
+ # Docstring inherited.
+ inputs = butlerQC.get(inputRefs)
+ dataId = butlerQC.quantum.dataId
+ refObjLoader = ReferenceObjectLoader(
+ dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat],
+ refCats=inputs.pop("refCat"),
+ name=self.config.connections.refCat,
+ config=self.config.refObjLoader,
+ )
+ # TODO: include the un-subtracted stars here!
+ subtractor = self.run(**inputs, dataId=dataId, refObjLoader=refObjLoader)
+ if self.config.doWriteSubtractedExposure:
+ outputExposure = inputs["inputExposure"].clone()
+ outputExposure.image -= subtractor.image
+ else:
+ outputExposure = None
+ outputBackgroundExposure = ExposureF(subtractor) if self.config.doWriteSubtractor else None
+ output = Struct(
+ outputExposure=outputExposure,
+ outputBackgroundExposure=outputBackgroundExposure,
+ )
+ butlerQC.put(output, outputRefs)
+
+ @timeMethod
+ def run(
+ self,
+ inputExposure: ExposureF,
+ inputCalexp: ExposureF,
+ inputBackground: BackgroundList,
+ inputExtendedPsf: ImageF,
+ dataId: dict[str, Any] | DataCoordinate,
+ refObjLoader: ReferenceObjectLoader,
+ ):
+ """Generate a bright star subtractor image using scaled extended PSF models.
+
+ Identifies bright stars within the calibrated exposure using a
+ reference catalog, extracts stamps around each, warps the extended PSF
+ model onto the stamp frame, fits for a scale factor and pedestal for
+ each star iteratively, and combines the scaled models into a single
+ subtractor exposure.
+
+ Parameters
+ ----------
+ inputExposure : `~lsst.afw.image.ExposureF`
+ The Post-ISR CCD frame. Note: Currently appears unused directly
+ within this method's main logic, but required by the pipeline
+ definition. Cutouts are based on `inputCalexp` + `inputBackground`.
+ inputCalexp: `~lsst.afw.image.ExposureF`
+ The background-subtracted calibrated exposure used for identifying
+ stars, extracting stamps, and fitting models.
+ inputBackground : `~lsst.afw.math.BackgroundList`
+ The background model associated with `inputCalexp`. Added back
+ before processing stamps.
+ inputExtendedPsf : `~lsst.afw.image.ImageF`
+ The extended PSF model (e.g., from MeasureExtendedPsfTask).
+ refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`
+ Loader used to query the reference catalog for bright stars within
+ the exposure footprint.
+ dataId : `dict` or `~lsst.daf.butler.DataCoordinate`
+ The data identifier for the input exposure.
+
+ Returns
+ -------
+ subtractor : `~lsst.afw.image.ExposureF`
+ An exposure containing the combined, scaled models of the bright
+ stars identified and processed. This image can be subtracted from
+ the original `inputExposure` (or `inputCalexp` + `inputBackground`).
+ The image plane contains the model flux, while the variance and
+ mask planes are typically empty or minimal unless specifically populated.
+ """
+ wcs = inputCalexp.getWcs()
+ bbox = inputCalexp.getBBox()
+ warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName)
+
+ refCatBright = self._getRefCatBright(refObjLoader, wcs, bbox)
+ refCatBright.sort("mag")
+ zipRaDec = zip(refCatBright["coord_ra"] * radians, refCatBright["coord_dec"] * radians)
+ spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec]
+ pixCoords = wcs.skyToPixel(spherePoints)
+
+ # Create image with background added back
+ # Using calibrated exposure for finding the scale factor and creating subtrator.
+ # The generated subtractor will be subtracted from PostISRCCd.
+ inputFixed = inputCalexp.getMaskedImage()
+ inputFixed += inputBackground.getImage()
+ inputCalexp.mask.addMaskPlane(NEIGHBOR_MASK_PLANE)
+ # Associate detected footprints (from DETECTED plane) with the bright reference stars.
+ allFootprints, associations = self._associateFootprints(inputCalexp, pixCoords, plane="DETECTED")
+
+ subtractorExp = ExposureF(bbox=bbox)
+ templateSubtractor = subtractorExp.maskedImage
+
+ detector = inputCalexp.detector
+ pixelScale = wcs.getPixelScale().asArcseconds() * arcseconds
+ pixToTan = detector.getTransform(PIXELS, TAN_PIXELS)
+ pixToFocalPlaneTan = detector.getTransform(PIXELS, FIELD_ANGLE).then(
+ makeTransform(AffineTransform.makeScaling(1 / pixelScale.asRadians()))
+ )
+
+ self.warpedDataDict = {}
+ removalIndices = []
+ for j in range(self.config.min_iterations):
+ scaleList = []
+ for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore
+ # Start with the background-added image for each star
+ inputMI = deepcopy(inputFixed)
+ restSubtractor = deepcopy(templateSubtractor)
+ myNumber = 0
+ for key in self.warpedDataDict.keys():
+ # Subtract the *current best models* of all *other* stars before fitting this one.
+ if self.warpedDataDict[key]["subtractor"] is not None and key != obj['id']:
+ restSubtractor.image += self.warpedDataDict[key]["subtractor"].image
+ myNumber += 1
+ self.log.debug(f"Number of stars subtracted before finding the scale factor for {obj['id']}: ", myNumber)
+ inputMI.image -= restSubtractor.image
+
+ footprintIndex = associations.get(starIndex, None)
+
+ if footprintIndex:
+ neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex]
+ self._setFootprints(inputMI, neighborFootprints, NEIGHBOR_MASK_PLANE)
+ else:
+ self._setFootprints(inputMI, allFootprints, NEIGHBOR_MASK_PLANE)
+ # Define linear shifting to recenter stamps
+ coordFocalPlaneTan = pixToFocalPlaneTan.applyForward(pixCoord) # center of warped star
+ shift = makeTransform(AffineTransform(Point2D(0, 0) - coordFocalPlaneTan))
+ angle = np.arctan2(coordFocalPlaneTan.getY(), coordFocalPlaneTan.getX()) * radians
+ rotation = makeTransform(AffineTransform.makeRotation(-angle))
+ pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation)
+ rawStamp= self._getCutout(inputExposure=inputMI, coordPix=pixCoord, stampSize=self.config.stampSize.list())
+ if rawStamp is None:
+ self.log.debug(f"No stamp for star with refID {obj['id']}")
+ removalIndices.append(starIndex)
+ continue
+ warpedStamp = self._warpRawStamp(rawStamp, warpingControl, pixToTan, pixCoord)
+ warpedModel = ImageF(warpedStamp.getBBox())
+ inputExtendedPsfGeneral = deepcopy(inputExtendedPsf)
+ good_pixels = warpImage(warpedModel, inputExtendedPsfGeneral, pixToPolar.inverted(), warpingControl)
+ self.warpedDataDict[obj["id"]] = {"stamp": warpedStamp, "model": warpedModel, "starIndex": starIndex, "pixCoord": pixCoord}
+ if j == 0:
+ self.warpedDataDict[obj["id"]]["scale"] = None
+ self.warpedDataDict[obj["id"]]["subtractor"] = None
+ fitPsfResults = {}
+ if self.config.scalePsfModel:
+ psfNeg = warpedModel.array < 0
+ self.modelScale = np.nanmean(warpedStamp.image.array) / np.nanmean(warpedModel.array[~psfNeg])
+ warpedModel.array *= self.modelScale ######## model scale correction ########
+ fitPsfResults = self._fitPsf( warpedStamp, warpedModel)
+ if fitPsfResults:
+ scaleList.append(fitPsfResults["scale"])
+ self.warpedDataDict[obj["id"]]["scale"] = fitPsfResults["scale"]
+
+
+ cond = np.isnan(warpedModel.array)
+ warpedModel.array[cond] = 0
+ warpedModel.array *= fitPsfResults["scale"]
+ overlapBBox = Box2I(warpedStamp.getBBox())
+ overlapBBox.clip(inputCalexp.getBBox())
+
+ subtractor = deepcopy(templateSubtractor)
+ subtractor[overlapBBox] += warpedModel[overlapBBox]
+ self.warpedDataDict[obj["id"]]["subtractor"] = subtractor
+
+
+ else:
+ scaleList.append(np.nan)
+ if "subtractor" not in self.warpedDataDict[obj["id"]].keys():
+ self.warpedDataDict[obj["id"]]["subtractor"] = None
+ self.warpedDataDict[obj["id"]]["scale"] = None
+ if j == 0:
+ refCatBright.remove_rows(removalIndices)
+ updatedPixCoords = [item for i, item in enumerate(pixCoords) if i not in removalIndices]
+ pixCoords = updatedPixCoords
+ new_scale_column = Column(scaleList, name=f'scale_0{j}')
+ # The following is handy when developing, not sure if we want to do that in the final version!
+ refCatBright.add_columns([new_scale_column])
+
+ subtractor = deepcopy(templateSubtractor)
+ for key in self.warpedDataDict.keys():
+ if self.warpedDataDict[key]["scale"] is not None:
+ subtractor.image.array += self.warpedDataDict[key]["subtractor"].image.array
+ return subtractor
+
+ def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table:
+ """Get a bright star subset of the reference catalog.
+
+ Trim the reference catalog to only those objects within the exposure
+ bounding box dilated by half the bright star stamp size.
+ This ensures all stars that overlap the exposure are included.
+
+ Parameters
+ ----------
+ refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`
+ Loader to find objects within a reference catalog.
+ wcs : `~lsst.afw.geom.SkyWcs`
+ World coordinate system.
+ bbox : `~lsst.geom.Box2I`
+ Bounding box of the exposure.
+
+ Returns
+ -------
+ refCatBright : `~astropy.table.Table`
+ Bright star subset of the reference catalog.
+ """
+ dilatedBBox = bbox.dilatedBy(self.paddedStampRadius)
+ withinExposure = refObjLoader.loadPixelBox(dilatedBBox, wcs, filterName="phot_g_mean")
+ refCatFull = withinExposure.refCat
+ fluxField: str = withinExposure.fluxField
+
+ brightFluxRange = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value())
+
+ subsetStars = (refCatFull[fluxField] > brightFluxRange[0]) & (
+ refCatFull[fluxField] < brightFluxRange[1]
+ )
+ refCatSubset = Table(refCatFull.extract("id", "coord_ra", "coord_dec", fluxField, where=subsetStars))
+ fluxNanojansky = refCatSubset[fluxField][:] * u.nJy # type: ignore
+ refCatSubset["mag"] = fluxNanojansky.to(u.ABmag).to_value() # AB magnitudes
+ return refCatSubset
+
+ def _associateFootprints(
+ self, inputExposure: ExposureF, pixCoords: list[Point2D], plane: str
+ ) -> tuple[list[Footprint], dict[int, int]]:
+ """Associate footprints from a given mask plane with specific objects.
+
+ Footprints from the given mask plane are associated with objects at the
+ coordinates provided, where possible.
+
+ Parameters
+ ----------
+ inputExposure : `~lsst.afw.image.ExposureF`
+ The input exposure with a mask plane.
+ pixCoords : `list` [`~lsst.geom.Point2D`]
+ The pixel coordinates of the objects.
+ plane : `str`
+ The mask plane used to identify masked pixels.
+
+ Returns
+ -------
+ footprints : `list` [`~lsst.afw.detection.Footprint`]
+ The footprints from the input exposure.
+ associations : `dict`[int, int]
+ Association indices between objects (key) and footprints (value).
+ """
+ detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK)
+ footprintSet = FootprintSet(inputExposure.mask, detThreshold)
+ footprints = footprintSet.getFootprints()
+ associations = {}
+ for starIndex, pixCoord in enumerate(pixCoords):
+ for footprintIndex, footprint in enumerate(footprints):
+ if footprint.contains(Point2I(pixCoord)):
+ associations[starIndex] = footprintIndex
+ break
+ self.log.debug(
+ "Associated %i of %i star%s to one each of the %i %s footprint%s.",
+ len(associations),
+ len(pixCoords),
+ "" if len(pixCoords) == 1 else "s",
+ len(footprints),
+ plane,
+ "" if len(footprints) == 1 else "s",
+ )
+ return footprints, associations
+
+ def _setFootprints(self, inputExposure: ExposureF, footprints: list, maskPlane: str):
+ """Set footprints in a given mask plane.
+
+ Parameters
+ ----------
+ inputExposure : `~lsst.afw.image.ExposureF`
+ The input exposure to modify.
+ footprints : `list` [`~lsst.afw.detection.Footprint`]
+ The footprints to set in the mask plane.
+ maskPlane : `str`
+ The mask plane to set the footprints in.
+
+ Notes
+ -----
+ This method modifies the ``inputExposure`` object in-place.
+ """
+ detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(maskPlane), Threshold.BITMASK)
+ detThresholdValue = int(detThreshold.getValue())
+ footprintSet = FootprintSet(inputExposure.mask, detThreshold)
+
+ # Wipe any existing footprints in the mask plane
+ inputExposure.mask.clearMaskPlane(int(np.log2(detThresholdValue)))
+
+ # Set the footprints in the mask plane
+ footprintSet.setFootprints(footprints)
+ footprintSet.setMask(inputExposure.mask, maskPlane)
+
+ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, Any]:
+ """Fit a scaled PSF and a pedestal to each bright star cutout.
+
+ Parameters
+ ----------
+ stampMI : `~lsst.afw.image.MaskedImageF`
+ The masked image of the bright star cutout.
+ psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF`
+ The PSF model to fit.
+
+ Returns
+ -------
+ fitPsfResults : `dict`[`str`, `float`]
+ The result of the PSF fitting, with keys:
+
+ ``scale`` : `float`
+ The scale factor.
+ ``scaleErr`` : `float`
+ The error on the scale factor.
+ ``pedestal`` : `float`
+ The pedestal value.
+ ``pedestalErr`` : `float`
+ The error on the pedestal value.
+ ``pedestalScaleCov`` : `float`
+ The covariance between the pedestal and scale factor.
+ ``xGradient`` : `float`
+ The gradient in the x-direction.
+ ``yGradient`` : `float`
+ The gradient in the y-direction.
+ ``globalReducedChiSquared`` : `float`
+ The global reduced chi-squared goodness-of-fit.
+ ``globalDegreesOfFreedom`` : `int`
+ The global number of degrees of freedom.
+ ``psfReducedChiSquared`` : `float`
+ The PSF BBox reduced chi-squared goodness-of-fit.
+ ``psfDegreesOfFreedom`` : `int`
+ The PSF BBox number of degrees of freedom.
+ ``psfMaskedFluxFrac`` : `float`
+ The fraction of the PSF image flux masked by bad pixels.
+ """
+ badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes)
+
+ # Calculate the fraction of the PSF image flux masked by bad pixels
+ psfMaskedPixels = ImageF(psfImage.getBBox())
+ psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool)
+ # TODO: This is np.float64, else FITS metadata serialization fails
+ # Amir: what do we want to do for subtraction? we do not have the luxury of removing the star from the process here!
+ psfMaskedFluxFrac = np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64)
+ # if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold:
+ # return {} # Handle cases where the PSF image is mostly masked
+
+ # Create a padded version of the input constant PSF image
+ paddedPsfImage = ImageF(stampMI.getBBox())
+ paddedPsfImage[psfImage.getBBox()] = psfImage.convertF()
+
+ # Create consistently masked data
+ mask = self.add_psf_mask(psfImage, stampMI)
+ badSpans = SpanSet.fromMask(mask, badMaskBitMask)
+ goodSpans = SpanSet(stampMI.getBBox()).intersectNot(badSpans)
+ varianceData = goodSpans.flatten(stampMI.variance.array, stampMI.getXY0())
+ if self.config.useMedianVariance:
+ varianceData = np.median(varianceData)
+ sigmaData = np.sqrt(varianceData)
+ imageData = goodSpans.flatten(stampMI.image.array, stampMI.getXY0()) # B
+ imageData /= sigmaData
+ psfData = goodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0())
+ psfData /= sigmaData
+
+ # Fit the PSF scale factor and global pedestal
+ nData = len(imageData)
+ coefficientMatrix = np.ones((nData, 4), dtype=float) # A
+ coefficientMatrix[:, 0] = psfData
+ coefficientMatrix[:, 1] /= sigmaData
+ coefficientMatrix[:, 2:] = goodSpans.indices().T
+ coefficientMatrix[:, 2] /= sigmaData
+ coefficientMatrix[:, 3] /= sigmaData
+ try:
+ solutions, sumSquaredResiduals, *_ = np.linalg.lstsq(coefficientMatrix, imageData, rcond=None)
+ covarianceMatrix = np.linalg.inv(np.dot(coefficientMatrix.transpose(), coefficientMatrix)) # C
+ except np.linalg.LinAlgError:
+ return {} # Handle singular matrix errors
+ if sumSquaredResiduals.size == 0:
+ return {}
+ scale = solutions[0]
+ if scale <= 0:
+ return {} # Handle cases where the PSF scale fit has failed
+ scaleErr = np.sqrt(covarianceMatrix[0, 0])
+ pedestal = solutions[1]
+ pedestalErr = np.sqrt(covarianceMatrix[1, 1])
+ scalePedestalCov = covarianceMatrix[0, 1]
+ xGradient = solutions[3]
+ yGradient = solutions[2]
+
+ # Calculate global (whole image) reduced chi-squared
+ globalChiSquared = np.sum(sumSquaredResiduals)
+ globalDegreesOfFreedom = nData - 4
+ globalReducedChiSquared = globalChiSquared / globalDegreesOfFreedom
+
+ # Calculate PSF BBox reduced chi-squared
+ psfBBoxGoodSpans = goodSpans.clippedTo(psfImage.getBBox())
+ psfBBoxGoodSpansX, psfBBoxGoodSpansY = psfBBoxGoodSpans.indices()
+ psfBBoxData = psfBBoxGoodSpans.flatten(stampMI.image.array, stampMI.getXY0())
+ psfBBoxModel = (
+ psfBBoxGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale
+ + pedestal
+ + psfBBoxGoodSpansX * xGradient
+ + psfBBoxGoodSpansY * yGradient
+ )
+ psfBBoxVariance = psfBBoxGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0())
+ psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 / psfBBoxVariance
+ psfBBoxChiSquared = np.sum(psfBBoxResiduals)
+ psfBBoxDegreesOfFreedom = len(psfBBoxData) - 4
+ psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom
+
+ return dict(
+ scale=scale,
+ scaleErr=scaleErr,
+ pedestal=pedestal,
+ pedestalErr=pedestalErr,
+ xGradient=xGradient,
+ yGradient=yGradient,
+ pedestalScaleCov=scalePedestalCov,
+ globalReducedChiSquared=globalReducedChiSquared,
+ globalDegreesOfFreedom=globalDegreesOfFreedom,
+ psfReducedChiSquared=psfBBoxReducedChiSquared,
+ psfDegreesOfFreedom=psfBBoxDegreesOfFreedom,
+ psfMaskedFluxFrac=psfMaskedFluxFrac,
+ )
+
+ def add_psf_mask(self, psfImage, stampMI):
+ """Add problematic PSF pixels to the stamp's mask.
+
+ Identifies pixels in the PSF model image that are NaN or negative
+ and sets the corresponding bits (hardcoded as plane 0, likely 'BAD')
+ in a copy of the input stamp's mask.
+
+ Parameters
+ ----------
+ psfImage : `~lsst.afw.image.ImageF`
+ PSF model image defined on the stamp grid.
+ stampMI : `~lsst.afw.image.MaskedImageF`
+ The masked image stamp of the star being fitted. Its mask is used
+ as the base.
+
+ Returns
+ -------
+ mask : `~lsst.afw.image.Mask`
+ A mask object based on the input stamp's mask, updated to include
+ masked pixels derived from the PSF model image.
+ """
+ cond = np.isnan(psfImage.array)
+ cond |= psfImage.array < 0
+ mask = deepcopy(stampMI.mask)
+ mask.array[cond] = np.bitwise_or(mask.array[cond], 1)
+ return mask
+
+
+ def _getCutout(self, inputExposure, coordPix: Point2D, stampSize: list[int]):
+ """Get a cutout from an input exposure, handling edge cases.
+
+ Generate a cutout from an input exposure centered on a given position
+ and with a given size.
+ If any part of the cutout is outside the input exposure bounding box,
+ the cutout is padded with NaNs.
+
+ Parameters
+ ----------
+ inputExposure : `~lsst.afw.image.ExposureF`
+ The image to extract bright star stamps from.
+ coordPix : `~lsst.geom.Point2D`
+ Center of the cutout in pixel space.
+ stampSize : `list` [`int`]
+ Size of the cutout, in pixels.
+
+ Returns
+ -------
+ stamp : `~lsst.afw.image.ExposureF` or `None`
+ The cutout, or `None` if the cutout is entirely outside the input
+ exposure bounding box.
+
+ Notes
+ -----
+ This method is a short-term workaround until DM-40042 is implemented.
+ At that point, it should be replaced by a call to the Exposure method
+ ``getCutout``, which will handle edge cases automatically.
+ """
+ corner = Point2I(np.array(coordPix) - np.array(stampSize) / 2)
+ dimensions = Extent2I(stampSize)
+ stampBBox = Box2I(corner, dimensions)
+ overlapBBox = Box2I(stampBBox)
+ overlapBBox.clip(inputExposure.getBBox())
+ if overlapBBox.getArea() > 0:
+ # Create full-sized stamp with pixels initially flagged as NO_DATA.
+ stamp = ExposureF(bbox=stampBBox)
+ stamp.image[:] = np.nan
+ stamp.mask.set(inputExposure.mask.getPlaneBitMask("NO_DATA"))
+ # # Restore pixels which overlap the input exposure.
+ overlap = inputExposure.Factory(inputExposure, overlapBBox)
+ stamp.maskedImage[overlapBBox] = overlap
+ else:
+ stamp = None
+ return stamp
+
+ def _warpRawStamp(self, rawStamp, warpingControl, pixToTan, pixCoord):
+ """Warps a raw image stamp onto a common tangent plane projection.
+
+ Applies a transformation (`pixToTan`) followed by a shift
+ transform to warp the input `rawStamp` onto a destination
+ `MaskedImageF` aligned with a tangent plane centered near the object.
+ The shift aims to place the object center at the center of the
+ destination image.
+
+ Parameters
+ ----------
+ rawStamp : `~lsst.afw.image.ExposureF`
+ The raw cutout image stamp (e.g., from `_getCutout`).
+ warpingControl : `~lsst.afw.math.WarpingControl`
+ Configuration for the warping process.
+ pixToTan : `~lsst.afw.geom.Transform`
+ Transformation from the raw stamp's pixel coordinates to the
+ common tangent plane coordinates.
+ pixCoord : `~lsst.geom.Point2D`
+ Pixel coordinates of the object center in the original exposure,
+ used to calculate the centering shift.
+
+ Returns
+ -------
+ warped_stamp : `~lsst.afw.image.MaskedImageF` or `None`
+ The warped and shifted masked image, or None if warping failed
+ (e.g., due to insufficient good pixels).
+ """
+ destImage = MaskedImageF(*self.config.stampSize)
+ bottomLeft = Point2D(rawStamp.getXY0())
+ newBottomLeft = pixToTan.applyForward(bottomLeft)
+ newBottomLeft = Point2I(newBottomLeft)
+ destImage.setXY0(newBottomLeft)
+ # Define linear shifting to recenter stamps
+ newCenter = pixToTan.applyForward(pixCoord)
+ self.modelCenter = self.config.stampSize[0] // 2, self.config.stampSize[1] // 2
+ shift = (self.modelCenter[0] + newBottomLeft[0] - newCenter[0], self.modelCenter[1] + newBottomLeft[1] - newCenter[1])
+ affineShift = AffineTransform(shift)
+ shiftTransform = makeTransform(affineShift)
+
+ # Define full transform (warp and shift)
+ starWarper = pixToTan.then(shiftTransform)
+
+ # Apply it
+ goodPix = warpImage(destImage, rawStamp.getMaskedImage(), starWarper, warpingControl)
+ if not goodPix:
+ return None
+ return destImage
+
+ # # Arbitrarily set origin of shifted star to 0
+ # destImage.setXY0(0, 0)
\ No newline at end of file
diff --git a/tests/test_brightStarCutout.py b/tests/test_brightStarCutout.py
new file mode 100644
index 000000000..67b88d02f
--- /dev/null
+++ b/tests/test_brightStarCutout.py
@@ -0,0 +1,102 @@
+# This file is part of pipe_tasks.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import unittest
+
+import lsst.afw.cameraGeom.testUtils
+import lsst.afw.image
+import lsst.utils.tests
+import numpy as np
+from lsst.afw.image import ImageD, ImageF, MaskedImageF
+from lsst.afw.math import FixedKernel
+from lsst.geom import Point2I
+from lsst.meas.algorithms import KernelPsf
+from lsst.pipe.tasks.brightStarSubtraction import BrightStarCutoutConfig, BrightStarCutoutTask
+
+
+class BrightStarCutoutTestCase(lsst.utils.tests.TestCase):
+ def setUp(self):
+ # Fit values
+ self.scale = 2.34e5
+ self.pedestal = 3210.1
+ self.xGradient = 5.432
+ self.yGradient = 10.987
+
+ # Create a pedestal + 2D plane
+ xCoords = np.linspace(-50, 50, 101)
+ yCoords = np.linspace(-50, 50, 101)
+ xPlane, yPlane = np.meshgrid(xCoords, yCoords)
+ pedestal = np.ones_like(xPlane) * self.pedestal
+
+ # Create a pseudo-PSF
+ dist_from_center = np.sqrt(xPlane**2 + yPlane**2)
+ psfArray = np.exp(-dist_from_center / 5)
+ psfArray /= np.sum(psfArray)
+ fixedKernel = FixedKernel(ImageD(psfArray))
+ self.psf = KernelPsf(fixedKernel)
+
+ # Bring everything together to construct a stamp masked image
+ stampArray = psfArray * self.scale + pedestal + xPlane * self.xGradient + yPlane * self.yGradient
+ stampIm = ImageF((stampArray).astype(np.float32))
+ stampVa = ImageF(stampIm.getBBox(), 654.321)
+ self.stampMI = MaskedImageF(image=stampIm, variance=stampVa)
+ self.stampMI.setXY0(Point2I(-50, -50))
+
+ # Ensure that all mask planes required by the task are in-place;
+ # new mask plane entries will be created as necessary
+ badMaskPlanes = [
+ "BAD",
+ "CR",
+ "CROSSTALK",
+ "EDGE",
+ "NO_DATA",
+ "SAT",
+ "SUSPECT",
+ "UNMASKEDNAN",
+ "NEIGHBOR",
+ ]
+ _ = [self.stampMI.mask.addMaskPlane(mask) for mask in badMaskPlanes]
+
+ def test_fitPsf(self):
+ """Test the PSF fitting method."""
+ brightStarCutoutConfig = BrightStarCutoutConfig()
+ brightStarCutoutTask = BrightStarCutoutTask(config=brightStarCutoutConfig)
+ fitPsfResults = brightStarCutoutTask._fitPsf(
+ self.stampMI,
+ self.psf,
+ )
+ self.assertAlmostEqual(fitPsfResults["scale"], self.scale, delta=1e-3)
+ self.assertAlmostEqual(fitPsfResults["pedestal"], self.pedestal, delta=1e-5)
+ self.assertAlmostEqual(fitPsfResults["xGradient"], self.xGradient, delta=1e-7)
+ self.assertAlmostEqual(fitPsfResults["yGradient"], self.yGradient, delta=1e-7)
+
+
+def setup_module(module):
+ lsst.utils.tests.init()
+
+
+class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ lsst.utils.tests.init()
+ unittest.main()