test

61 阅读6分钟
#!/usr/bin/env python3
"""
股指期货1分钟K线交易系统

完整流程:
1. 从JSON文件加载1分钟K线数据
2. 创建特征和标签
3. 训练LightGBM模型
4. 运行回测
5. 生成分析报告

使用方法:
    python future_index_trading_system.py --data_path demo/chap06/data.json

    
json_future_data_loader.py - JSON数据加载器
single_future_handler.py - 单标数据处理器
single_future_strategy.py - 单标交易策略
future_index_trading_system.py - 主程序

"""

import argparse
import warnings
import numpy as np
import pandas as pd
import lightgbm as lgb

# 导入自定义模块(相对导入)
import json_future_data_loader
import single_future_handler

warnings.filterwarnings('ignore')


# ============================================
# 模型训练
# ============================================

def train_model(train_data: pd.DataFrame, n_estimators: int = 100):
    """
    训练LightGBM模型

    Parameters:
    -----------
    train_data : pd.DataFrame
        训练数据,包含特征列和LABEL0列
    n_estimators : int
        树的数量

    Returns:
    --------
    model
        训练好的lightgbm模型
    """
    print("\n" + "=" * 60)
    print("训练LightGBM模型")
    print("=" * 60)

    # 分离特征和标签
    feature_cols = [col for col in train_data.columns if 'LABEL' not in col]
    label_cols = [col for col in train_data.columns if 'LABEL' in col]

    X_train = train_data[feature_cols]
    y_train = train_data[label_cols[0]]

    # 处理NaN值
    X_train = X_train.fillna(X_train.mean())
    X_train = X_train.fillna(0)
    y_train = y_train.fillna(y_train.median())

    print(f"特征数量: {len(feature_cols)}")
    print(f"训练样本数: {len(X_train)}")

    # 创建lightgbm数据集
    train_dataset = lgb.Dataset(X_train, label=y_train)

    # 训练参数
    params = {
        'objective': 'regression',
        'metric': 'mse',
        'num_leaves': 31,
        'max_depth': 6,
        'learning_rate': 0.1,
        'verbose': -1
    }

    # 训练模型
    print("\n开始训练...")
    model = lgb.train(params, train_dataset, num_boost_round=n_estimators)

    # 评估模型
    train_pred = model.predict(X_train)
    train_ic = np.corrcoef(train_pred, y_train.values)[0, 1]
    train_mse = np.mean((train_pred - y_train.values) ** 2)

    print(f"\n训练集IC: {train_ic:.4f}")
    print(f"训练集MSE: {train_mse:.6f}")

    return model


# ============================================
# 回测
# ============================================

def run_backtest(
    model,
    test_data: pd.DataFrame,
    test_df: pd.DataFrame,  # 添加原始OHLCV数据
    start_time: str = None,
    end_time: str = None,
    hold_bars: int = 5,
    long_threshold: float = 0.003,
    short_threshold: float = -0.003,
    init_cash: float = 1000000
):
    """
    运行回测

    Parameters:
    -----------
    model
        训练好的lightgbm模型
    test_data : pd.DataFrame
        测试数据(包含特征和标签)
    test_df : pd.DataFrame
        原始OHLCV数据
    start_time : str
        回测开始时间
    end_time : str
        回测结束时间
    hold_bars : int
        持仓K线数
    long_threshold : float
        做多阈值
    short_threshold : float
        做空阈值
    init_cash : float
        初始资金

    Returns:
    --------
    dict
        回测结果
    """
    print("\n" + "=" * 60)
    print("运行回测")
    print("=" * 60)

    # 确定回测时间范围
    if start_time is None:
        start_time = test_data.index.min().strftime('%Y-%m-%d %H:%M:%S')
    if end_time is None:
        end_time = test_data.index.max().strftime('%Y-%m-%d %H:%M:%S')

    print(f"回测时间: {start_time}{end_time}")

    # 获取特征
    feature_cols = [col for col in test_data.columns if 'LABEL' not in col]
    X_test = test_data[feature_cols]

    # 处理NaN值
    X_test = X_test.fillna(X_test.mean())
    X_test = X_test.fillna(0)

    # 生成预测
    print("生成预测信号...")
    predictions = model.predict(X_test)

    print(f"预测值统计:")
    print(f"  最小值: {predictions.min():.6f}")
    print(f"  最大值: {predictions.max():.6f}")
    print(f"  平均值: {predictions.mean():.6f}")
    print(f"  标准差: {predictions.std():.6f}")

    # 创建信号DataFrame
    signal_df = pd.DataFrame({
        'score': predictions
    }, index=X_test.index)

    # 简单回测逻辑(使用原始OHLCV数据)
    results = simple_backtest(
        signal_df,
        test_df,  # 使用原始OHLCV数据
        hold_bars=hold_bars,
        long_threshold=long_threshold,
        short_threshold=short_threshold,
        init_cash=init_cash
    )

    return results


def simple_backtest(
    signal_df: pd.DataFrame,
    test_data: pd.DataFrame,
    hold_bars: int = 5,
    long_threshold: float = 0.003,
    short_threshold: float = -0.003,
    init_cash: float = 1000000
):
    """
    简单回测逻辑

    Parameters:
    -----------
    signal_df : pd.DataFrame
        预测信号
    test_data : pd.DataFrame
        测试数据(包含close价格)
    hold_bars : int
        持仓K线数
    long_threshold : float
        做多阈值
    short_threshold : float
        做空阈值
    init_cash : float
        初始资金

    Returns:
    --------
    dict
        回测结果
    """
    cash = init_cash
    position = 0  # 0=空仓, >0=多仓数量, <0=空仓数量
    position_entry_idx = None
    position_entry_price = None

    trades = []
    portfolio_values = []

    for i, (idx, row) in enumerate(signal_df.iterrows()):
        if idx not in test_data.index:
            continue

        current_price = test_data.loc[idx, 'close']

        # 检查平仓
        if position != 0 and position_entry_idx is not None:
            bars_held = i - position_entry_idx
            if bars_held >= hold_bars:
                # 平仓
                if position > 0:
                    cash = position * current_price
                else:
                    cash = cash - abs(position) * current_price

                trades.append({
                    'time': idx,
                    'action': 'close',
                    'price': current_price,
                    'position': position
                })

                position = 0
                position_entry_idx = None
                position_entry_price = None

        # 检查开仓
        if position == 0:
            pred = row['score']

            if pred >= long_threshold:
                # 开多仓
                position = cash / current_price * 0.95  # 使用95%资金
                cash = 0
                position_entry_idx = i
                position_entry_price = current_price

                trades.append({
                    'time': idx,
                    'action': 'long',
                    'price': current_price,
                    'position': position
                })

            elif pred <= short_threshold:
                # 开空仓
                sell_value = cash * 0.95
                position = -sell_value / current_price
                position_entry_idx = i
                position_entry_price = current_price

                trades.append({
                    'time': idx,
                    'action': 'short',
                    'price': current_price,
                    'position': position
                })

        # 计算当前组合价值
        if position > 0:
            portfolio_value = cash + position * current_price
        elif position < 0:
            portfolio_value = cash + position * current_price
        else:
            portfolio_value = cash

        portfolio_values.append({
            'time': idx,
            'value': portfolio_value,
            'position': position
        })

    # 最后持仓平仓
    if position != 0:
        last_price = test_data.iloc[-1]['close']
        if position > 0:
            cash = position * last_price
        else:
            cash = cash - abs(position) * last_price
        portfolio_values[-1]['value'] = cash
        portfolio_values[-1]['position'] = 0

    # 计算收益率
    portfolio_df = pd.DataFrame(portfolio_values).set_index('time')
    portfolio_df['return'] = portfolio_df['value'].pct_change()

    return {
        'portfolio': portfolio_df,
        'trades': pd.DataFrame(trades)
    }


# ============================================
# 结果分析
# ============================================

def analyze_results(results: dict):
    """
    分析回测结果

    Parameters:
    -----------
    results : dict
        回测结果
    """
    print("\n" + "=" * 60)
    print("回测结果分析")
    print("=" * 60)

    portfolio_df = results['portfolio']
    trades_df = results['trades']

    # 计算收益率
    returns = portfolio_df['return'].dropna()

    if len(returns) > 0:
        # 计算累计收益率
        cumulative_returns = (1 + returns).cumprod() - 1

        final_return = cumulative_returns.iloc[-1]
        print(f"\n总收益率: {final_return:.4f} ({final_return*100:.2f}%)")
        print(f"平均收益率: {returns.mean():.6f}")
        print(f"收益率标准差: {returns.std():.6f}")

        if returns.std() > 0:
            sharpe = returns.mean() / returns.std() * np.sqrt(len(returns))
            print(f"夏普比率: {sharpe:.4f}")

        # 最大回撤
        cum_returns = (1 + returns).cumprod()
        peak = cum_returns.expanding().max()
        drawdown = (cum_returns - peak) / peak
        max_drawdown = drawdown.min()
        print(f"最大回撤: {max_drawdown:.4f} ({max_drawdown*100:.2f}%)")

    # 交易统计
    if len(trades_df) > 0:
        print(f"\n交易统计:")
        print(f"  总交易次数: {len(trades_df)}")

        long_trades = trades_df[trades_df['action'] == 'long']
        short_trades = trades_df[trades_df['action'] == 'short']

        print(f"  做多次数: {len(long_trades)}")
        print(f"  做空次数: {len(short_trades)}")


# ============================================
# 主函数
# ============================================

def main():
    """主函数"""
    # 解析命令行参数
    parser = argparse.ArgumentParser(description='股指期货1分钟K线交易系统')
    parser.add_argument('--data_path', type=str, default='demo/chap06/data.json',
                        help='JSON数据文件路径')
    parser.add_argument('--instrument', type=str, default='IF0',
                        help='标的代码')
    parser.add_argument('--train_ratio', type=float, default=0.7,
                        help='训练集比例')
    parser.add_argument('--n_estimators', type=int, default=100,
                        help='LightGBM树的数量')
    parser.add_argument('--hold_bars', type=int, default=5,
                        help='持仓K线数')
    parser.add_argument('--long_threshold', type=float, default=0.003,
                        help='做多阈值')
    parser.add_argument('--short_threshold', type=float, default=-0.003,
                        help='做空阈值')
    parser.add_argument('--init_cash', type=float, default=1000000,
                        help='初始资金')

    args = parser.parse_args()

    print("=" * 60)
    print("股指期货1分钟K线交易系统")
    print("=" * 60)
    print(f"\n配置参数:")
    print(f"  数据文件: {args.data_path}")
    print(f"  标的代码: {args.instrument}")
    print(f"  训练集比例: {args.train_ratio}")
    print(f"  树的数量: {args.n_estimators}")
    print(f"  持仓K线数: {args.hold_bars}")
    print(f"  做多阈值: {args.long_threshold}")
    print(f"  做空阈值: {args.short_threshold}")
    print(f"  初始资金: {args.init_cash:,.0f}")

    try:
        # 1. 加载数据
        print("\n" + "=" * 60)
        print("步骤1: 加载数据")
        print("=" * 60)

        loader = json_future_data_loader.JSONFutureDataLoader(args.data_path, args.instrument)
        df = loader.load()

        # 2. 划分数据
        print("\n" + "=" * 60)
        print("步骤2: 划分数据")
        print("=" * 60)

        train_size = int(len(df) * args.train_ratio)
        train_df = df.iloc[:train_size]
        test_df = df.iloc[train_size:]

        print(f"训练集: {len(train_df)} 条 ({train_df.index[0]}{train_df.index[-1]})")
        print(f"测试集: {len(test_df)} 条 ({test_df.index[0]}{test_df.index[-1]})")

        # 3. 创建特征和标签
        print("\n" + "=" * 60)
        print("步骤3: 创建特征和标签")
        print("=" * 60)

        handler = single_future_handler.SingleFutureHandler(train_df, feature_window=10, label_horizon=5)
        train_data = handler.create_features_labels()

        print(f"训练数据: {train_data.shape}")
        print(f"特征数量: {len([col for col in train_data.columns if 'LABEL' not in col])}")

        # 4. 训练模型
        model = train_model(train_data, n_estimators=args.n_estimators)

        # 5. 准备测试数据
        print("\n" + "=" * 60)
        print("步骤5: 准备测试数据")
        print("=" * 60)

        test_handler = single_future_handler.SingleFutureHandler(test_df, feature_window=10, label_horizon=5)
        test_data = test_handler.create_features_labels()

        print(f"测试数据: {test_data.shape}")

        # 6. 运行回测
        results = run_backtest(
            model=model,
            test_data=test_data,
            test_df=test_df,  # 传入原始OHLCV数据
            hold_bars=args.hold_bars,
            long_threshold=args.long_threshold,
            short_threshold=args.short_threshold,
            init_cash=args.init_cash
        )

        # 7. 分析结果
        analyze_results(results)

        print("\n" + "=" * 60)
        print("回测完成!")
        print("=" * 60)

    except FileNotFoundError as e:
        print(f"\n错误: 找不到文件 - {e}")
    except Exception as e:
        print(f"\n错误: {e}")
        import traceback
        traceback.print_exc()


if __name__ == '__main__':
    main()