lightgbm入门学习第一笔记

369 阅读4分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第13天,点击查看活动详情

1.载入python包

from pathlib import Path
# 画图的
import matplotlib.pyplot as plt
%matplotlib inline
# 读取数据的
import pandas as pd
# 计算用的
import lightgbm as lgb
# 忽略警告的
import warnings
warnings.filterwarnings('ignore')



# Interactive widgets are optional: fall back to static rendering when the
# ipywidgets package (https://github.com/jupyter-widgets/ipywidgets) is absent.
INTERACTIVE = True
try:
    from ipywidgets import interact, SelectMultiple
except ImportError:
    INTERACTIVE = False

2.加载数据

# Show where we are running from, and the examples root two levels up.
cwd = Path().absolute()
print(cwd)
print(cwd.parents[1])

regression_example_dir = cwd.parents[1] / 'regression'
# The regression example ships tab-separated files with no header row;
# column 0 holds the target, the remaining columns are features.
df_train = pd.read_csv(str(regression_example_dir / 'regression.train'), header=None, sep='\t')
df_test = pd.read_csv(str(regression_example_dir / 'regression.test'), header=None, sep='\t')

# Split the target (column 0) away from the feature columns.
y_train = df_train[0]
y_test = df_test[0]
X_train = df_train.drop(0, axis=1)
X_test = df_test.drop(0, axis=1)
C:\Users\Administrator\Desktop\LightGBM-master\examples\python-guide\notebooks
C:\Users\Administrator\Desktop\LightGBM-master\examples
# Preview the first five training rows (notebook display output follows).
df_train.head()
.dataframe tbody tr th:only-of-type { vertical-align: middle; } .dataframe tbody tr th { vertical-align: top; } .dataframe thead th { text-align: right; }
0 1 2 3 4 5 6 7 8 9 ... 19 20 21 22 23 24 25 26 27 28
0 1 0.869 -0.635 0.226 0.327 -0.690 0.754 -0.249 -1.092 0.000 ... -0.010 -0.046 3.102 1.354 0.980 0.978 0.920 0.722 0.989 0.877
1 1 0.908 0.329 0.359 1.498 -0.313 1.096 -0.558 -1.588 2.173 ... -1.139 -0.001 0.000 0.302 0.833 0.986 0.978 0.780 0.992 0.798
2 1 0.799 1.471 -1.636 0.454 0.426 1.105 1.282 1.382 0.000 ... 1.129 0.900 0.000 0.910 1.108 0.986 0.951 0.803 0.866 0.780
3 0 1.344 -0.877 0.936 1.992 0.882 1.786 -1.647 -0.942 0.000 ... -0.678 -1.360 0.000 0.947 1.029 0.999 0.728 0.869 1.027 0.958
4 1 1.105 0.321 1.522 0.883 -1.205 0.681 -1.070 -0.922 0.000 ... -0.374 0.113 0.000 0.756 1.361 0.987 0.838 1.133 0.872 0.808

5 rows × 29 columns

# Preview the first five test rows (notebook display output follows).
df_test.head()
.dataframe tbody tr th:only-of-type { vertical-align: middle; } .dataframe tbody tr th { vertical-align: top; } .dataframe thead th { text-align: right; }
0 1 2 3 4 5 6 7 8 9 ... 19 20 21 22 23 24 25 26 27 28
0 1 0.644 0.247 -0.447 0.862 0.374 0.854 -1.126 -0.790 2.173 ... -0.190 -0.744 3.102 0.958 1.061 0.980 0.875 0.581 0.905 0.796
1 0 0.385 1.800 1.037 1.044 0.349 1.502 -0.966 1.734 0.000 ... -0.440 0.638 3.102 0.695 0.909 0.981 0.803 0.813 1.149 1.116
2 0 1.214 -0.166 0.004 0.505 1.434 0.628 -1.174 -1.230 1.087 ... -1.383 1.355 0.000 0.848 0.911 1.043 0.931 1.058 0.744 0.696
3 1 0.420 1.111 0.137 1.516 -1.657 0.854 0.623 1.605 1.087 ... 0.731 1.424 3.102 1.597 1.282 1.105 0.730 0.148 1.231 1.234
4 0 0.897 -1.703 -1.306 1.022 -0.729 0.836 0.859 -0.333 2.173 ... -2.019 -0.289 0.000 0.805 0.930 0.984 1.430 2.198 1.934 1.684

5 rows × 29 columns

3.LightGBM数据集创建

class lightgbm.Dataset(data, label=None, reference=None, weight=None, group=None, init_score=None, feature_name='auto', categorical_feature='auto', params=None, free_raw_data=True)
  • data (str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array) – Data source of Dataset. If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
  • label (list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None)) – Label of the data.
  • reference (Dataset or None, optional (default=None)) – If this is Dataset for validation, training data should be used as reference.
  • weight (list, numpy 1-D array, pandas Series or None, optional (default=None)) – Weight for each instance. Weights should be non-negative.
  • group (list, numpy 1-D array, pandas Series or None, optional (default=None)) – Group/query data. Only used in the learning-to-rank task. sum(group) = n_samples. For example, if you have a 100-document dataset with group = [10, 20, 40, 10, 10, 10], that means that you have 6 groups, where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
  • init_score (list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None, optional (default=None)) – Init score for Dataset.
  • feature_name (list of str, or 'auto', optional (default="auto")) – Feature names. If ‘auto’ and data is pandas DataFrame, data columns names are used.
  • categorical_feature (list of str or int, or 'auto', optional (default="auto")) – Categorical features. If list of int, interpreted as indices. If list of str, interpreted as feature names (need to specify feature_name as well). If ‘auto’ and data is pandas DataFrame, pandas unordered categorical columns are used. All values in categorical features will be cast to int32 and thus should be less than int32 max value (2147483647). Large values could be memory consuming. Consider using consecutive integers starting from zero. All negative values in categorical features will be treated as missing values. The output cannot be monotonically constrained with respect to a categorical feature. Floating point numbers in categorical features will be rounded towards 0.
  • params (dict or None, optional (default=None)) – Other parameters for Dataset.
  • free_raw_data (bool, optional (default=True)) – If True, raw data is freed after constructing inner Dataset.
# Wrap the pandas frames as LightGBM Datasets; the validation set references
# the training set, per the Dataset docs quoted above.
lgb_train = lgb.Dataset(data=X_train, label=y_train)
lgb_test = lgb.Dataset(data=X_test, label=y_test, reference=lgb_train)

4.创建配置

# Training configuration: small trees (5 leaves), two regression metrics
# (MAE and MSE), and LightGBM's own logging silenced.
params = dict(
    num_leaves=5,
    metric=['l1', 'l2'],
    verbose=-1,
)

5.开始训练

  • train(params, train_set[, num_boost_round, ...]) Perform the training with given parameters.
evals_result = {}  # filled by record_evaluation; used below for plotting
# Name the columns f1..fN; column index 21 is declared categorical.
feature_names = [f'f{i + 1}' for i in range(X_train.shape[-1])]
gbm = lgb.train(
    params,
    lgb_train,
    num_boost_round=100,
    valid_sets=[lgb_train, lgb_test],
    feature_name=feature_names,
    categorical_feature=[21],
    callbacks=[
        lgb.log_evaluation(10),               # print metrics every 10 rounds
        lgb.record_evaluation(evals_result),  # keep metrics for the plots
    ],
)
[10]	training's l1: 0.457448	training's l2: 0.217995	valid_1's l1: 0.456464	valid_1's l2: 0.21641
[20]	training's l1: 0.436869	training's l2: 0.205099	valid_1's l1: 0.434057	valid_1's l2: 0.201616
[30]	training's l1: 0.421302	training's l2: 0.197421	valid_1's l1: 0.417019	valid_1's l2: 0.192514
[40]	training's l1: 0.411107	training's l2: 0.192856	valid_1's l1: 0.406303	valid_1's l2: 0.187258
[50]	training's l1: 0.403695	training's l2: 0.189593	valid_1's l1: 0.398997	valid_1's l2: 0.183688
[60]	training's l1: 0.398704	training's l2: 0.187043	valid_1's l1: 0.393977	valid_1's l2: 0.181009
[70]	training's l1: 0.394876	training's l2: 0.184982	valid_1's l1: 0.389805	valid_1's l2: 0.178803
[80]	training's l1: 0.391147	training's l2: 0.1828	valid_1's l1: 0.386476	valid_1's l2: 0.176799
[90]	training's l1: 0.388101	training's l2: 0.180817	valid_1's l1: 0.384404	valid_1's l2: 0.175775
[100]	training's l1: 0.385174	training's l2: 0.179171	valid_1's l1: 0.382929	valid_1's l2: 0.175321

6.查看训练过程

def render_metric(metric_name):
    """Plot how *metric_name* evolved on both datasets during training."""
    lgb.plot_metric(evals_result, metric=metric_name, figsize=(10, 5))
    plt.show()


if INTERACTIVE:
    # dropdown to switch between the metrics recorded in `params`
    interact(render_metric, metric_name=params['metric'])
else:
    render_metric(params['metric'][0])
interactive(children=(Dropdown(description='metric_name', options=('l1', 'l2'), value='l1'), Output()), _dom_c…

7.画特征图

def render_plot_importance(importance_type, max_features=10,
                           ignore_zero=True, precision=3):
    """Draw a feature-importance bar chart for the trained booster."""
    lgb.plot_importance(
        gbm,
        importance_type=importance_type,
        max_num_features=max_features,
        ignore_zero=ignore_zero,
        figsize=(12, 8),
        precision=precision,
    )
    plt.show()


if INTERACTIVE:
    # widgets: importance measure, feature-count slider, precision slider
    interact(render_plot_importance,
             importance_type=['split', 'gain'],
             max_features=(1, X_train.shape[-1]),
             precision=(0, 10))
else:
    render_plot_importance(importance_type='split')
interactive(children=(Dropdown(description='importance_type', options=('split', 'gain'), value='split'), IntSl…

8.特征值分布图

def render_histogram(feature):
    """Show the histogram of split values the model chose for *feature*."""
    lgb.plot_split_value_histogram(gbm, feature=feature,
                                   bins='auto', figsize=(10, 5))
    plt.show()


if INTERACTIVE:
    # dropdown over every feature name known to the booster
    interact(render_histogram, feature=gbm.feature_name())
else:
    render_histogram(feature='f26')
interactive(children=(Dropdown(description='feature', options=('f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8',…

9.画树图

9.1graphviz安装

pip install graphviz

9.2 win下graphviz安装

下载地址 gitlab.com/api/v4/proj…

安装并设置环境变量

import os

# Make the Graphviz `dot` executable discoverable on Windows by appending
# both possible install locations to PATH.  The original cell imported `os`
# twice and split the two appends around the function definition; they are
# consolidated here (PATH contents and order are unchanged).
os.environ["PATH"] += os.pathsep + r'C:\Program Files (x86)\Graphviz\bin/'
os.environ["PATH"] += os.pathsep + r'C:\Program Files\Graphviz\bin/'


def render_tree(tree_index, show_info, precision=3):
    """Return a graphviz digraph of tree *tree_index* from the booster.

    show_info: iterable of node-annotation names; the sentinel string
    'None' (as selected in the widget) disables annotations entirely.
    precision: number of decimal digits shown in node values.
    """
    show_info = None if 'None' in show_info else show_info
    return lgb.create_tree_digraph(gbm, tree_index=tree_index,
                                   show_info=show_info, precision=precision)

if INTERACTIVE:
    # slider over every trained tree, plus a multi-select of node
    # annotations; 'None' is a sentinel meaning "no extra info"
    node_info_options = ['None',
                         'split_gain',
                         'internal_value',
                         'internal_count',
                         'internal_weight',
                         'leaf_count',
                         'leaf_weight',
                         'data_percentage']
    interact(render_tree,
             tree_index=(0, gbm.num_trees() - 1),
             show_info=SelectMultiple(options=node_info_options,
                                      value=['None']),
             precision=(0, 10))
    tree = None
else:
    # static fallback: render one fixed tree with no annotations
    tree = render_tree(53, ['None'])
tree  # last expression, so the notebook displays the digraph
interactive(children=(IntSlider(value=49, description='tree_index', max=99), SelectMultiple(description='show_…

10.模型保存

# Save the trained booster to 'model.txt' (LightGBM's text model format).
gbm.save_model('model.txt')
<lightgbm.basic.Booster at 0x247fb2265e0>

11.预测

# Predict on the held-out features; num_iteration=gbm.best_iteration limits
# prediction to the booster's recorded best round, if one was recorded.
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
print(y_pred)
[ 6.33479642e-01  4.25887937e-01  2.56217117e-01  4.86929519e-01  2.33292447e-01  1.63434010e-01  3.68667133e-01  4.09774178e-01  7.04604027e-01  5.02551497e-01  6.40686199e-01  7.19385156e-01  7.41902740e-01  6.96221985e-01  4.40294282e-01  7.37699070e-01  4.08883945e-01  5.36636773e-01  7.08475765e-01  5.20571445e-01  7.10604143e-01  5.26923953e-01  5.78733282e-01  5.14174117e-01  4.11320699e-01  6.53895132e-01  6.43825483e-01  7.84754801e-01  4.01373544e-01  6.46254968e-01  5.77817060e-01  5.47259514e-01  4.71219971e-01  6.44547348e-01  5.26154488e-01  3.37513911e-01  2.05669407e-01  2.06145013e-01  6.24416038e-01  7.37623316e-01 -2.58819554e-04  4.79166726e-01  3.44027673e-01  5.20347726e-01  3.63943230e-01 -5.32882589e-03  5.52210841e-01  5.24886584e-01  8.26508593e-01  4.26354797e-01  1.11591699e+00  4.74355146e-01  1.20555073e-01  9.08264190e-01  7.30272709e-01  5.19243790e-01  2.05566261e-01  4.52380730e-01  5.60240917e-01  6.57711162e-01  4.18448166e-01  7.72577358e-01  7.62705553e-01  3.57582089e-01  6.86433592e-01  3.55394358e-01  6.45050591e-01  1.78149230e-01  8.13064495e-01  6.90434809e-01  3.98568149e-01  7.49444675e-01  8.03400362e-01  2.80792590e-01  8.37334643e-01  8.42950872e-01  3.75903502e-01  5.46412170e-01  8.21602639e-01  4.95276352e-01  3.14823086e-01  6.42259773e-01  7.75677625e-01  8.38889852e-01  1.62250988e-01  7.54015909e-01  5.92549060e-01  6.21280723e-01  4.33799559e-01  6.30701280e-01  3.91863126e-01  3.24147919e-01  4.38905589e-01  5.57890721e-01  5.20962492e-01  7.66845263e-01  3.02650380e-01  1.87034030e-01  8.84502576e-01  1.67438512e-01  4.71943755e-01  4.46656370e-01  5.89327051e-01  2.71114672e-01  3.28554464e-01  8.41607666e-01  6.14197765e-01  6.21224704e-01  2.99919552e-01  3.33537812e-01  5.54302342e-01  6.06740120e-01  3.97077406e-01  2.51648855e-02  5.13863957e-01  1.04445064e+00  3.46554304e-01  8.29346567e-01  2.09741596e-01  5.19185239e-01  3.34880879e-01  7.52658355e-01  7.15034072e-01  5.89717543e-01  
3.55840634e-01  6.40827149e-01  6.24056310e-01  4.74466221e-01  4.08941567e-01  5.92834501e-01  7.04916409e-01  4.97805554e-01  3.45694922e-01  3.55286154e-01  6.51943205e-01  8.53404327e-01  3.53576143e-01  8.41978567e-01  4.45335079e-01  5.19711068e-01  4.79440633e-01  5.50946397e-01  3.89575002e-01  5.09236565e-01  4.49776529e-01  2.64213341e-01  3.09399680e-01  6.56398594e-01  7.15026274e-01  2.90668166e-01  4.64332317e-01  6.63167362e-01  6.81758202e-01  3.54421235e-01  6.69146180e-01  6.20771418e-01  5.15806223e-01  5.30713745e-01  6.01431001e-01  3.07559043e-01  9.59358822e-01  8.14795706e-01  7.61612943e-01  6.74558258e-02  6.84204566e-01  3.75838640e-01  5.97677392e-01  8.05567575e-01  2.89311259e-01  9.47516951e-01  9.62208569e-01  3.06077249e-01  3.98820523e-01  7.22561270e-01  3.89382823e-01  6.82481436e-01  5.85777571e-01  4.73969313e-01  3.17099784e-01  1.62786423e-02  9.28805554e-02  4.35312005e-01  1.62981788e-01  3.33488493e-01  1.37538616e-01  8.69976045e-01  6.06629471e-01  5.48061659e-01  9.15549672e-01  4.59790384e-01  6.00812765e-01  2.31650100e-01  5.78019822e-01  3.59072379e-01  4.24190034e-01  7.46503330e-01  4.92974100e-01  4.13881427e-01  2.40509522e-01  5.92133544e-01  3.92611549e-01  1.53861351e-01  6.04505789e-01  1.99720258e-01  6.99974765e-01  1.92597930e-01  5.72096689e-01  4.83427643e-01  8.18670175e-01  2.93183045e-01  4.48306984e-02  7.71798929e-01  4.41011923e-01  3.32238273e-01  5.74983931e-01  3.09891315e-01  8.47782186e-01  4.87507926e-01  4.21680441e-01  4.67273852e-01  7.11907771e-01  2.15964274e-01  7.55342796e-01  6.03479333e-01  4.47199456e-01  1.03738758e+00  5.46845382e-01  5.69833595e-01  6.57568546e-01  3.26544957e-01  6.12572198e-01  2.88721290e-01  5.99573254e-01  5.98503670e-01  4.80925404e-01  4.42936828e-01  6.50282836e-01  4.35627017e-01  6.69963953e-01  6.01741465e-01  5.39032155e-01  1.89632639e-01  1.61598320e-01  3.49669360e-01  4.41367875e-01  2.76164840e-01  2.68309969e-01  7.03520925e-01  4.48102888e-01  
2.00202346e-01  2.18110963e-01  5.16792491e-01  4.06454191e-01  8.81546665e-01  7.71701934e-01  4.78158613e-01  5.67138930e-01  8.45366996e-01  3.26615315e-01  3.80870623e-01  4.83237280e-01  4.23641252e-01  3.44510160e-01  6.41305298e-01  4.23626728e-01  6.69024329e-01  7.42055495e-01  2.92683694e-01  6.14328007e-01 -3.68470297e-02  3.25890161e-01  4.91349827e-01  3.60415452e-01  6.46256965e-01  7.73289796e-01  6.62003822e-01  4.37675022e-01  7.78272465e-01  4.67588479e-01  3.30493786e-01  4.83922011e-01  2.70396954e-01  2.22045665e-01  3.77279852e-01  4.03430723e-01  3.42700616e-01  3.51725965e-01  3.54871133e-01  3.90862583e-01  6.49789563e-01  7.38801144e-01  1.59306780e-01  4.21493538e-01  7.43716890e-01  2.92961678e-01  2.86227018e-01  2.68711439e-01  2.98343415e-01  4.03548755e-01  4.83247845e-01  3.61430698e-01  4.08706892e-01  4.46188917e-01  3.96745275e-01  3.72945669e-01  8.44879805e-01  3.85470247e-01  6.23814168e-01  5.92190938e-01  5.34525640e-01  7.15716233e-01  5.17309415e-01  4.37156637e-01  9.37123487e-01  2.87930061e-01  6.76085183e-01  4.75419457e-01  6.42998135e-01  6.29169319e-01  2.65030304e-01  5.51242740e-01  6.30204954e-01  4.47487179e-01  8.93424905e-01  9.15606316e-01  2.79803242e-01  2.71098323e-01  4.25292686e-01  3.12530588e-01  4.98295333e-01  3.98292363e-01  6.48594219e-01  5.75798390e-01  4.06543825e-01  2.01902925e-02  3.32055077e-01  5.37098558e-01  3.31121613e-01  9.66681614e-01  3.76179940e-01  7.16699017e-01  4.45738667e-01  5.82761456e-01  3.99956852e-01  7.12842532e-01  3.85965607e-01  5.68207526e-01  7.95473372e-01  3.85063328e-01  6.15900186e-01  6.41450856e-01  8.50585329e-01  6.97652594e-01  2.30194548e-01  4.66913590e-01  3.29228032e-01  4.84563854e-01  2.25052285e-01  9.04352553e-01  3.28065048e-01  2.54094103e-01  3.39910841e-01  6.24037062e-01  4.83310106e-01  3.30008683e-01  6.89600274e-01  4.67247879e-01  2.95067274e-01  3.12245822e-01  7.82927166e-01  4.18300283e-01 -2.57235162e-02  5.73030246e-01  7.22150648e-01  
8.97854624e-01  6.88530388e-01  5.54530017e-01  5.41582864e-01  6.14265339e-01  1.59448357e-01  4.88603543e-01  8.68488192e-01  8.74196986e-01  3.96229659e-01  8.74277038e-01  6.75304616e-01  5.84603699e-01  4.92039502e-01  3.38159500e-01  4.61982991e-01  4.46062141e-01  4.65817114e-01  6.07446901e-01  5.00767145e-01  8.27301748e-01  3.47625866e-01  6.11007260e-01  7.36467474e-01  7.68005908e-01  2.49479597e-01  6.80701988e-01  6.43403743e-01  6.32443316e-01  7.04647895e-01  5.23572843e-01  8.82123994e-01  6.65484637e-01  8.40640498e-01  7.20916850e-01  2.76498614e-01  8.41627129e-01  4.33336973e-01  4.26323757e-01  8.72066599e-01  4.59959990e-01  5.19059765e-01  9.64819442e-01  4.33509174e-01  5.37756081e-01  4.41946705e-01  8.16990441e-01  4.33790288e-01  6.00526540e-01  8.75178569e-01  2.68953742e-01  6.44347962e-01  4.70963085e-01  5.52646015e-01  2.18118049e-01  3.96870713e-01  2.91110534e-01  5.62191673e-01  8.02750001e-01  9.22087286e-01  3.53827542e-01  6.85884451e-01  4.65114927e-01  5.96395922e-01  5.28037587e-01  8.85796589e-01  3.63259378e-01  6.00291109e-01  2.37146491e-01  6.13886042e-01  8.95054181e-01  8.33608076e-01  4.88892454e-01  7.87019209e-01  5.00058659e-01  7.15913488e-01  1.13429037e-01  4.11566084e-01  4.92317712e-01  6.14830892e-01  4.30842330e-01  3.51833957e-01  7.87238191e-01  3.34186183e-01  2.85357393e-01  4.69265545e-01  6.16605119e-01  4.40802936e-01  3.04184111e-01  5.41875948e-01  1.29971395e-01  7.78222288e-01  6.27903277e-01  6.08038577e-01  4.61289568e-01  5.51391866e-01  3.23979747e-01  5.27827811e-01  5.79608752e-01  5.43579353e-01  4.47888170e-01  5.64159242e-01  5.13248077e-01  5.72696510e-01  3.97392086e-01  4.91634195e-01  6.18278141e-01  8.85765337e-01  5.50707876e-01  7.34504063e-01  5.09513286e-01  6.92856337e-01  4.67934122e-01  6.11790217e-01  6.66984305e-01  6.02788571e-01  6.75515853e-01  3.48740944e-01  4.38187921e-01  3.16229197e-01  8.13288085e-01  8.55787527e-01  3.10834972e-01  5.84051349e-01  2.39132158e-01  
3.94795155e-01]

12.评估

from sklearn.metrics import mean_squared_error

# RMSE is the square root of the mean squared error on the test predictions.
mse = mean_squared_error(y_test, y_pred)
rmse_test = mse ** 0.5
print(f'The RMSE of prediction is: {rmse_test}')
The RMSE of prediction is: 0.4187133172590796