Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix task matching bug #4881

Merged
merged 3 commits into from
May 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,36 @@ creating a new release entry be sure to copy & paste the span tag with the
`actions:bind` attribute, which is used by a regex to find the text to be
updated. Only the first match gets replaced, so it's fine to leave the old
ones in. -->
-------------------------------------------------------------------------------
## __cylc-8.0rc4 (<span actions:bind='release-date'>Upcoming</span>)__

Fourth Release Candidate for Cylc 8 suitable for acceptance testing.

### Fixes

[#4881](https://github.com/cylc/cylc-flow/pull/4881) - Fix bug where commands
targeting a specific cycle point would not work if using an abbreviated
cycle point format.

-------------------------------------------------------------------------------
## __cylc-8.0rc3 (<span actions:bind='release-date'>Released 2022-05-19</span>)__

Third Release Candidate for Cylc 8 suitable for acceptance testing.

### Enhancements


[#4738](https://github.com/cylc/cylc-flow/pull/4738) and
[#4739](https://github.com/cylc/cylc-flow/pull/4739) - Implement `cylc trigger
[--flow=] [--wait]` for manual triggering with respect to active flows (the
default), specific flows, new flows, or one-off task runs.
default), specific flows, new flows, or one-off task runs. This replaces
the `--reflow` option from earlier pre-release versions.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍


[#4743](https://github.com/cylc/cylc-flow/pull/4743) - On stopping a specific
flow, remove active-waiting tasks with no remaining flow numbers.


[#4854](https://github.com/cylc/cylc-flow/pull/4854)
- Expansion and merger of comma separate platform definitions permitted.
- Expansion and merger of comma separated platform definitions permitted.
- Platform definition regular expressions which match "localhost" but are not
"localhost" are now explicitly forbidden and will raise an exception.

Expand Down
47 changes: 30 additions & 17 deletions cylc/flow/id_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
TYPE_CHECKING,
Expand All @@ -26,9 +27,12 @@
# overload,
)

from metomi.isodatetime.exceptions import ISO8601SyntaxError

from cylc.flow import LOG
from cylc.flow.id import IDTokens, Tokens
from cylc.flow.id_cli import contains_fnmatch
from cylc.flow.cycling.loader import get_point

if TYPE_CHECKING:
# from typing_extensions import Literal
Expand Down Expand Up @@ -109,7 +113,7 @@ def filter_ids(
_not_matched: 'List[str]' = []

# enable / disable pattern matching
match: 'Callable[[Any, Any], bool]'
match: Callable[[Any, Any], bool]
if pattern_match:
match = fnmatchcase
else:
Expand All @@ -128,7 +132,7 @@ def filter_ids(
]
_not_matched.extend(pattern_ids)

id_tokens_map = {}
id_tokens_map: Dict[str, Tokens] = {}
for id_ in ids:
try:
id_tokens_map[id_] = Tokens(id_, relative=True)
Expand All @@ -152,8 +156,7 @@ def filter_ids(
for icycle, itasks in pool.items():
if not itasks:
continue
str_cycle = str(icycle)
if not match(str_cycle, cycle):
if not point_match(icycle, cycle, pattern_match):
continue
if cycle_sel == '*':
cycles.append(icycle)
Expand All @@ -164,7 +167,7 @@ def filter_ids(
break

# filter by task
elif lowest_token == IDTokens.Task: # noqa: SIM106
elif lowest_token == IDTokens.Task:
cycle = tokens[IDTokens.Cycle.value]
cycle_sel_raw = tokens.get(IDTokens.Cycle.value + '_sel')
cycle_sel = cycle_sel_raw or '*'
Expand All @@ -173,8 +176,7 @@ def filter_ids(
task_sel = task_sel_raw or '*'
for pool in pools:
for icycle, itasks in pool.items():
str_cycle = str(icycle)
if not match(str_cycle, cycle):
if not point_match(icycle, cycle, pattern_match):
continue
for itask in itasks.values():
if (
Expand All @@ -189,15 +191,7 @@ def filter_ids(
or match(itask.state.status, cycle_sel)
)
# check namespace name
and (
# task name
match(itask.tdef.name, task)
# family name
or any(
match(ns, task)
for ns in itask.tdef.namespace_hierarchy
)
)
and itask.name_match(task, match_func=match)
# check task selector
and (
(
Expand All @@ -222,7 +216,7 @@ def filter_ids(
_cycles.extend(cycles)
_tasks.extend(tasks)

ret: 'List[Any]' = []
ret: List[Any] = []
if out == IDTokens.Cycle:
_cycles.extend({
itask.point
Expand All @@ -236,3 +230,22 @@ def filter_ids(
_tasks.extend(pool[icycle].values())
ret = _tasks
return ret, _not_matched


def point_match(
point: 'PointBase', value: str, pattern_match: bool = True
) -> bool:
"""Return whether a cycle point matches a string/pattern.

Args:
point: Cycle point to compare against.
value: String/pattern to test.
pattern_match: Whether to allow glob patterns in the value.
"""
try:
return point == get_point(value)
except (ValueError, ISO8601SyntaxError):
# Could be glob pattern
if pattern_match:
return fnmatchcase(str(point), value)
return False
33 changes: 12 additions & 21 deletions cylc/flow/task_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@
"""Provide a class to represent a task proxy in a running workflow."""

from collections import Counter
from contextlib import suppress
from copy import copy
from fnmatch import fnmatchcase
from time import time
from typing import Any, Dict, List, Set, Tuple, Optional, TYPE_CHECKING
from typing import (
Any, Callable, Dict, List, Set, Tuple, Optional, TYPE_CHECKING
)

from metomi.isodatetime.timezone import get_local_time_zone

from cylc.flow import LOG
from cylc.flow.cycling.loader import standardise_point_string
from cylc.flow.exceptions import PointParsingError
from cylc.flow.id import Tokens
from cylc.flow.platforms import get_platform
from cylc.flow.task_action_timer import TimerFlags
Expand Down Expand Up @@ -420,30 +419,22 @@ def reset_try_timers(self):
for timer in self.try_timers.values():
timer.timeout = None

def point_match(self, point: Optional[str]) -> bool:
"""Return whether a string/glob matches the task's point.

None is treated as '*'.
"""
if point is None:
return True
with suppress(PointParsingError): # point_str may be a glob
point = standardise_point_string(point)
return fnmatchcase(str(self.point), point)

def status_match(self, status: Optional[str]) -> bool:
"""Return whether a string matches the task's status.

None/an empty string is treated as a match.
"""
return (not status) or self.state.status == status

def name_match(self, name: str) -> bool:
"""Return whether a string/glob matches the task's name."""
if fnmatchcase(self.tdef.name, name):
return True
return any(
fnmatchcase(ns, name) for ns in self.tdef.namespace_hierarchy
def name_match(
self,
value: str,
match_func: Callable[[Any, Any], bool] = fnmatchcase
) -> bool:
"""Return whether a string/pattern matches the task's name or any of
its parent family names."""
return match_func(self.tdef.name, value) or any(
match_func(ns, value) for ns in self.tdef.namespace_hierarchy
)

def merge_flows(self, flow_nums: Set) -> None:
Expand Down
80 changes: 57 additions & 23 deletions tests/unit/test_id_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,43 +14,51 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from types import SimpleNamespace
from typing import TYPE_CHECKING, Callable
from unittest.mock import create_autospec

import pytest

from cylc.flow.id import IDTokens, Tokens
from cylc.flow.id_match import filter_ids
from cylc.flow.id_match import filter_ids, point_match
from cylc.flow.task_pool import Pool
from cylc.flow.cycling.integer import IntegerPoint, CYCLER_TYPE_INTEGER
from cylc.flow.cycling.iso8601 import ISO8601Point
from cylc.flow.task_proxy import TaskProxy
from cylc.flow.taskdef import TaskDef

if TYPE_CHECKING:
from cylc.flow.cycling import PointBase


def get_task_id(itask: TaskProxy) -> str:
return f"{itask.tokens.relative_id}:{itask.state.status}"


@pytest.fixture
def task_pool():
def task_pool(set_cycling_type: Callable):
def _task_proxy(id_, hier):
tokens = Tokens(id_, relative=True)
itask = SimpleNamespace()
itask.id_ = id_
itask.point = int(tokens['cycle'])
itask.state = SimpleNamespace()
itask.state.status = tokens['task_sel']
itask.tdef = SimpleNamespace()
itask.tdef.name = tokens['task']
if tokens['task'] in hier:
hier = hier[tokens['task']]
else:
hier = []
hier = hier.get(tokens['task'], [])
hier.append('root')
itask.tdef.namespace_hierarchy = hier
return itask
tdef = create_autospec(TaskDef, namespace_hierarchy=hier)
tdef.name = tokens['task']
return TaskProxy(
tdef,
start_point=IntegerPoint(tokens['cycle']),
status=tokens['task_sel'],
)

def _task_pool(pool, hier) -> 'Pool':
return {
cycle: {
IntegerPoint(cycle): {
id_.split(':')[0]: _task_proxy(id_, hier)
for id_ in ids
}
for cycle, ids in pool.items()
}

set_cycling_type(CYCLER_TYPE_INTEGER)
return _task_pool


Expand Down Expand Up @@ -119,7 +127,7 @@ def test_filter_ids_task_mode(task_pool, ids, matched, not_matched):
)

_matched, _not_matched = filter_ids([pool], ids)
assert [itask.id_ for itask in _matched] == matched
assert [get_task_id(itask) for itask in _matched] == matched
assert _not_matched == not_matched


Expand Down Expand Up @@ -180,7 +188,7 @@ def test_filter_ids_cycle_mode(task_pool, ids, matched, not_matched):
)

_matched, _not_matched = filter_ids([pool], ids, out=IDTokens.Cycle)
assert _matched == matched
assert _matched == [IntegerPoint(i) for i in matched]
assert _not_matched == not_matched


Expand Down Expand Up @@ -212,7 +220,7 @@ def test_filter_ids_pattern_match_off(task_pool):
out=IDTokens.Task,
pattern_match=False,
)
assert [itask.id_ for itask in _matched] == ['1/a:x']
assert [get_task_id(itask) for itask in _matched] == ['1/a:x']
assert _not_matched == []


Expand All @@ -234,7 +242,7 @@ def test_filter_ids_toggle_pattern_matching(task_pool, caplog):
out=IDTokens.Task,
pattern_match=True,
)
assert [itask.id_ for itask in _matched] == ['1/a:x']
assert [get_task_id(itask) for itask in _matched] == ['1/a:x']
assert _not_matched == []

# ensure pattern matching can be disabled
Expand All @@ -245,7 +253,7 @@ def test_filter_ids_toggle_pattern_matching(task_pool, caplog):
out=IDTokens.Task,
pattern_match=False,
)
assert [itask.id_ for itask in _matched] == []
assert [get_task_id(itask) for itask in _matched] == []
assert _not_matched == ['*/*']

# ensure the ID is logged
Expand Down Expand Up @@ -281,7 +289,7 @@ def test_filter_ids_namespace_hierarchy(task_pool, ids, matched, not_matched):
pattern_match=False,
)

assert [itask.id_ for itask in _matched] == matched
assert [get_task_id(itask) for itask in _matched] == matched
assert _not_matched == not_matched


Expand All @@ -295,3 +303,29 @@ def test_filter_ids_log_errors(caplog):
_, _not_matched = filter_ids({}, ['/////'])
assert _not_matched == ['/////']
assert caplog.record_tuples == [('cylc', 30, 'Invalid ID: /////')]


@pytest.mark.parametrize(
'point, value, pattern_match, expected',
[
(IntegerPoint(23), '23', True, True),
(IntegerPoint(23), '23', False, True),
(IntegerPoint(23), '2*', True, True),
(IntegerPoint(23), '2*', False, False),
(IntegerPoint(23), '2a', True, False),
(ISO8601Point('2049-01-01T00:00Z'), '2049', True, True),
(ISO8601Point('2049-01-01T00:00Z'), '2049', False, True),
(ISO8601Point('2049-03-01T00:00Z'), '2049', True, False),
(ISO8601Point('2049-03-01T00:00Z'), '2049*', True, True),
(ISO8601Point('2049-03-01T00:00Z'), '2049*', False, False),
(ISO8601Point('2049-01-01T00:00Z'), '20490101T00Z', False, True),
(ISO8601Point('2049-01-01T00:00Z'), '20490101T03+03', False, True),
(ISO8601Point('2049-01-01T00:00Z'), '2049a', True, False),
]
)
def test_point_match(
point: 'PointBase', value: str, pattern_match: bool, expected: bool,
set_cycling_type: Callable
):
set_cycling_type(point.TYPE, time_zone='Z')
assert point_match(point, value, pattern_match) is expected
Loading