From 87e182a7367450adb93ce46f346aa3999d90eaef Mon Sep 17 00:00:00 2001 From: zhupr Date: Mon, 8 Nov 2021 23:58:02 +0800 Subject: [PATCH 1/2] add default protocol_version --- qlib/config.py | 5 + qlib/contrib/online/manager.py | 9 +- qlib/contrib/online/utils.py | 6 +- qlib/data/cache.py | 10 +- qlib/data/client.py | 205 +++++++++++++++++---------------- qlib/utils/objm.py | 2 +- qlib/utils/serial.py | 5 +- qlib/workflow/task/manage.py | 8 +- 8 files changed, 129 insertions(+), 121 deletions(-) diff --git a/qlib/config.py b/qlib/config.py index 029434a886..d143bad9da 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -73,6 +73,9 @@ def set_conf_from_C(self, config_c): REG_CN = "cn" REG_US = "us" +# pickle.dump protocol version: https://docs.python.org/3/library/pickle.html#data-stream-format +PROTOCOL_VERSION = 4 + NUM_USABLE_CPU = max(multiprocessing.cpu_count() - 2, 1) DISK_DATASET_CACHE = "DiskDatasetCache" @@ -107,6 +110,8 @@ def set_conf_from_C(self, config_c): # for simple dataset cache "local_cache_path": None, "kernels": NUM_USABLE_CPU, + # pickle.dump protocol version + "dump_protocol_version": PROTOCOL_VERSION, # How many tasks belong to one process. Recommend 1 for high-frequency data and None for daily data. "maxtasksperchild": None, # If joblib_backend is None, use loky diff --git a/qlib/contrib/online/manager.py b/qlib/contrib/online/manager.py index 70b7bad408..7b07c4c076 100644 --- a/qlib/contrib/online/manager.py +++ b/qlib/contrib/online/manager.py @@ -1,17 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import os -import pickle import yaml import pathlib import pandas as pd import shutil -from ..backtest.account import Account -from ..backtest.exchange import Exchange +from ...backtest.account import Account from .user import User -from .utils import load_instance -from ...utils import save_instance, init_instance_by_config +from .utils import load_instance, save_instance +from ...utils import init_instance_by_config class UserManager: diff --git a/qlib/contrib/online/utils.py b/qlib/contrib/online/utils.py index 71a6a91ec2..2a775ba626 100644 --- a/qlib/contrib/online/utils.py +++ b/qlib/contrib/online/utils.py @@ -6,10 +6,10 @@ import yaml import pandas as pd from ...data import D +from ...config import C from ...log import get_module_logger -from ...utils import get_module_by_module_path, init_instance_by_config from ...utils import get_next_trading_date -from ..backtest.exchange import Exchange +from ...backtest.exchange import Exchange log = get_module_logger("utils") @@ -42,7 +42,7 @@ def save_instance(instance, file_path): """ file_path = pathlib.Path(file_path) with file_path.open("wb") as fr: - pickle.dump(instance, fr) + pickle.dump(instance, fr, C.dump_protocol_version) def create_user_folder(path): diff --git a/qlib/data/cache.py b/qlib/data/cache.py index 180c3a744f..6193dcf92b 100644 --- a/qlib/data/cache.py +++ b/qlib/data/cache.py @@ -230,7 +230,7 @@ def visit(cache_path: Union[str, Path]): d["meta"]["visits"] = d["meta"]["visits"] + 1 except KeyError: raise KeyError("Unknown meta keyword") - pickle.dump(d, f) + pickle.dump(d, f, protocol=C.dump_protocol_version) except Exception as e: get_module_logger("CacheUtils").warning(f"visit {cache_path} cache error: {e}") @@ -573,7 +573,7 @@ def gen_expression_cache(self, expression_data, cache_path, instrument, field, f meta_path = cache_path.with_suffix(".meta") with meta_path.open("wb") as f: - pickle.dump(meta, f) + pickle.dump(meta, f, protocol=C.dump_protocol_version) meta_path.chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH) df = expression_data.to_frame() @@ -638,7 +638,7 @@ def update(self, sid, cache_uri, freq: str = "day"): # update meta file d["info"]["last_update"] = str(new_calendar[-1]) with meta_path.open("wb") as f: - pickle.dump(d, f) + pickle.dump(d, f, protocol=C.dump_protocol_version) return 0 @@ -935,7 +935,7 @@ def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, f "meta": {"last_visit": time.time(), "visits": 1}, } with cache_path.with_suffix(".meta").open("wb") as f: - pickle.dump(meta, f) + pickle.dump(meta, f, protocol=C.dump_protocol_version) cache_path.with_suffix(".meta").chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH) # write index file im = DiskDatasetCache.IndexManager(cache_path) @@ -1057,7 +1057,7 @@ def update(self, cache_uri, freq: str = "day"): # update meta file d["info"]["last_update"] = str(new_calendar[-1]) with meta_path.open("wb") as f: - pickle.dump(d, f) + pickle.dump(d, f, protocol=C.dump_protocol_version) return 0 diff --git a/qlib/data/client.py b/qlib/data/client.py index 5244a7e45c..fc96161e85 100644 --- a/qlib/data/client.py +++ b/qlib/data/client.py @@ -1,102 +1,103 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -from __future__ import division -from __future__ import print_function - -import socketio - -import qlib -from ..log import get_module_logger -import pickle - - -class Client: - """A client class - - Provide the connection tool functions for ClientProvider. - """ - - def __init__(self, host, port): - super(Client, self).__init__() - self.sio = socketio.Client() - self.server_host = host - self.server_port = port - self.logger = get_module_logger(self.__class__.__name__) - # bind connect/disconnect callbacks - self.sio.on( - "connect", - lambda: self.logger.debug("Connect to server {}".format(self.sio.connection_url)), - ) - self.sio.on("disconnect", lambda: self.logger.debug("Disconnect from server!")) - - def connect_server(self): - """Connect to server.""" - try: - self.sio.connect("ws://" + self.server_host + ":" + str(self.server_port)) - except socketio.exceptions.ConnectionError: - self.logger.error("Cannot connect to server - check your network or server status") - - def disconnect(self): - """Disconnect from server.""" - try: - self.sio.eio.disconnect(True) - except Exception as e: - self.logger.error("Cannot disconnect from server : %s" % e) - - def send_request(self, request_type, request_content, msg_queue, msg_proc_func=None): - """Send a certain request to server. - - Parameters - ---------- - request_type : str - type of proposed request, 'calendar'/'instrument'/'feature'. - request_content : dict - records the information of the request. - msg_proc_func : func - the function to process the message when receiving response, should have arg `*args`. - msg_queue: Queue - The queue to pass the messsage after callback. - """ - head_info = {"version": qlib.__version__} - - def request_callback(*args): - """callback_wrapper - - :param *args: args[0] is the response content - """ - # args[0] is the response content - self.logger.debug("receive data and enter queue") - msg = dict(args[0]) - if msg["detailed_info"] is not None: - if msg["status"] != 0: - self.logger.error(msg["detailed_info"]) - else: - self.logger.info(msg["detailed_info"]) - if msg["status"] != 0: - ex = ValueError(f"Bad response(status=={msg['status']}), detailed info: {msg['detailed_info']}") - msg_queue.put(ex) - else: - if msg_proc_func is not None: - try: - ret = msg_proc_func(msg["result"]) - except Exception as e: - self.logger.exception("Error when processing message.") - ret = e - else: - ret = msg["result"] - msg_queue.put(ret) - self.disconnect() - self.logger.debug("disconnected") - - self.logger.debug("try connecting") - self.connect_server() - self.logger.debug("connected") - # The pickle is for passing some parameters with special type(such as - # pd.Timestamp) - request_content = {"head": head_info, "body": pickle.dumps(request_content)} - self.sio.on(request_type + "_response", request_callback) - self.logger.debug("try sending") - self.sio.emit(request_type + "_request", request_content) - self.sio.wait() +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from __future__ import division +from __future__ import print_function + +import socketio + +import qlib +from ..config import C +from ..log import get_module_logger +import pickle + + +class Client: + """A client class + + Provide the connection tool functions for ClientProvider. + """ + + def __init__(self, host, port): + super(Client, self).__init__() + self.sio = socketio.Client() + self.server_host = host + self.server_port = port + self.logger = get_module_logger(self.__class__.__name__) + # bind connect/disconnect callbacks + self.sio.on( + "connect", + lambda: self.logger.debug("Connect to server {}".format(self.sio.connection_url)), + ) + self.sio.on("disconnect", lambda: self.logger.debug("Disconnect from server!")) + + def connect_server(self): + """Connect to server.""" + try: + self.sio.connect("ws://" + self.server_host + ":" + str(self.server_port)) + except socketio.exceptions.ConnectionError: + self.logger.error("Cannot connect to server - check your network or server status") + + def disconnect(self): + """Disconnect from server.""" + try: + self.sio.eio.disconnect(True) + except Exception as e: + self.logger.error("Cannot disconnect from server : %s" % e) + + def send_request(self, request_type, request_content, msg_queue, msg_proc_func=None): + """Send a certain request to server. + + Parameters + ---------- + request_type : str + type of proposed request, 'calendar'/'instrument'/'feature'. + request_content : dict + records the information of the request. + msg_proc_func : func + the function to process the message when receiving response, should have arg `*args`. + msg_queue: Queue + The queue to pass the messsage after callback. + """ + head_info = {"version": qlib.__version__} + + def request_callback(*args): + """callback_wrapper + + :param *args: args[0] is the response content + """ + # args[0] is the response content + self.logger.debug("receive data and enter queue") + msg = dict(args[0]) + if msg["detailed_info"] is not None: + if msg["status"] != 0: + self.logger.error(msg["detailed_info"]) + else: + self.logger.info(msg["detailed_info"]) + if msg["status"] != 0: + ex = ValueError(f"Bad response(status=={msg['status']}), detailed info: {msg['detailed_info']}") + msg_queue.put(ex) + else: + if msg_proc_func is not None: + try: + ret = msg_proc_func(msg["result"]) + except Exception as e: + self.logger.exception("Error when processing message.") + ret = e + else: + ret = msg["result"] + msg_queue.put(ret) + self.disconnect() + self.logger.debug("disconnected") + + self.logger.debug("try connecting") + self.connect_server() + self.logger.debug("connected") + # The pickle is for passing some parameters with special type(such as + # pd.Timestamp) + request_content = {"head": head_info, "body": pickle.dumps(request_content, protocol=C.dump_protocol_version)} + self.sio.on(request_type + "_response", request_callback) + self.logger.debug("try sending") + self.sio.emit(request_type + "_request", request_content) + self.sio.wait() diff --git a/qlib/utils/objm.py b/qlib/utils/objm.py index eebd529c66..c125a6ae1e 100644 --- a/qlib/utils/objm.py +++ b/qlib/utils/objm.py @@ -106,7 +106,7 @@ def create_path(self) -> str: def save_obj(self, obj, name): with (self.path / name).open("wb") as f: - pickle.dump(obj, f) + pickle.dump(obj, f, protocol=C.dump_protocol_version) def save_objs(self, obj_name_l): for obj, name in obj_name_l: diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py index 04d16ab7a2..d8949785b7 100644 --- a/qlib/utils/serial.py +++ b/qlib/utils/serial.py @@ -5,6 +5,7 @@ import dill from pathlib import Path from typing import Union +from ..config import C class Serializable: @@ -85,7 +86,7 @@ def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list """ self.config(dump_all=dump_all, exclude=exclude) with Path(path).open("wb") as f: - self.get_backend().dump(self, f) + self.get_backend().dump(self, f, protocol=C.dump_protocol_version) @classmethod def load(cls, filepath): @@ -140,4 +141,4 @@ def general_dump(obj, path: Union[Path, str]): obj.to_pickle(path) else: with path.open("wb") as f: - pickle.dump(obj, f) + pickle.dump(obj, f, protocol=C.dump_protocol_version) diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index ec735c1972..ab4febf6a9 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -27,6 +27,7 @@ from tqdm.cli import tqdm from .utils import get_mongodb +from ...config import C class TaskManager: @@ -108,7 +109,7 @@ def _encode_task(self, task): for prefix in self.ENCODE_FIELDS_PREFIX: for k in list(task.keys()): if k.startswith(prefix): - task[k] = Binary(pickle.dumps(task[k])) + task[k] = Binary(pickle.dumps(task[k], protocol=C.dump_protocol_version)) return task def _decode_task(self, task): @@ -359,7 +360,10 @@ def commit_task_res(self, task, res, status=STATUS_DONE): # A workaround to use the class attribute. if status is None: status = TaskManager.STATUS_DONE - self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}}) + self.task_pool.update_one( + {"_id": task["_id"]}, + {"$set": {"status": status, "res": Binary(pickle.dumps(res, protocol=C.dump_protocol_version))}}, + ) def return_task(self, task, status=STATUS_WAITING): """ From fc27139f5d6a5e2f381224455e765e5a3279a76f Mon Sep 17 00:00:00 2001 From: zhupr Date: Wed, 10 Nov 2021 14:09:45 +0800 Subject: [PATCH 2/2] add comment to serial.Serializable.get_backend --- qlib/data/base.py | 5 +++-- qlib/data/data.py | 5 +++-- qlib/data/ops.py | 4 ++-- qlib/utils/serial.py | 2 ++ 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/qlib/data/base.py b/qlib/data/base.py index 1c32c949f3..99c5533579 100644 --- a/qlib/data/base.py +++ b/qlib/data/base.py @@ -154,10 +154,11 @@ def load(self, instrument, start_index, end_index, freq): raise ValueError("Invalid index range: {} {}".format(start_index, end_index)) try: series = self._load_internal(instrument, start_index, end_index, freq) - except Exception: + except Exception as e: get_module_logger("data").error( f"Loading data error: instrument={instrument}, expression={str(self)}, " - f"start_index={start_index}, end_index={end_index}, freq={freq}" + f"start_index={start_index}, end_index={end_index}, freq={freq}. " + f"error info: {str(e)}" ) raise series.name = str(self) diff --git a/qlib/data/data.py b/qlib/data/data.py index 9f27b5dadf..f4759c5afb 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -726,10 +726,11 @@ def expression(self, instrument, field, start_time=None, end_time=None, freq="da lft_etd, rght_etd = expression.get_extended_window_size() try: series = expression.load(instrument, max(0, start_index - lft_etd), end_index + rght_etd, freq) - except Exception: + except Exception as e: get_module_logger("data").error( f"Loading expression error: " - f"instrument={instrument}, field=({field}), start_time={start_time}, end_time={end_time}, freq={freq}" + f"instrument={instrument}, field=({field}), start_time={start_time}, end_time={end_time}, freq={freq}. " + f"error info: {str(e)}" ) raise # Ensure that each column type is consistent diff --git a/qlib/data/ops.py b/qlib/data/ops.py index 4a859f345c..9384177a8d 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -312,12 +312,12 @@ def _load_internal(self, instrument, start_index, end_index, freq): warning_info = ( f"Loading {instrument}: {str(self)}; np.{self.func}(series_left, series_right), " f"The length of series_left and series_right is different: ({len(series_left)}, {len(series_right)}), " - f"series_left is {str(self.feature_left)}, series_right is {str(self.feature_left)}. Please check the data" + f"series_left is {str(self.feature_left)}, series_right is {str(self.feature_right)}. Please check the data" ) else: warning_info = ( f"Loading {instrument}: {str(self)}; np.{self.func}(series_left, series_right), " - f"series_left is {str(self.feature_left)}, series_right is {str(self.feature_left)}. Please check the data" + f"series_left is {str(self.feature_left)}, series_right is {str(self.feature_right)}. Please check the data" ) try: res = getattr(np, self.func)(series_left, series_right) diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py index d8949785b7..4e9d7739bb 100644 --- a/qlib/utils/serial.py +++ b/qlib/utils/serial.py @@ -86,6 +86,7 @@ def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list """ self.config(dump_all=dump_all, exclude=exclude) with Path(path).open("wb") as f: + # pickle interface like backend; such as dill self.get_backend().dump(self, f, protocol=C.dump_protocol_version) @classmethod @@ -117,6 +118,7 @@ def get_backend(cls): Returns: module: pickle or dill module based on pickle_backend """ + # NOTE: pickle interface like backend; such as dill if cls.pickle_backend == "pickle": return pickle elif cls.pickle_backend == "dill":