网格搜索(Grid Search)与交叉验证(Cross-Validation)详解
1. 网格搜索(Grid Search)
-
定义:网格搜索是一种超参数优化方法,通过遍历预定义的参数组合,找到那一组使模型性能最优的参数组合。
-
为什么使用:
- 超参数(如学习率、树的深度等)无法通过训练数据直接学习,需要手动调整。
- 暴力搜索所有可能的参数组合,避免手动试错,确保找到全局最优解。
2. 交叉验证(Cross-Validation)
-
定义:将数据集划分为
k个子集(称为“折”),依次用其中k-1个子集训练模型,剩余1个子集验证模型性能,重复k次后取平均得分,作为某一种超参数组合的模型性能。 -
为什么使用:
- 充分利用有限数据,减少因数据划分不合理导致的模型偏差。
- 避免过拟合训练集,更准确评估模型的泛化能力。
- 我们把训练集,又分成训练集和验证集, 就是为了避免在训练和优化模型的时候都是用同样的数据, 而导致训练出来的模型过拟合.
3. 为何结合使用?
网格搜索结合交叉验证(如 GridSearchCV)能在搜索最优参数时,通过交叉验证评估参数性能,确保最终模型的稳定性和泛化性。
完整代码示例(含注释)
# 安装依赖(如未安装)
# !pip install xgboost scikit-learn numpy pandas
import numpy as np
import pandas as pd
from xgboost import XGBClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import accuracy_score, classification_report
# 加载数据集(乳腺癌分类)
data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = pd.Series(data.target)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# 初始化XGBoost分类器
model = XGBClassifier(random_state=42)
# 定义参数网格(示例参数,可根据需求扩展)
param_grid = {
'max_depth': [3, 5, 7], # 树的最大深度
'n_estimators': [50, 100, 200], # 树的数量
'learning_rate': [0.01, 0.1, 0.2] # 学习率
}
# 创建GridSearchCV实例(5折交叉验证)
grid_search = GridSearchCV(
estimator=model,
param_grid=param_grid,
scoring='accuracy', # 评估指标(分类准确率)
cv=5, # 5折交叉验证
n_jobs=-1, # 使用所有CPU核心加速
verbose=1 # 输出训练日志
)
# 执行网格搜索(在训练集上)
grid_search.fit(X_train, y_train)
# 输出最佳参数和模型性能
print("最佳参数组合:", grid_search.best_params_)
best_model = grid_search.best_estimator_
print("训练集交叉验证最佳准确率:", grid_search.best_score_)
# 在测试集上评估模型性能
y_pred = best_model.predict(X_test)
print("测试集准确率:", accuracy_score(y_test, y_pred))
print("\n分类报告:\n", classification_report(y_test, y_pred))
代码说明
-
数据集:使用乳腺癌分类数据集(二分类任务)。
-
参数网格:定义了
max_depth、n_estimators和learning_rate三个关键参数的候选值。 -
GridSearchCV:
- 通过5折交叉验证评估每组参数。
- 最终选择在训练集上交叉验证得分最高的参数组合。
-
模型评估:输出测试集准确率和分类报告(精确率、召回率、F1值)。
运行此代码将自动完成模型训练、调参和评估流程。根据实际任务,可调整参数网格中的取值范围或添加更多参数(如 subsample、colsample_bytree 等)。