Skip to content

Commit

Permalink
Auto injecting model and dataset for Recorder (microsoft#645)
Browse files Browse the repository at this point in the history
* Auto injecting model and dataset for Recorder

* Support using Feature in expression
  • Loading branch information
you-n-g authored Oct 15, 2021
1 parent 3b11912 commit 5a16a5a
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 11 deletions.
5 changes: 3 additions & 2 deletions qlib/data/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import Union, List, Type
from scipy.stats import percentileofscore

from .base import Expression, ExpressionOps
from .base import Expression, ExpressionOps, Feature
from ..log import get_module_logger
from ..utils import get_callable_kwargs

Expand Down Expand Up @@ -1485,6 +1485,7 @@ def __init__(self, feature_left, feature_right, N):
IdxMax,
IdxMin,
If,
Feature,
]


Expand Down Expand Up @@ -1517,7 +1518,7 @@ def register(self, ops_list: List[Union[Type[ExpressionOps], dict]]):
else:
_ops_class = _operator

if not issubclass(_ops_class, ExpressionOps):
if not issubclass(_ops_class, Expression):
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class))

if _ops_class.__name__ in self._ops:
Expand Down
21 changes: 15 additions & 6 deletions qlib/model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,19 @@ def fill_placeholder(config: dict, config_extend: dict):
# bfs
top = 0
tail = 1
item_quene = [config]
item_queue = [config]
while top < tail:
now_item = item_quene[top]
now_item = item_queue[top]
top += 1
if isinstance(now_item, list):
item_keys = range(len(now_item))
elif isinstance(now_item, dict):
item_keys = now_item.keys()
for key in item_keys:
if isinstance(now_item[key], list) or isinstance(now_item[key], dict):
item_quene.append(now_item[key])
item_queue.append(now_item[key])
tail += 1
elif now_item[key] in config_extend.keys():
elif isinstance(now_item[key], str) and now_item[key] in config_extend.keys():
now_item[key] = config_extend[now_item[key]]
return config

Expand Down Expand Up @@ -114,10 +114,19 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
task_config = fill_placeholder(task_config, placehorder_value)
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])
if isinstance(records, dict): # prevent only one dict
if isinstance(records, dict): # uniform the data format to list
records = [records]

for record in records:
r = init_instance_by_config(record, recorder=rec)
# Some recorders require the parameters `model` and `dataset`.
# Try to pass them to the initialization function automatically
# to make defining the task easier.
r = init_instance_by_config(
record,
recorder=rec,
default_module="qlib.workflow.record_temp",
try_kwargs={"model": model, "dataset": dataset},
)
r.generate()
return rec

Expand Down
21 changes: 18 additions & 3 deletions qlib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Union, Tuple, Any, Text, Optional
from typing import Dict, Union, Tuple, Any, Text, Optional
from types import ModuleType
from urllib.parse import urlparse

Expand Down Expand Up @@ -232,7 +232,11 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod


def init_instance_by_config(
config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
config: Union[str, dict, object],
default_module=None,
accept_types: Union[type, Tuple[type]] = (),
try_kwargs: Dict = {},
**kwargs,
) -> Any:
"""
get initialized instance with config
Expand Down Expand Up @@ -270,6 +274,10 @@ def init_instance_by_config(
Optional. If the config is an instance of a specific type, return the config directly.
This will be passed into the second parameter of isinstance.
try_kwargs: Dict
Try to pass in the kwargs in `try_kwargs` when initializing the instance.
If an error occurs, it will fall back to initialization without try_kwargs.
Returns
-------
object:
Expand All @@ -286,7 +294,14 @@ def init_instance_by_config(
return pickle.load(f)

klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)
return klass(**cls_kwargs, **kwargs)

try:
return klass(**cls_kwargs, **try_kwargs, **kwargs)
except (TypeError,):
# TypeError for handling errors like
# 1: `XXX() got multiple values for keyword argument 'YYY'`
# 2: `XXX() got an unexpected keyword argument 'YYY'`
return klass(**cls_kwargs, **kwargs)


@contextlib.contextmanager
Expand Down

0 comments on commit 5a16a5a

Please sign in to comment.