Skip to content

Commit e0ee2a7

Browse files
jelena's inference functions
1 parent 71b6586 commit e0ee2a7

File tree

1 file changed

+105
-0
lines changed

1 file changed

+105
-0
lines changed

selectiveInference/R/funs.randomized.R

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,108 @@ importance_weight = function(noise_scale,
221221
return(exp(W))
222222
}
223223

224+
### Jelena's functions
225+
226+
conditional_density = function(noise_scale, lasso_soln){
227+
228+
active_set = lasso_soln$active_set
229+
observed_raw = lasso_soln$observed_raw
230+
opt_linear = lasso_soln$optimization_transform$linear_term
231+
opt_offset = lasso_soln$optimization_transform$offset_term
232+
observed_opt_state = lasso_soln$observed_opt_state
233+
234+
nactive = length(active_set)
235+
B = opt_linear[,1:nactive]
236+
beta_offset = opt_offset
237+
p=length(observed_opt_state)
238+
if (nactive<p){
239+
beta_offset = beta_offset+(opt_linear[,(nactive+1):p] %*% observed_opt_state[(nactive+1):p])
240+
}
241+
opt_transform = list(linear_term=B,
242+
offset_term = beta_offset)
243+
reduced_B = chol(t(B) %*% B)
244+
beta_offset = beta_offset+observed_raw
245+
reduced_beta_offset = solve(t(reduced_B)) %*% (t(B) %*% beta_offset)
246+
247+
log_condl_optimization_density = function(opt_state) {
248+
if (sum(opt_state < 0) > 0) {
249+
return(-Inf)
250+
}
251+
D = selectiveInference:::log_density_gaussian_conditional_(noise_scale,
252+
reduced_B,
253+
as.matrix(opt_state),
254+
reduced_beta_offset)
255+
return(D)
256+
}
257+
lasso_soln$log_optimization_density = log_condl_optimization_density
258+
lasso_soln$observed_opt_state = observed_opt_state[1:nactive]
259+
lasso_soln$optimization_transform = opt_transform
260+
return(lasso_soln)
261+
}
262+
263+
264+
randomized_inference = function(X,y,sigma, lam, noise_scale, ridge_term){
265+
n=nrow(X)
266+
p=ncol(X)
267+
lasso_soln=selectiveInference:::randomizedLASSO(X, y, lam, noise_scale, ridge_term)
268+
active_set = lasso_soln$active_set
269+
inactive_set = lasso_soln$inactive_set
270+
nactive = length(active_set)
271+
print(paste("nactive", nactive))
272+
273+
#lasso_soln = conditional_density(noise_scale, lasso_soln)
274+
275+
dim=length(lasso_soln$observed_opt_state)
276+
print(paste("chain dim", dim))
277+
S = selectiveInference:::sample_opt_variables(lasso_soln, jump_scale=rep(1/sqrt(n), dim), nsample=10000)
278+
opt_samples = S$samples[2001:10000,]
279+
print(paste("dim opt samples", toString(dim(opt_samples))))
280+
281+
X_E=X[, active_set]
282+
X_minusE=X[, inactive_set]
283+
target_cov = solve(t(X_E) %*% X_E)*sigma^2
284+
cov_target_internal = rbind(target_cov, matrix(0, nrow=p-nactive, ncol=nactive))
285+
observed_target = solve(t(X_E) %*% X_E) %*% t(X_E) %*% y
286+
observed_internal = c(observed_target, t(X_minusE) %*% (y-X_E%*% observed_target))
287+
internal_transform = lasso_soln$internal_transform
288+
opt_transform = lasso_soln$optimization_transform
289+
observed_raw = lasso_soln$observed_raw
290+
291+
pivots = rep(0, nactive)
292+
ci = matrix(0, nactive, 2)
293+
for (i in 1:nactive){
294+
target_transform = selectiveInference:::linear_decomposition(observed_target[i],
295+
observed_internal,
296+
target_cov[i,i],
297+
cov_target_internal[,i],
298+
internal_transform)
299+
target_sample = rnorm(nrow(opt_samples)) * sqrt(target_cov[i,i])
300+
301+
pivot = function(candidate){
302+
weights = selectiveInference:::importance_weight(noise_scale,
303+
t(as.matrix(target_sample))+candidate,
304+
t(opt_samples),
305+
opt_transform,
306+
target_transform,
307+
observed_raw)
308+
return(mean((target_sample<observed_target[i])*weights)/mean(weights))
309+
}
310+
level = 0.9
311+
rootU = function(candidate){
312+
return (pivot(observed_target[i]+candidate)-(1-level)/2)
313+
}
314+
rootL = function(candidate){
315+
return (pivot(observed_target[i]+candidate)-(1+level)/2)
316+
}
317+
pivots[i] = pivot(0)
318+
line_min = -10*sd(target_sample)
319+
line_max = 10*sd(target_sample)
320+
ci[i,1] = uniroot(rootU, c(line_min, line_max))$root+observed_target[i]
321+
ci[i,2] = uniroot(rootL,c(line_min, line_max))$root+observed_target[i]
322+
}
323+
print(paste("pivots", toString(pivots)))
324+
for (i in 1:nactive){
325+
print(paste("CIs", toString(ci[i,])))
326+
}
327+
return(list(pivots=pivots, ci=ci))
328+
}

0 commit comments

Comments
 (0)