import numpy as np
import pandas as pd

import qlib
from qlib.constant import REG_CN
from qlib.contrib.strategy import WeightStrategyBase
from qlib.data.dataset.handler import DataHandlerLP
from qlib.model.base import Model
from qlib.strategy.base import BaseStrategy
from qlib.utils import init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import PortAnaRecord, SignalRecord
# Local qlib data bundle for the China (CN) market; qlib is initialised once
# here, before the model and dataset below are constructed.
provider_uri = "~/.qlib/qlib_data/cn_data"
qlib.init(region=REG_CN, provider_uri=provider_uri)
class MyCustomModel(Model):
    """A very simple custom model: a shrunken linear fit plus a mean correction.

    ``fit`` solves ordinary least squares of the training label on the centred,
    mean-imputed training features, multiplies the slopes by ``alpha`` and keeps
    the training-label mean as the intercept. ``predict`` therefore returns
    ``alpha * (x - mean_x) @ beta + mean_y`` for every row of the requested
    segment, so stocks are ranked by their features instead of all receiving the
    same constant score.

    Parameters
    ----------
    alpha : float
        Shrinkage applied to the fitted slopes: ``1.0`` keeps the OLS slopes,
        ``0.0`` predicts the training-label mean everywhere.
    """

    def __init__(self, alpha=1.0, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.coef_ = None  # per-feature slopes (pd.Series indexed by feature name), set by fit()
        self.intercept_ = None  # training-label mean added to every prediction
        self.feature_mean_ = None  # per-feature training means, used for centring and imputation

    def _prepare_data(self, dataset, segment="train"):
        """Return ``(features, label)`` for *segment* of a qlib ``DatasetH``.

        The ``train`` segment is read from the handler's learn data and the label
        is returned as a Series (first label column). Any other segment is read
        from the infer data and the label is ``None``.
        """
        if segment == "train":
            df = dataset.prepare(segment, col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
            return df["feature"], df["label"].iloc[:, 0]
        features = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
        return features, None

    def fit(self, dataset):
        """Fit on the ``train`` segment of *dataset* and return ``self``.

        Raises
        ------
        ValueError
            If the training segment has no row with a finite label.
        """
        train_x, train_y = self._prepare_data(dataset, segment="train")
        y = train_y.to_numpy(dtype=float)
        valid = np.isfinite(y)
        if not valid.any():
            raise ValueError("no finite labels in the train segment")
        # Features may contain +/-inf (e.g. from divisions); treat them as missing.
        train_x = train_x.loc[valid].replace([np.inf, -np.inf], np.nan)
        y = y[valid]

        self.feature_mean_ = train_x.mean().fillna(0.0)
        # Centre the features; NaN -> 0 afterwards means "impute with the column mean".
        x = (train_x - self.feature_mean_).fillna(0.0).to_numpy(dtype=float)
        y_mean = float(y.mean())
        slopes, *_ = np.linalg.lstsq(x, y - y_mean, rcond=None)

        self.coef_ = pd.Series(slopes * self.alpha, index=train_x.columns)
        self.intercept_ = y_mean
        return self

    def predict(self, dataset, segment="test"):
        """Score every row of *segment* (default ``"test"``).

        Returns
        -------
        pd.DataFrame
            One ``score`` column, indexed like the segment's feature frame.

        Raises
        ------
        ValueError
            If called before :meth:`fit`.
        """
        if self.coef_ is None:
            raise ValueError("model is not fitted yet!")
        test_x, _ = self._prepare_data(dataset, segment=segment)
        # Align to the training feature order and apply the same inf handling / imputation.
        x = test_x.reindex(columns=self.coef_.index).replace([np.inf, -np.inf], np.nan)
        x = (x - self.feature_mean_).fillna(0.0).to_numpy(dtype=float)
        preds = x @ self.coef_.to_numpy() + self.intercept_
        return pd.DataFrame({"score": preds}, index=test_x.index)
class MyCustomStrategy(WeightStrategyBase):
    """Equal-weight long-only strategy: hold every stock whose score exceeds a threshold.

    qlib's ``WeightStrategyBase`` drives each trading step. It reads that step's
    prediction scores from ``self.signal``, calls
    :meth:`generate_target_weight_position` and converts the returned target
    weights into orders. The strategy therefore has to be configured with
    ``signal`` (e.g. the ``"<PRED>"`` placeholder that ``PortAnaRecord`` fills
    with the saved predictions).

    Parameters
    ----------
    score_threshold : float
        Stocks scoring strictly above this are bought; all others get a target
        weight of 0, i.e. they are closed out.
    **kwargs
        Forwarded to ``WeightStrategyBase`` (``signal`` etc.).
    """

    def __init__(self, score_threshold=0.0, **kwargs):
        super().__init__(**kwargs)
        self.score_threshold = score_threshold

    def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):
        """Return ``{instrument: target_weight}`` for the current trading step.

        Selected stocks share the book equally (``1 / n_selected`` each).
        Unselected stocks, including those with a NaN score, get 0.
        """
        if isinstance(score, pd.DataFrame):
            # Multi-column signal: use the first column as the score.
            score = score.iloc[:, 0]
        keep = score > self.score_threshold  # NaN compares False -> not selected
        n_selected = int(keep.sum())
        weight = 1.0 / n_selected if n_selected else 0.0
        return {inst: (weight if selected else 0.0) for inst, selected in keep.items()}


task = {
    "model": {
        "class": "MyCustomModel",
        "module_path": __name__,
        "kwargs": {
            "alpha": 0.5,
        },
    },
    "dataset": {
        "class": "DatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {
            "handler": {
                "class": "Alpha158",
                "module_path": "qlib.contrib.data.handler",
                "kwargs": {
                    "start_time": "2010-01-01",
                    "end_time": "2020-12-31",
                    "fit_start_time": "2010-01-01",
                    "fit_end_time": "2015-12-31",
                    "instruments": "csi300",
                },
            },
            "segments": {
                "train": ("2010-01-01", "2015-12-31"),
                "valid": ("2016-01-01", "2017-12-31"),
                "test": ("2018-01-01", "2020-12-31"),
            },
        },
    },
}

model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])

# Stage 1: train the model and store it in its own recorder.
with R.start(experiment_name="train_custom_model"):
    model.fit(dataset)
    R.save_objects(trained_model=model)
    rid = R.get_recorder().id

# Stage 2: reload the trained model, then predict and backtest.
with R.start(experiment_name="backtest_custom_strategy"):
    train_recorder = R.get_recorder(recorder_id=rid, experiment_name="train_custom_model")
    trained_model = train_recorder.load_object("trained_model")

    # Write predictions and backtest results into the active backtest recorder,
    # not back into the already-finished training recorder.
    recorder = R.get_recorder()
    sr = SignalRecord(trained_model, dataset, recorder)
    sr.generate()

    backtest_config = {
        "executor": {
            "class": "SimulatorExecutor",
            "module_path": "qlib.backtest.executor",
            "kwargs": {
                "time_per_step": "day",
                "generate_portfolio_metrics": True,
            },
        },
        "strategy": {
            "class": "MyCustomStrategy",
            "module_path": __name__,
            "kwargs": {
                # Replaced by PortAnaRecord with the predictions saved by SignalRecord.
                "signal": "<PRED>",
                "score_threshold": 0.01,
            },
        },
        "backtest": {
            "start_time": "2018-01-01",
            "end_time": "2020-12-31",
            "account": 1e7,
            "benchmark": "SH000300",
            "exchange_kwargs": {
                "freq": "day",
                "deal_price": "close",
                "limit_threshold": 0.095,
                "open_cost": 0.0005,
                "close_cost": 0.0015,
                "min_cost": 5,
            },
        },
    }
    # PortAnaRecord has no `freq` keyword; the analysis frequency is `risk_analysis_freq`.
    par = PortAnaRecord(recorder, backtest_config, risk_analysis_freq="day")
    par.generate()

print("回测完成!")