@@ -82,17 +82,20 @@ fixedLassoInf <- function(x, y, beta,
82
82
83
83
tol.coef = tol.beta * sqrt(n / colSums(x ^ 2 ))
84
84
vars = which(abs(beta ) > tol.coef )
85
+ sign_vars = sign(beta [vars ])
85
86
86
87
if (sum(vars )== 0 ){
87
88
cat(" Empty model" ,fill = T )
88
89
return ()
89
90
}
90
- if (any(sign(g [vars ]) != sign(beta [vars ])))
91
+
92
+ if (any(sign(g [vars ]) != sign_vars )) {
91
93
warning(paste(" Solution beta does not satisfy the KKT conditions" ,
92
94
" (to within specified tolerances). You might try rerunning" ,
93
95
" glmnet with a lower setting of the" ,
94
96
" 'thresh' parameter, for a more accurate convergence." ))
95
-
97
+ }
98
+
96
99
# Get lasso polyhedral region, of form Gy >= u
97
100
98
101
logical.vars = rep(FALSE ,p )
@@ -132,13 +135,19 @@ fixedLassoInf <- function(x, y, beta,
132
135
}
133
136
134
137
# add additional targets for inference if provided
135
- if (! is.null(add.targets )) vars = sort(unique(c(vars ,add.targets ,recursive = T )))
136
-
137
- k = length(vars )
138
+ if (! is.null(add.targets )) {
139
+ # vars is boolean...
140
+ old_vars = vars & TRUE
141
+ vars [add.targets ] = TRUE
142
+ sign_vars = sign(beta [vars ])
143
+ sign_vars [! old_vars ] = NA
144
+ stop(" `add.targets` not fully implemented yet" )
145
+ }
146
+
147
+ k = length(vars )
138
148
pv = vlo = vup = numeric (k )
139
149
vmat = matrix (0 ,k ,n )
140
150
ci = tailarea = matrix (0 ,k ,2 )
141
- sign = numeric (k )
142
151
143
152
if (type == " full" & p > n ) {
144
153
if (intercept == T ) {
@@ -202,28 +211,36 @@ fixedLassoInf <- function(x, y, beta,
202
211
vj = M [j ,]
203
212
mj = sqrt(sum(vj ^ 2 ))
204
213
vj = vj / mj # Standardize (divide by norm of vj)
205
- sign [j ] = sign(sum(vj * y ))
206
- vj = sign [j ] * vj
214
+
215
+ if (! is.na(sign_vars [j ])) {
216
+ vj = sign_vars [j ] * vj
217
+ }
207
218
208
219
limits.info = TG.limits(y , A , b , vj , Sigma = diag(rep(sigma ^ 2 , n )))
209
220
a = TG.pvalue.base(limits.info , null_value = null_value [j ], bits = bits )
210
221
pv [j ] = a $ pv
222
+ if (is.na(sign_vars [j ])) { # for variables not in the active set, report 2-sided pvalue
223
+ pv [j ] = 2 * min(pv [j ], 1 - pv [j ])
224
+ }
211
225
vlo [j ] = a $ vlo * mj # Unstandardize (mult by norm of vj)
212
226
vup [j ] = a $ vup * mj # Unstandardize (mult by norm of vj)
213
- vmat [j ,] = vj * mj * sign [j ] # Unstandardize (mult by norm of vj)
214
-
227
+ if (! is.na(sign_vars [j ])) {
228
+ vmat [j ,] = vj * mj * sign_vars [j ] # Unstandardize (mult by norm of vj) and fix sign
229
+ } else {
230
+ vmat [j ,] = vj * mj # Unstandardize (mult by norm of vj)
231
+ }
215
232
a = TG.interval.base(limits.info ,
216
233
alpha = alpha ,
217
234
gridrange = gridrange ,
218
- flip = (sign [j ]== - 1 ),
235
+ flip = (sign_vars [j ]== - 1 ),
219
236
bits = bits )
220
237
ci [j ,] = (a $ int - null_value [j ]) * mj # Unstandardize (mult by norm of vj)
221
238
tailarea [j ,] = a $ tailarea
222
239
}
223
240
224
241
out = list (type = type ,lambda = lambda ,pv = pv ,ci = ci ,
225
242
tailarea = tailarea ,vlo = vlo ,vup = vup ,vmat = vmat ,y = y ,
226
- vars = vars ,sign = sign ,sigma = sigma ,alpha = alpha ,
243
+ vars = vars ,sign = sign_vars ,sigma = sigma ,alpha = alpha ,
227
244
sd = sigma * sqrt(rowSums(vmat ^ 2 )),
228
245
coef0 = vmat %*% y ,
229
246
call = this.call )
0 commit comments