MATLAB SVM寻找最佳参数 c g

893 阅读3分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 6 月更文挑战」的第22天,点击查看活动详情

image.png

老师布置了一个数据挖掘的作业,用SVM做分类。老师原话是:“说SVM训练慢,其实svmtrain的过程是很快的,慢是慢在寻找最佳参数。”

但是寻找最佳参数这里老师又直接给了一个别人写的网格计算的小代码,一个函数直接调用就OK了。并不是每个人都有老师,所以在这里分享一下。

配合livsvm库一起使用效果绝佳。

新建一个文件叫SVMcgForClass.m,如果你不嫌麻烦,拉到最底下,把下面代码解释部分的代码从头到尾复制进去,保存。

如果你懒得复制保存可以直接下一个:

搞好了那个文件之后,将文件放到你代码同目录下边。

1651661306(1).png

然后在你的代码中加一句:

[bestacc,bestc,bestg] = SVMcgForClass(train_label,train,cmin,cmax,gmin,gmax,v,cstep,gstep,accstep)

就会自动调用SVMcgForClass的函数,并返回最高准确率下(bestacc)的最佳的g和c(bestc、bestg)。

比如下图这样,我的训练集的分类标签是train_labels,训练集的分类特征是train_features,之后就是设置c和g的最大值最小值以及网格步长,还有交叉验证的折数(参数下边代码部分会详细解释)。 都设置好之后运行可以看到最右边的结果。

image.png


代码解释

解释一下下边这段代码啊,就是搞了个网格,挨个数据对测试,获取最佳的c和g,注意这是在计算精度相同的情况下获取最小的c和g。

原作者备注信息在代码的3-15行。

最后一次修改已经是10年了,(12年前的事了,现在居然还能用),我看了一下新浪博客已经倒闭了,所以我就发过来了。

function [bestacc,bestc,bestg] = 
SVMcgForClass(train_label,train,cmin,cmax,gmin,gmax,v,cstep,gstep,accstep)

第1行代码就是:写了一个SVMcgForClass函数,能返回在最好准确率情况下的最好的c和g。 参数:

  • train_label:测试集分类标签
  • train:测试集数据
  • cmin、cmax、cstep:c的取值范围以及每次增加的步长
  • gmin、gmax、gstep:g的取值范围以及每次增加的步长
  • v:SVM会做k折交叉验证,这里是验证的折数,即k的大小
  • accstep:精度的步长,这里是画图时候会用到的一个值
%% about the parameters of SVMcg 
if nargin < 10
    accstep = 4.5;
end
......

原代码的17-35行是参数设置。 nargin为“number of input arguments”的缩写。 在matlab中定义一个函数时, 在函数体内部, nargin是用来判断输入变量个数的函数。 他一共设置了10个参数。

  • nargin < 10就是说你只写了9个参数,不写accstep的时候默认设置accstep为4.5。
  • nargin < 8就是你还没写gstep和cstep,默认设置为0.8。
  • nargin < 7不写K折交叉默认5折。
  • nargin < 5不写g默认[-8,8]
  • nargin < 3不写c默认[-8,8]

所以我上边的例子代码写成下边这样也是可以的。

image.png

源代码的36-70行是设置网格挨个数据对都搞一个SVM训练。 cg用来存储你所有的SVM网络,以便寻找最好的一个。

源代码的71到最后就是对你刚才网格所有的数据进行一个可视化,一个等高线图一个3D的图。


完整代码

function [bestacc,bestc,bestg] = SVMcgForClass(train_label,train,cmin,cmax,gmin,gmax,v,cstep,gstep,accstep)

%SVMcg cross validation by faruto
%%
% by faruto
%Email:patrick.lee@foxmail.com QQ:516667408 http://blog.sina.com.cn/faruto BNU
%last modified 2010.01.17

%% 若转载请注明:
% faruto and liyang , LIBSVM-farutoUltimateVersion 
% a toolbox with implements for support vector machines based on libsvm, 2009. 
% 
% Chih-Chung Chang and Chih-Jen Lin, LIBSVM : a library for
% support vector machines, 2001. Software available at
% http://www.csie.ntu.edu.tw/~cjlin/libsvm

%% about the parameters of SVMcg 
if nargin < 10
    accstep = 4.5;
end
if nargin < 8
    cstep = 0.8;
    gstep = 0.8;
end
if nargin < 7
    v = 5;
end
if nargin < 5
    gmax = 8;
    gmin = -8;
end
if nargin < 3
    cmax = 8;
    cmin = -8;
end
%% X:c Y:g cg:CVaccuracy
[X,Y] = meshgrid(cmin:cstep:cmax,gmin:gstep:gmax);
[m,n] = size(X);
cg = zeros(m,n);

eps = 10^(-4);

%% record acc with different c & g,and find the bestacc with the smallest c
bestc = 1;
bestg = 0.1;
bestacc = 0;
basenum = 2;
for i = 1:m
    for j = 1:n
        cmd = ['-v ',num2str(v),' -c ',num2str( basenum^X(i,j) ),' -g ',num2str( basenum^Y(i,j) )];
        cg(i,j) = svmtrain(train_label, train, cmd);
        
        if cg(i,j) <= 90
            continue;
        end
        
        if cg(i,j) > bestacc
            bestacc = cg(i,j);
            bestc = basenum^X(i,j);
            bestg = basenum^Y(i,j);
        end        
        
        if abs( cg(i,j)-bestacc )<=eps && bestc > basenum^X(i,j) 
            bestacc = cg(i,j);
            bestc = basenum^X(i,j);
            bestg = basenum^Y(i,j);
        end        
        
    end
end
%% to draw the acc with different c & g
figure;
[C,h] = contour(X,Y,cg,70:accstep:100);
clabel(C,h,'Color','r');
xlabel('log2c','FontSize',12);
ylabel('log2g','FontSize',12);
firstline = 'SVC参数选择结果图(等高线图)[GridSearchMethod]'; 
secondline = ['Best c=',num2str(bestc),' g=',num2str(bestg), ...
    ' CVAccuracy=',num2str(bestacc),'%'];
title({firstline;secondline},'Fontsize',12);
grid on; 

figure;
meshc(X,Y,cg);
% mesh(X,Y,cg);
% surf(X,Y,cg);
axis([cmin,cmax,gmin,gmax,30,100]);
xlabel('log2c','FontSize',12);
ylabel('log2g','FontSize',12);
zlabel('Accuracy(%)','FontSize',12);
firstline = 'SVC参数选择结果图(3D视图)[GridSearchMethod]'; 
secondline = ['Best c=',num2str(bestc),' g=',num2str(bestg), ...
    ' CVAccuracy=',num2str(bestacc),'%'];
title({firstline;secondline},'Fontsize',12);