Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto injecting model and dataset for Recorder #645

Merged
merged 2 commits into from
Oct 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions qlib/data/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import Union, List, Type
from scipy.stats import percentileofscore

from .base import Expression, ExpressionOps
from .base import Expression, ExpressionOps, Feature
from ..log import get_module_logger
from ..utils import get_callable_kwargs

Expand Down Expand Up @@ -1485,6 +1485,7 @@ def __init__(self, feature_left, feature_right, N):
IdxMax,
IdxMin,
If,
Feature,
]


Expand Down Expand Up @@ -1517,7 +1518,7 @@ def register(self, ops_list: List[Union[Type[ExpressionOps], dict]]):
else:
_ops_class = _operator

if not issubclass(_ops_class, ExpressionOps):
if not issubclass(_ops_class, Expression):
raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class))

if _ops_class.__name__ in self._ops:
Expand Down
21 changes: 15 additions & 6 deletions qlib/model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,19 @@ def fill_placeholder(config: dict, config_extend: dict):
# bfs
top = 0
tail = 1
item_quene = [config]
item_queue = [config]
while top < tail:
now_item = item_quene[top]
now_item = item_queue[top]
top += 1
if isinstance(now_item, list):
item_keys = range(len(now_item))
elif isinstance(now_item, dict):
item_keys = now_item.keys()
for key in item_keys:
if isinstance(now_item[key], list) or isinstance(now_item[key], dict):
item_quene.append(now_item[key])
item_queue.append(now_item[key])
tail += 1
elif now_item[key] in config_extend.keys():
elif isinstance(now_item[key], str) and now_item[key] in config_extend.keys():
now_item[key] = config_extend[now_item[key]]
return config

Expand Down Expand Up @@ -114,10 +114,19 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
task_config = fill_placeholder(task_config, placehorder_value)
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])
if isinstance(records, dict): # prevent only one dict
if isinstance(records, dict): # uniform the data format to list
records = [records]

for record in records:
r = init_instance_by_config(record, recorder=rec)
# Some recorder require the parameter `model` and `dataset`.
# try to automatically pass in them to the initialization function
# to make defining the tasking easier
r = init_instance_by_config(
record,
recorder=rec,
default_module="qlib.workflow.record_temp",
try_kwargs={"model": model, "dataset": dataset},
)
r.generate()
return rec

Expand Down
21 changes: 18 additions & 3 deletions qlib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Union, Tuple, Any, Text, Optional
from typing import Dict, Union, Tuple, Any, Text, Optional
from types import ModuleType
from urllib.parse import urlparse

Expand Down Expand Up @@ -232,7 +232,11 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod


def init_instance_by_config(
config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
config: Union[str, dict, object],
default_module=None,
accept_types: Union[type, Tuple[type]] = (),
try_kwargs: Dict = {},
**kwargs,
) -> Any:
"""
get initialized instance with config
Expand Down Expand Up @@ -270,6 +274,10 @@ def init_instance_by_config(
Optional. If the config is a instance of specific type, return the config directly.
This will be passed into the second parameter of isinstance.

try_kwargs: Dict
Try to pass in kwargs in `try_kwargs` when initialized the instance
If error occurred, it will fail back to initialization without try_kwargs.

Returns
-------
object:
Expand All @@ -286,7 +294,14 @@ def init_instance_by_config(
return pickle.load(f)

klass, cls_kwargs = get_callable_kwargs(config, default_module=default_module)
return klass(**cls_kwargs, **kwargs)

try:
return klass(**cls_kwargs, **try_kwargs, **kwargs)
except (TypeError,):
# TypeError for handling errors like
# 1: `XXX() got multiple values for keyword argument 'YYY'`
# 2: `XXX() got an unexpected keyword argument 'YYY'
return klass(**cls_kwargs, **kwargs)


@contextlib.contextmanager
Expand Down