From 5a16a5ac2d43b2c08eea2537fcee7c457e1e94fa Mon Sep 17 00:00:00 2001 From: you-n-g Date: Fri, 15 Oct 2021 13:50:24 +0800 Subject: [PATCH] Auto injecting model and dataset for Recorder (#645) * Auto injecting model and dataset for Recorder * Support using Feature in expression --- qlib/data/ops.py | 5 +++-- qlib/model/trainer.py | 21 +++++++++++++++------ qlib/utils/__init__.py | 21 ++++++++++++++++++--- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/qlib/data/ops.py b/qlib/data/ops.py index 532072f89d..fc69e2e2f4 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -13,7 +13,7 @@ from typing import Union, List, Type from scipy.stats import percentileofscore -from .base import Expression, ExpressionOps +from .base import Expression, ExpressionOps, Feature from ..log import get_module_logger from ..utils import get_callable_kwargs @@ -1485,6 +1485,7 @@ def __init__(self, feature_left, feature_right, N): IdxMax, IdxMin, If, + Feature, ] @@ -1517,7 +1518,7 @@ def register(self, ops_list: List[Union[Type[ExpressionOps], dict]]): else: _ops_class = _operator - if not issubclass(_ops_class, ExpressionOps): + if not issubclass(_ops_class, Expression): raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class)) if _ops_class.__name__ in self._ops: diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 9c6866823f..1e2b25eab5 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -70,9 +70,9 @@ def fill_placeholder(config: dict, config_extend: dict): # bfs top = 0 tail = 1 - item_quene = [config] + item_queue = [config] while top < tail: - now_item = item_quene[top] + now_item = item_queue[top] top += 1 if isinstance(now_item, list): item_keys = range(len(now_item)) @@ -80,9 +80,9 @@ def fill_placeholder(config: dict, config_extend: dict): item_keys = now_item.keys() for key in item_keys: if isinstance(now_item[key], list) or isinstance(now_item[key], dict): - item_quene.append(now_item[key]) + 
item_queue.append(now_item[key]) tail += 1 - elif now_item[key] in config_extend.keys(): + elif isinstance(now_item[key], str) and now_item[key] in config_extend.keys(): now_item[key] = config_extend[now_item[key]] return config @@ -114,10 +114,19 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder: task_config = fill_placeholder(task_config, placehorder_value) # generate records: prediction, backtest, and analysis records = task_config.get("record", []) - if isinstance(records, dict): # prevent only one dict + if isinstance(records, dict): # normalize the data format to a list records = [records] + for record in records: - r = init_instance_by_config(record, recorder=rec) + # Some recorders require the parameters `model` and `dataset`. + # Try to pass them in to the initialization function automatically + # to make defining the task easier + r = init_instance_by_config( + record, + recorder=rec, + default_module="qlib.workflow.record_temp", + try_kwargs={"model": model, "dataset": dataset}, + ) r.generate() return rec diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 6a3f871d97..f6a6632ead 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -27,7 +27,7 @@ import numpy as np import pandas as pd from pathlib import Path -from typing import Union, Tuple, Any, Text, Optional +from typing import Dict, Union, Tuple, Any, Text, Optional from types import ModuleType from urllib.parse import urlparse @@ -232,7 +232,11 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod def init_instance_by_config( - config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs + config: Union[str, dict, object], + default_module=None, + accept_types: Union[type, Tuple[type]] = (), + try_kwargs: Dict = {}, + **kwargs, ) -> Any: """ get initialized instance with config @@ -270,6 +274,10 @@ def init_instance_by_config( Optional. 
If the config is a instance of specific type, return the config directly. This will be passed into the second parameter of isinstance. + try_kwargs: Dict + Try to pass in the kwargs in `try_kwargs` when initializing the instance. + If an error occurs, it will fall back to initialization without try_kwargs. + Returns ------- object: @@ -286,7 +294,14 @@ def init_instance_by_config( return pickle.load(f) klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module) - return klass(**cls_kwargs, **kwargs) + + try: + return klass(**cls_kwargs, **try_kwargs, **kwargs) + except (TypeError,): + # TypeError for handling errors like + # 1: `XXX() got multiple values for keyword argument 'YYY'` + # 2: `XXX() got an unexpected keyword argument 'YYY'` + return klass(**cls_kwargs, **kwargs) @contextlib.contextmanager