Hands-On Modeling with LightGBM (Complete Python Code Included)

WeChat official account: 尤而小屋
Author: Peter
Editor: Peter

Hi everyone, I'm Peter~

This is the third article in a series on modeling the UCI credit card default dataset. The first covered exploratory data analysis (EDA); the second built a baseline LightGBM model.

This third installment focuses on optimizing the LightGBM model, ultimately improving accuracy by more than 2 percentage points.

1 Importing Libraries

Import the libraries needed for modeling:

In [1]:

import pandas as pd 
import numpy as np
pd.set_option('display.max_columns', 100)
from IPython.display import display_html


import plotly_express as px
import plotly.graph_objects as go

import matplotlib
import matplotlib.pyplot as plt
plt.rcParams["font.sans-serif"]=["SimHei"] # 设置字体
plt.rcParams["axes.unicode_minus"]=False # 解决“-”负号的乱码问题

import seaborn as sns
%matplotlib inline 

import missingno as ms 
import gc

from datetime import datetime 
from sklearn.model_selection import train_test_split,StratifiedKFold,GridSearchCV
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.decomposition import PCA
from imblearn.under_sampling import ClusterCentroids
from imblearn.over_sampling import KMeansSMOTE, SMOTE
from sklearn.model_selection import KFold

from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, auc
from sklearn.metrics import roc_auc_score,precision_recall_curve, confusion_matrix,classification_report

# Classifiers
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn import tree
from pydotplus import graph_from_dot_data
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier,AdaBoostClassifier
from catboost import CatBoostClassifier
import lightgbm as lgb
import xgboost as xgb

from scipy import stats

import warnings 
warnings.filterwarnings("ignore")

2 Loading the Data

In [2]:

df = pd.read_csv("UCI.csv")

df.head()

Out[2]:

   ID  LIMIT_BAL  SEX  EDUCATION  MARRIAGE  AGE  PAY_0  PAY_2  PAY_3  PAY_4  PAY_5  PAY_6  BILL_AMT1  BILL_AMT2  BILL_AMT3  BILL_AMT4  BILL_AMT5  BILL_AMT6  PAY_AMT1  PAY_AMT2  PAY_AMT3  PAY_AMT4  PAY_AMT5  PAY_AMT6  default.payment.next.month
0   1    20000.0    2          2         1   24      2      2     -1     -1     -2     -2     3913.0     3102.0      689.0        0.0        0.0        0.0       0.0     689.0       0.0       0.0       0.0       0.0                           1
1   2   120000.0    2          2         2   26     -1      2      0      0      0      2     2682.0     1725.0     2682.0     3272.0     3455.0     3261.0       0.0    1000.0    1000.0    1000.0       0.0    2000.0                           1
2   3    90000.0    2          2         2   34      0      0      0      0      0      0    29239.0    14027.0    13559.0    14331.0    14948.0    15549.0    1518.0    1500.0    1000.0    1000.0    1000.0    5000.0                           0
3   4    50000.0    2          2         1   37      0      0      0      0      0      0    46990.0    48233.0    49291.0    28314.0    28959.0    29547.0    2000.0    2019.0    1200.0    1100.0    1069.0    1000.0                           0
4   5    50000.0    1          2         1   57     -1      0     -1      0      0      0     8617.0     5670.0    35835.0    20940.0    19146.0    19131.0    2000.0   36681.0   10000.0    9000.0     689.0     679.0                           0

3 Basic Data Information

1. Overall data volume

The dataset contains 30,000 records and 25 fields:

In [3]:

df.shape

Out[3]:

(30000, 25)

2. Field information

In [4]:

df.columns  # all column names

Out[4]:

Index(['ID', 'LIMIT_BAL', 'SEX', 'EDUCATION', 'MARRIAGE', 'AGE', 'PAY_0',
       'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'BILL_AMT1', 'BILL_AMT2',
       'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1',
       'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6',
       'default.payment.next.month'],
      dtype='object')

The dtype of each field:

In [5]:

df.dtypes

Out[5]:

ID                              int64
LIMIT_BAL                     float64
SEX                             int64
EDUCATION                       int64
MARRIAGE                        int64
AGE                             int64
PAY_0                           int64
PAY_2                           int64
PAY_3                           int64
PAY_4                           int64
PAY_5                           int64
PAY_6                           int64
BILL_AMT1                     float64
BILL_AMT2                     float64
BILL_AMT3                     float64
BILL_AMT4                     float64
BILL_AMT5                     float64
BILL_AMT6                     float64
PAY_AMT1                      float64
PAY_AMT2                      float64
PAY_AMT3                      float64
PAY_AMT4                      float64
PAY_AMT5                      float64
PAY_AMT6                      float64
default.payment.next.month      int64
dtype: object

In [6]:

pd.value_counts(df.dtypes)  # count the number of columns of each dtype

Out[6]:

float64    13
int64      12
Name: count, dtype: int64

Every field is numeric, split almost evenly between the two dtypes. The last field, default.payment.next.month, is our target.

What each field means:

  • ID: unique identifier for each record
  • LIMIT_BAL: amount of credit given (in NT dollars; includes both individual and family/supplementary credit)
  • SEX: gender (1 = male, 2 = female)
  • EDUCATION: 1 = graduate school; 2 = university; 3 = high school; 4 = other; 0/5/6 = unknown
  • MARRIAGE: marital status (1 = married, 2 = single, 3 = other)
  • AGE: age in years
  • PAY_0: repayment status in September 2005 (-2 = no consumption, -1 = paid in full, 1 = payment delayed one month, 2 = payment delayed two months, ..., 8 = payment delayed eight months, 9 = payment delayed nine months)
  • PAY_2: repayment status in August 2005 (same scale as above)
  • PAY_3: repayment status in July 2005 (same as above)
  • PAY_4: repayment status in June 2005 (same as above)
  • PAY_5: repayment status in May 2005 (same as above)
  • PAY_6: repayment status in April 2005 (same as above)
  • BILL_AMT1: bill statement amount in September 2005
  • BILL_AMT2: bill statement amount in August 2005
  • BILL_AMT3: bill statement amount in July 2005
  • BILL_AMT4: bill statement amount in June 2005
  • BILL_AMT5: bill statement amount in May 2005
  • BILL_AMT6: bill statement amount in April 2005
  • PAY_AMT1: amount of previous payment in September 2005
  • PAY_AMT2: amount of previous payment in August 2005
  • PAY_AMT3: amount of previous payment in July 2005
  • PAY_AMT4: amount of previous payment in June 2005
  • PAY_AMT5: amount of previous payment in May 2005
  • PAY_AMT6: amount of previous payment in April 2005
  • default.payment.next.month: the target variable, whether the client defaults on next month's payment (1 = yes, default; 0 = no default)

Notes:

  1. If PAY_AMT is below the bank's required minimum payment, it counts as a default;
  2. If PAY_AMT is greater than the previous month's bill amount BILL_AMT, it counts as paid on time;
  3. If PAY_AMT is above the minimum payment but below the previous month's bill amount, it counts as a delayed payment.

3. Descriptive statistics

In [7]:

df.describe().T  # with this many fields, the transposed view is easier to read

Out[7]:

                               count           mean            std        min       25%       50%        75%        max
ID                           30000.0   15000.500000    8660.398374        1.0   7500.75   15000.5   22500.25    30000.0
LIMIT_BAL                    30000.0  167484.322667  129747.661567    10000.0  50000.00  140000.0  240000.00  1000000.0
SEX                          30000.0       1.603733       0.489129        1.0      1.00       2.0       2.00        2.0
EDUCATION                    30000.0       1.853133       0.790349        0.0      1.00       2.0       2.00        6.0
MARRIAGE                     30000.0       1.551867       0.521970        0.0      1.00       2.0       2.00        3.0
AGE                          30000.0      35.485500       9.217904       21.0     28.00      34.0      41.00       79.0
PAY_0                        30000.0      -0.016700       1.123802       -2.0     -1.00       0.0       0.00        8.0
PAY_2                        30000.0      -0.133767       1.197186       -2.0     -1.00       0.0       0.00        8.0
PAY_3                        30000.0      -0.166200       1.196868       -2.0     -1.00       0.0       0.00        8.0
PAY_4                        30000.0      -0.220667       1.169139       -2.0     -1.00       0.0       0.00        8.0
PAY_5                        30000.0      -0.266200       1.133187       -2.0     -1.00       0.0       0.00        8.0
PAY_6                        30000.0      -0.291100       1.149988       -2.0     -1.00       0.0       0.00        8.0
BILL_AMT1                    30000.0   51223.330900   73635.860576  -165580.0   3558.75   22381.5   67091.00   964511.0
BILL_AMT2                    30000.0   49179.075167   71173.768783   -69777.0   2984.75   21200.0   64006.25   983931.0
BILL_AMT3                    30000.0   47013.154800   69349.387427  -157264.0   2666.25   20088.5   60164.75  1664089.0
BILL_AMT4                    30000.0   43262.948967   64332.856134  -170000.0   2326.75   19052.0   54506.00   891586.0
BILL_AMT5                    30000.0   40311.400967   60797.155770   -81334.0   1763.00   18104.5   50190.50   927171.0
BILL_AMT6                    30000.0   38871.760400   59554.107537  -339603.0   1256.00   17071.0   49198.25   961664.0
PAY_AMT1                     30000.0    5663.580500   16563.280354        0.0   1000.00    2100.0    5006.00   873552.0
PAY_AMT2                     30000.0    5921.163500   23040.870402        0.0    833.00    2009.0    5000.00  1684259.0
PAY_AMT3                     30000.0    5225.681500   17606.961470        0.0    390.00    1800.0    4505.00   896040.0
PAY_AMT4                     30000.0    4826.076867   15666.159744        0.0    296.00    1500.0    4013.25   621000.0
PAY_AMT5                     30000.0    4799.387633   15278.305679        0.0    252.50    1500.0    4031.50   426529.0
PAY_AMT6                     30000.0    5215.502567   17777.465775        0.0    117.75    1500.0    4000.00   528666.0
default.payment.next.month   30000.0       0.221200       0.415062        0.0      0.00       0.0       0.00        1.0

4. Overall field summary

In [8]:

df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 30000 entries, 0 to 29999
Data columns (total 25 columns):
 #   Column                      Non-Null Count  Dtype  
---  ------                      --------------  -----  
 0   ID                          30000 non-null  int64  
 1   LIMIT_BAL                   30000 non-null  float64
 2   SEX                         30000 non-null  int64  
 3   EDUCATION                   30000 non-null  int64  
 4   MARRIAGE                    30000 non-null  int64  
 5   AGE                         30000 non-null  int64  
 6   PAY_0                       30000 non-null  int64  
 7   PAY_2                       30000 non-null  int64  
 8   PAY_3                       30000 non-null  int64  
 9   PAY_4                       30000 non-null  int64  
 10  PAY_5                       30000 non-null  int64  
 11  PAY_6                       30000 non-null  int64  
 12  BILL_AMT1                   30000 non-null  float64
 13  BILL_AMT2                   30000 non-null  float64
 14  BILL_AMT3                   30000 non-null  float64
 15  BILL_AMT4                   30000 non-null  float64
 16  BILL_AMT5                   30000 non-null  float64
 17  BILL_AMT6                   30000 non-null  float64
 18  PAY_AMT1                    30000 non-null  float64
 19  PAY_AMT2                    30000 non-null  float64
 20  PAY_AMT3                    30000 non-null  float64
 21  PAY_AMT4                    30000 non-null  float64
 22  PAY_AMT5                    30000 non-null  float64
 23  PAY_AMT6                    30000 non-null  float64
 24  default.payment.next.month  30000 non-null  int64  
dtypes: float64(13), int64(12)
memory usage: 5.7 MB

For convenience, rename the original default.payment.next.month column to Label:

In [9]:

df.rename(columns={"default.payment.next.month":"Label"},inplace=True)

4 Missing Values

4.1 Missing-value counts

Count the missing values in each field:

In [10]:

df.isnull().sum().sort_values(ascending=False)

Out[10]:

ID           0
BILL_AMT2    0
PAY_AMT6     0
PAY_AMT5     0
PAY_AMT4     0
PAY_AMT3     0
PAY_AMT2     0
PAY_AMT1     0
BILL_AMT6    0
BILL_AMT5    0
BILL_AMT4    0
BILL_AMT3    0
BILL_AMT1    0
LIMIT_BAL    0
PAY_6        0
PAY_5        0
PAY_4        0
PAY_3        0
PAY_2        0
PAY_0        0
AGE          0
MARRIAGE     0
EDUCATION    0
SEX          0
Label        0
dtype: int64

In [11]:

# number of missing values per field
total = df.isnull().sum().sort_values(ascending=False)

In [12]:

# percentage of missing values per field
percent = (df.isnull().sum() / df.isnull().count() * 100).sort_values(ascending=False) 

percent

Out[12]:

ID           0.0
BILL_AMT2    0.0
PAY_AMT6     0.0
PAY_AMT5     0.0
PAY_AMT4     0.0
PAY_AMT3     0.0
PAY_AMT2     0.0
PAY_AMT1     0.0
BILL_AMT6    0.0
BILL_AMT5    0.0
BILL_AMT4    0.0
BILL_AMT3    0.0
BILL_AMT1    0.0
LIMIT_BAL    0.0
PAY_6        0.0
PAY_5        0.0
PAY_4        0.0
PAY_3        0.0
PAY_2        0.0
PAY_0        0.0
AGE          0.0
MARRIAGE     0.0
EDUCATION    0.0
SEX          0.0
Label        0.0
dtype: float64

Combine the counts and percentages into one complete missing-value summary:

In [13]:

pd.concat([total, percent],axis=1,keys=["Total","Percent"]).T

Out[13]:

(output: a 2 x 25 table with rows Total and Percent; every value is 0)

4.2 Visualizing missing values

In [14]:

ms.bar(df,color="blue")                                                     

plt.show()

An alternative way to draw it:

In [15]:

# ms.matrix(df, labels=True,label_rotation=45)
# plt.show()

Next, a detailed exploration of the individual fields:

In [16]:

df.columns

Out[16]:

Index(['ID', 'LIMIT_BAL', 'SEX', 'EDUCATION', 'MARRIAGE', 'AGE', 'PAY_0',
       'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'BILL_AMT1', 'BILL_AMT2',
       'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1',
       'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6', 'Label'],
      dtype='object')

The ID column carries no modeling signal, so drop it:

In [17]:

df.drop("ID",inplace=True,axis=1) 

5 Summary Statistics

5.1 Personal Information

Statistics for credit limit, education, marital status, and age:

In [18]:

df[['LIMIT_BAL', 'EDUCATION', 'MARRIAGE', 'AGE']].describe()

Out[18]:

            LIMIT_BAL     EDUCATION      MARRIAGE           AGE
count    30000.000000  30000.000000  30000.000000  30000.000000
mean    167484.322667      1.853133      1.551867     35.485500
std     129747.661567      0.790349      0.521970      9.217904
min      10000.000000      0.000000      0.000000     21.000000
25%      50000.000000      1.000000      1.000000     28.000000
50%     140000.000000      2.000000      2.000000     34.000000
75%     240000.000000      2.000000      2.000000     41.000000
max    1000000.000000      6.000000      3.000000     79.000000

In [19]:

df["EDUCATION"].value_counts().sort_values(ascending=False)

Out[19]:

EDUCATION
2    14030
1    10585
3     4917
5      280
4      123
6       51
0       14
Name: count, dtype: int64

The most common education level is university (EDUCATION = 2).

In [20]:

df["MARRIAGE"].value_counts().sort_values(ascending=False)        

Out[20]:

MARRIAGE
2    15964
1    13659
3      323
0       54
Name: count, dtype: int64

The most common marital status is MARRIAGE = 2, i.e., single.

5.2 LIMIT_BAL

Distribution of LIMIT_BAL:

In [21]:

df["LIMIT_BAL"].value_counts().sort_values(ascending=False)

Out[21]:

LIMIT_BAL
50000.0      3365
20000.0      1976
30000.0      1610
80000.0      1567
200000.0     1528
             ... 
800000.0        2
1000000.0       1
327680.0        1
760000.0        1
690000.0        1
Name: count, Length: 81, dtype: int64

The most frequent credit limit is 50,000.

In [22]:

plt.figure(figsize = (14,6))
plt.title('Density Plot of LIMIT_BAL')

sns.set_color_codes("pastel")
sns.distplot(df['LIMIT_BAL'],kde=True,bins=200)

plt.show()  

5.3 PAY_0-PAY_6

Monthly repayment status:

In [23]:

df[["PAY_0","PAY_2","PAY_3","PAY_4","PAY_5","PAY_6"]].describe()

Out[23]:

              PAY_0         PAY_2         PAY_3         PAY_4         PAY_5         PAY_6
count  30000.000000  30000.000000  30000.000000  30000.000000  30000.000000  30000.000000
mean      -0.016700     -0.133767     -0.166200     -0.220667     -0.266200     -0.291100
std        1.123802      1.197186      1.196868      1.169139      1.133187      1.149988
min       -2.000000     -2.000000     -2.000000     -2.000000     -2.000000     -2.000000
25%       -1.000000     -1.000000     -1.000000     -1.000000     -1.000000     -1.000000
50%        0.000000      0.000000      0.000000      0.000000      0.000000      0.000000
75%        0.000000      0.000000      0.000000      0.000000      0.000000      0.000000
max        8.000000      8.000000      8.000000      8.000000      8.000000      8.000000

Comparing repayment status between the two classes:

In [24]:

repay = df[['PAY_0', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'Label']]

repay = pd.melt(repay, 
                id_vars="Label",
                var_name="Payment Status",
                value_name="Delay(Month)"
               )
repay.head()

Out[24]:

   Label Payment Status  Delay(Month)
0      1          PAY_0             2
1      1          PAY_0            -1
2      0          PAY_0             0
3      0          PAY_0             0
4      0          PAY_0            -1

In [25]:

fig = px.box(repay, x="Payment Status", y="Delay(Month)",color="Label")

fig.show()

5.4 BILL_AMT1-BILL_AMT6

Monthly bill amounts:

In [26]:

df[['BILL_AMT1', 'BILL_AMT2', 'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6']].describe()

Out[26]:

            BILL_AMT1      BILL_AMT2     BILL_AMT3      BILL_AMT4      BILL_AMT5      BILL_AMT6
count    30000.000000   30000.000000  3.000000e+04   30000.000000   30000.000000   30000.000000
mean     51223.330900   49179.075167  4.701315e+04   43262.948967   40311.400967   38871.760400
std      73635.860576   71173.768783  6.934939e+04   64332.856134   60797.155770   59554.107537
min    -165580.000000  -69777.000000 -1.572640e+05 -170000.000000  -81334.000000 -339603.000000
25%       3558.750000    2984.750000  2.666250e+03    2326.750000    1763.000000    1256.000000
50%      22381.500000   21200.000000  2.008850e+04   19052.000000   18104.500000   17071.000000
75%      67091.000000   64006.250000  6.016475e+04   54506.000000   50190.500000   49198.250000
max     964511.000000  983931.000000  1.664089e+06  891586.000000  927171.000000  961664.000000

Comparing defaulting vs. non-defaulting clients:

In [27]:

df.columns

Out[27]:

Index(['LIMIT_BAL', 'SEX', 'EDUCATION', 'MARRIAGE', 'AGE', 'PAY_0', 'PAY_2',
       'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'BILL_AMT1', 'BILL_AMT2',
       'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1',
       'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6', 'Label'],
      dtype='object')

In [28]:

BILL_AMTS = ['BILL_AMT1', 'BILL_AMT2', 'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6']

plt.figure(figsize=(12,6))

for i, col in enumerate(BILL_AMTS):
    plt.subplot(2,3,i+1)
    sns.kdeplot(df.loc[(df["Label"] == 0),col], label="NO DEFAULT", color="red",shade=True)
    sns.kdeplot(df.loc[(df["Label"] == 1),col], label="DEFAULT", color="blue",shade=True)
    
    plt.xlim(-40000, 200000)
    plt.ylabel("")
    plt.xlabel(col, fontsize=12)
    plt.legend()
    plt.tight_layout()
    
plt.show()

5.5 PAY_AMT1-PAY_AMT6

Monthly payment amounts:

In [29]:

df[['PAY_AMT1', 'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6']].describe()

Out[29]:

             PAY_AMT1      PAY_AMT2      PAY_AMT3       PAY_AMT4       PAY_AMT5       PAY_AMT6
count    30000.000000  3.000000e+04   30000.00000   30000.000000   30000.000000   30000.000000
mean      5663.580500  5.921163e+03    5225.68150    4826.076867    4799.387633    5215.502567
std      16563.280354  2.304087e+04   17606.96147   15666.159744   15278.305679   17777.465775
min          0.000000  0.000000e+00       0.00000       0.000000       0.000000       0.000000
25%       1000.000000  8.330000e+02     390.00000     296.000000     252.500000     117.750000
50%       2100.000000  2.009000e+03    1800.00000    1500.000000    1500.000000    1500.000000
75%       5006.000000  5.000000e+03    4505.00000    4013.250000    4031.500000    4000.000000
max     873552.000000  1.684259e+06  896040.00000  621000.000000  426529.000000  528666.000000

In [30]:

PAY_AMTS = ['PAY_AMT1', 'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6']

plt.figure(figsize=(12,6))

for i, col in enumerate(PAY_AMTS):
    plt.subplot(2,3,i+1)
    sns.kdeplot(df.loc[(df["Label"] == 0),col], label="NO DEFAULT", color="red", shade=True)
    sns.kdeplot(df.loc[(df["Label"] == 1),col], label="DEFAULT", color="blue", shade=True)
    
    plt.xlim(-10000, 70000)
    plt.ylabel("")
    plt.xlabel(col, fontsize=12)
    plt.legend()
    plt.tight_layout()
    
plt.show()

6 Label

Compare the number of defaulting vs. non-defaulting clients (default.payment.next.month, renamed Label):

In [31]:

df["Label"].value_counts()

Out[31]:

Label
0    23364
1     6636
Name: count, dtype: int64

In [32]:

label = df["Label"].value_counts()
df_label = pd.DataFrame(label).reset_index()  

df_label

Out[32]:

   Label  count
0      0  23364
1      1   6636

In [33]:

# plt.figure(figsize = (6,6))
# plt.title('Not Default = 0 & Default = 1')
# sns.set_color_codes("pastel")

# sns.barplot(x = 'Label', y="count", data=df_label) 
# locs, labels = plt.xticks() 
# plt.show()

In [34]:

plt.figure(figsize = (5,5))
graph = sns.countplot(x="Label", data=df, palette=["red","blue"])

# annotate each bar with its class percentage
for i, p in enumerate(graph.patches):
    h = p.get_height()
    percentage = round(100 * df["Label"].value_counts()[i] / len(df), 2)
    graph.text(p.get_x() + p.get_width()/2., h - 100, f"{percentage} %", ha="center")
    
plt.title("class distribution")
plt.xticks([0,1], ["Non-Default","Default"])
plt.xlabel("Default Payment Next Month",fontsize=12)
plt.ylabel("Number of Clients")

plt.show()

The two classes are clearly imbalanced.


In [36]:

value_counts = df['Label'].value_counts()  

# proportion of each class
percentages = value_counts / len(df)
# bar chart with matplotlib
plt.bar(value_counts.index, value_counts.values)    

# add a percentage label above each bar
for i, v in enumerate(percentages.values):
    plt.text(i, value_counts.values[i] + 200, f'{v*100:.2f}%', ha='center', va="bottom")
    
# axis labels and title
plt.title("Class Distribution")
plt.xticks([0,1], ["Non-Default","Default"])
plt.xlabel("Default Payment Next Month",fontsize=12)
plt.ylabel("Number of Clients")

plt.show()

7 Correlation Analysis

7.1 Correlation heatmap

In [37]:

numeric = ['LIMIT_BAL','AGE','PAY_0','PAY_2',
           'PAY_3','PAY_4','PAY_5','PAY_6',
           'BILL_AMT1','BILL_AMT2','BILL_AMT3',
           'BILL_AMT4','BILL_AMT5','BILL_AMT6']  # numeric fields selected for the correlation analysis
numeric

Out[37]:

['LIMIT_BAL', 'AGE', 'PAY_0', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'BILL_AMT1', 'BILL_AMT2', 'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6']

In [38]:

corr = df[numeric].corr()
corr.head()

Out[38]:

           LIMIT_BAL       AGE     PAY_0     PAY_2     PAY_3     PAY_4     PAY_5     PAY_6  BILL_AMT1  BILL_AMT2  BILL_AMT3  BILL_AMT4  BILL_AMT5  BILL_AMT6
LIMIT_BAL   1.000000  0.144713 -0.271214 -0.296382 -0.286123 -0.267460 -0.249411 -0.235195   0.285430   0.278314   0.283236   0.293988   0.295562   0.290389
AGE         0.144713  1.000000 -0.039447 -0.050148 -0.053048 -0.049722 -0.053826 -0.048773   0.056239   0.054283   0.053710   0.051353   0.049345   0.047613
PAY_0      -0.271214 -0.039447  1.000000  0.672164  0.574245  0.538841  0.509426  0.474553   0.187068   0.189859   0.179785   0.179125   0.180635   0.176980
PAY_2      -0.296382 -0.050148  0.672164  1.000000  0.766552  0.662067  0.622780  0.575501   0.234887   0.235257   0.224146   0.222237   0.221348   0.219403
PAY_3      -0.286123 -0.053048  0.574245  0.766552  1.000000  0.777359  0.686775  0.632684   0.208473   0.237295   0.227494   0.227202   0.225145   0.222327

In [39]:

mask = np.triu(np.ones_like(corr, dtype=bool))

plt.figure(figsize=(12,10))
sns.heatmap(corr,
            mask=mask,
            vmin=-1,
            vmax=1,
            center=0,
            square=True,
            cbar_kws={'shrink': .5}, 
            annot=True, 
            annot_kws={'size': 10},
            cmap="Blues")

plt.show()

7.2 Pairwise relationships

In [40]:

plt.figure(figsize=(12,10))

pair_plot = sns.pairplot(df[['BILL_AMT1','BILL_AMT2','BILL_AMT3','BILL_AMT4','BILL_AMT5','BILL_AMT6','Label']], 
                         hue='Label',
                         diag_kind='kde', 
                         corner=True)

pair_plot._legend.remove()

8 Normality Check: QQ Plots

To check whether our data follow a Gaussian distribution, we use a graphical method called the quantile-quantile (QQ) plot for a qualitative assessment.

In a QQ plot, the quantiles of a variable are plotted against the quantiles expected under a normal distribution. If the variable is normally distributed, the points should fall along the 45-degree diagonal.

In [41]:

sns.set_color_codes('pastel')  # plotting style
fig, axs = plt.subplots(5, 3, figsize=(18,18))  # figure size and subplot grid

numeric = ['LIMIT_BAL','AGE','BILL_AMT1','BILL_AMT2','BILL_AMT3','BILL_AMT4','BILL_AMT5',
           'BILL_AMT6','PAY_AMT1','PAY_AMT2','PAY_AMT3','PAY_AMT4','PAY_AMT5','PAY_AMT6']

i, j = 0, 0
for f in numeric:
    if j == 3:
        j = 0
        i = i + 1
    stats.probplot(df[f],  # data to plot: all values of one field
                   dist='norm', # compare against a normal distribution
                   sparams=(df[f].mean(), df[f].std()), 
                   plot=axs[i,j])  # target subplot
    
    axs[i,j].get_lines()[0].set_marker('.') 
    
    axs[i,j].grid() 
    axs[i,j].get_lines()[1].set_linewidth(3.0)
    j = j+1

fig.tight_layout()
axs[4,2].set_visible(False)
plt.show()

9 Data Preprocessing

9.1 Categorical features

Handling the categorical fields:

In [42]:

df["EDUCATION"].value_counts()

Out[42]:

EDUCATION
2    14030
1    10585
3     4917
5      280
4      123
6       51
0       14
Name: count, dtype: int64

In [43]:

df["GRAD_SCHOOL"] = (df["EDUCATION"] == 1).astype("category")
df["UNIVERSITY"] = (df["EDUCATION"] == 2).astype("category")
df["HIGH_SCHOOL"] = (df["EDUCATION"] == 1).astype("category")

df.drop("EDUCATION",axis=1,inplace=True)

In [44]:

df['MALE'] = (df['SEX'] == 1).astype('category')
df.drop('SEX', axis=1, inplace=True)

In [45]:

df['MARRIED'] = (df['MARRIAGE'] == 1).astype('category')
df.drop('MARRIAGE', axis=1, inplace=True)
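
The cells above hand-roll one-hot indicator columns. The same result could be produced with pd.get_dummies; the sketch below is a hypothetical alternative (not part of the original pipeline) and would have to run before the drop calls above:

# Hypothetical alternative: one-hot encode EDUCATION with pandas, then keep
# only the three named levels to mirror the manual indicators above.
edu_dummies = pd.get_dummies(df["EDUCATION"], prefix="EDU")  # EDU_0 ... EDU_6
df = pd.concat([df, edu_dummies[["EDU_1", "EDU_2", "EDU_3"]]], axis=1)
df = df.drop(columns="EDUCATION")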

9.2 Train/Test Split

In [46]:

# separate features and target

y = df['Label']
X = df.drop('Label', axis=1, inplace=False)

Split while preserving the class proportions of y:

In [47]:

# split the data (default test_size=0.25)

X_train_raw, X_test_raw, y_train, y_test = train_test_split(X, y, random_state=24, stratify=y)
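
train_test_split defaults to a 75/25 split, and stratify=y keeps the 77.88% / 22.12% class ratio from section 6 in both parts. A quick sanity check (the commented values are inferred from the shapes and ratios reported elsewhere in this notebook):

print(X_train_raw.shape, X_test_raw.shape)  # (22500, 25) (7500, 25)
print(y_train.mean(), y_test.mean())        # both approximately 0.2212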

9.3 Feature Normalization/Standardization

Min-max normalization:

In [48]:

mm = MinMaxScaler()

X_train_norm = X_train_raw.copy()
X_test_norm = X_test_raw.copy()

In [49]:

# LIMIT_BAL + AGE

X_train_norm['LIMIT_BAL'] = mm.fit_transform(X_train_raw['LIMIT_BAL'].values.reshape(-1, 1))
X_test_norm['LIMIT_BAL'] = mm.transform(X_test_raw['LIMIT_BAL'].values.reshape(-1, 1))
X_train_norm['AGE'] = mm.fit_transform(X_train_raw['AGE'].values.reshape(-1, 1))
X_test_norm['AGE'] = mm.transform(X_test_raw['AGE'].values.reshape(-1, 1))

In [50]:

pay_list = ["PAY_0","PAY_2","PAY_3","PAY_4","PAY_5","PAY_6"]

for pay in pay_list:
    X_train_norm[pay] = mm.fit_transform(X_train_raw[pay].values.reshape(-1, 1))
    X_test_norm[pay] = mm.transform(X_test_raw[pay].values.reshape(-1, 1))

In [51]:

for i in range(1,7):
    X_train_norm['BILL_AMT' + str(i)] = mm.fit_transform(X_train_raw['BILL_AMT' + str(i)].values.reshape(-1, 1))
    X_test_norm['BILL_AMT' + str(i)] = mm.transform(X_test_raw['BILL_AMT' + str(i)].values.reshape(-1, 1))
    X_train_norm['PAY_AMT' + str(i)] = mm.fit_transform(X_train_raw['PAY_AMT' + str(i)].values.reshape(-1, 1))
    X_test_norm['PAY_AMT' + str(i)] = mm.transform(X_test_raw['PAY_AMT' + str(i)].values.reshape(-1, 1))
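
Since MinMaxScaler scales every input column independently, the per-column loops above can be collapsed into a single call. A more compact equivalent (a sketch; num_cols is a name introduced here, not used elsewhere in the notebook):

# Sketch: min-max scale all numeric columns in one fit/transform pass.
num_cols = (["LIMIT_BAL", "AGE"]
            + ["PAY_0", "PAY_2", "PAY_3", "PAY_4", "PAY_5", "PAY_6"]
            + ["BILL_AMT" + str(i) for i in range(1, 7)]
            + ["PAY_AMT" + str(i) for i in range(1, 7)])

mm = MinMaxScaler()
X_train_norm[num_cols] = mm.fit_transform(X_train_raw[num_cols])
X_test_norm[num_cols] = mm.transform(X_test_raw[num_cols])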

Standardization:

In [52]:

ss = StandardScaler()
X_train_std = X_train_raw.copy()
X_test_std = X_test_raw.copy()

X_train_std['LIMIT_BAL'] = ss.fit_transform(X_train_raw['LIMIT_BAL'].values.reshape(-1, 1))
X_test_std['LIMIT_BAL'] = ss.transform(X_test_raw['LIMIT_BAL'].values.reshape(-1, 1))

X_train_std['AGE'] = ss.fit_transform(X_train_raw['AGE'].values.reshape(-1, 1))
X_test_std['AGE'] = ss.transform(X_test_raw['AGE'].values.reshape(-1, 1))

In [53]:

pay_list = ["PAY_0","PAY_2","PAY_3","PAY_4","PAY_5","PAY_6"]

for pay in pay_list:
    X_train_std[pay] = ss.fit_transform(X_train_raw[pay].values.reshape(-1, 1))
    X_test_std[pay] = ss.transform(X_test_raw[pay].values.reshape(-1, 1))

In [54]:

for i in range(1,7):
    X_train_std['BILL_AMT' + str(i)] = ss.fit_transform(X_train_raw['BILL_AMT' + str(i)].values.reshape(-1, 1))
    X_test_std['BILL_AMT' + str(i)] = ss.transform(X_test_raw['BILL_AMT' + str(i)].values.reshape(-1, 1))
    X_train_std['PAY_AMT' + str(i)] = ss.fit_transform(X_train_raw['PAY_AMT' + str(i)].values.reshape(-1, 1))
    X_test_std['PAY_AMT' + str(i)] = ss.transform(X_test_raw['PAY_AMT' + str(i)].values.reshape(-1, 1))

In [55]:

sns.set_color_codes('deep')
numeric = ['LIMIT_BAL','AGE','BILL_AMT1','BILL_AMT2','BILL_AMT3','BILL_AMT4','BILL_AMT5',
           'BILL_AMT6','PAY_AMT1','PAY_AMT2','PAY_AMT3','PAY_AMT4','PAY_AMT5','PAY_AMT6']

fig, axs = plt.subplots(1, 2, figsize=(24,6))

sns.boxplot(data=X_train_norm[numeric], ax=axs[0])  
axs[0].set_title('Boxplot of normalized numeric features')
axs[0].set_xticklabels(labels=numeric, rotation=25)
axs[0].set_xlabel(' ')

sns.boxplot(data=X_train_std[numeric], ax=axs[1])
axs[1].set_title('Boxplot of standardized numeric features')
axs[1].set_xticklabels(labels=numeric, rotation=25)
axs[1].set_xlabel(' ')

fig.tight_layout()
plt.show()

9.4 Dimensionality Reduction

In [56]:

pc = len(X_train_norm.columns.values) # 25
pca = PCA(n_components=pc)  # keep as many components as there are features
pca.fit(X_train_norm)

sns.reset_orig()
sns.set_color_codes('pastel') # plotting colors
plt.figure(figsize = (8,4)) # figure size
plt.grid()  # show grid
plt.title('Explained Variance of Principal Components') # title
plt.plot(pca.explained_variance_ratio_, marker='o')  # per-component explained variance ratio
plt.plot(np.cumsum(pca.explained_variance_ratio_), marker='o')  # cumulative explained variance

plt.legend(["Individual Explained Variance", "Cumulative Explained Variance"])  # legend
plt.xlabel('Principal Component Indexes')  # axis titles
plt.ylabel('Explained Variance Ratio')  
plt.tight_layout()  # tighter layout
plt.axvline(12, 0, ls='--')  # dashed vertical line at x=12
plt.show()  # display the figure

What each part of the code does:

  1. pc = len(X_train_norm.columns.values) # 25: the number of features in the training set, 25 here.
  2. pca = PCA(n_components=pc): create a PCA object that keeps pc (= 25) components.
  3. pca.fit(X_train_norm): fit PCA on the training set X_train_norm.
  4. sns.reset_orig() and sns.set_color_codes('pastel'): seaborn plotting setup; reset_orig() restores the default colors, set_color_codes('pastel') switches to a pastel palette.
  5. plt.figure(figsize = (8,4)): create a new 8 x 4 figure.
  6. plt.grid(): show the grid.
  7. plt.title('Explained Variance of Principal Components'): set the title.
  8. plt.plot(pca.explained_variance_ratio_, marker='o'): plot each component's explained variance ratio.
  9. plt.plot(np.cumsum(pca.explained_variance_ratio_), marker='o'): plot the cumulative explained variance ratio.
  10. plt.legend([...]): add a legend for the individual and cumulative curves.
  11. plt.xlabel('Principal Component Indexes'): set the x-axis label.
  12. plt.ylabel('Explained Variance Ratio'): set the y-axis label.
  13. plt.tight_layout(): compact the layout.
  14. plt.axvline(12, 0, ls='--'): draw a dashed vertical line at x = 12 to mark the component count chosen below.
  15. plt.show(): display the figure.

The component indices themselves carry no intrinsic meaning; PCA simply orders the components by how much variance each one explains.

9.4.1 Cumulative explained variance

In [57]:

cumsum = np.cumsum(pca.explained_variance_ratio_)  # cumulative explained variance ratio
cumsum

Out[57]:

array([0.44924877, 0.6321187 , 0.8046163 , 0.87590932, 0.92253799,
       0.95438576, 0.96762706, 0.97773098, 0.9842774 , 0.98824928,
       0.99088299, 0.99280785, 0.99444757, 0.99576128, 0.99690533,
       0.99781622, 0.99844676, 0.99890236, 0.99924315, 0.99955744,
       0.9997182 , 0.99983861, 0.99992993, 1.        , 1.        ])

In [58]:

indexes = ['PC' + str(i) for i in range(1, pc+1)]

cumsum_df = pd.DataFrame(data=cumsum, index=indexes, columns=['var1'])

cumsum_df.head()

Out[58]:

         var1
PC1  0.449249
PC2  0.632119
PC3  0.804616
PC4  0.875909
PC5  0.922538

In [59]:

# round to 4 decimal places
cumsum_df['var2'] = pd.Series([round(val, 4) for val in cumsum_df['var1']], 
                              index = cumsum_df.index)
# format as a percentage string
cumsum_df['Cumulative Explained Variance'] = pd.Series(["{0:.2f}%".format(val * 100) for val in cumsum_df['var2']], 
                                                       index = cumsum_df.index)

cumsum_df.head()

Out[59]:

         var1    var2 Cumulative Explained Variance
PC1  0.449249  0.4492                        44.92%
PC2  0.632119  0.6321                        63.21%
PC3  0.804616  0.8046                        80.46%
PC4  0.875909  0.8759                        87.59%
PC5  0.922538  0.9225                        92.25%

In [60]:

cumsum_df = cumsum_df.drop(['var1','var2'], axis=1, inplace=False)
cumsum_df.T.iloc[:,:15]

Out[60]:

                                  PC1     PC2     PC3     PC4     PC5     PC6     PC7     PC8     PC9    PC10    PC11    PC12    PC13    PC14    PC15
Cumulative Explained Variance  44.92%  63.21%  80.46%  87.59%  92.25%  95.44%  96.76%  97.77%  98.43%  98.82%  99.09%  99.28%  99.44%  99.58%  99.69%
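
Instead of reading the threshold off this table manually, scikit-learn can pick the component count automatically: a float n_components in (0, 1) keeps the smallest number of components whose cumulative explained variance reaches that fraction. A sketch (not used in the original pipeline):

# Sketch: choose the number of components for a target variance ratio.
pca_auto = PCA(n_components=0.99)  # keep at least 99% of the variance
pca_auto.fit(X_train_norm)
print(pca_auto.n_components_)      # should print 11 here, per the table above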

9.4.2 Keeping 12 principal components

In [61]:

pc = 12
pca = PCA(n_components=pc)
pca.fit(X_train_norm)

X_train = pd.DataFrame(pca.transform(X_train_norm))
X_test = pd.DataFrame(pca.transform(X_test_norm))

# set the column names
X_train.columns = ['PC' + str(i) for i in range(1, pc+1)]
X_test.columns = ['PC' + str(i) for i in range(1, pc+1)]

X_train.head()

Out[61]:

        PC1       PC2       PC3       PC4       PC5       PC6       PC7       PC8       PC9      PC10      PC11      PC12
0 -0.234536 -0.310556  0.812443  0.583386  0.086486  0.193288 -0.045393 -0.059547  0.031720 -0.001745 -0.004745 -0.003148
1 -0.781139 -0.520069 -0.198721 -0.239243 -0.055078 -0.059366 -0.090988  0.049630 -0.070282  0.059528  0.033893  0.003430
2 -0.787315 -0.131143  0.747751 -0.187888  0.166084 -0.272372  0.157680 -0.008314  0.252000 -0.074637  0.029909  0.058873
3 -0.636174  0.390267 -0.599050 -0.132501 -0.213672 -0.049675 -0.114476 -0.006438  0.058377  0.035740  0.052377  0.030388
4 -0.790242 -0.497498 -0.205812 -0.227087  0.045253 -0.137781 -0.179086 -0.010123  0.019700  0.008193  0.001996  0.011253

10 Handling Class Imbalance

10.1 Class counts of the target variable

In [62]:

count = pd.value_counts(y_train)                               
count

Out[62]:

Label
0    17523
1     4977
Name: count, dtype: int64

In [63]:

percentage = pd.value_counts(y_train, normalize=True) 
percentage

Out[63]:

Label
0    0.7788
1    0.2212
Name: proportion, dtype: float64

In [64]:

class_count_df = pd.DataFrame(data=count.values,
                              index=['Non-defaulters', 'Defaulters'], 
                              columns=['Number'] 
                             )
class_count_df

Out[64]:

Number
Non-defaulters17523
Defaulters4977

In [65]:

class_count_df["Percentage"] = percentage.values 
class_count_df

Out[65]:

NumberPercentage
Non-defaulters175230.7788
Defaulters49770.2212

In [66]:

class_count_df["Percentage"] = class_count_df["Percentage"].apply(lambda x: "{:.2%}".format(x))
class_count_df

Out[66]:

NumberPercentage
Non-defaulters1752377.88%
Defaulters497722.12%

The same decimal-to-percentage conversion can be written as a small helper function (col is a placeholder for any column name):

def to_percent(x):  
    return "{:.2%}".format(x)

df[col] = df[col].apply(to_percent)

10.2 Method 1: Cluster Centroid Undersampling

ClusterCentroids shrinks the majority class by replacing it with the centroids of a KMeans clustering, so the resampled majority class ends up the same size as the minority class. Implementation:

In [67]:

undersample = ClusterCentroids(random_state=24)  # undersampler object

# undersample X_train / y_train
X_train_cc, y_train_cc = undersample.fit_resample(X_train, y_train)  

In [68]:

count_cc = pd.value_counts(y_train_cc)  # counts on the resampled y_train_cc
percentage_cc = pd.value_counts(y_train_cc, normalize=True) 
class_count_df_cc = pd.DataFrame(data=count_cc.values,
                              index=['Non-defaulters', 'Defaulters'], 
                              columns=['Number']
                             )

class_count_df_cc["Percentage"] = percentage_cc.values
class_count_df_cc["Percentage"] = class_count_df_cc["Percentage"].apply(lambda x: "{:.2%}".format(x))
class_count_df_cc

Out[68]:

NumberPercentage
Non-defaulters497750.00%
Defaulters497750.00%

Now y = 0 and y = 1 are balanced: the majority class has been reduced to the size of the minority class.

10.3 Method 2: Synthetic Minority Oversampling Technique (SMOTE)

SMOTE (Synthetic Minority Oversampling Technique) is an oversampling method designed for imbalanced datasets. It increases the number of minority-class samples by interpolating between existing ones. The main steps (sketched in code below) are:

  • For each minority-class sample, compute its distance to all other minority-class samples and find its K nearest neighbours.
  • Randomly pick one of these K neighbours and compute the difference between it and the current sample.
  • Generate a new synthetic sample along the line segment connecting the two samples.
  • Repeat until the requested number of synthetic samples has been generated.

The key idea is that interpolation expands the region of feature space covered by the minority class, which helps the model learn minority-class patterns and improves performance.
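
The interpolation at the core of SMOTE fits in one line. A minimal illustrative sketch of synthesizing a single point (the actual resampling below uses imblearn's implementation):

import numpy as np

rng = np.random.default_rng(24)

def smote_point(x_i, x_nn):
    # x_i: a minority-class sample; x_nn: one of its minority-class
    # nearest neighbours. Returns a point on the segment between them.
    lam = rng.uniform(0, 1)          # random interpolation factor
    return x_i + lam * (x_nn - x_i)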

Implementation:

In [69]:

oversample = SMOTE(random_state=24)

X_train_smote, y_train_smote = oversample.fit_resample(X_train, y_train)

In [70]:

count_smote = pd.value_counts(y_train_smote)  # y_train_smote                              
percentage_smote = pd.value_counts(y_train_smote, normalize=True) 
class_count_df_smote = pd.DataFrame(data=count_smote.values,
                              index=['Non-defaulters', 'Defaulters'], 
                              columns=['Number']
                             )
class_count_df_smote["Percentage"] = percentage_smote.values
class_count_df_smote["Percentage"] = class_count_df_smote["Percentage"].apply(lambda x: "{:.2%}".format(x))
class_count_df_smote

Out[70]:

NumberPercentage
Non-defaulters1752350.00%
Defaulters1752350.00%

After oversampling, the minority class now matches the majority class in size.

10.4 Method 3: K-Means Clustering + SMOTE

Implementation:

In [71]:

oversample = KMeansSMOTE(cluster_balance_threshold=0.00001, random_state=24)

X_train_ksmote, y_train_ksmote = oversample.fit_resample(X_train, y_train)

In [72]:

count_ksmote = pd.value_counts(y_train_ksmote)  # y_train_ksmote                              
percentage_ksmote = pd.value_counts(y_train_ksmote, normalize=True) 
class_count_df_ksmote = pd.DataFrame(data=count_ksmote.values,
                              index=['Non-defaulters', 'Defaulters'], 
                              columns=['Number']
                             )
class_count_df_ksmote["Percentage"] = percentage_ksmote.values
class_count_df_ksmote["Percentage"] = class_count_df_ksmote["Percentage"].apply(lambda x: "{:.2%}".format(x))
class_count_df_ksmote

Out[72]:

NumberPercentage
Non-defaulters1752850.01%
Defaulters1752349.99%

10.5 Comparing the three methods

In [73]:

display(class_count_df)
display(class_count_df_cc)
display(class_count_df_smote)
display(class_count_df_ksmote)
                Number Percentage
Non-defaulters   17523     77.88%
Defaulters        4977     22.12%

                Number Percentage
Non-defaulters    4977     50.00%
Defaulters        4977     50.00%

                Number Percentage
Non-defaulters   17523     50.00%
Defaulters       17523     50.00%

                Number Percentage
Non-defaulters   17528     50.01%
Defaulters       17523     49.99%

The original classes are highly imbalanced. After resampling, both cluster-centroid undersampling and SMOTE yield exactly equal class counts.

K-Means SMOTE, by contrast, leaves a slight difference between the two classes.

11 Model Evaluation

11.1 Cross-validation

k-fold cross-validation splits the data into k folds; in each round, k - 1 folds are used for training and the remaining fold for validation.
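
With an imbalanced target, the stratified variant is the safer default because it preserves the class ratio in every fold. A sketch using StratifiedKFold (already imported at the top; cross_val_score is an extra import, not part of the original notebook):

from sklearn.model_selection import cross_val_score

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=24)
scores = cross_val_score(lgb.LGBMClassifier(), X_train, y_train,
                         cv=cv, scoring="accuracy")
print(scores.mean(), scores.std())  # mean accuracy and spread across the 5 folds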

Classification metrics

1. Confusion matrix

$$\begin{array}{ccc} & \text{Predicted Negative} & \text{Predicted Positive} \\ \hline \text{Actual Negative} & \text{TN} & \text{FP} \\ \text{Actual Positive} & \text{FN} & \text{TP} \end{array}$$

2. Accuracy

$$\text{Accuracy}=\frac{TP+TN}{TP+FP+TN+FN}$$

3. Precision

$$\text{Precision}=\frac{TP}{TP+FP}$$

4. Recall

$$\text{Recall}=\frac{TP}{TP+FN}$$

5. F1 score (where r is recall and p is precision)

$$F1_{score}=\frac{2}{\frac{1}{r}+\frac{1}{p}}=\frac{2rp}{r+p}$$
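
All of these metrics are available from sklearn.metrics, already imported at the top of the notebook. A sketch of computing them once a model has produced predictions (y_pred is created in section 12.1 below):

print(confusion_matrix(y_test, y_pred))      # [[TN, FP], [FN, TP]]
print(accuracy_score(y_test, y_pred))        # (TP+TN) / total
print(precision_score(y_test, y_pred))       # TP / (TP+FP)
print(recall_score(y_test, y_pred))          # TP / (TP+FN)
print(f1_score(y_test, y_pred))              # harmonic mean of precision and recall
print(classification_report(y_test, y_pred)) # all of the above, per class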

12 Building a Binary Classification Model with LightGBM (Using the Imbalanced Data)

The model can be trained with any of the label sets prepared above:

# y_train        - training labels after PCA (imbalanced)
# y_train_cc     - cluster-centroid undersampling
# y_train_smote  - SMOTE oversampling
# y_train_ksmote - K-Means clustering + SMOTE

12.1 Baseline model

In [74]:

X_train  # features after normalization and PCA

Target values in the training set:

In [75]:

y_train              

Out[75]:

24832    0
969      0
20833    1
21670    0
25380    0
        ..
20828    1
897      1
16452    0
3888     0
5743     0
Name: Label, Length: 22500, dtype: int64

Model training:

In [76]:

# train the model

lgb_clf = lgb.LGBMClassifier()
lgb_clf.fit(X_train, y_train)
[LightGBM] [Info] Number of positive: 4977, number of negative: 17523
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000607 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 3060
[LightGBM] [Info] Number of data points in the train set: 22500, number of used features: 12
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.221200 -> initscore=-1.258687
[LightGBM] [Info] Start training from score -1.258687

Model prediction:

In [77]:

# predict on the test set

y_pred = lgb_clf.predict(X_test)
y_pred

Out[77]:

array([1, 0, 0, ..., 0, 0, 0], dtype=int64)

Baseline accuracy (acc):

In [78]:

acc = accuracy_score(y_test, y_pred)

print(acc)
0.8130666666666667

12.2 Optimizing the Model (Cross-validation + Hyperparameter Tuning)

In [79]:

# define a fresh LightGBM classifier
lgb_clf_new = lgb.LGBMClassifier()

12.2.1 Hyperparameter ranges

The LightGBM hyperparameters most commonly tuned:

  • num_leaves (number of leaves): the main complexity control in LightGBM's leaf-wise tree growth. Larger values allow more complex trees and raise the risk of overfitting; smaller values give simpler trees.
  • learning_rate: the step size of each boosting update. Smaller values converge more slowly but tend to generalize better; larger values speed up training but can overfit or become unstable.
  • n_estimators (number of trees): more trees increase model capacity and training time, and too many can overfit.
  • max_depth (maximum depth): caps the depth of each tree. Smaller values force shallower, simpler trees; larger values allow deeper, more complex ones.
  • min_child_samples (minimum samples per leaf): the minimum number of samples a leaf must contain. Smaller values allow more, finer-grained leaves and can overfit; larger values make the trees more conservative.
  • subsample (row subsampling ratio): the fraction of training rows sampled for each tree. Values below 1 inject randomness, which speeds up training and can improve generalization.
  • colsample_bytree (column subsampling ratio): the fraction of features sampled when building each tree, with the same trade-off as subsample but applied to columns.
  • reg_alpha (L1 regularization coefficient): larger values mean stronger regularization and sparser models.
  • reg_lambda (L2 regularization coefficient): larger values mean stronger regularization.

blog.csdn.net/deephub/art…

In [80]:

# hyperparameter grid for the search

param_grid = {
    'num_leaves': [31, 63, 127], 
    'learning_rate': [0.01, 0.02, 0.03, 0.04, 0.05],
    'n_estimators': [100, 200, 300],
    'max_depth': [4,5,6,7]
}
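
This grid contains 3 x 5 x 3 x 4 = 180 combinations, i.e. 900 model fits under 5-fold cross-validation. If that is too slow, RandomizedSearchCV (an alternative, not used in this article) samples a fixed number of combinations from the same grid:

from sklearn.model_selection import RandomizedSearchCV

# Sketch: try 30 random combinations instead of all 180.
random_search = RandomizedSearchCV(lgb.LGBMClassifier(),
                                   param_distributions=param_grid,
                                   n_iter=30,
                                   scoring='accuracy',
                                   cv=5, n_jobs=-1, random_state=24)
# random_search.fit(X_train, y_train) then behaves like the grid search fitted below.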

12.2.2 K-fold cross-validation

In [81]:

# hyperparameter tuning with k-fold cross-validation and grid search

# a stratified 5-fold alternative:
# cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)  

cv = 5

# grid search
grid_search = GridSearchCV(lgb_clf_new, # the LightGBM model
                           param_grid,  # parameter grid
                           scoring='accuracy',  # evaluation metric 
                           cv=cv,  # 5-fold cross-validation
                           n_jobs=-1  
                          ) 

In [82]:

# fit the grid-search object

grid_search.fit(X_train, y_train)

The best parameter combination:

In [83]:

grid_search.best_params_   

Out[83]:

{'learning_rate': 0.02, 'max_depth': 5, 'n_estimators': 300, 'num_leaves': 63}
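
The search object also stores the cross-validated score of this combination and a refitted copy of the model (an attribute sketch; the exact score depends on the run):

print(grid_search.best_score_)            # mean CV accuracy of the best combination
best_model = grid_search.best_estimator_  # already refit on the full X_train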

12.2.3 Building the new model

Build a new model with the best parameter combination found by the grid search:

In [84]:

new_model = lgb.LGBMClassifier(learning_rate=0.02, 
                                max_depth=5, 
                                n_estimators=300, 
                                num_leaves=63)

new_model.fit(X_train,y_train)

12.2.4 Evaluating the new model

In [85]:

y_pred_new = new_model.predict(X_test) 
y_pred_new

Out[85]:

array([1, 0, 0, ..., 0, 0, 0], dtype=int64)

Accuracy of the tuned model:

In [86]:

acc_new = accuracy_score(y_test, y_pred_new)

print(acc_new)
0.8330666666666667
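
A final one-line comparison against the baseline from section 12.1 confirms the gain promised at the start of the article:

print(f"baseline: {acc:.4f}  tuned: {acc_new:.4f}  gain: {acc_new - acc:+.4f}")
# baseline: 0.8131  tuned: 0.8331  gain: +0.0200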