# 自定义模型(Custom Model) + 自定义策略(Custom Strategy)
#
# 140 阅读2分钟
import qlib

from qlib.constant import REG_CN

from qlib.utils import init_instance_by_config

from qlib.workflow import R

from qlib.workflow.record_temp import SignalRecord, PortAnaRecord

  


from qlib.model.base import Model

from qlib.strategy.base import BaseStrategy

import pandas as pd

  


# --- 1. Initialize Qlib ---
# Data directory handed to Qlib; the region constant selects the CN settings.
provider_uri = "~/.qlib/qlib_data/cn_data"
qlib.init(region=REG_CN, provider_uri=provider_uri)

  


# --- 2. Custom model ---
class MyCustomModel(Model):
    """Toy model that scores every sample with averaged per-feature offsets.

    ``fit`` stores, for each feature column,
    ``alpha * (mean(label) - mean(feature))``.  ``predict`` broadcasts those
    offsets over the requested segment and averages them across the feature
    columns, so every row of the segment receives exactly one score.  This is
    a deliberately naive example, not a useful predictor.

    Args:
        alpha: Scale factor applied to the fitted offsets.
        **kwargs: Forwarded unchanged to the base-class initializer.
    """

    def __init__(self, alpha=1.0, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.coef_ = None  # per-feature offsets; populated by fit()

    def _prepare_data(self, dataset, segment):
        """Return ``(features, labels)`` for one segment of ``dataset``.

        NOTE(review): Qlib's ``Model`` base class documents only ``fit`` and
        ``predict``, so this helper is defined here instead of being assumed
        to be inherited.

        Args:
            dataset: Object exposing ``prepare(segment, col_set=..., data_key=...)``
                (e.g. ``DatasetH``).
            segment: Segment name such as ``"train"`` or ``"test"``.

        Returns:
            Tuple ``(features, labels)``: the ``"feature"`` column group as a
            DataFrame and the first ``"label"`` column as a Series.
        """
        # Function-scope import: only this helper needs the data-key constants.
        from qlib.data.dataset.handler import DataHandlerLP

        # Training reads the "learn" view of the handler; everything else
        # reads the "infer" view.
        data_key = DataHandlerLP.DK_L if segment == "train" else DataHandlerLP.DK_I
        frame = dataset.prepare(
            segment, col_set=["feature", "label"], data_key=data_key
        )
        features, labels = frame["feature"], frame["label"]
        if isinstance(labels, pd.DataFrame):
            # Collapse the label group to a Series so that ``labels.mean()``
            # is a scalar; a one-column frame would yield a Series whose
            # index cannot align with the feature names.
            labels = labels.iloc[:, 0]
        return features, labels

    def fit(self, dataset):
        """Fit the per-feature offsets on the ``"train"`` segment.

        Args:
            dataset: Dataset object accepted by ``_prepare_data``.

        Returns:
            ``self``, so calls can be chained.
        """
        train_x, train_y = self._prepare_data(dataset, segment="train")
        # Scalar label mean minus per-column feature means -> one offset per feature.
        self.coef_ = (train_y.mean() - train_x.mean()) * self.alpha
        return self

    def predict(self, dataset, segment="test"):
        """Score every row of ``segment``.

        Args:
            dataset: Dataset object accepted by ``_prepare_data``.
            segment: Segment to score; defaults to ``"test"``.

        Returns:
            DataFrame with a single ``"score"`` column that shares the index
            of the segment's feature frame.

        Raises:
            ValueError: If ``fit`` has not been called yet.
        """
        if self.coef_ is None:
            raise ValueError("model is not fitted yet!")

        test_x, _ = self._prepare_data(dataset, segment=segment)
        # Broadcast the fitted offsets to the shape of test_x; a NaN feature
        # value stays NaN in the broadcast frame.
        per_feature = test_x * 0 + self.coef_
        # Reduce across the feature columns to get one value per row.
        # Flattening the whole frame would yield rows * features values,
        # which cannot be paired with a row index of length ``rows``.
        score = per_feature.mean(axis=1)
        return score.to_frame("score")

  


# --- 3. Custom strategy ---
class MyCustomStrategy(BaseStrategy):
    """Threshold strategy: hold instruments scoring above a cutoff, drop the rest.

    Args:
        score_threshold: Scores strictly greater than this value are bought.
        **kwargs: Forwarded unchanged to the base-class initializer.
    """

    def __init__(self, score_threshold=0.0, **kwargs):
        super().__init__(**kwargs)
        self.score_threshold = score_threshold

    def generate_trade_decision(self, execute_result=None, **kwargs):
        """Turn the latest scores into target positions and a trade decision.

        Args:
            execute_result: Unused.  Accepted so the method also tolerates a
                single positional argument.  NOTE(review): Qlib documents the
                base signature as ``generate_trade_decision(execute_result=None)``;
                confirm how ``score`` is supplied when run inside a backtest.
            **kwargs: Must contain ``"score"``: a Series indexed by
                instrument, or a DataFrame whose ``"score"`` column (or first
                column when no ``"score"`` column exists) holds the values.
                ``"current_positions"`` may be passed but is not used.

        Returns:
            Whatever ``self._target_positions_to_trade_decision`` returns for
            the computed ``{instrument: weight}`` mapping.

        Raises:
            KeyError: If ``"score"`` is missing from ``kwargs``.
        """
        score = kwargs["score"]
        if isinstance(score, pd.DataFrame):
            # Iterating a DataFrame with .items() walks columns, not rows, so
            # reduce it to the per-instrument Series first.
            score = score["score"] if "score" in score.columns else score.iloc[:, 0]

        # Equal weight over the whole universe (illustrative only); the guard
        # keeps an empty score set from dividing by zero.
        weight = 1.0 / len(score) if len(score) else 0.0
        target_positions = {
            inst: weight if value > self.score_threshold else 0.0
            for inst, value in score.items()
        }

        # NOTE(review): `_target_positions_to_trade_decision` is not defined in
        # this class; confirm the installed BaseStrategy provides it, otherwise
        # this call raises AttributeError.
        return self._target_positions_to_trade_decision(target_positions)

  


# --- 4. Configure the model and the dataset ---
# Each entry is an instance config: class name, module to import it from, and
# constructor keyword arguments.
task = {
    "model": {
        "class": "MyCustomModel",
        "module_path": __name__,  # resolved from the current module
        "kwargs": {"alpha": 0.5},
    },
    "dataset": {
        "class": "DatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {
            "handler": {
                "class": "Alpha158",
                "module_path": "qlib.contrib.data.handler",
                "kwargs": {
                    "start_time": "2010-01-01",
                    "end_time": "2020-12-31",
                    "fit_start_time": "2010-01-01",
                    "fit_end_time": "2015-12-31",
                    "instruments": "csi300",
                },
            },
            "segments": {
                "train": ("2010-01-01", "2015-12-31"),
                "valid": ("2016-01-01", "2017-12-31"),
                "test": ("2018-01-01", "2020-12-31"),
            },
        },
    },
}

# Build the model first, then the dataset, from their config entries.
model, dataset = (
    init_instance_by_config(task[part]) for part in ("model", "dataset")
)

# --- 5. Train the model ---
with R.start(experiment_name="train_custom_model"):
    model.fit(dataset)
    R.save_objects(**{"trained_model": model})
    # Keep the recorder id so the stored model can be loaded again later.
    rid = R.get_recorder().id

  


# --- 6. Signal generation and backtest analysis ---
with R.start(experiment_name="backtest_custom_strategy"):
    # Fetch the recorder of the training run and load the stored model from it.
    recorder = R.get_recorder(recorder_id=rid, experiment_name="train_custom_model")
    trained_model = recorder.load_object("trained_model")

    # Generate the prediction signal.
    sr = SignalRecord(trained_model, dataset, recorder)
    sr.generate()

    # Backtest configuration: executor, strategy, and backtest settings.
    backtest_config = {
        "executor": {
            "class": "SimulatorExecutor",
            "module_path": "qlib.backtest.executor",
            "kwargs": {
                "time_per_step": "day",
                "generate_portfolio_metrics": True,
            },
        },
        "strategy": {
            "class": "MyCustomStrategy",
            "module_path": __name__,
            "kwargs": {
                "score_threshold": 0.01,
            },
        },
        "backtest": {
            "start_time": "2018-01-01",
            "end_time": "2020-12-31",
            "account": 1e7,
            "benchmark": "SH000300",
            "exchange_kwargs": {
                "freq": "day",
                "deal_price": "close",
                "limit_threshold": 0.095,
                "open_cost": 0.0005,
                "close_cost": 0.0015,
                "min_cost": 5,
            },
        },
    }

    # The analysis frequency is passed as `risk_analysis_freq`; per Qlib's
    # documented signature `PortAnaRecord` has no `freq` keyword, and unknown
    # keywords are forwarded to a base initializer that rejects them.
    par = PortAnaRecord(recorder, backtest_config, risk_analysis_freq="day")
    par.generate()

print("回测完成!")