"""
基于CSV文件的均值回归策略
"""
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from typing import Tuple, List, Dict, Optional
import warnings
warnings.filterwarnings('ignore')
from config import CSV_DATA_PATH
class MeanReversionStrategyCSV:
    """
    Mean-reversion strategy backed by a CSV file.

    The whole CSV at ``CSV_DATA_PATH`` is loaded once on construction into
    ``self.data``; per-stock date ranges are filtered from it on demand.
    ``run_strategy`` chains: filter rows -> compute indicators ->
    generate signals -> evaluate returns.
    """

    def __init__(self):
        """Load the CSV data into ``self.data`` (re-raises if loading fails)."""
        self.data = None
        self.load_csv_data()

    def load_csv_data(self):
        """
        Read ``CSV_DATA_PATH`` into ``self.data`` and parse ``trade_date``.

        Raises:
            Exception: any read/parse error is printed and then re-raised.
        """
        try:
            self.data = pd.read_csv(CSV_DATA_PATH)
            self.data['trade_date'] = pd.to_datetime(self.data['trade_date'])
            print(f"CSV数据加载成功,共 {len(self.data)} 条记录")
        except Exception as e:
            print(f"CSV数据加载失败: {e}")
            raise

    def get_stock_data(self, ts_code: str, start_date: str, end_date: str) -> pd.DataFrame:
        """
        Select one stock's rows from the loaded CSV data.

        Args:
            ts_code (str): Stock code matched against the ``ts_code`` column.
            start_date (str): Inclusive start date (YYYY-MM-DD).
            end_date (str): Inclusive end date (YYYY-MM-DD).

        Returns:
            pd.DataFrame: A copy of the matching rows sorted by ``trade_date``
            with a fresh 0..n-1 index. Any error is printed and an empty
            DataFrame is returned instead of raising.
        """
        try:
            start_date = pd.to_datetime(start_date)
            end_date = pd.to_datetime(end_date)
            mask = (
                (self.data['ts_code'] == ts_code) &
                (self.data['trade_date'] >= start_date) &
                (self.data['trade_date'] <= end_date)
            )
            df = self.data[mask].copy()
            df = df.sort_values('trade_date').reset_index(drop=True)
            return df
        except Exception as e:
            print(f"获取股票数据失败: {e}")
            return pd.DataFrame()

    def calculate_technical_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Add moving-average, Bollinger, RSI and deviation columns to ``df``.

        Columns are added in place and the same object is returned:
        ``ma5``/``ma10``/``ma20``/``ma60``, ``bb_middle``/``bb_upper``/
        ``bb_lower`` (20-period, 2 std), ``rsi`` (14-period) and
        ``price_deviation`` (percent distance of ``close`` from ``ma20``).

        Args:
            df (pd.DataFrame): Price data with a ``close`` column.

        Returns:
            pd.DataFrame: ``df`` with the indicator columns (unchanged if empty).
        """
        if df.empty:
            return df
        close = df['close']
        df['ma5'] = close.rolling(window=5).mean()
        df['ma10'] = close.rolling(window=10).mean()
        df['ma20'] = close.rolling(window=20).mean()
        df['ma60'] = close.rolling(window=60).mean()
        # The Bollinger middle band is the same 20-period SMA as ma20; reuse it
        # rather than recomputing an identical rolling mean.
        df['bb_middle'] = df['ma20']
        bb_std = close.rolling(window=20).std()
        df['bb_upper'] = df['bb_middle'] + (bb_std * 2)
        df['bb_lower'] = df['bb_middle'] - (bb_std * 2)
        # RSI from simple rolling means of gains/losses (not Wilder smoothing).
        # The first diff is NaN; where() replaces it with 0 in both legs.
        delta = close.diff()
        gain = delta.where(delta > 0, 0).rolling(window=14).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
        rs = gain / loss
        df['rsi'] = 100 - (100 / (1 + rs))
        df['price_deviation'] = (close - df['ma20']) / df['ma20'] * 100
        return df

    def generate_signals(self, df: pd.DataFrame, *, rsi_oversold: float = 30,
                         rsi_overbought: float = 70, deviation_pct: float = 5,
                         warmup: int = 20) -> pd.DataFrame:
        """
        Add an integer ``signal`` column: 1 = buy, -1 = sell, 0 = none.

        A row is a buy if close < bb_lower, or rsi < ``rsi_oversold``, or
        price_deviation < -``deviation_pct``. Otherwise it is a sell if
        close > bb_upper, or rsi > ``rsi_overbought``, or
        price_deviation > ``deviation_pct``. Buy wins when both hold.
        Comparisons against NaN indicator values are False, so they never
        trigger a signal.

        Args:
            df (pd.DataFrame): Output of ``calculate_technical_indicators``.
            rsi_oversold (float): RSI level below which to buy.
            rsi_overbought (float): RSI level above which to sell.
            deviation_pct (float): Percent deviation from ma20 that triggers.
            warmup (int): Rows at positions < ``warmup`` never get a signal
                (ma20/Bollinger first become defined at position 19).

        Returns:
            pd.DataFrame: ``df`` with the ``signal`` column (unchanged if empty).
        """
        if df.empty:
            return df
        close = df['close']
        buy = (
            (close < df['bb_lower'])
            | (df['rsi'] < rsi_oversold)
            | (df['price_deviation'] < -deviation_pct)
        )
        sell = (
            (close > df['bb_upper'])
            | (df['rsi'] > rsi_overbought)
            | (df['price_deviation'] > deviation_pct)
        )
        # Positional warm-up mask (works for any index, like the iloc loop did).
        in_window = np.arange(len(df)) >= warmup
        # np.select picks the first matching condition, so buy takes precedence.
        df['signal'] = np.select(
            [in_window & buy.to_numpy(), in_window & sell.to_numpy()],
            [1, -1],
            default=0,
        ).astype(np.int64)
        return df

    def calculate_returns(self, df: pd.DataFrame) -> Dict:
        """
        Evaluate the strategy against buy-and-hold.

        Args:
            df (pd.DataFrame): Data with ``close`` and ``signal`` columns;
                return columns are added to it in place.

        Returns:
            Dict: ``total_return``, ``buy_hold_return``, ``excess_return``,
            ``sharpe_ratio`` (annualised with sqrt(252), which presumes daily
            bars; NaN when strategy returns have zero/undefined volatility),
            ``max_drawdown``, ``buy_signals``, ``sell_signals`` and
            ``total_trades`` (number of non-zero signal rows, not round
            trips). An empty dict is returned for empty input.
        """
        if df.empty:
            return {}
        df['daily_return'] = df['close'].pct_change()
        # Use the previous row's signal as today's position (1 long, -1 short,
        # 0 flat) so a signal never trades on the bar that produced it.
        df['strategy_return'] = df['signal'].shift(1) * df['daily_return']
        df['cumulative_return'] = (1 + df['daily_return']).cumprod()
        df['strategy_cumulative_return'] = (1 + df['strategy_return']).cumprod()
        total_return = df['strategy_cumulative_return'].iloc[-1] - 1
        buy_hold_return = df['cumulative_return'].iloc[-1] - 1
        mean_return = df['strategy_return'].mean()
        std_return = df['strategy_return'].std()
        # Guard zero/undefined volatility (e.g. no signals -> all-zero returns)
        # explicitly instead of relying on float division yielding NaN/inf.
        if pd.isna(std_return) or std_return == 0:
            sharpe_ratio = np.nan
        else:
            sharpe_ratio = mean_return / std_return * np.sqrt(252)
        max_drawdown = (df['strategy_cumulative_return'] / df['strategy_cumulative_return'].cummax() - 1).min()
        buy_signals = int((df['signal'] == 1).sum())
        sell_signals = int((df['signal'] == -1).sum())
        return {
            'total_return': total_return,
            'buy_hold_return': buy_hold_return,
            'excess_return': total_return - buy_hold_return,
            'sharpe_ratio': sharpe_ratio,
            'max_drawdown': max_drawdown,
            'buy_signals': buy_signals,
            'sell_signals': sell_signals,
            'total_trades': buy_signals + sell_signals
        }

    def run_strategy(self, ts_code: str, start_date: str, end_date: str) -> Dict:
        """
        Run the full mean-reversion pipeline for one stock and date range.

        Args:
            ts_code (str): Stock code.
            start_date (str): Inclusive start date (YYYY-MM-DD).
            end_date (str): Inclusive end date (YYYY-MM-DD).

        Returns:
            Dict: Metrics from ``calculate_returns`` plus ``ts_code``,
            ``start_date``, ``end_date`` and ``data_points``; or
            ``{"error": ...}`` when no rows match.
        """
        print(f"运行均值回归策略 - 股票代码: {ts_code}")
        print(f"时间范围: {start_date} 到 {end_date}")
        df = self.get_stock_data(ts_code, start_date, end_date)
        if df.empty:
            return {"error": "未获取到数据"}
        df = self.calculate_technical_indicators(df)
        df = self.generate_signals(df)
        results = self.calculate_returns(df)
        results['ts_code'] = ts_code
        results['start_date'] = start_date
        results['end_date'] = end_date
        results['data_points'] = len(df)
        return results
def mean_reversion_csv_strategy(ts_code: str, start_date: str, end_date: str) -> Dict:
    """
    Convenience entry point for the CSV-backed mean-reversion strategy.

    A fresh ``MeanReversionStrategyCSV`` is constructed on every call, so the
    CSV file is (re)loaded each time before the backtest runs.

    Args:
        ts_code (str): Stock code to backtest.
        start_date (str): Inclusive start date (YYYY-MM-DD).
        end_date (str): Inclusive end date (YYYY-MM-DD).

    Returns:
        Dict: Whatever ``MeanReversionStrategyCSV.run_strategy`` returns.
    """
    return MeanReversionStrategyCSV().run_strategy(ts_code, start_date, end_date)
if __name__ == "__main__":
    # Demo run: backtest a single hard-coded stock code over 2023 and dump
    # every result entry as "key: value".
    demo_results = mean_reversion_csv_strategy("000002.SZ", "2023-01-01", "2023-12-31")
    print("策略结果:")
    for metric_name, metric_value in demo_results.items():
        print(f"{metric_name}: {metric_value}")