-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtunesampler.m
More file actions
81 lines (63 loc) · 1.89 KB
/
tunesampler.m
File metadata and controls
81 lines (63 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
% Put the GPML toolbox (change its path here if needed) and this project's
% model code on the MATLAB path.
% Both folders are resolved relative to this script's own location rather
% than the current working directory, so the script also works when started
% from elsewhere, e.g. run('/path/to/tunesampler.m'). If mfilename returns ''
% (e.g. when the code is evaluated section by section), fileparts yields ''
% and fullfile falls back to the original cwd-relative paths.
addpath(fullfile(fileparts(mfilename('fullpath')), "../CNNForecasting/gpml-matlab-v3.6-2015-07-07"));
addpath(fullfile(fileparts(mfilename('fullpath')), "model"));
% addpath("/Users/yahoo/Documents/WashU/CSE515T/Code/Gaussian Process/gpml-matlab-v3.6-2015-07-07");
% Run the startup script (presumably GPML's startup.m from the folder added
% above -- confirm no other startup.m shadows it).
startup;
% Reset the global random number generator to MATLAB's default settings so
% the jittered start point and tuning below are reproducible.
rng('default');
% Load the data. load_data is a project script; presumably it defines the
% inputs x and targets y used further down -- confirm.
load_data;
% Append three derived columns. Assuming load_data yields a 3-column x, the
% columns afterwards are:
%   1: day number
%   2: group id
%   3: unit id
%   4: copy of the day number (useful for prediction)
%   5: weekday index, computed as mod(day number, 7)
%   6: copy of the day number used by the drift process; it is zeroed for
%      every row whose group id (column 2) equals 1 ("task 1")
x = horzcat(x, ...
    x(:, 1), ...          % day number copy
    mod(x(:, 1), 7), ...  % weekday index
    x(:, 1));             % drift-process day number
x(x(:, 2) == 1, size(x, 2)) = 0;
% Build the GP model. localnewsmodel is a project script; presumably it
% defines theta, inference_method, mean_function and covariance_function
% -- confirm.
localnewsmodel;
% Find the MAP hyperparameters: minimize_v2 optimizes the objective
% returned by gp starting from theta, using L-BFGS with a budget of
% p.length = 100 (see minimize_v2 for the exact meaning of "length").
% p is assigned field by field, so any fields already on p are kept.
p.method = 'LBFGS';
p.length = 100;
theta = minimize_v2(theta, @gp, p, ...
    inference_method, mean_function, covariance_function, [], ...
    x, y);
% ---- HMC sampler settings ----
jitter = 1e-1;        % std. dev. of Gaussian noise added to start points
burn_in = 1000;       % burn-in length passed to drawSamples
num_samples = 3000;   % number of samples requested from drawSamples
num_chains = 5;       % NOTE(review): not referenced below; one chain is drawn

% ---- Which hyperparameters to sample ----
% Logical mask over the unwrapped hyperparameter vector; only the selected
% entries are sampled, the rest are handed to l through theta.
theta_ind = false(size(unwrap(theta)));
% Alternative: sample only entries 14:16 (the drift parameters):
% theta_ind([14:16]) = true;
theta_ind([1, 2, 3, 6, 7, 10, 12, 14, 15, 16, 17]) = true;

% Initial point: the MAP values of the selected entries.
theta_0 = unwrap(theta);
theta_0 = theta_0(theta_ind);

% Log-density for the sampler, as a function of the selected entries only.
% l is a project function; presumably it writes them back into theta and
% returns the log posterior and its gradient, as hmcSampler requires --
% confirm. The handle captures the current values of theta_ind, theta, the
% model functions, x and y.
f = @(theta_sampled) l(theta_sampled, theta_ind, theta, ...
    inference_method, mean_function, covariance_function, x, y);
% Create the HMC sampler at a jittered copy of the MAP point, then tune it
% (step size / number of steps). randn draws from the global stream that
% was reset by rng('default') above. toc prints the tuning time.
hmc = hmcSampler(f, theta_0 + jitter * randn(size(theta_0)));
tic;
[hmc, tune_info] = tuneSampler(hmc, ...
    'verbositylevel', 2, ...
    'numprint', 10, ...
    'numstepsizetuningiterations', 100, ...
    'numstepslimit', 500);
toc;
% Draw one chain with the tuned sampler and save the workspace.
% i is the chain index. It seeds the RNG (after tuning, so the chain's start
% jitter and sampling depend only on i) and it names the output file. Running
% another of the num_chains chains therefore only requires changing i, and
% no longer overwrites localnews_1.mat.
% (i shadows the imaginary unit; the name is kept because the saved
% workspace has always contained it.)
i = 1;
rng(i);
tic;
[chain, endpoint, acceptance_ratio] = drawSamples(hmc, ...
    'start', theta_0 + jitter * randn(size(theta_0)), ...
    'burnin', burn_in, ...
    'numsamples', num_samples, ...
    'verbositylevel', 1, ...
    'numprint', 10);
toc;
% Save every workspace variable (samples, tuning info, data, model, ...)
% to localnews_<i>.mat in the current folder.
save(sprintf("localnews_%d.mat", i));
% save("tunesampler.mat");