Skip to content

Commit d2438c9

Browse files
author
Jelena Markovic
committed
amir sampler working with subgrad condition false as well
1 parent b4fd8b5 commit d2438c9

File tree

3 files changed

+26
-15
lines changed

3 files changed

+26
-15
lines changed

selectiveInference/R/funs.randomized.R

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -374,14 +374,23 @@ randomizedLassoInf = function(X,
374374
}
375375
inactive_set = lasso_soln$inactive_set
376376

377-
378377
noise_scale = lasso_soln$noise_scale # set to default value in randomizedLasso
379-
380-
if (condition_subgrad==TRUE){
378+
379+
constraints = matrix(0,nactive,2)
380+
constraints[,2] = Inf
381+
if (condition_subgrad==TRUE){
381382
condl_lasso=conditional_density(noise_scale, lasso_soln)
382383
lasso_soln = condl_lasso$lasso_soln
383-
reduced_opt_transform = condl_lasso$reduced_opt_transform
384-
}
384+
cur_opt_transform = condl_lasso$reduced_opt_transform
385+
} else{
386+
if (nactive<p){
387+
subgrad_constraints = matrix(-lam, p-nactive, 2)
388+
subgrad_constraints[,2]=lam
389+
constraints = rbind(constraints, subgrad_constraints)
390+
}
391+
cur_opt_transform = list(linear_term = lasso_soln$optimization_transform$linear_term,
392+
offset_term = lasso_soln$optimization_transform$offset_term+lasso_soln$observed_raw)
393+
}
385394

386395
ndim = length(lasso_soln$observed_opt_state)
387396

@@ -391,8 +400,9 @@ randomizedLassoInf = function(X,
391400
} else if (sampler == "A"){
392401
opt_samples = gaussian_sampler(noise_scale,
393402
lasso_soln$observed_opt_state,
394-
reduced_opt_transform$linear_term,
395-
reduced_opt_transform$offset_term,
403+
cur_opt_transform$linear_term,
404+
cur_opt_transform$offset_term,
405+
constraints,
396406
nsamples=nsample)
397407
opt_sample = opt_samples[(burnin+1):nsample,]
398408
}
@@ -439,7 +449,6 @@ randomizedLassoInf = function(X,
439449
cur_linear = reduced_target_opt_linear[,2:ncol(reduced_target_opt_linear)]
440450
cur_offset = temp %*% opt_transform$offset_term
441451
cur_transform = list(linear_term = as.matrix(cur_linear), offset_term = cur_offset)
442-
443452
raw = target_transform$linear_term * observed_target[i] + target_transform$offset_term
444453
} else {
445454
cur_transform = opt_transform

selectiveInference/R/sampler.R

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,12 @@ log_concave_sampler = function(negative_log_density,
7474
return (samples)
7575
}
7676

77-
gaussian_sampler = function(noise_scale, observed, linear_term, offset_term, nsamples){
77+
gaussian_sampler = function(noise_scale,
78+
observed,
79+
linear_term,
80+
offset_term,
81+
constraints,
82+
nsamples){
7883

7984
negative_log_density = function(x) {
8085
recon = linear_term %*% x+offset_term
@@ -84,9 +89,6 @@ gaussian_sampler = function(noise_scale, observed, linear_term, offset_term, nsa
8489
recon = linear_term %*% x+offset_term
8590
return(t(linear_term)%*% recon/(noise_scale^2))
8691
}
87-
dim = length(observed)
88-
constraints = matrix(0,dim,2)
89-
constraints[,2] = Inf
9092

9193
return(log_concave_sampler(negative_log_density,
9294
grad_negative_log_density,

tests/randomized/test_instances.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ test_KKT=function(){
6161

6262

6363

64-
collect_results = function(n,p,s, nsim=100, level=0.9, condition_subgrad=TRUE, lam=1.2){
64+
collect_results = function(n,p,s, nsim=100, level=0.9, condition_subgrad=FALSE, lam=1.2){
6565

6666
rho=0.3
6767
sigma=1
@@ -76,7 +76,7 @@ collect_results = function(n,p,s, nsim=100, level=0.9, condition_subgrad=TRUE, l
7676
lam=lam,
7777
sigma=sigma,
7878
level=level,
79-
sampler = "R",
79+
sampler = "A",
8080
burnin=1000,
8181
nsample=5000,
8282
condition_subgrad=condition_subgrad)
@@ -104,7 +104,7 @@ collect_results = function(n,p,s, nsim=100, level=0.9, condition_subgrad=TRUE, l
104104
}
105105

106106
set.seed(1)
107-
collect_results(n=100, p=2000, s=0, lam=2.5)
107+
collect_results(n=100, p=20, s=0, lam=1.2)
108108
#test_randomized_lasso()
109109
#test_KKT()
110110

0 commit comments

Comments
 (0)