Skip to content

Commit 7034930

Browse files
committed
add SVM solver
1 parent 7e05ed6 commit 7034930

File tree

3 files changed

+60
-1
lines changed

3 files changed

+60
-1
lines changed

NAMESPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
importFrom(Rcpp, evalCpp)
22
useDynLib(rehline, .registration = TRUE)
33

4-
export(rehline, elastic_qr, elastic_huber)
4+
export(rehline, elastic_huber, elastic_qr, svm)

R/svm.R

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
##' Solving Regularized Support Vector Machine
2+
##'
3+
##' @description
4+
##' This function solves the regularized support vector machine
5+
##' of the following form:
6+
##' \deqn{
7+
##' \min_{\beta}\ \frac{C}{n}\sum_{i=1}^n \max(1-y_i x_i^T\beta,0) +
8+
##' \frac{1}{2}\Vert\beta\Vert_2^2
9+
##' }
10+
##' where \eqn{\beta\in\mathbb{R}^d} is a length-\eqn{d} vector,
11+
##' \eqn{x_i\in\mathbb{R}^d} is the feature vector for the \eqn{i}-th observation,
12+
##' \eqn{y_i\in\{-1,1\}} is a binary label,
13+
##' and \eqn{C} is the cost parameter.
14+
##'
15+
##' @param x The data matrix \eqn{X=(x_1,\ldots,x_n)^T} of size
16+
##' \eqn{n\times d}, representing \eqn{n} observations
17+
##' and \eqn{d} features.
18+
##' @param y The length-\eqn{n} response vector.
19+
##' @param C The cost parameter.
20+
##' @param max_iter Maximum number of iterations.
21+
##' @param tol Tolerance parameter for convergence test.
22+
##' @param shrink Whether to use the shrinkage algorithm.
23+
##' @param verbose Level of verbosity.
24+
##'
25+
##' @return A list of the following components:
26+
##' \item{beta}{Optimized value of the \eqn{\beta} vector.}
27+
##' \item{xi,Lambda,Gamma}{Values of dual variables.}
28+
##' \item{niter}{Number of iterations used.}
29+
##' \item{dual_objfns}{Dual objective function values during the optimization process.}
30+
##'
31+
##' @author Yixuan Qiu \url{https://statr.me}
32+
##'
33+
##' Ben Dai \url{https://bendai.org}
34+
##'
35+
svm = function(x, y, C = 1, max_iter = 1000, tol = 1e-5, shrink = TRUE, verbose = 0)
36+
{
37+
n = nrow(x)
38+
39+
Umat = -C / n * matrix(y, nrow = 1)
40+
Vmat = matrix(C / n, 1, n)
41+
42+
res = rehline(x, Umat, Vmat, max_iter = max_iter,
43+
tol = tol, shrink = shrink, verbose = verbose)
44+
res$beta
45+
}

test/test_svm.R

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
library(rehline)
2+
library(LiblineaR)
23
library(reticulate)
34

45
py_config()
@@ -23,13 +24,21 @@ read_npz = function(npz_file)
2324

2425
dat = read_npz("./dataset/sim/exp_svm.npz")
2526
X = dat[["X"]]
27+
y = dat[["y"]]
2628
U = dat[["U"]]
2729
V = dat[["V"]]
30+
C = 0.5
2831

2932
print(dim(X))
3033
print(dim(U))
3134
print(dim(V))
3235

36+
# LibLinear
37+
set.seed(123)
38+
y = matrix(y, ncol = 1)
39+
res = LiblineaR(X, y, type = 3, cost = C, epsilon = 1e-6, bias = 0, verbose = TRUE)
40+
print(as.numeric(res$W))
41+
3342
# Algorithm with shrinking
3443
set.seed(123)
3544
res = rehline(X, U, V, max_iter = 1000, tol = 1e-6, verbose = 1)
@@ -43,6 +52,11 @@ set.seed(123)
4352
res = rehline(X, U, V, max_iter = 1000, tol = 1e-6, shrink = FALSE, verbose = 1)
4453
print(res$beta)
4554

55+
# Directly call svm()
56+
set.seed(123)
57+
res = svm(X, y, C = C * nrow(X), max_iter = 1000, tol = 1e-6, verbose = 1)
58+
print(res$beta)
59+
4660
# Add constraints
4761
set.seed(123)
4862
K = 5

0 commit comments

Comments
 (0)