Skip to content

Commit 95b98c2

Browse files
minor cleanup -- doesn't seem to use conditional_density?
1 parent 09db28a commit 95b98c2

File tree

1 file changed

+15
-24
lines changed

1 file changed

+15
-24
lines changed

selectiveInference/R/funs.randomized.R

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,7 @@ importance_weight = function(noise_scale,
221221
return(exp(W))
222222
}
223223

224-
### Jelena's functions
225-
226-
conditional_density = function(noise_scale, lasso_soln){
224+
conditional_density = function(noise_scale, lasso_soln) {
227225

228226
active_set = lasso_soln$active_set
229227
observed_raw = lasso_soln$observed_raw
@@ -241,7 +239,7 @@ conditional_density = function(noise_scale, lasso_soln){
241239
opt_transform = list(linear_term=B,
242240
offset_term = beta_offset)
243241
reduced_B = chol(t(B) %*% B)
244-
beta_offset = beta_offset+observed_raw
242+
beta_offset = beta_offset + observed_raw
245243
reduced_beta_offset = solve(t(reduced_B)) %*% (t(B) %*% beta_offset)
246244

247245
log_condl_optimization_density = function(opt_state) {
@@ -260,26 +258,23 @@ conditional_density = function(noise_scale, lasso_soln){
260258
return(lasso_soln)
261259
}
262260

261+
randomized_inference = function(X, y, sigma, lam, noise_scale, ridge_term){
263262

264-
randomized_inference = function(X,y,sigma, lam, noise_scale, ridge_term){
265-
n=nrow(X)
266-
p=ncol(X)
267-
lasso_soln=selectiveInference:::randomizedLASSO(X, y, lam, noise_scale, ridge_term)
263+
n = nrow(X)
264+
p = ncol(X)
265+
lasso_soln = selectiveInference:::randomizedLASSO(X, y, lam, noise_scale, ridge_term)
268266
active_set = lasso_soln$active_set
269267
inactive_set = lasso_soln$inactive_set
270268
nactive = length(active_set)
271-
print(paste("nactive", nactive))
272-
273-
#lasso_soln = conditional_density(noise_scale, lasso_soln)
274-
275-
dim=length(lasso_soln$observed_opt_state)
269+
270+
dim = length(lasso_soln$observed_opt_state)
276271
print(paste("chain dim", dim))
277272
S = selectiveInference:::sample_opt_variables(lasso_soln, jump_scale=rep(1/sqrt(n), dim), nsample=10000)
278273
opt_samples = S$samples[2001:10000,]
279274
print(paste("dim opt samples", toString(dim(opt_samples))))
280275

281-
X_E=X[, active_set]
282-
X_minusE=X[, inactive_set]
276+
X_E = X[, active_set]
277+
X_minusE = X[, inactive_set]
283278
target_cov = solve(t(X_E) %*% X_E)*sigma^2
284279
cov_target_internal = rbind(target_cov, matrix(0, nrow=p-nactive, ncol=nactive))
285280
observed_target = solve(t(X_E) %*% X_E) %*% t(X_E) %*% y
@@ -288,7 +283,7 @@ randomized_inference = function(X,y,sigma, lam, noise_scale, ridge_term){
288283
opt_transform = lasso_soln$optimization_transform
289284
observed_raw = lasso_soln$observed_raw
290285

291-
pivots = rep(0, nactive)
286+
pvalus = rep(0, nactive)
292287
ci = matrix(0, nactive, 2)
293288
for (i in 1:nactive){
294289
target_transform = selectiveInference:::linear_decomposition(observed_target[i],
@@ -300,7 +295,7 @@ randomized_inference = function(X,y,sigma, lam, noise_scale, ridge_term){
300295

301296
pivot = function(candidate){
302297
weights = selectiveInference:::importance_weight(noise_scale,
303-
t(as.matrix(target_sample))+candidate,
298+
t(as.matrix(target_sample)) + candidate,
304299
t(opt_samples),
305300
opt_transform,
306301
target_transform,
@@ -314,15 +309,11 @@ randomized_inference = function(X,y,sigma, lam, noise_scale, ridge_term){
314309
rootL = function(candidate){
315310
return (pivot(observed_target[i]+candidate)-(1+level)/2)
316311
}
317-
pivots[i] = pivot(0)
312+
pvalues[i] = pivot(0)
318313
line_min = -10*sd(target_sample)
319314
line_max = 10*sd(target_sample)
320315
ci[i,1] = uniroot(rootU, c(line_min, line_max))$root+observed_target[i]
321-
ci[i,2] = uniroot(rootL,c(line_min, line_max))$root+observed_target[i]
322-
}
323-
print(paste("pivots", toString(pivots)))
324-
for (i in 1:nactive){
325-
print(paste("CIs", toString(ci[i,])))
316+
ci[i,2] = uniroot(rootL, c(line_min, line_max))$root+observed_target[i]
326317
}
327-
return(list(pivots=pivots, ci=ci))
318+
return(list(pvalues=pvalues, ci=ci))
328319
}

0 commit comments

Comments
 (0)