Skip to content

Commit ae9aec8

Browse files
WIP: have sample from optimization density, importance weight -- need to write pivot function
1 parent 8bd3c12 commit ae9aec8

File tree

5 files changed

+112
-3
lines changed

5 files changed

+112
-3
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ addons:
1111
warnings_are_errors: true
1212
before_install:
1313
- tlmgr install index # for texlive and vignette?
14-
- R -e 'install.packages(c("Rcpp", "intervals"), repos="http://cloud.r-project.org")'
14+
- R -e 'install.packages(c("Rcpp", "intervals", "adaptMCMC", "glmnet"), repos="http://cloud.r-project.org")'
1515
- cd C-software
1616
- git submodule init
1717
- git submodule update

selectiveInference/DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Depends:
1010
glmnet,
1111
intervals,
1212
survival,
13+
adaptMCMC,
1314
Suggests:
1415
Rmpfr
1516
Description: New tools for post-selection inference, for use with forward

selectiveInference/NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,4 @@ importFrom("stats", "coef", "df", "lm", "pf")
4545
importFrom("stats", "glm", "residuals", "vcov")
4646
importFrom("stats", "rbinom", "rexp")
4747
importFrom("Rcpp", "sourceCpp")
48+
importFrom("adaptMCMC", "MCMC")

selectiveInference/R/funs.randomized.R

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ randomizedLASSO = function(X,
6767
objective_stop, # objective_stop
6868
kkt_stop, # kkt_stop
6969
param_stop) # param_stop
70-
70+
7171
sign_soln = sign(result$soln)
7272

7373
unpenalized = lam == 0
@@ -78,6 +78,14 @@ randomizedLASSO = function(X,
7878
active_set = which(active)
7979
inactive_set = which(inactive)
8080

81+
# observed opt state
82+
83+
observed_scalings = abs(result$soln)[active]
84+
observed_unpen = result$soln[unpenalized]
85+
observed_subgrad = result$gradient[inactive]
86+
87+
observed_opt_state = c(observed_unpen, observed_scalings, observed_subgrad)
88+
8189
# affine transform for optimization variables
8290

8391
E = c(unpenalized_set, active_set)
@@ -120,12 +128,97 @@ randomizedLASSO = function(X,
120128
internal_transform = list(linear_term = linear_term,
121129
offset_term = offset_term)
122130

131+
# density for sampling optimization variables
132+
133+
observed_raw = -t(X) %*% Y
134+
inactive_lam = lam[inactive_set]
135+
inactive_start = sum(unpenalized) + sum(active)
136+
active_start = sum(unpenalized)
137+
138+
# XXX only for Gaussian so far
139+
140+
log_optimization_density = function(opt_state
141+
) {
142+
143+
144+
if ((sum(abs(opt_state[(inactive_start + 1):p]) > inactive_lam) > 0) ||
145+
(sum(opt_state[(active_start+1):inactive_start] < 0) > 0)) {
146+
return(-Inf)
147+
}
148+
149+
D = log_density_gaussian_conditional_(noise_scale,
150+
opt_transform$linear_term,
151+
as.matrix(opt_state),
152+
observed_raw)
153+
return(D)
154+
}
155+
123156
return(list(active_set = active_set,
124157
inactive_set = inactive_set,
125158
unpenalized_set = unpenalized_set,
126159
sign_soln = sign_soln,
127160
optimization_transform = opt_transform,
128-
internal_transform = internal_transform
161+
internal_transform = internal_transform,
162+
log_optimization_density = log_optimization_density,
163+
observed_opt_state = observed_opt_state,
164+
observed_raw = observed_raw
129165
))
130166

131167
}
168+
169+
sample_opt_variables = function(randomizedLASSO_obj, jump_scale, nsample=10000) {
170+
return(MCMC(randomizedLASSO_obj$log_optimization_density,
171+
nsample,
172+
randomizedLASSO_obj$observed_opt_state,
173+
acc.rate=0.2,
174+
scale=jump_scale))
175+
}
176+
177+
# Carry out a linear decompositon of an internal
178+
# representation with respect to a target
179+
180+
# Returns an affine transform into raw coordinates (i.e. \omega or randomization coordinates)
181+
182+
linear_decomposition = function(observed_target,
183+
observed_internal,
184+
var_target,
185+
cov_target_internal,
186+
internal_transform) {
187+
var_target = as.matrix(var_target)
188+
if (nrow(var_target) == 1) {
189+
nuisance = observed_internal - cov_target_internal * observed_target / var_target
190+
target_linear = internal_transform$linear_part %*% cov_target_internal / var_target
191+
} else {
192+
nuisance = observed_internal - cov_target_internal %*% solve(var_target) %*% observed_target
193+
target_linear = internal_transform$linear_part %*% cov_target_internal %*% solve(var_target)
194+
}
195+
target_offset = internal_transform$linear_part %*% nuisance + internal_transform$offset
196+
return(list(linear_term=target_linear,
197+
offset_term=target_offset))
198+
}
199+
200+
# XXX only for Gaussian so far
201+
202+
importance_weight = function(noise_scale,
203+
target_sample,
204+
opt_sample,
205+
opt_transform,
206+
target_transform,
207+
observed_raw) {
208+
209+
log_num = log_density_gaussian_(noise_scale,
210+
target_transform$linear_term,
211+
as.matrix(target_sample),
212+
optimization_transform$linear_term,
213+
as.matrix(opt_state),
214+
target_transform$offset_term + optimization_transform$offset_term)
215+
216+
log_den = log_density_gaussian_conditional_(noise_scale,
217+
opt_transform$linear_term,
218+
as.matrix(opt_sample),
219+
observed_raw)
220+
W = log_num - log_den
221+
W = W - max(W)
222+
return(exp(W))
223+
}
224+

tests/randomized/test_randomized.R

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,20 @@ smoke_test = function() {
1212
}
1313
A = smoke_test()
1414

15+
sampler_test = function() {
16+
17+
n = 100; p = 50
18+
X = matrix(rnorm(n * p), n, p)
19+
y = rnorm(n)
20+
lam = 20 / sqrt(n)
21+
noise_scale = 0.01 * sqrt(n)
22+
ridge_term = .1 / sqrt(n)
23+
obj = selectiveInference:::randomizedLASSO(X, y, lam, noise_scale, ridge_term)
24+
S = selectiveInference:::sample_opt_variables(obj, jump_scale=rep(1/sqrt(n), p), nsample=10000)
25+
return(S$samples[2001:10000,])
26+
}
27+
B = sampler_test()
28+
1529
gaussian_density_test = function() {
1630

1731
noise_scale = 10.

0 commit comments

Comments
 (0)