Skip to content

Commit e62f264

Browse files
committed
formal peaky behavior CTC
1 parent 4279e54 commit e62f264

File tree

8 files changed

+3737
-0
lines changed

8 files changed

+3737
-0
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
This is the code for the paper "Why does CTC result in peaky behavior".
2+
3+
Please see the docstring of `simple_model.py`,
4+
which covers most synthetic experiments from the paper.
5+
6+
`simple_model.txt` contains some further comments and notes about some experiments on synthetic data.
7+
8+
See the docstring of `fst_utils.py` for the symbolic computations and proofs.
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
2+
from sage.all import *
3+
# noinspection PyUnresolvedReferences
4+
from sage.calculus.all import var
5+
#import numpy
6+
7+
8+
def gen_model_1label():
    """
    Symbolically analyse a toy one-label model with Sage.

    The quantity of interest is \\sum_{s:y} p(x|s), with
    two possible inputs, x1 = (1,0) and x2 = (0,1), and
    two possible labels, "a" and the blank "B".
    The emission model is
      p(x1|s=a) = theta_a,  p(x2|s=a) = 1 - theta_a,
      p(x2|s=B) = theta_B,  p(x1|s=B) = 1 - theta_B.

    For simplicity the FSA is a*B* and the input is x1^{na},x2^{nB}, with T = na + nB.
    Then we can just count: every alignment is identified by its switch point t = 0...T.
    Symmetric case...

    Prints the repeatedly simplified sum, its derivative w.r.t. theta_a,
    and whatever ``solve`` reports for derivative == 0. Returns nothing.
    """
    # Symbol names (the string arguments) are what Sage prints; the Python names are local only.
    num_a = var("na", domain=ZZ)
    num_b = var("nb", domain=ZZ)
    theta_a = var("theta_a", domain=RR)
    theta_b = var("theta_b", domain=RR)
    t = var("t", domain=ZZ)

    # Idea from the original author: make 2 parts of the sum, one t=0...na, another t=na..T.
    # That should get rid of the min/max cases and simplify it.
    # Each factor is one emission probability raised to the number of frames it applies to,
    # for the alignment that emits "a" on the first t frames and "B" on the rest.
    factors = (
        theta_a ** min_symbolic(t, num_a),  # x1 frames labelled "a"
        (1 - theta_a) ** max_symbolic(t - num_a, 0),  # x2 frames labelled "a"
        theta_b ** min_symbolic(num_a + num_b - t, num_b),  # exp = min(na - t, 0) + nb
        (1 - theta_b) ** max_symbolic(num_a - t, 0),  # x1 frames labelled "B"
    )
    summand = factors[0] * factors[1] * factors[2] * factors[3]
    # `sum` is Sage's symbolic sum here (expression, variable, lower, upper).
    total = sum(summand, t, 0, num_a + num_b)

    # NOTE(review): indentation was lost in the scraped source; the print is assumed
    # to sit inside the loop (one line of output per simplification pass) -- confirm.
    for _pass in range(6):
        total = total.simplify()
        print(total)

    # Disabled numeric exploration on a grid of (theta_a, theta_b), kept for reference:
    # total_num = total.substitute(na=10, nb=10)
    # xs = ys = numpy.linspace(0, 1., num=11)
    # values = numpy.zeros((len(xs), len(ys)))
    # for ix, x in enumerate(xs):
    #     for iy, y in enumerate(ys):
    #         value = total_num.substitute(theta_a=x, theta_b=y)
    #         print("theta = (%f, %f) -> sum = %s" % (x, y, value))
    #         #values[ix, iy] = float(total_num.subs(theta_a, x).subs(theta_b, y).doit())
    # print(values)

    # Variables to differentiate and solve for.
    # Alternatives tried by the author: (theta_a, theta_b) and (theta_b,).
    wrt = (theta_a,)
    derivative = total.diff(*wrt)
    derivative = derivative.simplify()
    print("diff:", derivative)
    # Further repeated simplify() passes on the derivative were tried and disabled
    # (author's note: "also makes it harder?").

    solutions = solve(derivative == 0, *wrt, domain=RR)
    print("num opts:", len(solutions))
    for solution in solutions:
        print("opt:", solution)
61+
62+
63+
def main():
    """
    Command-line entry point: run the module-level function named by the first argument.

    Usage: ``<script> <func>``, e.g. ``<script> gen_model_1label``.

    Prints a usage line and exits with status 1 when no function name is given,
    or when the given name does not refer to a callable in this module
    (previously the latter surfaced as a bare ``KeyError``/``TypeError``).
    """
    # Local import so this does not depend on the star import happening to provide `sys`.
    import sys

    if len(sys.argv) >= 2:
        func_name = sys.argv[1]
        func = globals().get(func_name)
        if callable(func):
            func()  # eg test_ctc()
            return
        print("Unknown function: %r" % func_name)

    print("Usage: %s <func>" % __file__)
    sys.exit(1)
70+
71+
72+
# Script entry point: dispatch to the function named on the command line.
if __name__ == "__main__":
    # Optional nicer tracebacks (disabled):
    # import better_exchook
    # better_exchook.install()
    main()

0 commit comments

Comments
 (0)