Skip to content

Commit f84f96a

Browse files
Merge pull request #24 from kevinbfry/rlasso_changes
Randomized lasso now matches Python in the unit test.
2 parents 7364528 + f99384a commit f84f96a

File tree

3 files changed

+96
-109
lines changed

3 files changed

+96
-109
lines changed

selectiveInference/DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ Depends:
1010
glmnet,
1111
intervals,
1212
survival,
13-
adaptMCMC
13+
adaptMCMC,
14+
MASS
1415
Suggests:
1516
Rmpfr
1617
Description: New tools for post-selection inference, for use with forward

selectiveInference/NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,4 @@ importFrom("stats", "glm", "residuals", "vcov")
6363
importFrom("stats", "rbinom", "rexp")
6464
importFrom("Rcpp", "sourceCpp")
6565
importFrom("adaptMCMC", "MCMC")
66+
importFrom("MASS","mvrnorm")

selectiveInference/R/funs.randomized.R

Lines changed: 93 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ randomizedLasso = function(X,
1515
objective_tol=1.e-8, # tolerance for relative decrease in objective
1616
objective_stop=FALSE,
1717
kkt_stop=TRUE,
18-
parameter_stop=TRUE)
18+
parameter_stop=TRUE,
19+
for_test=FALSE)
1920
{
2021
family = match.arg(family)
2122

@@ -221,24 +222,25 @@ randomizedLasso = function(X,
221222
conditional_law$constraints = constraints
222223
law = conditional_law
223224

224-
return(list(X=X,
225-
y=y,
226-
lam=lam,
227-
family=family,
228-
active_set=active_set,
229-
inactive_set=inactive_set,
230-
unpenalized_set=unpenalized_set,
231-
sign_soln=sign_soln,
232-
law=law,
233-
internal_transform=internal_transform,
234-
observed_internal=observed_internal,
235-
observed_raw=observed_raw,
236-
noise_scale=noise_scale,
237-
soln=result$soln,
238-
perturb=perturb_,
239-
ridge_term=ridge_term
240-
))
241-
225+
return_list = list(X=X,
226+
y=y,
227+
lam=lam,
228+
family=family,
229+
active_set=active_set,
230+
inactive_set=inactive_set,
231+
unpenalized_set=unpenalized_set,
232+
sign_soln=sign_soln,
233+
law=law,
234+
internal_transform=internal_transform,
235+
observed_internal=observed_internal,
236+
observed_raw=observed_raw,
237+
noise_scale=noise_scale,
238+
soln=result$soln)
239+
if (for_test) {
240+
return_list$perturb=perturb_
241+
return_list$ridge_term=ridge_term
242+
}
243+
return(return_list)
242244
}
243245

244246
sample_opt_variables = function(law, jump_scale, nsample=10000) {
@@ -256,25 +258,27 @@ importance_weight = function(noise_scale,
256258
opt_sample,
257259
opt_transform,
258260
target_transform,
259-
observed_raw) {
260-
261+
law_density) {
262+
263+
for_candidate = function(candidate) {
264+
261265
log_num = log_density_gaussian_(noise_scale,
262266
target_transform$linear_term,
263-
as.matrix(target_sample),
267+
as.matrix(target_sample) + candidate,
264268
opt_transform$linear_term,
265269
as.matrix(opt_sample),
266270
target_transform$offset_term + opt_transform$offset_term)
267-
268-
log_den = log_density_gaussian_conditional_(noise_scale,
269-
opt_transform$linear_term,
270-
as.matrix(opt_sample),
271-
observed_raw + opt_transform$offset_term)
272-
271+
272+
log_den = law_density(as.matrix(opt_sample))
273+
273274
W = log_num - log_den
274275
W = W - max(W)
275276
W = exp(W)
276-
W = W / sum(W)
277-
return(W)
277+
return(list(W=W,
278+
log_num=log_num,
279+
log_den=log_den))
280+
}
281+
return(for_candidate)
278282
}
279283

280284
conditional_opt_transform = function(noise_scale,
@@ -352,7 +356,7 @@ compute_target = function(rand_lasso_soln,
352356
if (rand_lasso_soln$family == 'gaussian') {
353357
glm_y = glm(y ~ X_E-1)
354358
sigma_res = sigma(glm_y)
355-
glm_cov = vcov(glm_y)*sigma_est^2/(sigma_res^2)
359+
glm_cov = vcov(glm_y)
356360
} else if (rand_lasso_soln$family == 'binomial') {
357361
glm_y = glm(y ~ X_E-1, family=binomial())
358362
glm_cov = vcov(glm_y)
@@ -367,17 +371,18 @@ compute_target = function(rand_lasso_soln,
367371
stop("unregularized (relaxed) fit has NA values -- X[,active_set] likely singular")
368372
}
369373

370-
crosscov_target_internal=rbind(cov_target, matrix(0, nrow=p-nactive, ncol=nactive))
374+
crosscov_target_internal=-(t(X)%*%X_E)%*%cov_target
371375
}
372376

373-
alternatives = c()
374-
for (i in 1:length(rand_lasso_soln$sign_soln)) {
375-
if (rand_lasso_soln$sign_soln[i] == 1) {
376-
alternatives = c(alternatives, 'greater')
377-
} else {
378-
alternatives = c(alternatives, 'less')
379-
}
380-
}
377+
alternatives = rep("two-sided",length(rand_lasso_soln$sign_soln))
378+
# alternatives = c()
379+
# for (i in 1:length(rand_lasso_soln$sign_soln)) {
380+
# if (rand_lasso_soln$sign_soln[i] == 1) {
381+
# alternatives = c(alternatives, 'greater')
382+
# } else {
383+
# alternatives = c(alternatives, 'less')
384+
# }
385+
# }
381386

382387
if (type=="full"){
383388

@@ -455,7 +460,11 @@ randomizedLassoInf = function(rand_lasso_soln,
455460
level=0.9,
456461
sampler=c("norejection", "adaptMCMC"),
457462
nsample=10000,
458-
burnin=2000)
463+
burnin=2000,
464+
weight_mat=NULL,
465+
opt_samples=NULL,
466+
target_samples=NULL,
467+
for_test=FALSE)
459468
{
460469

461470
n = nrow(rand_lasso_soln$X)
@@ -476,19 +485,21 @@ randomizedLassoInf = function(rand_lasso_soln,
476485
sampler = match.arg(sampler)
477486

478487
law = rand_lasso_soln$law
479-
480-
if (sampler == "adaptMCMC"){
481-
S = sample_opt_variables(law,
482-
jump_scale=rep(1/sqrt(n), length(law$observed_opt_state)), nsample=nsample)
483-
opt_samples = as.matrix(S$samples[(burnin+1):nsample,,drop=FALSE])
484-
} else if (sampler == "norejection") {
485-
opt_samples = gaussian_sampler(noise_scale,
486-
law$observed_opt_state,
487-
law$sampling_transform$linear_term,
488-
law$sampling_transform$offset_term,
489-
law$constraints,
490-
nsamples=nsample,
491-
burnin=burnin)
488+
489+
if (is.null(opt_samples)) {
490+
if (sampler == "adaptMCMC"){
491+
S = sample_opt_variables(law,
492+
jump_scale=rep(1/sqrt(n), length(law$observed_opt_state)), nsample=nsample)
493+
opt_samples = as.matrix(S$samples[(burnin+1):nsample,,drop=FALSE])
494+
} else if (sampler == "norejection") {
495+
opt_samples = gaussian_sampler(noise_scale,
496+
law$observed_opt_state,
497+
law$sampling_transform$linear_term,
498+
law$sampling_transform$offset_term,
499+
law$constraints,
500+
nsamples=nsample,
501+
burnin=burnin)
502+
}
492503
}
493504

494505
if (is.null(targets)){
@@ -511,65 +522,31 @@ randomizedLassoInf = function(rand_lasso_soln,
511522

512523
names(pvalues) = names(targets$observed_target)
513524
rownames(ci) = names(targets$observed_target)
514-
525+
526+
527+
target_samples = mvrnorm(nrow(as.matrix(opt_samples)),rep(0,nactive),targets$cov_target)
528+
515529
for (i in 1:nactive){
516-
pre_nuisance = observed_internal - (as.vector(targets$crosscov_target_internal[,i]) *
517-
targets$observed_target[i] /
518-
targets$cov_target[i,i])
530+
target_sample = target_samples[,i]
519531

520-
nuisance = internal_transform$linear_term %*% pre_nuisance[1:nactive]
521-
nuisance[inactive_set] = nuisance[inactive_set] - pre_nuisance[(nactive+1):p]
522-
523-
pre_linear_term = targets$crosscov_target_internal[,i] / targets$cov_target[i,i]
524-
linear_term = rep(0, p)
525-
linear_term = internal_transform$linear_term %*% pre_linear_term[1:nactive]
526-
linear_term[inactive_set] = linear_term[inactive_set] - pre_linear_term[(nactive+1):p]
532+
reduced_linear = solve(t(law$sampling_transform$linear_term)) %*% t(importance_transform$linear_term)
533+
linear_term = reduced_linear%*%(as.matrix(targets$crosscov_target_internal[,i],ncol=1) /
534+
targets$cov_target[i,i])
535+
obs_opt_contrib = linear_term * targets$observed_target[i]
527536
target_transform = list(linear_term=linear_term,
528-
offset_term=as.vector(nuisance + internal_transform$offset_term)) # internal_transform$offset_term is 0...
537+
offset_term=as.vector(-obs_opt_contrib))
529538

530-
# compute sufficient statistic for root finding
531-
532-
target_sample = rnorm(nrow(as.matrix(opt_samples))) * sqrt(targets$cov_target[i,i])
533-
534-
# weight in the numerator is of the form
535-
# -1/(2 noise_scale^2)\|Do + q + P(t+\theta)\|^2_2
536-
# with D=importance_transform$linear_term
537-
# q=target_transform$offset_term + importance_transform$offset_term
538-
# P=target_transform$linear_term
539-
540-
# weight in the denominator is of the form
541-
# -1/(2 noise_scale^2)\|Do + q_D\|^2_2
542-
# with D=importance_transform$linear_term
543-
# q_D = observed_raw + importance_transform$offset_term
544-
545-
# reference measure just is the ratio at \theta=0
546-
# sufficient statistic is linear term in \theta
547-
548-
den = importance_transform$linear_term %*% t(opt_samples) + observed_raw + importance_transform$offset_term
549-
550-
num1 = (importance_transform$linear_term %*% t(opt_samples) +
551-
target_transform$linear_term %*% t(as.matrix(target_sample)) +
552-
importance_transform$offset_term +
553-
target_transform$offset_term)
554-
num2 = (importance_transform$linear_term %*% t(opt_samples) +
555-
target_transform$linear_term %*% t(as.matrix(target_sample) + 1) +
556-
importance_transform$offset_term +
557-
target_transform$offset_term)
558-
559-
sufficient_stat = -apply(num2^2 - num1^2, 2, sum) / (2 * noise_scale^2)
560-
561-
reference_measure = importance_weight(noise_scale,
562-
t(as.matrix(target_sample)),
563-
t(opt_samples),
564-
importance_transform,
565-
target_transform,
566-
observed_raw)
567-
log_reference_measure = log(reference_measure)
539+
weighting_transform = law$sampling_transform
540+
541+
importance_for_candidate = importance_weight(noise_scale,
542+
t(as.matrix(target_sample)),
543+
t(as.matrix(opt_samples)),
544+
weighting_transform,
545+
target_transform,
546+
law$log_optimization_density)
568547

569548
pivot = function(candidate){
570-
arg_ = candidate * sufficient_stat + log_reference_measure
571-
arg_ = arg_ - max(arg_)
572-
weights = exp(arg_)
549+
weights = importance_for_candidate(candidate)$W
573550
p = mean((target_sample + candidate <= targets$observed_target[i]) * weights)/mean(weights)
574551
return(p)
575552
}
@@ -608,7 +585,15 @@ randomizedLassoInf = function(rand_lasso_soln,
608585
}
609586
}
610587
}
611-
return(list(targets=targets, pvalues=pvalues, ci=ci))
588+
589+
return_list = list(targets=targets,
590+
pvalues=pvalues,
591+
ci=ci)
592+
if (for_test) {
593+
return_list$opt_samples=opt_samples
594+
return_list$target_samples=target_samples
595+
}
596+
return(return_list)
612597
}
613598

614599
logistic_fitted = function(X, beta){

0 commit comments

Comments
 (0)