如何使用网格搜索来优化超参数

666 阅读6分钟

使用网格搜索来优化超参数

与参数不同,超参数是在训练机器学习模型之前设置的。这些超参数需要被优化,以使一个模型适应一个数据集。然而,在一个数据集上的最佳超参数设置,在另一个数据集上不太可能是最佳的。这使得超参数优化的任务变得更加棘手。

在这篇文章中,我们将深入探讨优化超参数的一个关键技术,即网格搜索。

前提条件

在我们开始之前,仔细阅读并安装以下内容可能会有所帮助。-本文的一个部分将简要介绍超参数和网格搜索方法。

-[VSCode]将是我的代码编辑器,我们将使用Python作为语言。

- [GridSearchCV]是[scikit-learn]库中的一个工具,用来进行交叉验证的网格搜索。

目标

本文的目标是通过在Python中的简单实现来理解网格搜索优化。

网格搜索

网格搜索是指一种用于识别模型的最佳超参数的技术。与参数不同,在训练数据中寻找超参数是无法实现的。因此,为了找到正确的超参数,我们为每个超参数的组合创建一个模型。

因此,网格搜索被认为是一种非常传统的超参数优化方法,因为我们基本上是 "粗暴 "地执行所有可能的组合。然后通过交叉验证对模型进行评估。拥有最佳精度的模型自然被认为是最好的。

grid

网格布局。

从上面的图片中,我们注意到,数值是以矩阵的形式排列的。

交叉验证

我们已经提到,交叉验证被用来评估模型的性能。交叉验证衡量一个模型对独立数据集的概括程度。我们使用交叉验证来很好地估计一个预测模型的性能如何。

通过这种方法,我们有一对数据集:一个独立的数据集和一个训练数据集。我们可以对一个单一的数据集进行分割,以产生这两个数据集。这些分区的大小是一样的,被称为褶皱。考虑中的模型是在所有的折叠上训练的,除了一个。

然后用被排除的折叠来测试模型。这个过程不断重复,直到所有的折叠被用作测试集。然后用模型在所有折叠上的平均性能来估计模型的性能。

在一个被称为k-fold交叉验证的技术中,用户指定了用kk表示的折叠数量。这意味着,当k=5k=5时,有5个折子。

crossvalidation

K-折交叉验证法,K为5。

网格搜索的实现

下面给出的例子是网格搜索的一个基本实现。我们首先指定我们寻求检查的超参数。然后我们提供一组要测试的值。之后,网格搜索将在交叉验证的帮助下尝试所有可能的超参数组合。让我们把这个过程分成以下几个步骤。

步骤

  1. 加载数据集

我的第一步是使用from sklearn.datasets import load_irisiris = load_iris() 加载数据集。虹膜数据集是Python中的sci-kit learn库。数据被存储在一个150\*4150 \* 4的数组中。要查看数据集的内容,我们可以使用print(iris.data)print(iris.data.shape)

要看到数据集的内容,我们的输入就变成了。

from sklearn.datasets import load_iris
iris = load_iris()
print(iris.data)
print(iris.data.shape)

我们的输出。

[[5.1 3.5 1.4 0.2]
 [4.9 3. 1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5. 3.6 1.4 0.2]
 [5.4 3.9 1.7 0.4]
 [4.6 3.4 1.4 0.3]
 [5. 3.4 1.5 0.2]
 [4.4 2.9 1.4 0.2]
 [4.9 3.1 1.5 0.1]
 [5.4 3.7 1.5 0.2]
 [4.8 3.4 1.6 0.2]
 [4.8 3. 1.4 0.1]
 [4.3 3. 1.1 0.1]
 [5.8 4. 1.2 0.2]
 [5.7 4.4 1.5 0.4]
 [5.4 3.9 1.3 0.4]
 [5.1 3.5 1.4 0.3]
 [5.7 3.8 1.7 0.3]
 [5.1 3.8 1.5 0.3]
 [5.4 3.4 1.7 0.2]
 [5.1 3.7 1.5 0.4]
 [4.6 3.6 1. 0.2]
 [5.1 3.3 1.7 0.5]
 [4.8 3.4 1.9 0.2]
 [5. 3. 1.6 0.2]
 [5. 3.4 1.6 0.4]
 [5.2 3.5 1.5 0.2]
 [5.2 3.4 1.4 0.2]
 [4.7 3.2 1.6 0.2]
 [4.8 3.1 1.6 0.2]
 [5.4 3.4 1.5 0.4]
 [5.2 4.1 1.5 0.1]
 [5.5 4.2 1.4 0.2]
 [4.9 3.1 1.5 0.2]
 [5. 3.2 1.2 0.2]
 [5.5 3.5 1.3 0.2]
 [4.9 3.6 1.4 0.1]
 [4.4 3. 1.3 0.2]
 [5.1 3.4 1.5 0.2]
 [5. 3.5 1.3 0.3]
 [4.5 2.3 1.3 0.3]
 [4.4 3.2 1.3 0.2]
 [5. 3.5 1.6 0.6]
 [5.1 3.8 1.9 0.4]
 [4.8 3. 1.4 0.3]
 [5.1 3.8 1.6 0.2]
 [4.6 3.2 1.4 0.2]
 [5.3 3.7 1.5 0.2]
 [5. 3.3 1.4 0.2]
 [7. 3.2 4.7 1.4]
 [6.4 3.2 4.5 1.5]
 [6.9 3.1 4.9 1.5]
 [5.5 2.3 4. 1.3]
 [6.5 2.8 4.6 1.5]
 [5.7 2.8 4.5 1.3]
 [6.3 3.3 4.7 1.6]
 [4.9 2.4 3.3 1. ]
 [6.6 2.9 4.6 1.3]
 [5.2 2.7 3.9 1.4]
 [5. 2. 3.5 1. ]
 [5.9 3. 4.2 1.5]
 [6. 2.2 4.
1. ]
 [6.1 2.9 4.7 1.4]
 [5.6 2.9 3.6 1.3]
 [6.7 3.1 4.4 1.4]
 [5.6 3. 4.5 1.5]
 [5.8 2.7 4.1 1. ]
 [6.2 2.2 4.5 1.5]
 [5.6 2.5 3.9 1.1]
 [5.9 3.2 4.8 1.8]
 [6.1 2.8 4. 1.3]
 [6.3 2.5 4.9 1.5]
 [6.1 2.8 4.7 1.2]
 [6.4 2.9 4.3 1.3]
 [6.6 3. 4.4 1.4]
 [6.8 2.8 4.8 1.4]
 [6.7 3. 5.
1.7]
 [6. 2.9 4.5 1.5]
 [5.7 2.6 3.5 1. ]
 [5.5 2.4 3.8 1.1]
 [5.5 2.4 3.7 1. ]
 [5.8 2.7 3.9 1.2]
 [6. 2.7 5.1 1.6]
 [5.4 3. 4.5 1.5]
 [6. 3.4 4.5 1.6]
 [6.7 3.1 4.7 1.5]
 [6.3 2.3 4.4 1.3]
 [5.6 3. 4.1 1.3]
 [5.5 2.5 4. 1.3]
 [5.5 2.6 4.4 1.2]
 [6.1 3. 4.6 1.4]
 [5.8 2.6 4. 1.2]
 [5. 2.3 3.3 1. ]
 [5.6 2.7 4.2 1.3]
 [5.7 3. 4.2 1.2]
 [5.7 2.9 4.2 1.3]
 [6.2 2.9 4.3 1.3]
 [5.1 2.5 3. 1.1]
 [5.7 2.8 4.1 1.3]
 [6.3 3.3 6. 2.5]
 [5.8 2.7 5.1 1.9]
 [7.1 3. 5.9 2.1]
 [6.3 2.9 5.6 1.8]
 [6.5 3. 5.8 2.2]
 [7.6 3. 6.6 2.1]
 [4.9 2.5 4.5 1.7]
 [7.3 2.9 6.3 1.8]
 [6.7 2.5 5.8 1.8]
 [7.2 3.6 6.1 2.5]
 [6.5 3.2 5.1 2. ]
 [6.4 2.7 5.3 1.9]
 [6.8 3. 5.5 2.1]
 [5.7 2.5 5. 2. ]
 [5.8 2.8 5.1 2.4]
 [6.4 3.2 5.3 2.3]
 [6.5 3. 5.5 1.8]
 [7.7 3.8 6.7 2.2]
 [7.7 2.6 6.9 2.3]
 [6. 2.2 5.
1.5]
 [6.9 3.2 5.7 2.3]
 [5.6 2.8 4.9 2. ]
 [7.7 2.8 6.7 2. ]
 [6.3 2.7 4.9 1.8]
 [6.7 3.3 5.7 2.1]
 [7.2 3.2 6. 1.8]
 [6.2 2.8 4.8 1.8]
 [6.1 3. 4.9 1.8]
 [6.4 2.8 5.6 2.1]
 [7.2 3. 5.8 1.6]
 [7.4 2.8 6.1 1.9]
 [7.9 3.8 6.4 2. ]
 [6.4 2.8 5.6 2.2]
 [6.3 2.8 5.1 1.5]
 [6.1 2.6 5.6 1.4]
 [7.7 3. 6.1 2.3]
 [6.3 3.4 5.6 2.4]
 [6.4 3.1 5.5 1.8]
 [6. 3. 4.8 1.8]
 [6.9 3.1 5.4 2.1]
 [6.7 3.1 5.6 2.4]
 [6.9 3.1 5.1 2.3]
 [5.8 2.7 5.1 1.9]
 [6.8 3.2 5.9 2.3]
 [6.7 3.3 5.7 2.5]
 [6.7 3. 5.2 2.3]
 [6.3 2.5 5. 1.9]
 [6.5 3. 5.2 2. ]
 [6.2 3.4 5.4 2.3]
 [5.9 3. 5.1 1.8]]
(150, 4)

然而,值得注意的是,上面的可视化步骤是为了了解数据集,而不一定是在实现网格搜索。

  1. 导入GridSearchCV,svmSVR

sklearn.model_selection 加载数据集后,我们再从GridSearchCV ,以及svmSVR ,如下图所示。

GridSearchCV它确保了一个详尽的网格搜索,从参数值的网格中孕育出候选人。正如我们稍后将看到的,这些值是用参数param_grid

我们导入svm ,因为我们要使用的算法类型是支持向量机。SVR 类代表Epsilon支持向量回归。有了这个,该模型有两个自由参数;C和epsilon。我们将在下一个步骤中设置参数。

from sklearn.model_selection import GridSearchCV
from sklearn import svm
from sklearn.svm import SVR
  1. 设置估计器参数。

SVR 在这个实现中,我们使用了rbf 模型的内核。rbf 代表径向基函数。它给模型引入了某种形式的非线性,因为使用的数据是非线性的。我们的意思是,数据的排列不遵循特定的序列。

estimator=SVR(kernel='rbf')
  1. 指定超参数和取值范围。

然后我们指定我们寻求考察的超参数。当使用SVR的rbf 核时,要使用的三个超参数是:C,epsilon, 和gamma 。我们可以给每个参数几个值来选择。

记住,可以改变这些值,并对它们进行测试,看看哪个值的集合可以得到更好的结果。下面是我随机选择的值。

param_grid={
            'C': [1.1, 5.4, 170, 1001],
            'epsilon': [0.0003, 0.007, 0.0109, 0.019, 0.14, 0.05, 8, 0.2, 3, 2, 7],
            'gamma': [0.7001, 0.008, 0.001, 3.1, 1, 1.3, 5]
        }

param_grid 参数需要一个参数列表和每个参数的范围,如我们上面所示。

  1. 评估。

我们提到,交叉验证是为了估计模型的性能而进行的。在k-fold交叉验证中,k是折叠的数量。如下图所示,通过cv=5 ,我们用交叉验证法来训练模型5次。这意味着,5将是kk值。

scoring='neg_mean_squared_error' 给我们的是平均平方误差。它在网格搜索中以这种形式使用。这是指取平均平方误差的负值来最大化和优化它,而不是最小化实际误差。

n_jobs 参数指定了用库joblib并行化的例程应该使用的并发进程的数量。在我们的例子中,-1意味着所有的CPU都在使用。

verbose 给了我们一个产生日志信息的选项。我们把它保持在0的位置来禁用它,因为它可能会减慢我们的算法。

grid = GridSearchCV(

estimator=SVR(kernel='rbf'),
        param_grid={
            'C': [1.1, 5.4, 170, 1001],
            'epsilon': [0.0003, 0.007, 0.0109, 0.019, 0.14, 0.05, 8, 0.2, 3, 2, 7],
            'gamma': [0.7001, 0.008, 0.001, 3.1, 1, 1.3, 5]
        },
        cv=5, scoring='neg_mean_squared_error', verbose=0, n_jobs=-1)
  1. 拟合数据。

我们通过grid.fit(X,y) ,它对所有的参数进行拟合。

所有的代码

既然我们现在了解了代码的关键方面,让我们来运行所有的代码。

from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn import svm
from sklearn.svm import SVR

iris = load_iris()
svc = svm.SVR()

grid = GridSearchCV(

estimator=SVR(kernel='rbf'),

param_grid={

'C': [1.1, 5.4, 170, 1001],

'epsilon': [0.0003, 0.007, 0.0109, 0.019, 0.14, 0.05, 8, 0.2, 3, 2, 7],

'gamma': [0.7001, 0.008, 0.001, 3.1, 1, 1.3, 5]
        },
        cv=5, scoring='neg_mean_squared_error', verbose=0, n_jobs=-1)

X = iris.data
y = iris.target

grid.fit(X,y)


#print the best parameters from all possible combinations
print("best parameters are: ", grid.best_params_)

结果

best parameters are:
{'C': 170, 'epsilon': 0.0003, 'gamma': 0.008}

我们成功地进行了网格搜索,并确定了最佳参数是C的170,ε值的0.0003和gamma的0.008。

结论

由于网格搜索尝试了所有可能的组合,所以它成为一种计算上很昂贵的方法。我们已经定义了网格搜索,并通过一个简单的Python例子探索了它是如何工作的。

还有其他的优化方法,其复杂性和有效性各不相同。我希望在未来能介绍一些。