# ===== Pandas implementation (Pandas 实现) =====
import pandas as pd
# Load raw FX transactions from Excel; the code below reads the columns
# CUST_ID, NAT_CURR_CD, TRAN_DT, BUY_SELL_IND ('B'/'S') and TRAN_AMT.
data_initial=pd.read_excel('fx.xlsx')
# Parse transaction dates so timedelta arithmetic works during matching.
data_initial['TRAN_DT']=pd.to_datetime(data_initial['TRAN_DT'])
# Chronological order with a fresh 0-based index; row 0 becomes the earliest
# transaction and is later used as the reference date for uncovered sells.
data_initial=data_initial.sort_values(by='TRAN_DT').reset_index(drop=True)
# Notebook-style preview; has no effect when run as a plain script.
data_initial.head()
def calculate_hold_period_pandas(data_initial):
    """Return a copy of ``data_initial`` with a HOLD_PERIOD column on sell rows.

    Within each (CUST_ID, NAT_CURR_CD) group, every sell transaction is
    matched FIFO against the strictly earlier buy transactions of the same
    group.  HOLD_PERIOD is the amount-weighted holding time in days; any
    part of a sell not covered by prior buys is treated as held since the
    earliest transaction in the whole dataset (row 0 after the sort).

    Parameters
    ----------
    data_initial : pandas.DataFrame
        Must be sorted by TRAN_DT with a 0-based RangeIndex and contain the
        columns CUST_ID, NAT_CURR_CD, TRAN_DT (datetime64),
        BUY_SELL_IND ('B'/'S') and TRAN_AMT.

    Returns
    -------
    pandas.DataFrame
        Copy of the input with HOLD_PERIOD filled in on sell rows
        (NaN on buy rows).
    """
    data_out = data_initial.copy()
    for cust_id in data_initial['CUST_ID'].unique():
        df = data_initial[data_initial['CUST_ID'] == cust_id]
        for curr in df['NAT_CURR_CD'].unique():
            df_curr = df[df['NAT_CURR_CD'] == curr]
            # .copy() so the FIFO bookkeeping below mutates a private frame
            # instead of a view of data_initial (avoids SettingWithCopy).
            df_buy = df_curr[df_curr['BUY_SELL_IND'] == 'B'].copy()
            df_sell = df_curr[df_curr['BUY_SELL_IND'] == 'S']
            for sell_index, sell_row in df_sell.iterrows():
                # Only buys strictly before the sell are eligible.
                bought = df_buy[df_buy['TRAN_DT'] < sell_row['TRAN_DT']]
                if sell_row['TRAN_AMT'] >= bought['TRAN_AMT'].sum():
                    # All prior buys are fully consumed; the uncovered
                    # remainder is assumed held since the dataset start.
                    residual = sell_row['TRAN_AMT'] - bought['TRAN_AMT'].sum()
                    hold = (residual * (sell_row['TRAN_DT'] - data_initial.loc[0, 'TRAN_DT'])).total_seconds() / 86400
                    hold += (((sell_row['TRAN_DT'] - bought['TRAN_DT']) * bought['TRAN_AMT']).dt.total_seconds() / 86400).sum()
                    data_out.loc[sell_index, 'HOLD_PERIOD'] = hold / sell_row['TRAN_AMT']
                    # Consumed lots must not match later sells.
                    df_buy = df_buy.drop(bought.index)
                else:
                    cum = bought['TRAN_AMT'].cumsum()
                    bought1 = bought[cum < sell_row['TRAN_AMT']]     # fully consumed lots
                    bought2 = bought[~(cum < sell_row['TRAN_AMT'])]  # first row: partially consumed
                    residual = sell_row['TRAN_AMT'] - bought1['TRAN_AMT'].sum()
                    hold = (residual * (sell_row['TRAN_DT'] - bought2.iloc[0]['TRAN_DT'])).total_seconds() / 86400
                    hold += (((sell_row['TRAN_DT'] - bought1['TRAN_DT']) * bought1['TRAN_AMT']).dt.total_seconds() / 86400).sum()
                    data_out.loc[sell_index, 'HOLD_PERIOD'] = hold / sell_row['TRAN_AMT']
                    df_buy = df_buy.drop(bought1.index)
                    # Shrink the partially-consumed lot by the matched amount.
                    df_buy.loc[bought2.index[0], 'TRAN_AMT'] -= residual
    return data_out


data_out = calculate_hold_period_pandas(data_initial)
# Notebook-style preview of the matched result; no effect in a plain script.
data_out
# Per-customer summary statistics (median/mean/min/max) of the holding period.
data_out.groupby(['CUST_ID']).agg({'HOLD_PERIOD': ['median', 'mean','min','max']})
# ===== PySpark implementation (PySpark 实现) =====
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, expr, sum as spark_sum, when
from pyspark.sql.window import Window
from pyspark.sql.types import DoubleType
# Build (or reuse) the Spark session for this job.
spark = SparkSession.builder \
    .appName("Calculate Hold Period") \
    .getOrCreate()

# Read the FX transactions; requires the spark-excel (com.crealytics)
# connector on the cluster classpath.
data_initial = spark.read.format("com.crealytics.spark.excel") \
    .option("header", "true") \
    .load("fx.xlsx")

# Parse transaction timestamps so date arithmetic is possible downstream.
data_initial = data_initial.withColumn("TRAN_DT", col("TRAN_DT").cast("timestamp"))
# Chronological order.  (The original code also added an "index" column via
# monotonically_increasing_id() and immediately dropped it — a no-op that
# has been removed.)
data_initial = data_initial.orderBy("TRAN_DT")
data_out = data_initial.select("*")
# Distinct customers, collected to the driver for the per-customer loop.
customer_list = [row.CUST_ID for row in data_initial.select("CUST_ID").distinct().collect()]
# Time-ordered window within each (customer, currency) group.
windowSpec = Window.partitionBy("CUST_ID", "NAT_CURR_CD").orderBy("TRAN_DT")
def calculate_hold_period(df):
    """Return ``df`` with a HOLD_PERIOD column computed for each sell row.

    Mirrors the pandas implementation above: within each
    (CUST_ID, NAT_CURR_CD) group, sells are matched FIFO against strictly
    earlier buys; HOLD_PERIOD is the amount-weighted holding time in days,
    and any uncovered sell amount is treated as held since the earliest
    transaction date in the whole dataset.

    NOTE(review): the original body was not runnable — it referenced an
    undefined ``df_initial``, assigned to an out-of-scope ``data_out``, and
    used window aggregates (``expr(...).over(...)``) inside filters, which
    Spark does not allow.  Because FIFO matching is inherently sequential
    per group, the rows are collected to the driver here; acceptable for
    moderate data volumes.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Columns CUST_ID, NAT_CURR_CD, TRAN_DT (timestamp),
        BUY_SELL_IND ('B'/'S'), TRAN_AMT.

    Returns
    -------
    pyspark.sql.DataFrame
        Input rows plus a nullable double HOLD_PERIOD column
        (null on buy rows).
    """
    field_names = df.schema.fieldNames()
    out_schema = df.schema.add("HOLD_PERIOD", DoubleType())
    rows = [r.asDict() for r in df.orderBy("TRAN_DT").collect()]
    if not rows:
        return df.sparkSession.createDataFrame([], out_schema)
    # Dataset-wide earliest transaction: reference date for uncovered sells.
    start_dt = rows[0]["TRAN_DT"]
    # Group row positions by (customer, currency), preserving time order.
    groups = {}
    for pos, r in enumerate(rows):
        groups.setdefault((r["CUST_ID"], r["NAT_CURR_CD"]), []).append(pos)
    hold = {}  # row position -> computed HOLD_PERIOD
    for positions in groups.values():
        fifo = []  # open buy lots as [tran_dt, remaining_amount], oldest first
        for pos in positions:
            r = rows[pos]
            if r["BUY_SELL_IND"] == "B":
                fifo.append([r["TRAN_DT"], float(r["TRAN_AMT"])])
                continue
            sell_dt, sell_amt = r["TRAN_DT"], float(r["TRAN_AMT"])
            remaining, weighted_days = sell_amt, 0.0
            # Consume strictly earlier lots, oldest first (FIFO).
            while fifo and remaining > 0 and fifo[0][0] < sell_dt:
                lot_dt, lot_amt = fifo[0]
                take = min(lot_amt, remaining)
                weighted_days += take * (sell_dt - lot_dt).total_seconds() / 86400
                remaining -= take
                if take < lot_amt:
                    fifo[0][1] = lot_amt - take  # lot partially consumed
                else:
                    fifo.pop(0)                  # lot fully consumed
            # Uncovered amount: assumed held since the dataset start.
            weighted_days += remaining * (sell_dt - start_dt).total_seconds() / 86400
            hold[pos] = weighted_days / sell_amt
    out_rows = [tuple(r[f] for f in field_names) + (hold.get(pos),)
                for pos, r in enumerate(rows)]
    return df.sparkSession.createDataFrame(out_rows, out_schema)
# Compute per-sell holding periods and preview the result.
data_out = calculate_hold_period(data_initial)
data_out.show()
# Per-customer summary statistics of the holding period
# (median via percentile_approx, plus mean/min/max).
data_out.groupBy("CUST_ID").agg(expr("percentile_approx(HOLD_PERIOD, 0.5)").alias("median_hold_period"),
expr("avg(HOLD_PERIOD)").alias("mean_hold_period"),
expr("min(HOLD_PERIOD)").alias("min_hold_period"),
expr("max(HOLD_PERIOD)").alias("max_hold_period")).show()
# Release the Spark session's resources.
spark.stop()