Skip to content

Commit

Permalink
Speed up imports by using importlib instead of pkg_resources
Browse files Browse the repository at this point in the history
Speed up imports by up to a second by replacing uses of `pkg_resources`
with the new Python standard library module `importlib.resources` (or,
for Python < 3.7, the backport `importlib_resources`). The old
`pkg_resources` module is known to be slow because it does a lot of
work on startup.

See, for example,
[pypa/setuptools#926](https://github.com/pypa/setuptools/issues/926) and
[pypa/setuptools#510](https://github.com/pypa/setuptools/issues/510).
  • Loading branch information
lpsinger committed May 4, 2020
1 parent ffc0c12 commit 46374fa
Show file tree
Hide file tree
Showing 32 changed files with 124 additions and 115 deletions.
Empty file added gwcelery/data/__init__.py
Empty file.
Empty file.
File renamed without changes.
File renamed without changes.
6 changes: 3 additions & 3 deletions gwcelery/tasks/em_bright.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from ..import app
from . import gracedb, lvalert
from .p_astro import _format_prob
from ..util import NamedTemporaryFile, PromiseProxy, resource_pickle
from ..util import NamedTemporaryFile, PromiseProxy, read_pickle

NS_CLASSIFIER = PromiseProxy(
resource_pickle, ('ligo.data', 'knn_ns_classifier.pkl'))
read_pickle, ('ligo.data', 'knn_ns_classifier.pkl'))
EM_CLASSIFIER = PromiseProxy(
resource_pickle, ('ligo.data', 'knn_em_classifier.pkl'))
read_pickle, ('ligo.data', 'knn_em_classifier.pkl'))

log = get_task_logger(__name__)

Expand Down
11 changes: 5 additions & 6 deletions gwcelery/tasks/first2years.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Create mock events from the "First Two Years" paper."""
from importlib import resources
import io
import random

Expand All @@ -9,8 +10,8 @@
import lal
from ligo.skymap.io.events.ligolw import ContentHandler
import numpy as np
import pkg_resources

from ..data import first2years as data_first2years
from ..import app
from . import gracedb

Expand All @@ -19,9 +20,8 @@

def pick_coinc():
"""Pick a coincidence from the "First Two Years" paper."""
filename = pkg_resources.resource_filename(
__name__, '../data/first2years/2016/gstlal.xml.gz')
xmldoc = utils.load_filename(filename, contenthandler=ContentHandler)
with resources.open_binary(data_first2years, 'gstlal.xml.gz') as f:
xmldoc, _ = utils.load_fileobj(f, contenthandler=ContentHandler)
root, = xmldoc.childNodes

# Remove unneeded tables
Expand Down Expand Up @@ -139,8 +139,7 @@ def _vet_event(superevents):

@gracedb.task(ignore_result=True, shared=False)
def _upload_psd(graceid):
psd = pkg_resources.resource_string(
__name__, '../data/first2years/2016/psd.xml.gz')
psd = resources.read_binary(data_first2years, 'psd.xml.gz')
gracedb.upload(psd, 'psd.xml.gz', graceid, 'Noise PSD', ['psd'])


Expand Down
10 changes: 4 additions & 6 deletions gwcelery/tasks/p_astro.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,16 @@

from . import gracedb, lvalert
from .. import app
from ..util import PromiseProxy, resource_json
from ..util import PromiseProxy, read_json

MEAN_VALUES_DICT = PromiseProxy(
resource_json, ('ligo.data',
'H1L1V1-mean_counts-1126051217-61603201.json'))
read_json, ('ligo.data', 'H1L1V1-mean_counts-1126051217-61603201.json'))

THRESHOLDS_DICT = PromiseProxy(
resource_json, ('ligo.data',
'H1L1V1-pipeline-far_snr-thresholds.json'))
read_json, ('ligo.data', 'H1L1V1-pipeline-far_snr-thresholds.json'))

P_ASTRO_LIVETIME = PromiseProxy(
resource_json, ('ligo.data', 'p_astro_livetime.json'))
read_json, ('ligo.data', 'p_astro_livetime.json'))


log = get_task_logger(__name__)
Expand Down
4 changes: 2 additions & 2 deletions gwcelery/templates/index.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,8 @@
</div>
<div class=card-body>
<ul>
{% for package in packages %}
<li><a href="https://pypi.org/project/{{package.project_name}}/">{{package.project_name}}</a> {{package.version}}</li>
{% for distribution in distributions %}
<li><a href="https://pypi.org/project/{{distribution.metadata['Name']}}/">{{distribution.metadata['Name']}}</a> {{distribution.version}}</li>
{% endfor %}
</ul>
</div>
Expand Down
Empty file added gwcelery/tests/data/__init__.py
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
9 changes: 5 additions & 4 deletions gwcelery/tests/test_tasks_bayestar.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
from importlib import resources
from unittest.mock import patch
from xml.sax import SAXParseException

from astropy import table
from astropy.io import fits
from celery.exceptions import Ignore
import numpy as np
import pkg_resources
import pytest

from . import data
from ..tasks.bayestar import localize
from ..util.tempfile import NamedTemporaryFile


def test_localize_bad_psd():
"""Test running BAYESTAR with a bad PSD file"""
# Test data
coinc = pkg_resources.resource_string(__name__, 'data/coinc.xml')
coinc = resources.read_binary(data, 'coinc.xml')
psd = b''

# Run function under test
Expand All @@ -37,8 +38,8 @@ def mock_bayestar(event, *args, **kwargs):

@pytest.fixture
def coinc_psd():
return (pkg_resources.resource_string(__name__, 'data/coinc.xml'),
pkg_resources.resource_string(__name__, 'data/psd.xml.gz'))
return (resources.read_binary(data, 'coinc.xml'),
resources.read_binary(data, 'psd.xml.gz'))


@patch('ligo.skymap.bayestar.localize', mock_bayestar)
Expand Down
22 changes: 13 additions & 9 deletions gwcelery/tests/test_tasks_detchar.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from importlib import resources
from io import BytesIO
import logging
from unittest.mock import call, patch
Expand All @@ -6,29 +7,29 @@
from gwpy.timeseries import Bits
import matplotlib.pyplot as plt
import numpy as np
from pkg_resources import resource_filename
import pytest

from ..import app
from ..import _version
from ..tasks import detchar
from . import data


@pytest.fixture
def llhoft_glob_pass():
old = app.conf['llhoft_glob']
app.conf['llhoft_glob'] = resource_filename(
__name__, 'data/llhoft/pass/{detector}/*.gwf')
yield
with resources.path(data, '') as path:
app.conf['llhoft_glob'] = str(path / 'llhoft/pass/{detector}/*.gwf')
yield
app.conf['llhoft_glob'] = old


@pytest.fixture
def llhoft_glob_fail():
old = app.conf['llhoft_glob']
app.conf['llhoft_glob'] = resource_filename(
__name__, 'data/llhoft/fail/{detector}/*.gwf')
yield
with resources.path(data, '') as path:
app.conf['llhoft_glob'] = str(path / 'llhoft/fail/{detector}/*.gwf')
yield
app.conf['llhoft_glob'] = old


Expand Down Expand Up @@ -86,8 +87,11 @@ def test_create_cache_old_data(mock_find, llhoft_glob_fail):
mock_find.assert_called()


@patch('gwcelery.tasks.detchar.create_cache', return_value=[resource_filename(
__name__, 'data/llhoft/omegascan/scanme.gwf')])
with resources.path(data, '') as expected_path:
expected_path = str(expected_path / 'llhoft/omegascan/scanme.gwf')


@patch('gwcelery.tasks.detchar.create_cache', return_value=[expected_path])
def test_make_omegascan_worked(mock_create_cache, scan_strainname):
durs = [1, 1, 1]
t0 = 1126259463
Expand Down
16 changes: 8 additions & 8 deletions gwcelery/tests/test_tasks_external_skymaps.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from importlib import resources
from unittest.mock import patch

import pkg_resources
import pytest

from ..util import resource_json
from . import data
from ..util import read_json
from .test_tasks_skymaps import toy_fits_filecontents # noqa: F401
from .test_tasks_skymaps import toy_3d_fits_filecontents # noqa: F401
from ..tasks import external_skymaps
Expand All @@ -19,14 +20,14 @@ def mock_get_event(exttrig):


def mock_get_superevent(graceid):
return resource_json(__name__, 'data/mock_superevent_object.json')
return read_json(data, 'mock_superevent_object.json')


def mock_get_log(graceid):
if graceid == 'S12345':
return resource_json(__name__, 'data/gracedb_setrigger_log.json')
return read_json(data, 'gracedb_setrigger_log.json')
elif graceid == 'E12345':
return resource_json(__name__, 'data/gracedb_externaltrigger_log.json')
return read_json(data, 'gracedb_externaltrigger_log.json')
else:
raise ValueError

Expand All @@ -41,9 +42,8 @@ def download(filename, graceid):
elif (graceid == 'E12345' and
filename == ('nasa.gsfc.gcn_Fermi%23GBM_Gnd_Pos_2017-08-17'
+ 'T12%3A41%3A06.47_524666471_57-431.xml')):
return pkg_resources.resource_string(
__name__, 'data/externaltrigger_original_data.xml'
)
return resources.read_binary(
data, 'externaltrigger_original_data.xml')
else:
raise ValueError

Expand Down
48 changes: 23 additions & 25 deletions gwcelery/tests/test_tasks_external_triggers.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from importlib.resources import read_binary
from unittest.mock import patch, call

import pytest

from pkg_resources import resource_string

from . import data
from ..tasks import external_triggers
from ..tasks import detchar
from ..util import resource_json
from ..util import read_json


@pytest.mark.parametrize('pipeline, path',
[['Fermi', 'data/fermi_grb_gcn.xml'],
['INTEGRAL', 'data/integral_grb_gcn.xml'],
['AGILE', 'data/agile_grb_gcn.xml']])
[['Fermi', 'fermi_grb_gcn.xml'],
['INTEGRAL', 'integral_grb_gcn.xml'],
['AGILE', 'agile_grb_gcn.xml']])
@patch('gwcelery.tasks.external_skymaps.create_upload_external_skymap')
@patch('gwcelery.tasks.external_skymaps.get_upload_external_skymap.run')
@patch('gwcelery.tasks.detchar.dqr_json', return_value='dqrjson')
Expand All @@ -29,7 +29,7 @@ def test_handle_create_grb_event(mock_create_event, mock_get_event,
mock_get_upload_external_skymap,
mock_create_upload_external_skymap,
pipeline, path):
text = resource_string(__name__, path)
text = read_binary(data, path)
external_triggers.handle_grb_gcn(payload=text)
mock_create_event.assert_called_once_with(filecontents=text,
search='GRB',
Expand Down Expand Up @@ -98,11 +98,10 @@ def test_handle_create_subthreshold_grb_event(mock_get_upload_ext_skymap,
mock_create_event,
mock_get_event,
mock_get_events):
text = resource_string(__name__,
'data/fermi_subthresh_grb_lowconfidence.xml')
text = read_binary(data, 'fermi_subthresh_grb_lowconfidence.xml')
external_triggers.handle_grb_gcn(payload=text)
mock_create_event.assert_not_called()
text = resource_string(__name__, 'data/fermi_subthresh_grb_gcn.xml')
text = read_binary(data, 'fermi_subthresh_grb_gcn.xml')
external_triggers.handle_grb_gcn(payload=text)
mock_get_events.assert_called_once_with(query=(
'group: External pipeline: '
Expand Down Expand Up @@ -136,7 +135,7 @@ def test_handle_noise_fermi_event(mock_check_vectors,
mock_get_event,
mock_get_events,
mock_get_upload_external_skymap):
text = resource_string(__name__, 'data/fermi_noise_gcn.xml')
text = read_binary(data, 'fermi_noise_gcn.xml')
external_triggers.handle_grb_gcn(payload=text)
mock_get_events.assert_called_once_with(query=(
'group: External pipeline: '
Expand All @@ -153,8 +152,8 @@ def test_handle_noise_fermi_event(mock_check_vectors,


@pytest.mark.parametrize('filename',
['data/fermi_grb_gcn.xml',
'data/fermi_noise_gcn.xml'])
['fermi_grb_gcn.xml',
'fermi_noise_gcn.xml'])
@patch('gwcelery.tasks.external_skymaps.get_upload_external_skymap.run')
@patch('gwcelery.tasks.gracedb.create_label')
@patch('gwcelery.tasks.gracedb.remove_label')
Expand All @@ -175,7 +174,7 @@ def test_handle_replace_grb_event(mock_get_event, mock_get_events,
mock_replace_event, mock_remove_label,
mock_create_label,
mock_get_upload_external_skymap, filename):
text = resource_string(__name__, filename)
text = read_binary(data, filename)
external_triggers.handle_grb_gcn(payload=text)
mock_replace_event.assert_called_once_with('E1', text)
if 'grb' in filename:
Expand Down Expand Up @@ -278,7 +277,7 @@ def test_handle_skymap_combine(mock_create_combined_skymap):
@patch('gwcelery.tasks.gracedb.create_event')
def test_handle_create_snews_event(mock_create_event, mock_get_event,
mock_upload, mock_json):
text = resource_string(__name__, 'data/snews_gcn.xml')
text = read_binary(data, 'snews_gcn.xml')
external_triggers.handle_snews_gcn(payload=text)
mock_create_event.assert_called_once_with(filecontents=text,
search='Supernova',
Expand Down Expand Up @@ -310,7 +309,7 @@ def test_handle_create_snews_event(mock_create_event, mock_get_event,
@patch('gwcelery.tasks.gracedb.replace_event')
@patch('gwcelery.tasks.gracedb.get_events', return_value=[{'graceid': 'E1'}])
def test_handle_replace_snews_event(mock_get_events, mock_replace_event):
text = resource_string(__name__, 'data/snews_gcn.xml')
text = read_binary(data, 'snews_gcn.xml')
external_triggers.handle_snews_gcn(payload=text)
mock_replace_event.assert_called_once_with('E1', text)

Expand All @@ -319,7 +318,7 @@ def test_handle_replace_snews_event(mock_get_events, mock_replace_event):
def test_handle_grb_exttrig_creation(mock_raven_coincidence_search):
"""Test dispatch of an LVAlert message for an exttrig creation."""
# Test LVAlert payload.
alert = resource_json(__name__, 'data/lvalert_exttrig_creation.json')
alert = read_json(data, 'lvalert_exttrig_creation.json')

# Run function under test
external_triggers.handle_grb_lvalert(alert)
Expand All @@ -334,7 +333,7 @@ def test_handle_grb_exttrig_creation(mock_raven_coincidence_search):
def test_handle_subgrb_exttrig_creation(mock_raven_coincidence_search):
"""Test dispatch of an LVAlert message for an exttrig creation."""
# Test LVAlert payload.
alert = resource_json(__name__, 'data/lvalert_subgrb_creation.json')
alert = read_json(data, 'lvalert_subgrb_creation.json')

# Run function under test
external_triggers.handle_grb_lvalert(alert)
Expand All @@ -352,8 +351,7 @@ def test_handle_subgrb_targeted_creation(mock_raven_coincidence_search,
mock_create_upload_external_skymap):
"""Test dispatch of an LVAlert message for an exttrig creation."""
# Test LVAlert payload.
alert = resource_json(__name__,
'data/lvalert_exttrig_subgrb_targeted_creation.json')
alert = read_json(data, 'lvalert_exttrig_subgrb_targeted_creation.json')

# Run function under test
external_triggers.handle_grb_lvalert(alert)
Expand All @@ -371,13 +369,13 @@ def test_handle_subgrb_targeted_creation(mock_raven_coincidence_search,


@pytest.mark.parametrize('calls, path',
[[False, 'data/lvalert_snews_test_creation.json'],
[True, 'data/lvalert_snews_creation.json']])
[[False, 'lvalert_snews_test_creation.json'],
[True, 'lvalert_snews_creation.json']])
@patch('gwcelery.tasks.raven.coincidence_search')
def test_handle_sntrig_creation(mock_raven_coincidence_search, calls, path):
"""Test dispatch of an LVAlert message for SNEWS alerts."""
# Test LVAlert payload.
alert = resource_json(__name__, path)
alert = read_json(data, path)

# Run function under test
external_triggers.handle_snews_lvalert(alert)
Expand All @@ -399,7 +397,7 @@ def test_handle_superevent_cbc_creation(mock_raven_coincidence_search,
mock_get_superevent):
"""Test dispatch of an LVAlert message for a CBC superevent creation."""
# Test LVAlert payload.
alert = resource_json(__name__, 'data/lvalert_superevent_creation.json')
alert = read_json(data, 'lvalert_superevent_creation.json')

# Run function under test
external_triggers.handle_grb_lvalert(alert)
Expand All @@ -423,7 +421,7 @@ def test_handle_superevent_burst_creation(mock_raven_coincidence_search,
mock_get_superevent):
"""Test dispatch of an LVAlert message for a burst superevent creation."""
# Test LVAlert payload.
alert = resource_json(__name__, 'data/lvalert_superevent_creation.json')
alert = read_json(data, 'lvalert_superevent_creation.json')

# Run function under test
external_triggers.handle_grb_lvalert(alert)
Expand Down
Loading

0 comments on commit 46374fa

Please sign in to comment.