-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsvm_kernel.m
More file actions
85 lines (74 loc) · 2.25 KB
/
svm_kernel.m
File metadata and controls
85 lines (74 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
function [trainPoints, trainLabels,fun] = svm_kernel(varargin)
% SVM_KERNEL  Train a kernel SVM on given or interactively clicked points.
%
% [X, y, h] = svm_kernel(varargin)
%
% If no feature matrix is supplied, the user is asked to click points for
% two classes (labels -1 and +1) in a [-5,5]x[-5,5] axis using GETPTS.
% The current figure is cleared (CLF) in either case.
%
% Inputs (name/value pairs, parsed by EXTRACTPARS):
%   'X',X          : feature vectors as ROWS of X (n x d). This is the layout
%                    the interactive path produces and the kernels assume.
%   'Y',y          : labels, one per row of X. The interactive path builds an
%                    n x 1 column of -1/+1.
%   'C',C          : loss/regularization trade-off passed to trainSVM
%                    (default=1)
%   'lambda',C     : alias for 'C' (the name used in earlier documentation);
%                    if given, it takes precedence over 'C'
%   'bias',b       : accepted (default=1) but not used by this function
% Visualization (accepted but not used by this function):
%   'vismargin',true/false : default=true
%   'viscolor',true/false  : default=false
% Kernel parameters:
%   'kernel',s     : 'rbf' (default), 'polynomial', 'linear' or 'mkl'
%                    ('mkl' = rescaled polynomial + rescaled rbf)
%   'sigma',sigma  : rbf kernel width (default=0.5)
%   'degree',d     : degree of the polynomial kernel (default=4)
%
% Outputs:
%   X : training feature vectors (n x d)
%   y : training labels
%   h : function handle; h(Xt) returns kernel(Xt,X)*alpha + b, the SVM
%       decision values for the rows of Xt
%
% Example:
%   [X, y, h] = svm_kernel('kernel','polynomial','degree',2);
%   scores = h([0 0; 1 1]);
pars.X=[];
pars.Y=[];
pars.bias=1;        % not referenced below
pars.C=1;
pars.lambda=[];     % optional alias for C
pars.sigma=0.5;
pars.kernel='rbf';
pars.degree=4;
pars.viscolor=false;  % not referenced below
pars.vismargin=true;  % not referenced below
pars=extractpars(varargin,pars);
if ~isempty(pars.lambda)
    pars.C=pars.lambda;
end

% Use the supplied training data; if none was given, obtain points from the user.
trainPoints=pars.X;
trainLabels=pars.Y;
clf;
axis([-5 5 -5 5]);
if isempty(trainPoints)
    % Symbols and label values for the two classes.
    symbols = {'o','x'};
    classvals = [-1 1];
    trainLabels=[];
    hold on; % keep previously plotted points while collecting the next class
    xlim([-5 5]); ylim([-5 5]);
    for c = 1:2
        title(sprintf('Click to create points from class %d. Press enter when finished.', c));
        [x, y] = getpts;
        plot(x,y,symbols{c},'LineWidth', 2, 'Color', 'black');
        % Append the clicked points as rows and their labels as a column.
        trainPoints = vertcat(trainPoints, [x y]);
        trainLabels = vertcat(trainLabels, repmat(classvals(c), numel(x), 1));
    end
end
if isempty(trainPoints)
    error('svm_kernel:noData', 'No training points were given.');
end

% NOTE(review): DISTANCE is an external helper called with column-oriented
% inputs (d x n). It presumably returns the matrix of squared Euclidean
% distances; confirm, because a plain (non-squared) distance would not give
% the usual Gaussian RBF kernel.
rbfKernel  = @(x,z) exp(-distance(x',z')./(2*pars.sigma^2));
polyKernel = @(x,z) (x*z'+1).^pars.degree;
switch pars.kernel
    case 'rbf'
        kernel=rbfKernel;
        disp('RBF');
    case 'linear'
        kernel=@(x,z) x*z';
        disp('linear');
    case 'polynomial'
        kernel=polyKernel;
        disp('polynomial');
    case 'mkl'
        % Normalise each component by its maximum on the TRAINING kernel
        % matrix, so that h(Xt) evaluates the same kernel the SVM was trained
        % with. Rescaling kernel(Xt,X) by its own maximum would make
        % predictions depend on which test points are evaluated together.
        polyScale = max(max(polyKernel(trainPoints,trainPoints)));
        rbfScale  = max(max(rbfKernel(trainPoints,trainPoints)));
        kernel=@(x,z) polyKernel(x,z)./polyScale + rbfKernel(x,z)./rbfScale;
        disp('mkl');
    otherwise
        error('svm_kernel:unknownKernel', ...
            'Unknown kernel ''%s''; use ''rbf'', ''polynomial'', ''linear'' or ''mkl''.', ...
            pars.kernel);
end

% Train on the n x n Gram matrix and return the decision function.
K=kernel(trainPoints,trainPoints);
[alpha, b] = trainSVM(K, trainLabels, pars.C);
fun = @(Xt) kernel(Xt,trainPoints) * alpha + b;