@@ -221,3 +221,108 @@ importance_weight = function(noise_scale,
221
221
return (exp(W ))
222
222
}
223
223
224
+ # ## Jelena's functions
225
+
226
+ conditional_density = function (noise_scale , lasso_soln ){
227
+
228
+ active_set = lasso_soln $ active_set
229
+ observed_raw = lasso_soln $ observed_raw
230
+ opt_linear = lasso_soln $ optimization_transform $ linear_term
231
+ opt_offset = lasso_soln $ optimization_transform $ offset_term
232
+ observed_opt_state = lasso_soln $ observed_opt_state
233
+
234
+ nactive = length(active_set )
235
+ B = opt_linear [,1 : nactive ]
236
+ beta_offset = opt_offset
237
+ p = length(observed_opt_state )
238
+ if (nactive < p ){
239
+ beta_offset = beta_offset + (opt_linear [,(nactive + 1 ): p ] %*% observed_opt_state [(nactive + 1 ): p ])
240
+ }
241
+ opt_transform = list (linear_term = B ,
242
+ offset_term = beta_offset )
243
+ reduced_B = chol(t(B ) %*% B )
244
+ beta_offset = beta_offset + observed_raw
245
+ reduced_beta_offset = solve(t(reduced_B )) %*% (t(B ) %*% beta_offset )
246
+
247
+ log_condl_optimization_density = function (opt_state ) {
248
+ if (sum(opt_state < 0 ) > 0 ) {
249
+ return (- Inf )
250
+ }
251
+ D = selectiveInference ::: log_density_gaussian_conditional_(noise_scale ,
252
+ reduced_B ,
253
+ as.matrix(opt_state ),
254
+ reduced_beta_offset )
255
+ return (D )
256
+ }
257
+ lasso_soln $ log_optimization_density = log_condl_optimization_density
258
+ lasso_soln $ observed_opt_state = observed_opt_state [1 : nactive ]
259
+ lasso_soln $ optimization_transform = opt_transform
260
+ return (lasso_soln )
261
+ }
262
+
263
+
264
+ randomized_inference = function (X ,y ,sigma , lam , noise_scale , ridge_term ){
265
+ n = nrow(X )
266
+ p = ncol(X )
267
+ lasso_soln = selectiveInference ::: randomizedLASSO(X , y , lam , noise_scale , ridge_term )
268
+ active_set = lasso_soln $ active_set
269
+ inactive_set = lasso_soln $ inactive_set
270
+ nactive = length(active_set )
271
+ print(paste(" nactive" , nactive ))
272
+
273
+ # lasso_soln = conditional_density(noise_scale, lasso_soln)
274
+
275
+ dim = length(lasso_soln $ observed_opt_state )
276
+ print(paste(" chain dim" , dim ))
277
+ S = selectiveInference ::: sample_opt_variables(lasso_soln , jump_scale = rep(1 / sqrt(n ), dim ), nsample = 10000 )
278
+ opt_samples = S $ samples [2001 : 10000 ,]
279
+ print(paste(" dim opt samples" , toString(dim(opt_samples ))))
280
+
281
+ X_E = X [, active_set ]
282
+ X_minusE = X [, inactive_set ]
283
+ target_cov = solve(t(X_E ) %*% X_E )* sigma ^ 2
284
+ cov_target_internal = rbind(target_cov , matrix (0 , nrow = p - nactive , ncol = nactive ))
285
+ observed_target = solve(t(X_E ) %*% X_E ) %*% t(X_E ) %*% y
286
+ observed_internal = c(observed_target , t(X_minusE ) %*% (y - X_E %*% observed_target ))
287
+ internal_transform = lasso_soln $ internal_transform
288
+ opt_transform = lasso_soln $ optimization_transform
289
+ observed_raw = lasso_soln $ observed_raw
290
+
291
+ pivots = rep(0 , nactive )
292
+ ci = matrix (0 , nactive , 2 )
293
+ for (i in 1 : nactive ){
294
+ target_transform = selectiveInference ::: linear_decomposition(observed_target [i ],
295
+ observed_internal ,
296
+ target_cov [i ,i ],
297
+ cov_target_internal [,i ],
298
+ internal_transform )
299
+ target_sample = rnorm(nrow(opt_samples )) * sqrt(target_cov [i ,i ])
300
+
301
+ pivot = function (candidate ){
302
+ weights = selectiveInference ::: importance_weight(noise_scale ,
303
+ t(as.matrix(target_sample ))+ candidate ,
304
+ t(opt_samples ),
305
+ opt_transform ,
306
+ target_transform ,
307
+ observed_raw )
308
+ return (mean((target_sample < observed_target [i ])* weights )/ mean(weights ))
309
+ }
310
+ level = 0.9
311
+ rootU = function (candidate ){
312
+ return (pivot(observed_target [i ]+ candidate )- (1 - level )/ 2 )
313
+ }
314
+ rootL = function (candidate ){
315
+ return (pivot(observed_target [i ]+ candidate )- (1 + level )/ 2 )
316
+ }
317
+ pivots [i ] = pivot(0 )
318
+ line_min = - 10 * sd(target_sample )
319
+ line_max = 10 * sd(target_sample )
320
+ ci [i ,1 ] = uniroot(rootU , c(line_min , line_max ))$ root + observed_target [i ]
321
+ ci [i ,2 ] = uniroot(rootL ,c(line_min , line_max ))$ root + observed_target [i ]
322
+ }
323
+ print(paste(" pivots" , toString(pivots )))
324
+ for (i in 1 : nactive ){
325
+ print(paste(" CIs" , toString(ci [i ,])))
326
+ }
327
+ return (list (pivots = pivots , ci = ci ))
328
+ }
0 commit comments