@@ -221,9 +221,7 @@ importance_weight = function(noise_scale,
221
221
return (exp(W ))
222
222
}
223
223
224
- # ## Jelena's functions
225
-
226
- conditional_density = function (noise_scale , lasso_soln ){
224
+ conditional_density = function (noise_scale , lasso_soln ) {
227
225
228
226
active_set = lasso_soln $ active_set
229
227
observed_raw = lasso_soln $ observed_raw
@@ -241,7 +239,7 @@ conditional_density = function(noise_scale, lasso_soln){
241
239
opt_transform = list (linear_term = B ,
242
240
offset_term = beta_offset )
243
241
reduced_B = chol(t(B ) %*% B )
244
- beta_offset = beta_offset + observed_raw
242
+ beta_offset = beta_offset + observed_raw
245
243
reduced_beta_offset = solve(t(reduced_B )) %*% (t(B ) %*% beta_offset )
246
244
247
245
log_condl_optimization_density = function (opt_state ) {
@@ -260,26 +258,23 @@ conditional_density = function(noise_scale, lasso_soln){
260
258
return (lasso_soln )
261
259
}
262
260
261
+ randomized_inference = function (X , y , sigma , lam , noise_scale , ridge_term ){
263
262
264
- randomized_inference = function (X ,y ,sigma , lam , noise_scale , ridge_term ){
265
- n = nrow(X )
266
- p = ncol(X )
267
- lasso_soln = selectiveInference ::: randomizedLASSO(X , y , lam , noise_scale , ridge_term )
263
+ n = nrow(X )
264
+ p = ncol(X )
265
+ lasso_soln = selectiveInference ::: randomizedLASSO(X , y , lam , noise_scale , ridge_term )
268
266
active_set = lasso_soln $ active_set
269
267
inactive_set = lasso_soln $ inactive_set
270
268
nactive = length(active_set )
271
- print(paste(" nactive" , nactive ))
272
-
273
- # lasso_soln = conditional_density(noise_scale, lasso_soln)
274
-
275
- dim = length(lasso_soln $ observed_opt_state )
269
+
270
+ dim = length(lasso_soln $ observed_opt_state )
276
271
print(paste(" chain dim" , dim ))
277
272
S = selectiveInference ::: sample_opt_variables(lasso_soln , jump_scale = rep(1 / sqrt(n ), dim ), nsample = 10000 )
278
273
opt_samples = S $ samples [2001 : 10000 ,]
279
274
print(paste(" dim opt samples" , toString(dim(opt_samples ))))
280
275
281
- X_E = X [, active_set ]
282
- X_minusE = X [, inactive_set ]
276
+ X_E = X [, active_set ]
277
+ X_minusE = X [, inactive_set ]
283
278
target_cov = solve(t(X_E ) %*% X_E )* sigma ^ 2
284
279
cov_target_internal = rbind(target_cov , matrix (0 , nrow = p - nactive , ncol = nactive ))
285
280
observed_target = solve(t(X_E ) %*% X_E ) %*% t(X_E ) %*% y
@@ -288,7 +283,7 @@ randomized_inference = function(X,y,sigma, lam, noise_scale, ridge_term){
288
283
opt_transform = lasso_soln $ optimization_transform
289
284
observed_raw = lasso_soln $ observed_raw
290
285
291
- pivots = rep(0 , nactive )
286
+ pvalus = rep(0 , nactive )
292
287
ci = matrix (0 , nactive , 2 )
293
288
for (i in 1 : nactive ){
294
289
target_transform = selectiveInference ::: linear_decomposition(observed_target [i ],
@@ -300,7 +295,7 @@ randomized_inference = function(X,y,sigma, lam, noise_scale, ridge_term){
300
295
301
296
pivot = function (candidate ){
302
297
weights = selectiveInference ::: importance_weight(noise_scale ,
303
- t(as.matrix(target_sample ))+ candidate ,
298
+ t(as.matrix(target_sample )) + candidate ,
304
299
t(opt_samples ),
305
300
opt_transform ,
306
301
target_transform ,
@@ -314,15 +309,11 @@ randomized_inference = function(X,y,sigma, lam, noise_scale, ridge_term){
314
309
rootL = function (candidate ){
315
310
return (pivot(observed_target [i ]+ candidate )- (1 + level )/ 2 )
316
311
}
317
- pivots [i ] = pivot(0 )
312
+ pvalues [i ] = pivot(0 )
318
313
line_min = - 10 * sd(target_sample )
319
314
line_max = 10 * sd(target_sample )
320
315
ci [i ,1 ] = uniroot(rootU , c(line_min , line_max ))$ root + observed_target [i ]
321
- ci [i ,2 ] = uniroot(rootL ,c(line_min , line_max ))$ root + observed_target [i ]
322
- }
323
- print(paste(" pivots" , toString(pivots )))
324
- for (i in 1 : nactive ){
325
- print(paste(" CIs" , toString(ci [i ,])))
316
+ ci [i ,2 ] = uniroot(rootL , c(line_min , line_max ))$ root + observed_target [i ]
326
317
}
327
- return (list (pivots = pivots , ci = ci ))
318
+ return (list (pvalues = pvalues , ci = ci ))
328
319
}
0 commit comments