Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-17426 #1

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 124 additions & 2 deletions python/lsst/pipe/drivers/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import lsst.afw.image as afwImage
import lsst.afw.geom as afwGeom
import lsst.afw.cameraGeom as afwCameraGeom
import lsst.meas.algorithms as measAlg
import lsst.afw.table as afwTable

from lsst.pex.config import Config, Field, ListField, ChoiceField, ConfigField, RangeField
from lsst.pex.config import Config, Field, ListField, ChoiceField, ConfigField, RangeField, ConfigurableField
from lsst.pipe.base import Task


Expand Down Expand Up @@ -284,7 +286,9 @@ def measureScale(self, image, skyBackground):
statistic = afwMath.stringToStatisticsProperty(self.config.stats.statistic)
imageSamples = []
skySamples = []
for xStart, yStart, xStop, yStop in zip(xLimits[:-1], yLimits[:-1], xLimits[1:], yLimits[1:]):
for xIndex, yIndex in itertools.product(range(self.config.xNumSamples), range(self.config.yNumSamples)):
xStart, xStop = xLimits[xIndex : xIndex + 2]
yStart, yStop = yLimits[yIndex : yIndex + 2]
box = afwGeom.Box2I(afwGeom.Point2I(xStart, yStart), afwGeom.Point2I(xStop, yStop))
subImage = image.Factory(image, box)
subSky = sky.Factory(sky, box)
Expand Down Expand Up @@ -464,6 +468,9 @@ class FocalPlaneBackgroundConfig(Config):
"NONE": "No background estimation is to be attempted",
},
)
doSmooth = Field(dtype=bool, default=False, doc="Do smoothing?")
smoothScale = Field(dtype=float, doc="Smoothing scale")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No default?

smoothWindowSize = Field(dtype=int, default=15, doc="Window size for smoothing")
binning = Field(dtype=int, default=64, doc="Binning to use for CCD background model (pixels)")


Expand Down Expand Up @@ -717,7 +724,122 @@ def getStatsImage(self):
values /= self._numbers
thresh = self.config.minFrac*self.config.xSize*self.config.ySize
isBad = self._numbers.getArray() < thresh
if self.config.doSmooth:
array = values.getArray()
array[isBad] = numpy.nan
gridSize = min(self.config.xSize, self.config.ySize)
array[:] = NanSafeSmoothing.gaussianSmoothing(array, self.config.smoothWindowSize, self.config.smoothScale / gridSize)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line length.

isBad = numpy.isnan(values.array)
interpolateBadPixels(values.getArray(), isBad, self.config.interpolation)
return values


class MaskObjectsConfig(Config):
"""Configuration for MaskObjectsTask"""
nIter = Field(doc="Iteration for masking", dtype=int, default=3)
subtractBackground = ConfigurableField(target=measAlg.SubtractBackgroundTask, doc='Background configuration')
detection = ConfigurableField(target=measAlg.SourceDetectionTask, doc="Detection configuration")
detectSigma = Field(dtype=float, default=5., doc='Detection PSF gaussian sigmas')
doInterpolate = Field(dtype=bool, default=True, doc='Interpolate masked region?')
interpolate = ConfigurableField(target=measAlg.SubtractBackgroundTask, doc='Interpolate configuration')

def setDefaults(self):
self.detection.reEstimateBackground = False
self.detection.doTempLocalBackground = False
self.detection.doTempWideBackground = False
self.detection.thresholdValue = 2.5
# self.detection.thresholdPolarity = "both"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove commented code.

self.subtractBackground.binSize = 1024
self.subtractBackground.useApprox = False
self.interpolate.binSize = 256
self.interpolate.useApprox = False

def validate(self):
    """Reject detection settings that must stay disabled for this configuration.

    ``setDefaults`` switches off ``reEstimateBackground``,
    ``doTempLocalBackground`` and ``doTempWideBackground`` on the detection
    sub-configuration; this check refuses configurations that turn any of
    them back on.

    Raises
    ------
    RuntimeError
        If any of the three detection options listed above is enabled.
    """
    # NOTE(review): the base-class ``validate`` is not chained here, exactly
    # as before -- confirm whether that is intended.
    # Raise explicitly rather than ``assert``: assertions are stripped under
    # ``python -O`` and carry no message telling the user what to change.
    for option in ("reEstimateBackground", "doTempLocalBackground", "doTempWideBackground"):
        if getattr(self.detection, option):
            raise RuntimeError("MaskObjectsConfig requires detection.%s to be False" % option)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you change these into RuntimeErrors with useful error messages?



class MaskObjectsTask(Task):
"""MaskObjectsTask

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need a better short summary than the name of the task. Maybe "Iterative masking of objects on an image"?


This task makes more exhaustive object mask by iteratively doing detection and background-subtraction.
michitaro marked this conversation as resolved.
Show resolved Hide resolved
The purpose of this task is to get true background removing faint tails of large objects.
This is useful to make clean SKY from relatively small number of visits.

We deliberately use the specified 'detectSigma' instead of the PSF,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put detectSigma in double backticks instead of quotes:

``detectSigma``

in order to better pick up the faint wings of objects.
"""
ConfigClass = MaskObjectsConfig

def __init__(self, *args, **kwargs):
super(MaskObjectsTask, self).__init__(*args, **kwargs)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python 3, so just need super().__init__(*args, **kwargs).

# Disposable schema suppresses warning from SourceDetectionTask.__init__
self.makeSubtask("detection", schema=afwTable.Schema())
self.makeSubtask('interpolate')
self.makeSubtask('subtractBackground')

def run(self, exp):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for i in range(self.config.nIter):
self.log.info("Masking %d/%d", i + 1, self.config.nIter)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be self.log.debug.

bg = self.subtractBackground.run(exp).background
fp = self.detection.detectFootprints(exp, sigma=self.config.detectSigma, clearMask=True)
exp.maskedImage += bg.getImage()

if self.config.doInterpolate:
self.log.info("Interpolating")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.log.debug

smooth = self.interpolate.run(exp).background
exp.maskedImage += smooth.getImage()
mask = exp.maskedImage.mask
detected = mask.array & mask.getPlaneBitMask(['DETECTED']) > 0
exp.maskedImage.image.array[detected] = smooth.getImage().getArray()[detected]


class NanSafeSmoothing:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this section needs a different implementation.
It looks like you're using class here just as a namespace, which feels dirty. I don't think it adds anything beyond making a stand-alone function with a couple of embedded functions.
A direct convolution implemented in python is going to be crazy slow (especially since you're not making use of the fact that the kernel is separable). Is there a reason you can't use the convolution functionality in afw, as is done in SourceDetectionTask? If you're worried about NaNs, you can always (temporarily?) replace them with zeros when you do the convolution.

Copy link
Author

@michitaro michitaro Apr 2, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're worried about NaNs, you can always (temporarily?) replace them with zeros when you do the convolution.

This is why I didn't use the existing convolution code. Replacing NaNs with 0 is not equivalent to my code.
Sometimes super-pixels on the edge of the field are a bit high. In that case it is better for the super-pixels outside the field to be extrapolated (which is what I intended) than to assume that they are zero (replacing NaNs with 0). In other words, replacing NaNs with 0 leads to a low sky estimate near a bright edge.

Screen Shot 2019-04-02 at 17 17 16

A direct convolution in Python is crazy slow, but the targets of this code are relatively small images such as FocalPlaneBackground._values, whose dimensions are about 140x150. Convolutions on such images finish in about 0.5 seconds on my Mac.

It looks like you're using class here just as a namespace, which feels dirty. I don't think it adds anything beyond making a stand-alone function with a couple of embedded functions.

I totally agree with you. I will put these functions flat in the module.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I figured out the difference between your convolution and the vanilla convolution. This will allow us to condense the code, and probably run it faster:

def safeConvolve(array, sigma=2):
    """Gaussian-smooth an array while giving NaN pixels zero weight.

    The data (with NaNs replaced by zero) and a validity map (0 where NaN,
    1 elsewhere) are filtered with the same zero-padded Gaussian; their ratio
    is the weighted mean of the valid pixels under the kernel.  That ratio is
    then multiplied by the filtered all-ones array, i.e. the fraction of the
    kernel weight that falls inside the array bounds.

    Parameters
    ----------
    array : array-like
        Input data; NaN marks pixels to be ignored.  Non-floating input is
        converted to float.
    sigma : `float`, optional
        Gaussian sigma passed to ``gaussian_filter``.

    Returns
    -------
    smoothed : `numpy.ndarray`
        Smoothed array of the same shape as ``array``.  Pixels with no valid
        input inside the kernel footprint are NaN.
    """
    array = np.asarray(array)
    if not np.issubdtype(array.dtype, np.floating):
        # gaussian_filter returns the dtype of its input, so integer input
        # would truncate the fractional edge weights computed below.
        array = array.astype(float)

    def smooth(data):
        return gaussian_filter(data, sigma, mode="constant", cval=0.0)

    bad = np.isnan(array)
    convolved = smooth(np.where(bad, 0.0, array))
    numerator = smooth(np.ones_like(array))
    denominator = smooth(np.where(bad, 0.0, 1.0))

    # Where no valid pixel contributes the denominator is zero; leave NaN
    # there instead of evaluating 0/0 (which also emits a RuntimeWarning).
    result = np.full_like(convolved, np.nan)
    np.divide(convolved*numerator, denominator, out=result, where=denominator > 0)
    return result

This reproduces your result:
image

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent solution.

'''
Smooth image dealing with NaN pixels
'''

# Smooth ``array`` with a Gaussian kernel: the kernel is built by
# ``_gaussianKernel(windowSize, sigma)`` and applied by ``_safeConvolve2d``,
# which is the step that deals with NaN pixels.  Returns whatever
# ``_safeConvolve2d`` returns.
@classmethod
def gaussianSmoothing(cls, array, windowSize, sigma):
return cls._safeConvolve2d(array, cls._gaussianKernel(windowSize, sigma))

# Build a normalised 2-D Gaussian kernel.
# ``windowSize`` is used as the half-width ``r``: the grid runs from -r to r
# in unit steps, so the kernel has shape (2*r + 1, 2*r + 1).  ``sigma`` is the
# Gaussian width in the same grid units.
@classmethod
def _gaussianKernel(cls, windowSize, sigma):
''' Returns 2D gaussian kernel '''
s = sigma
r = windowSize
X, Y = numpy.meshgrid(
numpy.linspace(-r, r, 2 * r + 1),
numpy.linspace(-r, r, 2 * r + 1),
)
# Separable Gaussian: product of a 1-D normal in X and one in Y.
kernel = cls._normalDist(X, s) * cls._normalDist(Y, s)
# cut off
# Zero everything outside a circle of radius r + 0.5 around the centre.
kernel[X**2 + Y**2 > (r + 0.5)**2] = 0.
# Rescale to unit sum, so any constant factor from _normalDist cancels here.
return kernel / kernel.sum()

@staticmethod
def _normalDist(x, s=1., m=0.):
    """Evaluate the normal probability density at ``x``.

    Parameters
    ----------
    x : `float` or `numpy.ndarray`
        Point(s) at which to evaluate the density.
    s : `float`, optional
        Standard deviation.
    m : `float`, optional
        Mean.

    Returns
    -------
    density : `float` or `numpy.ndarray`
        ``exp(-(x - m)**2/(2*s**2))/(s*sqrt(2*pi))``, same shape as ``x``.
    """
    # Divide by the normalisation s*sqrt(2*pi) exactly once, so the result
    # is the normal density.
    norm = s*numpy.sqrt(2.*numpy.pi)
    return numpy.exp(-(x - m)**2/(2.*s**2))/norm

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why two normalisation factors of (s * numpy.sqrt(2. * numpy.pi))?


# Apply ``kernel`` to the 2-D ``image`` pixel by pixel, skipping non-finite
# products and renormalising by the kernel weight that actually contributed.
# Returns an array shaped like ``image``; pixels with no finite product
# under the kernel stay NaN.
# NOTE(review): the kernel is multiplied onto each window without flipping,
# so this equals a true convolution only for symmetric kernels.
@staticmethod
def _safeConvolve2d(image, kernel):
''' Convolve 2D safely dealing with NaNs in `image` '''
# Both inputs must be 2-D and the kernel must have odd extent on each axis.
assert numpy.ndim(image) == 2
assert numpy.ndim(kernel) == 2
assert kernel.shape[0] % 2 == 1 and kernel.shape[1] % 2 == 1
ks = kernel.shape
# Half-widths of the kernel along each axis.
kl = (ks[0] - 1) // 2, \
(ks[1] - 1) // 2
# Pad with NaN so pixels beyond the image border are treated as missing data.
image2 = numpy.pad(image, ((kl[0], kl[0]), (kl[1], kl[1])), 'constant', constant_values=numpy.nan)
# Output starts all-NaN and is only filled where some product is finite.
convolved = numpy.empty_like(image)
convolved.fill(numpy.nan)
for yi in range(convolved.shape[0]):
for xi in range(convolved.shape[1]):
# Kernel-sized window of the padded image for output pixel (yi, xi).
patch = image2[yi : yi + ks[0], xi : xi + ks[1]]
c = patch * kernel
ok = numpy.isfinite(c)
if numpy.any(ok):
# Sum of the finite products divided by the kernel weight over them.
convolved[yi, xi] = numpy.nansum(c) / kernel[ok].sum()
return convolved
41 changes: 15 additions & 26 deletions python/lsst/pipe/drivers/constructCalibs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from builtins import zip
from builtins import range

from lsst.pex.config import Config, ConfigurableField, Field, ListField, ConfigField
from lsst.pex.config import Config, ConfigurableField, Field, ListField, ConfigField, ConfigurableField
from lsst.pipe.base import Task, Struct, TaskRunner, ArgumentParser
import lsst.daf.base as dafBase
import lsst.afw.math as afwMath
Expand All @@ -26,7 +26,7 @@

from lsst.ctrl.pool.parallel import BatchPoolTask
from lsst.ctrl.pool.pool import Pool, NODE
from lsst.pipe.drivers.background import SkyMeasurementTask, FocalPlaneBackground, FocalPlaneBackgroundConfig
from lsst.pipe.drivers.background import SkyMeasurementTask, FocalPlaneBackground, FocalPlaneBackgroundConfig, MaskObjectsTask

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line lengths are limited to 110 chars.

from lsst.pipe.drivers.visualizeVisit import makeCameraImage

from .checksum import checksum
Expand Down Expand Up @@ -1115,10 +1115,9 @@ def processSingle(self, sensorRef):

class SkyConfig(CalibConfig):
"""Configuration for sky frame construction"""
detection = ConfigurableField(target=measAlg.SourceDetectionTask, doc="Detection configuration")
detectSigma = Field(dtype=float, default=2.0, doc="Detection PSF gaussian sigma")
subtractBackground = ConfigurableField(target=measAlg.SubtractBackgroundTask,
doc="Regular-scale background configuration, for object detection")
maskObjects = ConfigurableField(target=MaskObjectsTask,
doc="Configuration for masking objects aggressively")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need an extra leading space to line up the doc with the previous line.

largeScaleBackground = ConfigField(dtype=FocalPlaneBackgroundConfig,
doc="Large-scale background configuration")
sky = ConfigurableField(target=SkyMeasurementTask, doc="Sky measurement")
Expand All @@ -1145,8 +1144,7 @@ class SkyTask(CalibTask):

def __init__(self, *args, **kwargs):
CalibTask.__init__(self, *args, **kwargs)
self.makeSubtask("detection")
self.makeSubtask("subtractBackground")
self.makeSubtask("maskObjects")
self.makeSubtask("sky")

def scatterProcess(self, pool, ccdIdLists):
Expand Down Expand Up @@ -1230,27 +1228,18 @@ def processSingleBackground(self, dataRef):
return dataRef.get("postISRCCD")
exposure = CalibTask.processSingle(self, dataRef)

# Detect sources. Requires us to remove the background; we'll restore it later.
bgTemp = self.subtractBackground.run(exposure).background
footprints = self.detection.detectFootprints(exposure, sigma=self.config.detectSigma)
image = exposure.getMaskedImage()
if footprints.background is not None:
image += footprints.background.getImage()

# Mask high pixels
variance = image.getVariance()
noise = np.sqrt(np.median(variance.getArray()))
isHigh = image.getImage().getArray() > self.config.maskThresh*noise
image.getMask().getArray()[isHigh] |= image.getMask().getPlaneBitMask("DETECTED")

# Restore the background: it's what we want!
image += bgTemp.getImage()
self.maskObjects.run(exposure)
mi = exposure.maskedImage

# Set detected/bad pixels to background to ensure they don't corrupt the background
maskVal = image.getMask().getPlaneBitMask(self.config.mask)
isBad = image.getMask().getArray() & maskVal > 0
bgLevel = np.median(image.getImage().getArray()[~isBad])
image.getImage().getArray()[isBad] = bgLevel
if self.config.maskObjects.doInterpolate:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks encapsulation: this code shouldn't assume that self.config.maskObjects has a doInterpolate variable.
Maybe the way to deal with this is to move this whole block into the MaskObjectsTask?

mi.mask.array &= ~mi.mask.getPlaneBitMask(['DETECTED'])
else:
maskVal = mi.mask.getPlaneBitMask(self.config.mask)
isBad = mi.mask.array & maskVal > 0
bgLevel = np.median(mi.image.array[~isBad])
mi.image.array[isBad] = bgLevel

dataRef.put(exposure, "postISRCCD")
return exposure

Expand Down
54 changes: 34 additions & 20 deletions python/lsst/pipe/drivers/skyCorrection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from lsst.pex.config import Config, Field, ConfigurableField, ConfigField
from lsst.ctrl.pool.pool import Pool
from lsst.ctrl.pool.parallel import BatchPoolTask
from lsst.pipe.drivers.background import SkyMeasurementTask, FocalPlaneBackground, FocalPlaneBackgroundConfig
from lsst.pipe.drivers.background import SkyMeasurementTask, FocalPlaneBackground, FocalPlaneBackgroundConfig, MaskObjectsTask

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line length.

import lsst.pipe.drivers.visualizeVisit as visualizeVisit

DEBUG = False # Debugging outputs?
Expand Down Expand Up @@ -41,21 +41,21 @@ def makeCameraImage(camera, exposures, filename=None, binning=8):
class SkyCorrectionConfig(Config):
"""Configuration for SkyCorrectionTask"""
bgModel = ConfigField(dtype=FocalPlaneBackgroundConfig, doc="Background model")
bgModel2 = ConfigField(dtype=FocalPlaneBackgroundConfig, doc="2nd Background model")
sky = ConfigurableField(target=SkyMeasurementTask, doc="Sky measurement")
detection = ConfigurableField(target=measAlg.SourceDetectionTask, doc="Detection configuration")
doDetection = Field(dtype=bool, default=True, doc="Detect sources (to find good sky)?")
detectSigma = Field(dtype=float, default=5.0, doc="Detection PSF gaussian sigma")
maskObjects = ConfigurableField(target=MaskObjectsTask, doc="Mask Objects")
doMaskObjects = Field(dtype=bool, default=True, doc="Mask objects to find good sky?")
doBgModel = Field(dtype=bool, default=True, doc="Do background model subtraction?")
doSky = Field(dtype=bool, default=True, doc="Do sky frame subtraction?")
binning = Field(dtype=int, default=8, doc="Binning factor for constructing focal-plane images")

def setDefaults(self):
Config.setDefaults(self)
self.detection.reEstimateBackground = False
self.detection.thresholdPolarity = "both"
self.detection.doTempLocalBackground = False
self.detection.thresholdType = "pixel_stdev"
self.detection.thresholdValue = 3.0
self.maskObjects.doInterpolate = False

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why disable interpolation here? Because you don't want to destroy the image?
It sounds like doInterpolate shouldn't be a configuration parameter, but a functional parameter for MaskObjectsTask.run.

self.bgModel2.doSmooth = True

def validate(self):
    """Reject configurations that enable interpolation in ``maskObjects``.

    ``setDefaults`` sets ``maskObjects.doInterpolate`` to False; this check
    refuses configurations that switch it back on.

    Raises
    ------
    RuntimeError
        If ``maskObjects.doInterpolate`` is enabled.
    """
    # NOTE(review): the base-class ``validate`` is not chained here, exactly
    # as before -- confirm whether that is intended.
    # Raise explicitly rather than ``assert``: assertions are stripped under
    # ``python -O`` and carry no message telling the user what to change.
    if self.maskObjects.doInterpolate:
        raise RuntimeError("SkyCorrectionConfig requires maskObjects.doInterpolate to be False")

class SkyCorrectionTask(BatchPoolTask):
"""Correct sky over entire focal plane"""
Expand All @@ -64,9 +64,8 @@ class SkyCorrectionTask(BatchPoolTask):

def __init__(self, *args, **kwargs):
BatchPoolTask.__init__(self, *args, **kwargs)
self.makeSubtask("maskObjects")
self.makeSubtask("sky")
# Disposable schema suppresses warning from SourceDetectionTask.__init__
self.makeSubtask("detection", schema=afwTable.Schema())

@classmethod
def _makeArgumentParser(cls, *args, **kwargs):
Expand Down Expand Up @@ -102,7 +101,10 @@ def run(self, expRef):
algorithms. We optionally apply:

1. A large-scale background model.
This step removes very-large-scale sky such as moonlight.
2. A sky frame.
3. A medium-scale background model.
This step removes residual sky (This is smooth on the focal plane).

Only the master node executes this method. The data is held on
the slave nodes, which do all the hard work.
Expand Down Expand Up @@ -159,12 +161,30 @@ def run(self, expRef):
calibs = pool.mapToPrevious(self.collectSky, dataIdList)
makeCameraImage(camera, calibs, "sky" + extension)

exposures = self.smoothFocalPlaneSubtraction(camera, pool, dataIdList)

# Persist camera-level image of calexp
image = makeCameraImage(camera, exposures)
expRef.put(image, "calexp_camera")

pool.mapToPrevious(self.write, dataIdList)

def smoothFocalPlaneSubtraction(self, camera, pool, dataIdList):
'''Do 2nd Focal Plane subtraction

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use numpydoc style for docstrings.


After doSky, we get smooth focal plane image.
(Before doSky, sky pistons remain in HSC-G)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is a bit too terse to understand what you mean.

Now make smooth focal plane background and subtract it.
'''
bgModel = FocalPlaneBackground.fromCamera(self.config.bgModel2, camera)
data = [Struct(dataId=dataId, bgModel=bgModel.clone()) for dataId in dataIdList]
bgModelList = pool.mapToPrevious(self.accumulateModel, data)
for ii, bg in enumerate(bgModelList):
self.log.info("Background %d: %d pixels", ii, bg._numbers.array.sum())
bgModel.merge(bg)
exposures = pool.mapToPrevious(self.subtractModel, dataIdList, bgModel)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the same code as before (lines 138-143). Can you factor it into a common method so the code isn't duplicated?

return exposures

def loadImage(self, cache, dataId):
"""Load original image and restore the sky

Expand All @@ -187,15 +207,6 @@ def loadImage(self, cache, dataId):
bgOld = cache.butler.get("calexpBackground", dataId, immediate=True)
image = cache.exposure.getMaskedImage()

if self.config.doDetection:
# We deliberately use the specified 'detectSigma' instead of the PSF, in order to better pick up
# the faint wings of objects.
results = self.detection.detectFootprints(cache.exposure, doSmooth=True,
sigma=self.config.detectSigma, clearMask=True)
if hasattr(results, "background") and results.background:
# Restore any background that was removed during detection
maskedImage += results.background.getImage()

# We're removing the old background, so change the sense of all its components
for bgData in bgOld:
statsImage = bgData[0].getStatsImage()
Expand All @@ -206,6 +217,9 @@ def loadImage(self, cache, dataId):
for bgData in bgOld:
cache.bgList.append(bgData)

if self.config.doMaskObjects:
self.maskObjects.run(cache.exposure)

return self.collect(cache)

def measureSkyFrame(self, cache, dataId):
Expand Down