Skip to content

Commit 2a8dc77

Browse files
more renaming
1 parent 1fa2e7b commit 2a8dc77

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

selectiveInference/R/funs.ROSI.R

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ solve_restricted_problem = function(X, y, var, lambda_glmnet, penalty_factor, lo
3535
return(restricted_soln)
3636
}
3737

38-
solve_problem_Q = function(Q_sq,
38+
solve_problem_Q = function(Xdesign,
3939
Qbeta_bar,
4040
lambda_glmnet,
4141
penalty_factor,
@@ -46,25 +46,24 @@ solve_problem_Q = function(Q_sq,
4646
kkt_stop=TRUE,
4747
objective_stop=TRUE,
4848
parameter_stop=TRUE){
49-
n=nrow(Q_sq)
50-
p=ncol(Q_sq)
49+
n=nrow(Xdesign)
50+
p=ncol(Xdesign)
5151

52-
Xinfo = Q_sq
5352
linear_func = -as.numeric(Qbeta_bar)
5453
soln = as.numeric(rep(0., p))
5554
ever_active = as.integer(rep(0, p))
5655
ever_active[1] = 1
5756
ever_active = as.integer(ever_active)
5857
nactive = as.integer(1)
59-
Xsoln = as.numeric(rep(0, nrow(Xinfo)))
58+
Xsoln = as.numeric(rep(0, n))
6059
gradient = 1. * linear_func
61-
max_active=as.integer(p)
60+
max_active = as.integer(p)
6261

6362
linear_func = linear_func/n
6463
gradient = gradient/n
6564

66-
#solve_QP_wide solves n*slinear_func^T\beta+\beta^T Xinfo\beta+\sum\lambda_i|\beta_i|
67-
result = solve_QP_wide(Xinfo, # this is a design matrix
65+
#solve_QP_wide solves n*slinear_func^T\beta+1/2(X\beta)^T (X\beta)+\sum\lambda_i|\beta_i|
66+
result = solve_QP_wide(Xdesign, # this is a design matrix
6867
as.numeric(penalty_factor*lambda_glmnet), # vector of Lagrange multipliers
6968
0, # ridge_term
7069
max_iter,
@@ -95,9 +94,9 @@ truncation_set = function(X,
9594
y,
9695
Qbeta_bar,
9796
QE,
98-
Q_sq,
97+
Xdesign,
9998
target_stat,
100-
target_cov,
99+
QiE,
101100
var,
102101
active_set,
103102
lambda_glmnet,
@@ -108,7 +107,7 @@ truncation_set = function(X,
108107
if (solver=="QP"){
109108
penalty_factor_rest = rep(penalty_factor)
110109
penalty_factor_rest[var] = 10^10
111-
restricted_soln = solve_problem_Q(Q_sq,
110+
restricted_soln = solve_problem_Q(Xdesign,
112111
Qbeta_bar,
113112
lambda_glmnet,
114113
penalty_factor=penalty_factor_rest)
@@ -124,7 +123,7 @@ truncation_set = function(X,
124123
n = nrow(X)
125124
idx = match(var, active_set) # active_set[idx]=var
126125
nuisance_res = (Qbeta_bar[var] - # nuisance stat restricted to active vars
127-
solve(target_cov) %*% target_stat)/n
126+
solve(QiE) %*% target_stat)/n
128127
center = nuisance_res - (QE[idx,] %*% restricted_soln/n)
129128
radius = penalty_factor[var]*lambda_glmnet
130129
return(list(center=center*n, radius=radius*n))
@@ -158,10 +157,13 @@ tnorm.union.surv = function(z,
158157

159158
pval = matrix(NA, nrow = dim(intervals)[1], ncol = length(mean))
160159

160+
print('truncation intervals')
161+
print(intervals)
162+
161163
for(jj in 1:dim(intervals)[1]){
162164
if(z <= intervals[jj,1]){
163165
pval[jj,] = 1
164-
}else if(z >= intervals[jj,2]){
166+
} else if(z >= intervals[jj,2]){
165167
pval[jj,] = 0
166168
}else{
167169
pval[jj,] = tnorm.surv(z, mean, sd, intervals[jj,1], intervals[jj,2], bits=bits)
@@ -321,7 +323,7 @@ setup_Qbeta = function(X,
321323
Qi = solve(Q) ## (X^TWX)^{-1}
322324
QiE = Qi[active_set,][, active_set]
323325

324-
Q_sq = W_root %*% X
326+
Xdesign = W_root %*% X
325327
beta_bar = soln - Qi %*% gradient(X, y, soln, loss=loss)
326328
Qbeta_bar = Q%*%soln - gradient(X, y, soln, loss=loss)
327329
beta_barE = beta_bar[active_set]
@@ -342,12 +344,12 @@ setup_Qbeta = function(X,
342344
M2 = M_active %*% t(W_root %*% X)
343345
QiE = M2 %*% t(M2) # size |E|\times |E|
344346
QE = hessian_active(X, soln, loss, active_set)
345-
Q_sq = W_root %*% X
347+
Xdesign = W_root %*% X
346348
Qbeta_bar = t(QE)%*%soln[active_set] - G
347349
}
348350

349351
return(list(QE=QE,
350-
Q_sq=Q_sq,
352+
Xdesign=Xdesign,
351353
Qbeta_bar=Qbeta_bar,
352354
QiE=QiE,
353355
beta_barE=beta_barE,
@@ -395,7 +397,7 @@ ROSI = function(X,
395397
use_debiased=use_debiased)
396398

397399
QE = as.matrix(setup_params$QE)
398-
Q_sq = setup_params$Q_sq
400+
Xdesign = setup_params$Xdesign
399401
QiE = as.matrix(setup_params$QiE)
400402
beta_barE = setup_params$beta_barE
401403
Qbeta_bar = setup_params$Qbeta_bar
@@ -426,8 +428,8 @@ ROSI = function(X,
426428
y=y,
427429
Qbeta_bar=Qbeta_bar,
428430
QE=QE,
429-
Q_sq=Q_sq,
430-
target_cov=target_cov, # this Hessian, i.e. without dispersion
431+
Xdesign=Xdesign,
432+
QiE=QiE[i,i,drop=FALSE], # this is part of inverse Hessian, i.e. without dispersion
431433
target_stat=target_stat,
432434
var=active_set[i],
433435
active_set=active_set,
@@ -484,6 +486,7 @@ ROSI = function(X,
484486
pvalues = c(pvalues, NA)
485487
sel_intervals = rbind(sel_intervals, c(NA, NA))
486488
warning("observation not within the truncation limits!")
489+
print("observation not within the truncation limits!")
487490
}
488491
}
489492

@@ -528,10 +531,8 @@ print.ROSI <- function(x, ...) {
528531

529532
# Some little used functions -- not exported
530533

531-
532534
compute_coverage = function(ci, beta){
533-
print(ci)
534-
print(beta)
535+
535536
nactive=length(beta)
536537
coverage_vector = rep(0, nactive)
537538
for (i in 1:nactive){
@@ -554,7 +555,6 @@ loss_label = function(family) {
554555
}
555556
}
556557

557-
558558
family_label = function(loss){
559559
if (loss=="ls"){
560560
return("gaussian")

0 commit comments

Comments
 (0)