Skip to content

Commit

Permalink
Merge pull request microsoft#520 from wangwenxi-handsome/nested_decis…
Browse files Browse the repository at this point in the history
…ion_exe

abstract Quote class from Exchange
  • Loading branch information
you-n-g authored Jul 23, 2021
2 parents 379a8d0 + f5af958 commit 12cb519
Show file tree
Hide file tree
Showing 3 changed files with 572 additions and 138 deletions.
106 changes: 53 additions & 53 deletions qlib/backtest/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from qlib.backtest.position import Position
import random
import logging
from typing import List, Tuple, Union
from typing import List, Tuple, Union, Callable, Iterable

import numpy as np
import pandas as pd
Expand All @@ -16,6 +16,7 @@
from ..utils.resam import resam_ts_data, ts_data_last
from ..log import get_module_logger
from .order import Order, OrderDir, OrderHelper
from .high_performance_ds import PandasQuote


class Exchange:
Expand All @@ -33,6 +34,7 @@ def __init__(
close_cost=0.0025,
min_cost=5,
extra_quote=None,
quote_cls=PandasQuote,
**kwargs,
):
"""__init__
Expand Down Expand Up @@ -103,10 +105,11 @@ def __init__(

# TODO: the quote, trade_dates, codes are not necessray.
# It is just for performance consideration.
self.limit_type = self._get_limit_type(limit_threshold)
if limit_threshold is None:
if C.region == REG_CN:
self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold")
elif self._get_limit_type(limit_threshold) == self.LT_FLT and abs(limit_threshold) > 0.1:
elif self.limit_type == self.LT_FLT and abs(limit_threshold) > 0.1:
if C.region == REG_CN:
self.logger.warning(f"limit_threshold may not be set to a reasonable value")

Expand All @@ -128,10 +131,9 @@ def __init__(
# $change is for calculating the limit of the stock

necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"}
if self._get_limit_type(limit_threshold) == self.LT_TP_EXP:
if self.limit_type == self.LT_TP_EXP:
for exp in limit_threshold:
necessary_fields.add(exp)
subscribe_fields = list(necessary_fields | set(subscribe_fields))
all_fields = list(necessary_fields | set(subscribe_fields))

self.all_fields = all_fields
Expand All @@ -141,39 +143,43 @@ def __init__(
self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold
self.volume_threshold = volume_threshold
self.extra_quote = extra_quote
self.set_quote(codes, start_time, end_time)

def set_quote(self, codes, start_time, end_time):
if len(codes) == 0:
codes = D.instruments()

self.quote = D.features(codes, self.all_fields, start_time, end_time, freq=self.freq, disk_cache=True).dropna(
subset=["$close"]
)
self.quote.columns = self.all_fields

self.get_quote_from_qlib()

# init quote by quote_df
self.quote_cls = quote_cls
self.quote = self.quote_cls(self.quote_df)

def get_quote_from_qlib(self):
# get stock data from qlib
if len(self.codes) == 0:
self.codes = D.instruments()
self.quote_df = D.features(
self.codes, self.all_fields, self.start_time, self.end_time, freq=self.freq, disk_cache=True
).dropna(subset=["$close"])
self.quote_df.columns = self.all_fields

# check buy_price data and sell_price data
for attr in "buy_price", "sell_price":
pstr = getattr(self, attr) # price string
if self.quote[pstr].isna().any():
if self.quote_df[pstr].isna().any():
self.logger.warning("{} field data contains nan.".format(pstr))

if self.quote["$factor"].isna().any():
# update trade_w_adj_price
if self.quote_df["$factor"].isna().any():
# The 'factor.day.bin' file not exists, and `factor` field contains `nan`
# Use adjusted price
self.trade_w_adj_price = True
self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.")
if self.trade_unit is not None:
self.logger.warning(f"trade unit {self.trade_unit} is not supported in adjusted_price mode.")

else:
# The `factor.day.bin` file exists and all data `close` and `factor` are not `nan`
# Use normal price
self.trade_w_adj_price = False

# update limit
self._update_limit()
self._update_limit(self.limit_threshold)

quote_df = self.quote
# concat extra_quote
if self.extra_quote is not None:
# process extra_quote
if "$close" not in self.extra_quote:
Expand All @@ -192,21 +198,15 @@ def set_quote(self, codes, start_time, end_time):
if "limit_buy" not in self.extra_quote.columns:
self.extra_quote["limit_buy"] = False
self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.")

assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"}
quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)

quote_dict = {}
for stock_id, stock_val in quote_df.groupby(level="instrument"):
quote_dict[stock_id] = stock_val.droplevel(level="instrument")

self.quote = quote_dict
assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"}
self.quote_df = pd.concat([self.quote_df, extra_quote], sort=False, axis=0)

LT_TP_EXP = "(exp)" # Tuple[str, str]
LT_FLT = "float" # float
LT_NONE = "none" # none

def _get_limit_type(self, limit_threshold):
"""get limit type"""
if isinstance(limit_threshold, Tuple):
return self.LT_TP_EXP
elif isinstance(limit_threshold, float):
Expand All @@ -216,19 +216,19 @@ def _get_limit_type(self, limit_threshold):
else:
raise NotImplementedError(f"This type of `limit_threshold` is not supported")

def _update_limit(self):
def _update_limit(self, limit_threshold):
# check limit_threshold
lt_type = self._get_limit_type(self.limit_threshold)
if lt_type == self.LT_NONE:
self.quote["limit_buy"] = False
self.quote["limit_sell"] = False
elif lt_type == self.LT_TP_EXP:
limit_type = self._get_limit_type(limit_threshold)
if limit_type == self.LT_NONE:
self.quote_df["limit_buy"] = False
self.quote_df["limit_sell"] = False
elif limit_type == self.LT_TP_EXP:
# set limit
self.quote["limit_buy"] = self.quote[self.limit_threshold[0]]
self.quote["limit_sell"] = self.quote[self.limit_threshold[1]]
elif lt_type == self.LT_FLT:
self.quote["limit_buy"] = self.quote["$change"].ge(self.limit_threshold)
self.quote["limit_sell"] = self.quote["$change"].le(-self.limit_threshold) # pylint: disable=E1130
self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]]
self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]]
elif limit_type == self.LT_FLT:
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130

def check_stock_limit(self, stock_id, start_time, end_time, direction=None):
"""
Expand All @@ -242,20 +242,20 @@ def check_stock_limit(self, stock_id, start_time, end_time, direction=None):
"""
if direction is None:
buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all")
sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all")
buy_limit = self.quote.get_data(stock_id, start_time, end_time, fields="limit_buy", method="all")
sell_limit = self.quote.get_data(stock_id, start_time, end_time, fields="limit_sell", method="all")
return buy_limit or sell_limit
elif direction == Order.BUY:
return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all")
return self.quote.get_data(stock_id, start_time, end_time, fields="limit_buy", method="all")
elif direction == Order.SELL:
return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all")
return self.quote.get_data(stock_id, start_time, end_time, fields="limit_sell", method="all")
else:
raise ValueError(f"direction {direction} is not supported!")

def check_stock_suspended(self, stock_id, start_time, end_time):
# is suspended
if stock_id in self.quote:
return resam_ts_data(self.quote[stock_id], start_time, end_time, method=None) is None
if stock_id in self.quote.get_all_stock():
return self.quote.get_data(stock_id, start_time, end_time) is None
else:
return True

Expand Down Expand Up @@ -316,13 +316,13 @@ def deal_order(self, order, trade_account=None, position=None):
return trade_val, trade_cost, trade_price

def get_quote_info(self, stock_id, start_time, end_time, method=ts_data_last):
return resam_ts_data(self.quote[stock_id], start_time, end_time, method=method)
return self.quote.get_data(stock_id, start_time, end_time, method=method)

def get_close(self, stock_id, start_time, end_time, method=ts_data_last):
return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method=method)
return self.quote.get_data(stock_id, start_time, end_time, fields="$close", method=method)

def get_volume(self, stock_id, start_time, end_time, method="sum"):
return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method=method)
return self.quote.get_data(stock_id, start_time, end_time, fields="$volume", method=method)

def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method=ts_data_last):
if direction == OrderDir.SELL:
Expand All @@ -331,7 +331,7 @@ def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, me
pstr = self.buy_price
else:
raise NotImplementedError(f"This type of input is not supported")
deal_price = resam_ts_data(self.quote[stock_id][pstr], start_time, end_time, method=method)
deal_price = self.quote.get_data(stock_id, start_time, end_time, fields=pstr, method=method)
if method is not None and (np.isclose(deal_price, 0.0) or np.isnan(deal_price)):
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
self.logger.warning(f"setting deal_price to close price")
Expand All @@ -347,9 +347,9 @@ def get_factor(self, stock_id, start_time, end_time) -> Union[float, None]:
`float`: return factor if the factor exists
"""
assert (start_time is not None and end_time is not None, "the time range must be given")
if stock_id not in self.quote:
if stock_id not in self.quote.get_all_stock():
return None
return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method=ts_data_last)
return self.quote.get_data(stock_id, start_time, end_time, fields="$factor", method=ts_data_last)

def generate_amount_position_from_weight_position(
self, weight_position, cash, start_time, end_time, direction=OrderDir.BUY
Expand Down
Loading

0 comments on commit 12cb519

Please sign in to comment.