Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

abstract Quote class from Exchange #520

Merged
106 changes: 53 additions & 53 deletions qlib/backtest/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from qlib.backtest.position import Position
import random
import logging
from typing import List, Tuple, Union
from typing import List, Tuple, Union, Callable, Iterable

import numpy as np
import pandas as pd
Expand All @@ -16,6 +16,7 @@
from ..utils.resam import resam_ts_data, ts_data_last
from ..log import get_module_logger
from .order import Order, OrderDir, OrderHelper
from .high_performane_ds import PandasQuote


class Exchange:
Expand All @@ -33,6 +34,7 @@ def __init__(
close_cost=0.0025,
min_cost=5,
extra_quote=None,
quote_cls=PandasQuote,
**kwargs,
):
"""__init__
Expand Down Expand Up @@ -103,10 +105,11 @@ def __init__(

# TODO: the quote, trade_dates, codes are not necessary.
# It is just for performance consideration.
self.limit_type = self._get_limit_type(limit_threshold)
if limit_threshold is None:
if C.region == REG_CN:
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
elif self._get_limit_type(limit_threshold) == self.LT_FLT and abs(limit_threshold) > 0.1:
elif self.limit_type == self.LT_FLT and abs(limit_threshold) > 0.1:
if C.region == REG_CN:
self.logger.warning(f"limit_threshold may not be set to a reasonable value")

Expand All @@ -128,10 +131,9 @@ def __init__(
# $change is for calculating the limit of the stock

necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
if self._get_limit_type(limit_threshold) == self.LT_TP_EXP:
if self.limit_type == self.LT_TP_EXP:
for exp in limit_threshold:
necessary_fields.add(exp)
subscribe_fields = list(necessary_fields | set(subscribe_fields))
all_fields = list(necessary_fields | set(subscribe_fields))

self.all_fields = all_fields
Expand All @@ -141,39 +143,43 @@ def __init__(
self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold
self.volume_threshold = volume_threshold
self.extra_quote = extra_quote
self.set_quote(codes, start_time, end_time)

def set_quote(self, codes, start_time, end_time):
if len(codes) == 0:
codes = D.instruments()

self.quote = D.features(codes, self.all_fields, start_time, end_time, freq=self.freq, disk_cache=True).dropna(
subset=["$close"]
)
self.quote.columns = self.all_fields

self.get_quote_from_qlib()

# init quote by quote_df
self.quote_cls = quote_cls
self.quote = self.quote_cls(self.quote_df)

def get_quote_from_qlib(self):
# get stock data from qlib
if len(self.codes) == 0:
self.codes = D.instruments()
self.quote_df = D.features(
self.codes, self.all_fields, self.start_time, self.end_time, freq=self.freq, disk_cache=True
).dropna(subset=["$close"])
self.quote_df.columns = self.all_fields

# check buy_price data and sell_price data
for attr in "buy_price", "sell_price":
pstr = getattr(self, attr) # price string
if self.quote[pstr].isna().any():
if self.quote_df[pstr].isna().any():
self.logger.warning("{} field data contains nan.".format(pstr))

if self.quote["$factor"].isna().any():
# update trade_w_adj_price
if self.quote_df["$factor"].isna().any():
# The 'factor.day.bin' file does not exist, or the `factor` field contains `nan`
# Use adjusted price
self.trade_w_adj_price = True
self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.")
if self.trade_unit is not None:
self.logger.warning(f"trade unit {self.trade_unit} is not supported in adjusted_price mode.")

else:
# The `factor.day.bin` file exists and all data `close` and `factor` are not `nan`
# Use normal price
self.trade_w_adj_price = False

# update limit
self._update_limit()
self._update_limit(self.limit_threshold)

quote_df = self.quote
# concat extra_quote
if self.extra_quote is not None:
# process extra_quote
if "$close" not in self.extra_quote:
Expand All @@ -192,21 +198,15 @@ def set_quote(self, codes, start_time, end_time):
if "limit_buy" not in self.extra_quote.columns:
self.extra_quote["limit_buy"] = False
self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.")

assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"}
quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)

quote_dict = {}
for stock_id, stock_val in quote_df.groupby(level="instrument"):
quote_dict[stock_id] = stock_val.droplevel(level="instrument")

self.quote = quote_dict
assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"}
self.quote_df = pd.concat([self.quote_df, extra_quote], sort=False, axis=0)

LT_TP_EXP = "(exp)" # Tuple[str, str]
LT_FLT = "float" # float
LT_NONE = "none" # none

def _get_limit_type(self, limit_threshold):
"""get limit type"""
if isinstance(limit_threshold, Tuple):
return self.LT_TP_EXP
elif isinstance(limit_threshold, float):
Expand All @@ -216,19 +216,19 @@ def _get_limit_type(self, limit_threshold):
else:
raise NotImplementedError(f"This type of `limit_threshold` is not supported")

def _update_limit(self):
def _update_limit(self, limit_threshold):
# check limit_threshold
lt_type = self._get_limit_type(self.limit_threshold)
if lt_type == self.LT_NONE:
self.quote["limit_buy"] = False
self.quote["limit_sell"] = False
elif lt_type == self.LT_TP_EXP:
limit_type = self._get_limit_type(limit_threshold)
if limit_type == self.LT_NONE:
self.quote_df["limit_buy"] = False
self.quote_df["limit_sell"] = False
elif limit_type == self.LT_TP_EXP:
# set limit
self.quote["limit_buy"] = self.quote[self.limit_threshold[0]]
self.quote["limit_sell"] = self.quote[self.limit_threshold[1]]
elif lt_type == self.LT_FLT:
self.quote["limit_buy"] = self.quote["$change"].ge(self.limit_threshold)
self.quote["limit_sell"] = self.quote["$change"].le(-self.limit_threshold) # pylint: disable=E1130
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
elif limit_type == self.LT_FLT:
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130

def check_stock_limit(self, stock_id, start_time, end_time, direction=None):
"""
Expand All @@ -242,20 +242,20 @@ def check_stock_limit(self, stock_id, start_time, end_time, direction=None):

"""
if direction is None:
buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all")
sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all")
buy_limit = self.quote.get_data(stock_id, start_time, end_time, fields="limit_buy", method="all")
sell_limit = self.quote.get_data(stock_id, start_time, end_time, fields="limit_sell", method="all")
return buy_limit or sell_limit
elif direction == Order.BUY:
return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all")
return self.quote.get_data(stock_id, start_time, end_time, fields="limit_buy", method="all")
elif direction == Order.SELL:
return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all")
return self.quote.get_data(stock_id, start_time, end_time, fields="limit_sell", method="all")
else:
raise ValueError(f"direction {direction} is not supported!")

def check_stock_suspended(self, stock_id, start_time, end_time):
# is suspended
if stock_id in self.quote:
return resam_ts_data(self.quote[stock_id], start_time, end_time, method=None) is None
if stock_id in self.quote.get_all_stock():
return self.quote.get_data(stock_id, start_time, end_time) is None
else:
return True

Expand Down Expand Up @@ -316,13 +316,13 @@ def deal_order(self, order, trade_account=None, position=None):
return trade_val, trade_cost, trade_price

def get_quote_info(self, stock_id, start_time, end_time, method=ts_data_last):
return resam_ts_data(self.quote[stock_id], start_time, end_time, method=method)
return self.quote.get_data(stock_id, start_time, end_time, method=method)

def get_close(self, stock_id, start_time, end_time, method=ts_data_last):
return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method=method)
return self.quote.get_data(stock_id, start_time, end_time, fields="$close", method=method)

def get_volume(self, stock_id, start_time, end_time, method="sum"):
return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method=method)
return self.quote.get_data(stock_id, start_time, end_time, fields="$volume", method=method)

def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method=ts_data_last):
if direction == OrderDir.SELL:
Expand All @@ -331,7 +331,7 @@ def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, me
pstr = self.buy_price
else:
raise NotImplementedError(f"This type of input is not supported")
deal_price = resam_ts_data(self.quote[stock_id][pstr], start_time, end_time, method=method)
deal_price = self.quote.get_data(stock_id, start_time, end_time, fields=pstr, method=method)
if method is not None and (np.isclose(deal_price, 0.0) or np.isnan(deal_price)):
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
self.logger.warning(f"setting deal_price to close price")
Expand All @@ -347,9 +347,9 @@ def get_factor(self, stock_id, start_time, end_time) -> Union[float, None]:
`float`: return factor if the factor exists
"""
assert (start_time is not None and end_time is not None, "the time range must be given")
if stock_id not in self.quote:
if stock_id not in self.quote.get_all_stock():
return None
return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method=ts_data_last)
return self.quote.get_data(stock_id, start_time, end_time, fields="$factor", method=ts_data_last)

def generate_amount_position_from_weight_position(
self, weight_position, cash, start_time, end_time, direction=OrderDir.BUY
Expand Down
Loading