Skip to content

Commit 2c0e3d6

Browse files
WIP: function to fit randomized lasso
1 parent ffccea5 commit 2c0e3d6

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Functions to fit and "infer" about parameters in the
2+
# randomized LASSO
3+
#
4+
# min 1/2 || y - \beta_0 - X \beta ||_2^2 + \lambda || \beta ||_1 - \omega^T\beta + \frac{\epsilon}{2} \|\beta\|^2_2
5+
6+
fit_randomized_lasso = function(X,
7+
y,
8+
lam,
9+
noise_scale,
10+
ridge_term,
11+
noise_type=c('gaussian', 'laplace'),
12+
max_iter=100, # how many iterations for each optimization problem
13+
kkt_tol=1.e-4, # tolerance for the KKT conditions
14+
objective_tol=1.e-8, # tolerance for relative decrease in objective
15+
objective_stop=FALSE,
16+
kkt_stop=TRUE,
17+
param_stop=TRUE)
18+
{
19+
20+
n = nrow(X); p = ncol(X)
21+
22+
noise_type = match.arg(noise_type)
23+
24+
if (noise_type == 'gaussian') {
25+
D = Norm(mean=0, sd=noise_scale)
26+
}
27+
else if (noise_type == 'laplace') {
28+
D = DExp(rate = 1 / noise_scale) # D is a Laplace distribution with rate = 1.
29+
}
30+
perturb_ = distr::r(D)(p)
31+
32+
lam = as.numeric(lam)
33+
if (length(lam) == 1) {
34+
lam = rep(lam, p)
35+
}
36+
if (length(lam) != p) {
37+
stop("Lagrange parameter should be single float or of length ncol(X)")
38+
}
39+
40+
soln = rep(0, p)
41+
Xsoln = rep(0, n)
42+
linear_func = (- t(X) %*% y - perturb_)
43+
gradient = 1. * linear_func
44+
ever_active = rep(0, p)
45+
nactive = as.integer(0)
46+
47+
result = solve_QP_wide(X, # design matrix
48+
lam, # vector of Lagrange multipliers
49+
ridge_term / n, # ridge_term
50+
max_iter,
51+
soln,
52+
linear_func,
53+
gradient,
54+
Xsoln,
55+
ever_active,
56+
nactive,
57+
kkt_tol,
58+
objective_tol,
59+
p,
60+
objective_stop, # objective_stop
61+
kkt_stop, # kkt_stop
62+
param_stop) # param_stop
63+
return(result)
64+
}

tests/test_randomized.R

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
library(selectiveInference)
2+
3+
test = function() {
4+
5+
n = 100; p = 50
6+
X = matrix(rnorm(n * p), n, p)
7+
y = rnorm(n)
8+
lam = 20 / sqrt(n)
9+
noise_scale = 0.01 * sqrt(n)
10+
ridge_term = .1 / sqrt(n)
11+
fit_randomized_lasso(X, y, lam, noise_scale, ridge_term)
12+
}
13+
14+
print(test())

0 commit comments

Comments
 (0)