均值回归策略--基于CSV文件

51 阅读4分钟
# -*- coding: utf-8 -*-
"""
基于CSV文件的均值回归策略
"""

import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from typing import Tuple, List, Dict, Optional
import warnings
warnings.filterwarnings('ignore')

from config import CSV_DATA_PATH


class MeanReversionStrategyCSV:
    """
    基于CSV文件的均值回归策略类
    """
    
    def __init__(self):
        """初始化CSV数据"""
        self.data = None
        self.load_csv_data()
    
    def load_csv_data(self):
        """Load the price table from ``CSV_DATA_PATH`` into ``self.data``.

        The ``trade_date`` column is read as text and then converted to
        ``datetime64``. Reading it as text matters for compact dates such as
        ``20230103``: left to type inference they are parsed as integers,
        and ``pd.to_datetime`` interprets integers as nanoseconds since the
        Unix epoch (timestamps in 1970) rather than as calendar dates, which
        would make every later date filter come back empty.

        Raises:
            Exception: Any error raised while reading or parsing the file
                (e.g. missing file, missing ``trade_date`` column,
                unparseable dates) is reported on stdout and re-raised
                unchanged.
        """
        try:
            # Force text so YYYYMMDD values are not inferred as int64.
            self.data = pd.read_csv(CSV_DATA_PATH, dtype={'trade_date': str})
            self.data['trade_date'] = pd.to_datetime(self.data['trade_date'])
            print(f"CSV数据加载成功,共 {len(self.data)} 条记录")
        except Exception as e:
            print(f"CSV数据加载失败: {e}")
            raise
    
    def get_stock_data(self, ts_code: str, start_date: str, end_date: str) -> pd.DataFrame:
        """
        从CSV数据中筛选股票数据
        
        Args:
            ts_code (str): 股票代码
            start_date (str): 开始日期 (YYYY-MM-DD)
            end_date (str): 结束日期 (YYYY-MM-DD)
            
        Returns:
            pd.DataFrame: 股票数据
        """
        try:
            # 转换日期格式
            start_date = pd.to_datetime(start_date)
            end_date = pd.to_datetime(end_date)
            
            # 筛选数据
            mask = (
                (self.data['ts_code'] == ts_code) &
                (self.data['trade_date'] >= start_date) &
                (self.data['trade_date'] <= end_date)
            )
            
            df = self.data[mask].copy()
            df = df.sort_values('trade_date').reset_index(drop=True)
            
            return df
        except Exception as e:
            print(f"获取股票数据失败: {e}")
            return pd.DataFrame()
    
    def calculate_technical_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        计算技术指标
        
        Args:
            df (pd.DataFrame): 股票数据
            
        Returns:
            pd.DataFrame: 包含技术指标的数据
        """
        if df.empty:
            return df
        
        # 计算移动平均线
        df['ma5'] = df['close'].rolling(window=5).mean()
        df['ma10'] = df['close'].rolling(window=10).mean()
        df['ma20'] = df['close'].rolling(window=20).mean()
        df['ma60'] = df['close'].rolling(window=60).mean()
        
        # 计算布林带
        df['bb_middle'] = df['close'].rolling(window=20).mean()
        bb_std = df['close'].rolling(window=20).std()
        df['bb_upper'] = df['bb_middle'] + (bb_std * 2)
        df['bb_lower'] = df['bb_middle'] - (bb_std * 2)
        
        # 计算RSI
        delta = df['close'].diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
        rs = gain / loss
        df['rsi'] = 100 - (100 / (1 + rs))
        
        # 计算价格偏离度
        df['price_deviation'] = (df['close'] - df['ma20']) / df['ma20'] * 100
        
        return df
    
    def generate_signals(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        生成交易信号
        
        Args:
            df (pd.DataFrame): 包含技术指标的数据
            
        Returns:
            pd.DataFrame: 包含交易信号的数据
        """
        if df.empty:
            return df
        
        df['signal'] = 0  # 0: 持有, 1: 买入, -1: 卖出
        
        # 均值回归信号条件
        for i in range(20, len(df)):
            current_price = df.iloc[i]['close']
            ma20 = df.iloc[i]['ma20']
            bb_lower = df.iloc[i]['bb_lower']
            bb_upper = df.iloc[i]['bb_upper']
            rsi = df.iloc[i]['rsi']
            price_deviation = df.iloc[i]['price_deviation']
            
            # 买入信号:价格低于布林带下轨或RSI超卖
            if (current_price < bb_lower or rsi < 30 or price_deviation < -5):
                df.iloc[i, df.columns.get_loc('signal')] = 1
            
            # 卖出信号:价格高于布林带上轨或RSI超买
            elif (current_price > bb_upper or rsi > 70 or price_deviation > 5):
                df.iloc[i, df.columns.get_loc('signal')] = -1
        
        return df
    
    def calculate_returns(self, df: pd.DataFrame) -> Dict:
        """
        计算策略收益
        
        Args:
            df (pd.DataFrame): 包含交易信号的数据
            
        Returns:
            Dict: 收益统计信息
        """
        if df.empty:
            return {}
        
        # 计算日收益率
        df['daily_return'] = df['close'].pct_change()
        
        # 计算策略收益率
        df['strategy_return'] = df['signal'].shift(1) * df['daily_return']
        
        # 计算累计收益
        df['cumulative_return'] = (1 + df['daily_return']).cumprod()
        df['strategy_cumulative_return'] = (1 + df['strategy_return']).cumprod()
        
        # 统计信息
        total_return = df['strategy_cumulative_return'].iloc[-1] - 1
        buy_hold_return = df['cumulative_return'].iloc[-1] - 1
        sharpe_ratio = df['strategy_return'].mean() / df['strategy_return'].std() * np.sqrt(252)
        max_drawdown = (df['strategy_cumulative_return'] / df['strategy_cumulative_return'].cummax() - 1).min()
        
        # 交易统计
        buy_signals = len(df[df['signal'] == 1])
        sell_signals = len(df[df['signal'] == -1])
        
        return {
            'total_return': total_return,
            'buy_hold_return': buy_hold_return,
            'excess_return': total_return - buy_hold_return,
            'sharpe_ratio': sharpe_ratio,
            'max_drawdown': max_drawdown,
            'buy_signals': buy_signals,
            'sell_signals': sell_signals,
            'total_trades': buy_signals + sell_signals
        }
    
    def run_strategy(self, ts_code: str, start_date: str, end_date: str) -> Dict:
        """
        运行均值回归策略
        
        Args:
            ts_code (str): 股票代码
            start_date (str): 开始日期 (YYYY-MM-DD)
            end_date (str): 结束日期 (YYYY-MM-DD)
            
        Returns:
            Dict: 策略结果
        """
        print(f"运行均值回归策略 - 股票代码: {ts_code}")
        print(f"时间范围: {start_date}{end_date}")
        
        # 获取数据
        df = self.get_stock_data(ts_code, start_date, end_date)
        if df.empty:
            return {"error": "未获取到数据"}
        
        # 计算技术指标
        df = self.calculate_technical_indicators(df)
        
        # 生成交易信号
        df = self.generate_signals(df)
        
        # 计算收益
        results = self.calculate_returns(df)
        
        # 添加数据信息
        results['ts_code'] = ts_code
        results['start_date'] = start_date
        results['end_date'] = end_date
        results['data_points'] = len(df)
        
        return results


def mean_reversion_csv_strategy(ts_code: str, start_date: str, end_date: str) -> Dict:
    """
    Convenience entry point for the CSV-backed mean-reversion strategy.

    Builds a fresh ``MeanReversionStrategyCSV`` (which loads the CSV on
    construction) and runs it once for the given stock and date range.

    Args:
        ts_code (str): Stock code.
        start_date (str): Start date (YYYY-MM-DD).
        end_date (str): End date (YYYY-MM-DD).

    Returns:
        Dict: Whatever ``MeanReversionStrategyCSV.run_strategy`` returns.
    """
    return MeanReversionStrategyCSV().run_strategy(ts_code, start_date, end_date)


if __name__ == "__main__":
    # Manual smoke run: one stock over calendar year 2023, with every
    # entry of the result dict printed on its own line.
    ts_code, start_date, end_date = "000002.SZ", "2023-01-01", "2023-12-31"

    results = mean_reversion_csv_strategy(ts_code, start_date, end_date)
    print("策略结果:")
    for metric, figure in results.items():
        print(f"{metric}: {figure}")