Skip to content

Commit

Permalink
hydra: support plugins (iterative#10240)
Browse files Browse the repository at this point in the history
* hydra: support plugins

* Update dvc/utils/hydra.py

* fix tests

* load plugins after some checks run

---------

Co-authored-by: Ruslan Kuprieiev <[email protected]>
Co-authored-by: Saugat Pachhai (सौगात) <[email protected]>
Co-authored-by: skshetry <[email protected]>
  • Loading branch information
4 people authored and BradyJ27 committed Apr 22, 2024
1 parent 253d18f commit 4355cc8
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 4 deletions.
1 change: 1 addition & 0 deletions dvc/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def __call__(self, data):
Exclusive("config_dir", "config_source"): str,
Exclusive("config_module", "config_source"): str,
"config_name": str,
"plugins_path": str,
},
"studio": {
"token": str,
Expand Down
4 changes: 4 additions & 0 deletions dvc/repo/experiments/queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,11 +486,15 @@ def _update_params(self, params: Dict[str, List[str]]):
else:
config_dir = None
config_name = hydra_config.get("config_name", "config")
plugins_path = os.path.join(
self.repo.root_dir, hydra_config.get("plugins_path", "")
)
compose_and_dump(
path,
config_dir,
config_module,
config_name,
plugins_path,
overrides,
)
else:
Expand Down
15 changes: 15 additions & 0 deletions dvc/utils/hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,24 @@
logger = logger.getChild(__name__)


def load_hydra_plugins(plugins_path: str):
import sys

from hydra.core.plugins import Plugins

sys.path.append(plugins_path)
try:
Plugins.instance()
finally:
sys.path.remove(plugins_path)


def compose_and_dump(
output_file: "StrPath",
config_dir: Optional[str],
config_module: Optional[str],
config_name: str,
plugins_path: str,
overrides: List[str],
) -> None:
"""Compose Hydra config and dumpt it to `output_file`.
Expand All @@ -30,6 +43,7 @@ def compose_and_dump(
Ignored if `config_dir` is not `None`.
config_name: Name of the config file containing defaults,
without the .yaml extension.
plugins_path: Path to auto discover Hydra plugins.
overrides: List of `Hydra Override`_ patterns.
.. _Hydra Override:
Expand All @@ -47,6 +61,7 @@ def compose_and_dump(
initialize_config_dir if config_dir else initialize_config_module
)

load_hydra_plugins(plugins_path)
with initialize_config( # type: ignore[attr-defined]
config_source, version_base=None
):
Expand Down
38 changes: 34 additions & 4 deletions tests/func/utils/test_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@ def test_compose_and_dump_overrides(tmp_dir, suffix, overrides, expected):
output_file = tmp_dir / f"params.{suffix}"
config_dir = hydra_setup(tmp_dir, "conf", "config")
config_module = None
compose_and_dump(output_file, config_dir, config_module, config_name, overrides)
compose_and_dump(
output_file, config_dir, config_module, config_name, str(tmp_dir), overrides
)
assert output_file.parse() == expected


Expand Down Expand Up @@ -229,7 +231,9 @@ def test_compose_and_dump_dir_module(
)

with error_context:
compose_and_dump(output_file, config_dir, config_module, config_name, [])
compose_and_dump(
output_file, config_dir, config_module, config_name, str(tmp_dir), []
)
assert output_file.parse() == config_content


Expand All @@ -241,7 +245,7 @@ def test_compose_and_dump_yaml_handles_string(tmp_dir):
config.parent.mkdir()
config.write_text("foo: 'no'\n")
output_file = tmp_dir / "params.yaml"
compose_and_dump(output_file, str(config.parent), None, "config", [])
compose_and_dump(output_file, str(config.parent), None, "config", str(tmp_dir), [])
assert output_file.read_text() == "foo: 'no'\n"


Expand All @@ -253,12 +257,38 @@ def test_compose_and_dump_resolves_interpolation(tmp_dir):
config.parent.mkdir()
config.dump({"data": {"root": "path/to/root", "raw": "${.root}/raw"}})
output_file = tmp_dir / "params.yaml"
compose_and_dump(output_file, str(config.parent), None, "config", [])
compose_and_dump(output_file, str(config.parent), None, "config", str(tmp_dir), [])
assert output_file.parse() == {
"data": {"root": "path/to/root", "raw": "path/to/root/raw"}
}


def test_compose_and_dump_plugins(tmp_dir):
"""Ensure Hydra plugins are loaded."""
from hydra.core.plugins import Plugins

from dvc.utils.hydra import compose_and_dump

# clear cached plugins
Plugins._instances.pop(Plugins, None)

config = tmp_dir / "conf" / "config.yaml"
config.parent.mkdir()
config.write_text("foo: '${plus_10:1}'\n")

plugins = tmp_dir / "hydra_plugins"
plugins.mkdir()
(plugins / "resolver.py").write_text(
"""\
from omegaconf import OmegaConf
OmegaConf.register_new_resolver('plus_10', lambda x: x + 10)"""
)

output_file = tmp_dir / "params.yaml"
compose_and_dump(output_file, str(config.parent), None, "config", str(tmp_dir), [])
assert output_file.read_text() == "foo: 11\n"


@pytest.mark.parametrize(
"overrides, expected",
[
Expand Down

0 comments on commit 4355cc8

Please sign in to comment.