Skip to content

Commit 69bf473

Browse files
Merge branch 'jelena-markovic-logistic'
2 parents 183bc3a + 42cd878 commit 69bf473

File tree

3 files changed

+90
-29
lines changed

3 files changed

+90
-29
lines changed

selectiveInference/R/funs.randomized.R

Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
randomizedLasso = function(X,
77
y,
88
lam,
9+
family=c("gaussian","binomial"),
910
noise_scale=NULL,
1011
ridge_term=NULL,
1112
noise_type=c('gaussian', 'laplace'),
@@ -17,6 +18,7 @@ randomizedLasso = function(X,
1718
kkt_stop=TRUE,
1819
parameter_stop=TRUE)
1920
{
21+
family = match.arg(family)
2022

2123
n = nrow(X); p = ncol(X)
2224

@@ -65,8 +67,8 @@ randomizedLasso = function(X,
6567
nactive = as.integer(0)
6668

6769
result = solve_QP_wide(X, # design matrix
68-
lam / n, # vector of Lagrange multipliers
69-
ridge_term / n, # ridge_term
70+
lam / n, # vector of Lagrange multipliers
71+
ridge_term / n, # ridge_term
7072
max_iter,
7173
soln,
7274
linear_func,
@@ -83,7 +85,7 @@ randomizedLasso = function(X,
8385
parameter_stop) # param_stop
8486

8587
sign_soln = sign(result$soln)
86-
88+
8789
unpenalized = lam == 0
8890
active = (!unpenalized) & (sign_soln != 0)
8991
inactive = (!unpenalized) & (sign_soln == 0)
@@ -110,8 +112,26 @@ randomizedLasso = function(X,
110112
I = inactive_set
111113
X_E = X[,E]
112114
X_I = X[,I]
113-
L_E = t(X) %*% X[,E]
114-
115+
116+
if (length(E)==0){
117+
return(list(active_set=c()))
118+
}
119+
120+
if (family=="binomial"){
121+
unpen_reg = glm(y~X_E-1, family="binomial")
122+
unpen_est = unpen_reg$coefficients
123+
pi_fn = function(beta){
124+
temp = X_E %*% as.matrix(beta)
125+
return(as.vector(exp(temp)/(1+exp(temp)))) # n-dimensional
126+
}
127+
pi_vec = pi_fn(unpen_est)
128+
W_E = diag(pi_vec*(1-pi_vec))
129+
} else if (family=="gaussian"){
130+
W_E = diag(rep(1,n))
131+
}
132+
133+
L_E = t(X) %*% W_E %*% X[,E]
134+
115135
coef_term = L_E
116136

117137
signs_ = c(rep(1, sum(unpenalized)), sign_soln[active])
@@ -155,8 +175,12 @@ randomizedLasso = function(X,
155175
offset_term = offset_term)
156176

157177
# density for sampling optimization variables
158-
178+
159179
observed_raw = -t(X) %*% y
180+
if (family=="binomial"){
181+
beta_E = result$soln[active_set]
182+
observed_raw = observed_raw + t(X)%*%pi_fn(beta_E) - L_E %*% beta_E
183+
}
160184
inactive_lam = lam[inactive_set]
161185
inactive_start = sum(unpenalized) + sum(active)
162186
active_start = sum(unpenalized)
@@ -191,11 +215,11 @@ randomizedLasso = function(X,
191215
optimization_transform = opt_transform,
192216
internal_transform = internal_transform,
193217
log_optimization_density = log_optimization_density,
194-
observed_opt_state = observed_opt_state,
218+
observed_opt_state = observed_opt_state,
195219
observed_raw = observed_raw,
196-
noise_scale = noise_scale,
197-
soln = result$soln,
198-
perturb = perturb_
220+
noise_scale = noise_scale,
221+
soln = result$soln,
222+
perturb = perturb_
199223
))
200224

201225
}
@@ -330,6 +354,7 @@ conditional_density = function(noise_scale, lasso_soln) {
330354
randomizedLassoInf = function(X,
331355
y,
332356
lam,
357+
family=c("gaussian", "binomial"),
333358
sigma=NULL,
334359
noise_scale=NULL,
335360
ridge_term=NULL,
@@ -349,10 +374,13 @@ randomizedLassoInf = function(X,
349374

350375
n = nrow(X)
351376
p = ncol(X)
352-
377+
378+
family = match.arg(family)
379+
353380
lasso_soln = randomizedLasso(X,
354381
y,
355382
lam,
383+
family=family,
356384
noise_scale=noise_scale,
357385
ridge_term=ridge_term,
358386
max_iter=max_iter,
@@ -409,17 +437,30 @@ randomizedLassoInf = function(X,
409437
X_E = X[, active_set]
410438
X_minusE = X[, inactive_set]
411439

412-
# if no sigma given, use OLS estimate
413-
440+
if (family == "gaussian") {
441+
lm_y = lm(y ~ X_E - 1)
442+
sigma_resid = sqrt(sum(resid(lm_y)^2) / lm_y$df.resid)
443+
observed_target = lm_y$coefficients
444+
W_E = diag(rep(1,n))
445+
observed_internal = c(observed_target, t(X_minusE) %*% (y-X_E%*% observed_target))
446+
} else if (family == "binomial") {
447+
glm_y = glm(y~X_E-1)
448+
sigma_resid = sqrt(sum(resid(glm_y)^2) / glm_y$df.resid)
449+
observed_target = as.matrix(glm_y$coefficients)
450+
temp = X_E%*%observed_target
451+
pi_vec = exp(temp)/(1+exp(temp))
452+
observed_internal = c(observed_target, t(X_minusE) %*% (y-pi_vec))
453+
W_E=diag(as.vector(pi_vec *(1-pi_vec)))
454+
}
455+
456+
# if no sigma given, use the estimate
457+
414458
if (is.null(sigma)) {
415-
lm_y = lm(y ~ X_E - 1)
416-
sigma = sqrt(sum(resid(lm_y)^2) / lm_y$df.resid)
459+
sigma = sigma_resid
417460
}
418-
419-
target_cov = solve(t(X_E) %*% X_E)*sigma^2
461+
462+
target_cov = solve(t(X_E) %*% W_E %*% X_E)*sigma^2
420463
cov_target_internal = rbind(target_cov, matrix(0, nrow=p-nactive, ncol=nactive))
421-
observed_target = solve(t(X_E) %*% X_E) %*% t(X_E) %*% y
422-
observed_internal = c(observed_target, t(X_minusE) %*% (y-X_E%*% observed_target))
423464
internal_transform = lasso_soln$internal_transform
424465
opt_transform = lasso_soln$optimization_transform
425466
observed_raw = lasso_soln$observed_raw
@@ -494,5 +535,10 @@ randomizedLassoInf = function(X,
494535
return(list(active_set=active_set, pvalues=pvalues, ci=ci))
495536
}
496537

538+
539+
540+
541+
542+
497543

498544

selectiveInference/man/randomizedLassoInf.Rd

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ randomization.
1414
randomizedLassoInf(X,
1515
y,
1616
lam,
17+
family=c("gaussian", "binomial"),
1718
sigma=NULL,
1819
noise_scale=NULL,
1920
ridge_term=NULL,
@@ -49,6 +50,9 @@ Value of lambda used to compute beta. See the above warning
4950
where obj is the object returned by glmnet (and [-1] removes the intercept,
5051
which glmnet always puts in the first component)
5152
}
53+
\item{family}{
54+
Response type: "gaussian" (default), "binomial".
55+
}
5256
\item{sigma}{
5357
Estimate of error standard deviation. If NULL (default), this is estimated
5458
using the mean squared residual of the full least squares based on

tests/randomized/test_instances.R

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
library(selectiveInference)
22

3-
gaussian_instance = function(n, p, s, sigma=1, rho=0, signal=6, X=NA,
4-
random_signs=TRUE, scale=TRUE, center=TRUE, seed=NA){
3+
get_instance = function(n, p, s, sigma=1, rho=0, signal=6, family="gaussian",
4+
X=NA, random_signs=TRUE, scale=TRUE, center=TRUE, seed=NA){
55
if (!is.na(seed)){
66
set.seed(seed)
77
}
@@ -19,11 +19,20 @@ gaussian_instance = function(n, p, s, sigma=1, rho=0, signal=6, X=NA,
1919
signs = sample(c(-1,1), s, replace = TRUE)
2020
beta = beta * signs
2121
}
22-
y = X %*% beta + rnorm(n)*sigma
22+
mu = X %*% beta
23+
if (family=="gaussian"){
24+
y = mu + rnorm(n)*sigma
25+
} else if (family=="binomial"){
26+
prob = exp(mu)/(1+exp(mu))
27+
y= rbinom(n,1, prob)
28+
}
2329
result = list(X=X,y=y,beta=beta)
2430
return(result)
2531
}
2632

33+
34+
35+
2736
test_randomized_lasso = function(n=100,p=200,s=0){
2837
set.seed(1)
2938
data = gaussian_instance(n=n,p=p,s=s, rho=0.3, sigma=3)
@@ -61,27 +70,29 @@ test_KKT=function(){
6170

6271

6372

64-
collect_results = function(n,p,s, nsim=100, level=0.9, condition_subgrad=FALSE, lam=1.2){
73+
collect_results = function(n,p,s, nsim=100, level=0.9,
74+
family = "binomial",
75+
condition_subgrad=FALSE, lam=1.2){
6576

6677
rho=0.3
6778
sigma=1
6879
sample_pvalues = c()
6980
sample_coverage = c()
7081
for (i in 1:nsim){
71-
data = gaussian_instance(n=n,p=p,s=s, rho=rho, sigma=sigma)
82+
data = get_instance(n=n,p=p,s=s, rho=rho, sigma=sigma, family=family)
7283
X=data$X
7384
y=data$y
74-
beta=data$beta
7585
result = selectiveInference:::randomizedLassoInf(X, y,
76-
lam=lam,
86+
lam,
87+
family = family,
88+
sampler = "A",
7789
sigma=sigma,
7890
level=level,
79-
sampler = "A",
8091
burnin=1000,
8192
nsample=5000,
8293
condition_subgrad=condition_subgrad)
8394
if (length(result$active_set)>0){
84-
true_beta = beta[result$active_set]
95+
true_beta = data$beta[result$active_set]
8596
coverage = rep(0, nrow(result$ci))
8697
for (i in 1:nrow(result$ci)){
8798
if (result$ci[i,1]<true_beta[i] & result$ci[i,2]>true_beta[i]){
@@ -104,7 +115,7 @@ collect_results = function(n,p,s, nsim=100, level=0.9, condition_subgrad=FALSE,
104115
}
105116

106117
set.seed(1)
107-
collect_results(n=100, p=20, s=0, lam=1.2)
118+
collect_results(n=100, p=20, s=0, lam=0.8)
108119
#test_randomized_lasso()
109120
#test_KKT()
110121

0 commit comments

Comments
 (0)