交易数据的持仓时间计算

184 阅读 · 3 分钟

Pandas 实现


import pandas as pd


def compute_hold_periods(data_initial):
    """Compute a FIFO, amount-weighted holding period (in days) per sell.

    Parameters
    ----------
    data_initial : pd.DataFrame
        Must contain columns CUST_ID, NAT_CURR_CD, BUY_SELL_IND ('B'/'S'),
        TRAN_AMT and TRAN_DT (datetime-like).

    Returns
    -------
    pd.DataFrame
        A copy of the input, sorted by TRAN_DT with a fresh 0..n-1 index and
        a HOLD_PERIOD column filled on every sell row (buy rows stay NaN).
    """
    # Sort by trade date and reset the index so positional label 0 is the
    # earliest transaction and FIFO matching walks forward in time.
    data_initial = data_initial.sort_values(by='TRAN_DT').reset_index(drop=True)
    data_out = data_initial.copy()

    # Sells are matched against earlier buys independently per customer
    # and per currency.
    for cust in data_initial['CUST_ID'].unique():
        df = data_initial[data_initial['CUST_ID'] == cust]
        for curr in df['NAT_CURR_CD'].unique():
            df_curr = df[df['NAT_CURR_CD'] == curr]
            # Explicit copy: the remaining buy lots are mutated as sells
            # consume them (the original mutated a filtered slice in place,
            # which triggers SettingWithCopyWarning and is fragile).
            df_buy = df_curr[df_curr['BUY_SELL_IND'] == 'B'].copy()
            df_sell = df_curr[df_curr['BUY_SELL_IND'] == 'S']
            # Row labels survive all the filtering, so sell_index addresses
            # the same row in data_out.
            for sell_index, sell_row in df_sell.iterrows():
                sell_amt = sell_row['TRAN_AMT']
                sell_dt = sell_row['TRAN_DT']
                # Only buys strictly before the sell date can be matched.
                bought = df_buy[df_buy['TRAN_DT'] < sell_dt]
                bought_total = bought['TRAN_AMT'].sum()
                if sell_amt >= bought_total:
                    # The sell clears every open lot.  The uncovered residual
                    # is weighted from the earliest trade date in the file
                    # (original convention: the position predates the data).
                    residual = sell_amt - bought_total
                    hold = (residual * (sell_dt - data_initial.loc[0, 'TRAN_DT'])).total_seconds() / 86400
                    hold += (((sell_dt - bought['TRAN_DT']) * bought['TRAN_AMT']).dt.total_seconds() / 86400).sum()
                    # Weighted average: the sell amount is the denominator.
                    data_out.loc[sell_index, 'HOLD_PERIOD'] = hold / sell_amt
                    df_buy = df_buy.drop(bought.index)
                else:
                    # FIFO split: lots whose running total stays below the sell
                    # amount are fully consumed (bought1); the first lot that
                    # crosses it is consumed only partially (bought2 head).
                    cum = bought['TRAN_AMT'].cumsum()
                    fully_used = cum < sell_amt
                    bought1 = bought[fully_used]
                    bought2 = bought[~fully_used]
                    residual = sell_amt - bought1['TRAN_AMT'].sum()
                    hold = (residual * (sell_dt - bought2.iloc[0]['TRAN_DT'])).total_seconds() / 86400
                    hold += (((sell_dt - bought1['TRAN_DT']) * bought1['TRAN_AMT']).dt.total_seconds() / 86400).sum()
                    data_out.loc[sell_index, 'HOLD_PERIOD'] = hold / sell_amt
                    df_buy = df_buy.drop(bought1.index)
                    # Shrink the partially consumed lot by the amount taken.
                    df_buy.loc[bought2.index[0], 'TRAN_AMT'] -= residual
    return data_out


if __name__ == "__main__":
    data_initial = pd.read_excel('fx.xlsx')
    data_initial['TRAN_DT'] = pd.to_datetime(data_initial['TRAN_DT'])
    data_out = compute_hold_periods(data_initial)
    print(data_out)
    # Per-customer holding-period statistics.
    print(data_out.groupby(['CUST_ID']).agg(
        {'HOLD_PERIOD': ['median', 'mean', 'min', 'max']}))

PySpark 实现

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, expr, sum as spark_sum, when
from pyspark.sql.window import Window
from pyspark.sql.types import DoubleType

# Initialize the Spark session.
spark = SparkSession.builder \
    .appName("Calculate Hold Period") \
    .getOrCreate()

# Read the Excel file into a Spark DataFrame (requires the com.crealytics
# spark-excel package on the classpath).
data_initial = spark.read.format("com.crealytics.spark.excel") \
    .option("header", "true") \
    .load("fx.xlsx")

# Parse the transaction date column as a timestamp.
data_initial = data_initial.withColumn("TRAN_DT", col("TRAN_DT").cast("timestamp"))

# Sort by transaction date.  Spark rows carry no index, so nothing needs to be
# "reset": the original added a monotonically_increasing_id column and dropped
# it on the same line, a pure no-op that is removed here.
data_initial = data_initial.orderBy("TRAN_DT")

# Start the output from the input; HOLD_PERIOD is attached later.
data_out = data_initial.select("*")

# Distinct customer IDs, for per-customer processing.
customer_list = [row.CUST_ID for row in data_initial.select("CUST_ID").distinct().collect()]

# Window for cumulative sums within each (customer, currency) partition,
# ordered by trade date.
windowSpec = Window.partitionBy("CUST_ID", "NAT_CURR_CD").orderBy("TRAN_DT")

# Function to calculate hold periods for every sell transaction.
def calculate_hold_period(df):
    """Return *df* with a HOLD_PERIOD column (in days) on every sell row.

    Within each (CUST_ID, NAT_CURR_CD) partition, sells are matched FIFO
    against earlier buys; the holding period is the amount-weighted average
    age of the matched buy lots.  Any sell amount not covered by recorded
    buys is weighted from the earliest trade date in the dataset (the
    position is assumed to predate the data).

    The original body was not runnable: it referenced an undefined
    ``df_initial``, read ``data_out`` before assignment, embedded invalid
    SQL in ``expr`` strings, fed ``Row`` lists to ``isin`` and truncated
    fractional days via ``timedelta.days``.  This version computes the
    matching on the driver and joins the result back.

    NOTE(review): collecting to the driver is fine for the small Excel
    inputs this script targets, not for genuinely large data.
    """
    cols = ["CUST_ID", "NAT_CURR_CD", "BUY_SELL_IND", "TRAN_DT", "TRAN_AMT"]
    records = df.select(*cols).collect()
    if not records:
        return df.withColumn("HOLD_PERIOD", expr("CAST(NULL AS DOUBLE)"))

    # Earliest trade date across the whole dataset (matches the Pandas
    # version's data_initial.loc[0, 'TRAN_DT']).
    earliest = min(r.TRAN_DT for r in records)

    # Group transactions per (customer, currency) on the driver.
    groups = {}
    for r in records:
        groups.setdefault((r.CUST_ID, r.NAT_CURR_CD), []).append(r)

    def days(delta):
        # Fractional days — not the truncating timedelta.days.
        return delta.total_seconds() / 86400.0

    results = []  # (CUST_ID, NAT_CURR_CD, TRAN_DT, HOLD_PERIOD)
    for (cust, curr), txns in groups.items():
        txns.sort(key=lambda r: r.TRAN_DT)
        lots = []  # open buy lots as mutable [trade_date, remaining_amount]
        for r in txns:
            if r.BUY_SELL_IND == 'B':
                lots.append([r.TRAN_DT, r.TRAN_AMT])
                continue
            if r.BUY_SELL_IND != 'S':
                continue
            # Only buys strictly before the sell date are eligible.
            open_lots = [lot for lot in lots if lot[0] < r.TRAN_DT]
            total_open = sum(amt for _, amt in open_lots)
            if r.TRAN_AMT >= total_open:
                # Sell clears every open lot; weight the uncovered residual
                # from the earliest date in the dataset.
                residual = r.TRAN_AMT - total_open
                weighted = residual * days(r.TRAN_DT - earliest)
                weighted += sum(amt * days(r.TRAN_DT - dt) for dt, amt in open_lots)
                for lot in open_lots:
                    lots.remove(lot)
            else:
                # FIFO: consume lots oldest-first until the sell is covered;
                # the last touched lot is only partially consumed.
                weighted = 0.0
                remaining = r.TRAN_AMT
                for lot in open_lots:
                    take = min(lot[1], remaining)
                    weighted += take * days(r.TRAN_DT - lot[0])
                    lot[1] -= take
                    remaining -= take
                    if remaining <= 0:
                        break
                lots = [lot for lot in lots if lot[1] > 0]
            results.append((cust, curr, r.TRAN_DT, weighted / r.TRAN_AMT))

    if not results:
        # Buys only: emit a null HOLD_PERIOD column.
        return df.withColumn("HOLD_PERIOD", expr("CAST(NULL AS DOUBLE)"))

    holds = df.sparkSession.createDataFrame(
        results, ["CUST_ID", "NAT_CURR_CD", "TRAN_DT", "HOLD_PERIOD"])
    # Left join keeps buy rows (HOLD_PERIOD = null).  Assumes at most one
    # sell per (customer, currency, timestamp) — duplicates would fan out.
    return df.join(holds, ["CUST_ID", "NAT_CURR_CD", "TRAN_DT"], "left")

# Attach the per-sell holding periods to the transactions.
data_out = calculate_hold_period(data_initial)

# Show the enriched transactions.
data_out.show()

# Per-customer summary statistics of the holding period.
summary_columns = [
    expr("percentile_approx(HOLD_PERIOD, 0.5)").alias("median_hold_period"),
    expr("avg(HOLD_PERIOD)").alias("mean_hold_period"),
    expr("min(HOLD_PERIOD)").alias("min_hold_period"),
    expr("max(HOLD_PERIOD)").alias("max_hold_period"),
]
data_out.groupBy("CUST_ID").agg(*summary_columns).show()

# Shut down the Spark session.
spark.stop()