Skip to content

Commit e2efbdc

Browse files
author
Jelena Markovic
committed
working! bugs fixed. new sampler
1 parent d2a3b9b commit e2efbdc

File tree

4 files changed

+217
-31
lines changed

4 files changed

+217
-31
lines changed

selectiveInference/R/funs.randomized.R

Lines changed: 57 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,15 @@ randomizedLasso = function(X,
3333
if (is.null(noise_scale)) {
3434
noise_scale = 0.5 * sd(y) * sqrt(mean_diag)
3535
}
36-
36+
37+
print(paste("ridge term", ridge_term))
38+
print(paste("noise scale", noise_scale))
39+
3740
noise_type = match.arg(noise_type)
3841

3942
if (noise_scale > 0) {
4043
if (noise_type == 'gaussian') {
44+
set.seed(1)
4145
perturb_ = rnorm(p) * noise_scale
4246
}
4347
else if (noise_type == 'laplace') {
@@ -65,8 +69,8 @@ randomizedLasso = function(X,
6569
nactive = as.integer(0)
6670

6771
result = solve_QP_wide(X, # design matrix
68-
lam / n, # vector of Lagrange multipliers
69-
ridge_term / n, # ridge_term
72+
lam / n, # vector of Lagrange multipliers
73+
ridge_term / n, # ridge_term
7074
max_iter,
7175
soln,
7276
linear_func,
@@ -76,12 +80,12 @@ randomizedLasso = function(X,
7680
nactive,
7781
kkt_tol,
7882
objective_tol,
79-
parameter_tol,
83+
parameter_tol,
8084
p,
81-
objective_stop, # objective_stop
82-
kkt_stop, # kkt_stop
83-
parameter_stop) # param_stop
84-
85+
objective_stop, # objective_stop
86+
kkt_stop, # kkt_stop
87+
parameter_stop) # param_stop
88+
8589
sign_soln = sign(result$soln)
8690

8791
unpenalized = lam == 0
@@ -96,7 +100,11 @@ randomizedLasso = function(X,
96100

97101
observed_scalings = abs(result$soln)[active]
98102
observed_unpen = result$soln[unpenalized]
99-
observed_subgrad = result$gradient[inactive]
103+
observed_subgrad = -n*result$gradient[inactive]
104+
105+
if (length(which(abs(observed_subgrad)>lam[1]))){
106+
print("subgradient eq not satisfied")
107+
}
100108

101109
observed_opt_state = c(observed_unpen, observed_scalings, observed_subgrad)
102110

@@ -111,14 +119,15 @@ randomizedLasso = function(X,
111119
coef_term = L_E
112120

113121
signs_ = c(rep(1, sum(unpenalized)), sign_soln[active])
122+
123+
coef_term[active,] = coef_term[active,] + ridge_term * diag(rep(1, sum(active))) # ridge term
124+
114125
if (length(signs_) == 1) {
115-
coef_term = coef_term * signs_
126+
coef_term = coef_term * signs_
116127
} else {
117-
coef_term = coef_term %*% diag(signs_) # scaligns are non-negative
128+
coef_term = coef_term %*% diag(signs_) # scaligns are non-negative
118129
}
119-
120-
coef_term[active,] = coef_term[active,] + ridge_term * diag(rep(1, sum(active))) # ridge term
121-
130+
122131
subgrad_term = matrix(0, p, sum(inactive)) # for subgrad
123132
for (i in 1:sum(inactive)) {
124133
subgrad_term[inactive_set[i], i] = 1
@@ -155,7 +164,8 @@ randomizedLasso = function(X,
155164
inactive_lam = lam[inactive_set]
156165
inactive_start = sum(unpenalized) + sum(active)
157166
active_start = sum(unpenalized)
158-
167+
168+
159169
# XXX only for Gaussian so far
160170

161171
log_optimization_density = function(opt_state) {
@@ -185,9 +195,11 @@ randomizedLasso = function(X,
185195
optimization_transform = opt_transform,
186196
internal_transform = internal_transform,
187197
log_optimization_density = log_optimization_density,
188-
observed_opt_state = observed_opt_state,
198+
observed_opt_state = observed_opt_state,
189199
observed_raw = observed_raw,
190-
noise_scale = noise_scale
200+
noise_scale = noise_scale,
201+
soln = result$soln,
202+
perturb = perturb_
191203
))
192204

193205
}
@@ -314,19 +326,22 @@ conditional_density = function(noise_scale, lasso_soln) {
314326
lasso_soln$log_optimization_density = log_condl_optimization_density
315327
lasso_soln$observed_opt_state = observed_opt_state[1:nactive]
316328
lasso_soln$optimization_transform = opt_transform
317-
return(lasso_soln)
329+
reduced_opt_transform =list(linear_term = reduced_B, offset_term = reduced_beta_offset)
330+
return(list(lasso_soln=lasso_soln,
331+
reduced_opt_transform = reduced_opt_transform))
318332
}
319333

320334
randomizedLassoInf = function(X,
321335
y,
322336
lam,
337+
sampler="A",
323338
sigma=NULL,
324339
noise_scale=NULL,
325340
ridge_term=NULL,
326341
condition_subgrad=TRUE,
327342
level=0.9,
328-
nsample=10000,
329-
burnin=2000,
343+
nsample=10000,
344+
burnin=2000,
330345
max_iter=100, # how many iterations for each optimization problem
331346
kkt_tol=1.e-4, # tolerance for the KKT conditions
332347
parameter_tol=1.e-8, # tolerance for relative convergence of parameter
@@ -353,22 +368,35 @@ randomizedLassoInf = function(X,
353368
parameter_stop=parameter_stop)
354369

355370
active_set = lasso_soln$active_set
356-
if (length(active_set)==0){
371+
nactive = length(active_set)
372+
print(paste("nactive", nactive))
373+
if (nactive==0){
357374
return (list(active_set=active_set, pvalues=c(), ci=c()))
358375
}
359376
inactive_set = lasso_soln$inactive_set
360-
nactive = length(active_set)
377+
361378

362379
noise_scale = lasso_soln$noise_scale # set to default value in randomizedLasso
363380

364381
if (condition_subgrad==TRUE){
365-
lasso_soln=conditional_density(noise_scale, lasso_soln)
382+
condl_lasso=conditional_density(noise_scale, lasso_soln)
383+
lasso_soln = condl_lasso$lasso_soln
384+
reduced_opt_transform = condl_lasso$reduced_opt_transform
366385
}
367386

368387
ndim = length(lasso_soln$observed_opt_state)
369-
370-
S = sample_opt_variables(lasso_soln, jump_scale=rep(1/sqrt(n), ndim), nsample=nsample)
371-
opt_samples = as.matrix(S$samples[(burnin+1):nsample,,drop=FALSE])
388+
389+
if (sampler =="R"){
390+
S = sample_opt_variables(lasso_soln, jump_scale=rep(1/sqrt(n), ndim), nsample=nsample)
391+
opt_samples = as.matrix(S$samples[(burnin+1):nsample,,drop=FALSE])
392+
} else if (sampler == "A"){
393+
opt_samples = gaussian_sampler(noise_scale,
394+
lasso_soln$observed_opt_state,
395+
reduced_opt_transform$linear_term,
396+
reduced_opt_transform$offset_term,
397+
nsamples=nsample)
398+
opt_sample = opt_samples[(burnin+1):nsample,]
399+
}
372400

373401
X_E = X[, active_set]
374402
X_minusE = X[, inactive_set]
@@ -458,3 +486,6 @@ randomizedLassoInf = function(X,
458486
}
459487
return(list(active_set=active_set, pvalues=pvalues, ci=ci))
460488
}
489+
490+
491+

selectiveInference/R/sampler.R

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
2+
# Generate a chain of positions inside a box by repeatedly moving along
# random unit directions.
#
# Arguments:
#   negative_log_density       function(x) returning a scalar value at x.
#   grad_negative_log_density  function(x) returning the gradient at x.
#   constraints                two-column matrix (or something as.matrix()
#                              turns into one), one row per coordinate:
#                              column 1 is the lower bound, column 2 the
#                              upper bound (may be Inf).
#   observed                   starting position; it is stored as the first
#                              sample.
#   nsamples                   number of positions to record; 0 yields a
#                              0-row matrix.
#
# Returns an nsamples x nrow(constraints) numeric matrix, one position per row.
#
# NOTE(review): presumably the chain is meant to target the density
# proportional to exp(-negative_log_density) on the box -- this is not
# verified here.
log_concave_sampler <- function(negative_log_density,
                                grad_negative_log_density,
                                constraints,
                                observed,
                                nsamples) {
  constraints <- as.matrix(constraints)
  n_dim <- nrow(constraints)
  stopifnot(ncol(constraints) == 2, length(observed) == n_dim)

  # Upper limit on the travel time considered along any single direction.
  max_travel_time <- 10

  # Travel time along state$velocity, starting from state$pos.
  next_move_time <- function(state) {
    pos <- as.matrix(state$pos)
    velocity <- as.matrix(state$velocity)
    moving_down <- as.vector(velocity < 0)
    moving_up <- as.vector(velocity > 0)

    # [tau_min, tau_max] is the time window during which the ray stays in
    # the box: exit times shrink tau_max, entry times raise tau_min.
    tau_min <- 0
    tau_max <- max_travel_time
    if (any(moving_down)) {
      speed <- -velocity[moving_down]
      exit_times <- (pos[moving_down] - constraints[moving_down, 1]) / speed
      tau_max <- min(tau_max, min(exit_times))
      entry_times <- (pos[moving_down] - constraints[moving_down, 2]) / speed
      tau_min <- max(tau_min, max(entry_times))
    }
    if (any(moving_up)) {
      speed <- velocity[moving_up]
      exit_times <- (constraints[moving_up, 2] - pos[moving_up]) / speed
      tau_max <- min(tau_max, min(exit_times))
      entry_times <- (constraints[moving_up, 1] - pos[moving_up]) / speed
      tau_min <- max(tau_min, max(entry_times))
    }

    value_at <- function(t) {
      negative_log_density(pos + velocity * t)
    }
    # Directional derivative of the negative log density along the ray.
    slope_at <- function(t) {
      as.numeric(t(velocity) %*% grad_negative_log_density(pos + velocity * t))
    }

    # tau_star: where the negative log density is smallest on the window.
    # A sign change of the slope brackets an interior root; otherwise take
    # the endpoint with the smaller value (tau_max on ties).
    tau_star <- tau_max
    if (slope_at(tau_min) * slope_at(tau_max) < 0) {
      tau_star <- uniroot(slope_at, c(tau_min, tau_max))$root
    } else if (value_at(tau_min) < value_at(tau_max)) {
      tau_star <- tau_min
    }
    tau_min <- max(tau_min, tau_star)

    # Find the time past tau_star at which the negative log density has
    # risen by an Exp(1) draw; if that level is not reached inside the
    # window, travel to the end of the window.
    target <- value_at(tau_star) + rexp(1)
    excess_at <- function(t) {
      value_at(t) - target
    }
    if (excess_at(tau_min) * excess_at(tau_max) < 0) {
      uniroot(excess_at, c(tau_min, tau_max))$root
    } else {
      tau_max
    }
  }

  # Uniformly distributed unit vector. sum(z^2) is a plain scalar, so the
  # division never recycles a 1x1 matrix against a longer vector.
  draw_direction <- function() {
    z <- rnorm(n_dim)
    z / sqrt(sum(z^2))
  }

  state <- list(pos = observed, velocity = draw_direction())
  samples <- matrix(0, nrow = nsamples, ncol = n_dim)
  # seq_len() keeps the loop empty when nsamples is 0 (1:0 would run twice).
  for (i in seq_len(nsamples)) {
    samples[i, ] <- state$pos
    # Only half of the computed travel time is used for the move.
    step <- next_move_time(state) / 2
    state <- list(pos = state$pos + state$velocity * step,
                  velocity = draw_direction())
  }
  samples
}
76+
77+
# Run log_concave_sampler() on a quadratic negative log density
#   ||linear_term %*% x + offset_term||^2 / (2 * noise_scale^2)
# with every coordinate constrained to [0, Inf).
#
# Arguments:
#   noise_scale  scalar dividing the quadratic (enters as noise_scale^2).
#   observed     starting position; its length sets the dimension.
#   linear_term  matrix (or scalar) multiplying the position.
#   offset_term  offset added to linear_term %*% x.
#   nsamples     number of samples requested from log_concave_sampler().
#
# Returns whatever log_concave_sampler() returns.
gaussian_sampler <- function(noise_scale, observed, linear_term, offset_term, nsamples) {

  variance <- noise_scale^2

  # Affine image of the position that the quadratic is evaluated on.
  reconstruct <- function(x) {
    linear_term %*% x + offset_term
  }

  quadratic <- function(x) {
    r <- reconstruct(x)
    as.numeric(t(r) %*% r / (2 * variance))
  }

  quadratic_gradient <- function(x) {
    t(linear_term) %*% reconstruct(x) / variance
  }

  # Column 1: lower bounds (all 0); column 2: upper bounds (all Inf).
  n_coord <- length(observed)
  orthant <- matrix(c(rep(0, n_coord), rep(Inf, n_coord)),
                    nrow = n_coord, ncol = 2)

  log_concave_sampler(quadratic,
                      quadratic_gradient,
                      orthant,
                      observed,
                      nsamples)
}

tests/randomized/test_instances.R

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,46 @@ gaussian_instance = function(n, p, s, sigma=1, rho=0, signal=6, X=NA,
2424
return(result)
2525
}
2626

27+
# Fit the randomized lasso on one simulated Gaussian data set (fixed seed)
# and print the solution, its number of non-zero entries and the observed
# optimization state.
test_randomized_lasso <- function(n = 100, p = 200, s = 0) {
  set.seed(1)
  sim <- gaussian_instance(n = n, p = p, s = s, rho = 0.3, sigma = 3)

  # Positional arguments after X and y: lam, noise_scale, ridge_term.
  fit <- selectiveInference:::randomizedLasso(sim$X, sim$y,
                                              2.,
                                              0.5,
                                              1. / sqrt(n))

  coefs <- fit$soln
  print(coefs)
  print(sum(coefs != 0, na.rm = TRUE))
  print(fit$observed_opt_state) # compared with python code
}
40+
41+
# Fit the randomized lasso on one simulated Gaussian data set (fixed seed)
# and print
#   linear_term %*% observed_opt_state + offset_term + observed_raw - perturb,
# which should be zero.
test_KKT <- function() {
  set.seed(1)
  n_obs <- 200
  n_feat <- 100
  sim <- gaussian_instance(n = n_obs, p = n_feat, s = 0, rho = 0.3, sigma = 3)

  # Positional arguments after X and y: lam, noise_scale, ridge_term.
  fit <- selectiveInference:::randomizedLasso(sim$X, sim$y,
                                              2.,
                                              0.5,
                                              1. / sqrt(n_obs))
  print("check KKT")

  affine <- fit$optimization_transform
  kkt_residual <- affine$linear_term %*% fit$observed_opt_state +
    affine$offset_term + fit$observed_raw - fit$perturb
  print(kkt_residual) ## should be zero
}
61+
62+
2763

2864
collect_results = function(n,p,s, nsim=100, level=0.9, condition_subgrad=TRUE, lam=1.2){
2965

30-
rho=0.3
66+
rho=0.
3167
sigma=1
3268
sample_pvalues = c()
3369
sample_coverage = c()
@@ -36,7 +72,14 @@ collect_results = function(n,p,s, nsim=100, level=0.9, condition_subgrad=TRUE, l
3672
X=data$X
3773
y=data$y
3874
beta=data$beta
39-
result = selectiveInference:::randomizedLassoInf(X, y, lam, level=level, burnin=2000, nsample=4000, condition_subgrad=condition_subgrad)
75+
result = selectiveInference:::randomizedLassoInf(X, y,
76+
lam=lam,
77+
sigma=sigma,
78+
level=level,
79+
sampler = "A",
80+
burnin=1000,
81+
nsample=5000,
82+
condition_subgrad=condition_subgrad)
4083
true_beta = beta[result$active_set]
4184
coverage = rep(0, nrow(result$ci))
4285
if (length(result$active_set)>0){
@@ -61,7 +104,7 @@ collect_results = function(n,p,s, nsim=100, level=0.9, condition_subgrad=TRUE, l
61104
}
62105

63106
set.seed(1)
64-
collect_results(n=200, p=100, s=0, lam=2)
65-
66-
107+
collect_results(n=100, p=2000, s=0, lam=3)
108+
#test_randomized_lasso()
109+
#test_KKT()
67110

tests/randomized/test_sampler.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
2+
# Visual check of log_concave_sampler() on the one-dimensional negative log
# density x^2 / 2 restricted to [2, 3], started at 2. Prints the sample mean
# and draws a histogram of the samples.
test_log_concave_sampler <- function() {
  samples <- log_concave_sampler(negative_log_density = function(x) {x^2 / 2},
                                 grad_negative_log_density = function(x) {x},
                                 constraints = t(as.matrix(c(2, 3))),
                                 observed = 2, nsamples = 10000)
  # A bare `mean(samples)` inside a function is evaluated and discarded;
  # print it so the value is actually reported.
  print(mean(samples))
  hist(samples)
}
10+
11+
12+
# Visual check of gaussian_sampler() in one dimension with noise_scale = 1,
# observed = 1, linear_term = 1 and offset_term = 0. Prints the sample mean
# and draws a histogram of the samples.
test_gaussian_sampler <- function() {
  samples <- gaussian_sampler(1, 1, 1, 0, 10000)
  # A bare `mean(samples)` inside a function is evaluated and discarded;
  # print it so the value is actually reported.
  print(mean(samples))
  hist(samples)
}

0 commit comments

Comments
 (0)