-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_mixture2.m
More file actions
55 lines (47 loc) · 1.31 KB
/
test_mixture2.m
File metadata and controls
55 lines (47 loc) · 1.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
% --- Reference mixture and synthetic data -------------------------------
% Build a two-component mixture with weights [0.5 0.5], display it, draw
% samples from it and plot them in figure 1.
% Each normal_density call receives a 2x1 vector and a 2x2 matrix
% (presumably mean and covariance -- confirm against normal_density).
norm1 = normal_density([3;4], [1 1;1 2]);
norm2 = normal_density([2;3], [2 1;1 1]);
mix = mixture_density([0.5 0.5], norm1, norm2);
disp(mix);
% Presumably 80 samples, one per column (rows 1 and 2 are plotted against
% each other below) -- confirm against sample().
data = sample(mix, 80);
figure(1);
plot(data(1, :), data(2, :), '.');
% NOTE(review): no explicit `hold on` here; presumably draw() adds to the
% existing axes rather than replacing the scatter plot -- confirm.
draw(mix);
% --- Fit a mixture to the sampled data ----------------------------------
% norm1/norm2 are overwritten here: the components of the mixture to be
% trained start from a random 2x1 vector (randn) and the 2x2 identity.
% The order of the two randn calls determines which component receives
% which random draw.
norm1 = normal_density(randn(2, 1), eye(2));
% NOTE(review): the prior is a normal_density built with NaN as its first
% argument; what NaN means to set_prior is not visible here -- confirm.
norm1 = set_prior(norm1, normal_density(NaN, 1));
norm2 = normal_density(randn(2, 1), eye(2));
norm2 = set_prior(norm2, normal_density(NaN, 1));
% Optional covariance-type restrictions, currently disabled.
%norm1 = set_cov_type(norm1, 'spherical');
%norm2 = set_cov_type(norm2, 'spherical');
%norm1 = set_cov_type(norm1, 'diagonal');
%norm2 = set_cov_type(norm2, 'diagonal');
obj = mixture_density([0.5 0.5], norm1, norm2);
% train() returns the updated object; obj is replaced by the result.
obj = train(obj, data);
disp(obj);
% Draw the trained mixture on figure 1 with style 'r', the same figure used
% for the data and the reference mixture.
figure(1);
draw(obj, 'r');
% --- Optional: posterior predictive comparison (disabled) ---------------
% The constant condition `if 0` means this block never runs; change it to
% `if 1` to enable. It uses mix, obj and data from the code above.
if 0
%mix2 = sample_posterior(mix);
%draw(mix2, 'k');
post = posterior_predict(obj, data);
figure(2);
plot(data(1, :), data(2, :), '.');
density_image(post);
% compute the divergence with the true distribution (smaller is better)
% Presumably a grid over [0,6] in steps of 0.1 in each of the two
% dimensions -- confirm against lattice().
x = lattice([0 0.1 6; 0 0.1 6]);
ptrue = logProb(mix, x);
p1 = logProb(obj, x);
p2 = logProb(post, x);
% No trailing semicolon: both sums are printed. Each is
% sum(exp(ptrue) .* (ptrue - p)) over the lattice points, a KL-style sum
% that is not multiplied by any cell-area factor.
sum(exp(ptrue) .* (ptrue - p1))
sum(exp(ptrue) .* (ptrue - p2))
end
% --- Optional: component-mean direction vs. data axis (disabled) --------
% The constant condition `if 0` means this block never runs; change it to
% `if 1` to enable. It uses obj and data from the code above.
if 0
% compare m1-m2 to the principal axis of the data
c = get_components(obj);
m1 = get_mean(c{1});
m2 = get_mean(c{2});
v = m1-m2;
% No trailing semicolon: the unit-length direction between the two
% component means is printed.
v/norm(v)
% c and v are reused from here on: c becomes the result of cov_t(data)
% and v the first output of sorted_eig.
c = cov_t(data);
% Presumably eigenvectors ordered so that column 1 is the principal axis --
% confirm the sort order of sorted_eig().
[v,e] = sorted_eig(c);
v(:,1)
end