|
| 1 | + |
| 2 | +from sage.all import * |
| 3 | +# noinspection PyUnresolvedReferences |
| 4 | +from sage.calculus.all import var |
| 5 | +#import numpy |
| 6 | + |
| 7 | + |
| 8 | +def gen_model_1label(): |
| 9 | + """ |
| 10 | + \\sum_{s:y} p(x|s), |
| 11 | + two possible inputs x1 (1,0) and x2 (0,1), |
| 12 | + two possible labels "a" and (blank) "B". |
| 13 | + Define p(x1|s=a) = theta_a, p(x2|s=a) = 1 - theta_a, |
| 14 | + p(x2|s=B) = theta_B, p(x1|s=B) = 1 - theta_B. |
| 15 | +
|
| 16 | + For simplicity, fsa ^= a*B*, and the input be x1^{na},x2^{nB}, T = na + nB. |
| 17 | + Then we can just count. All alignments can be iterated through by t=0...T. |
| 18 | + Symmetric case... |
| 19 | + """ |
| 20 | + na = var("na", domain=ZZ) |
| 21 | + nb = var("nb", domain=ZZ) |
| 22 | + theta_a = var("theta_a", domain=RR) |
| 23 | + theta_b = var("theta_b", domain=RR) |
| 24 | + t = var("t", domain=ZZ) |
| 25 | + # Make 2 parts of the sum, one t=0...na, another t=na..T. |
| 26 | + # Should get rid of the min/max cases, simplify it. |
| 27 | + p1 = theta_a ** min_symbolic(t, na) |
| 28 | + p2 = (1 - theta_a) ** max_symbolic(t - na, 0) |
| 29 | + p3 = theta_b ** min_symbolic(na + nb - t, nb) # exp = min(na - t, 0) + nb |
| 30 | + p4 = (1 - theta_b) ** max_symbolic(na - t, 0) |
| 31 | + sum_ = sum(p1 * p2 * p3 * p4, t, 0, na + nb) |
| 32 | + |
| 33 | + for _ in range(6): |
| 34 | + sum_ = sum_.simplify() |
| 35 | + print(sum_) |
| 36 | + |
| 37 | + #sum__ = sum_.substitute(na=10, nb=10) |
| 38 | + #xs = ys = numpy.linspace(0, 1., num=11) |
| 39 | + #values = numpy.zeros((len(xs), len(ys))) |
| 40 | + #for ix, x in enumerate(xs): |
| 41 | + # for iy, y in enumerate(ys): |
| 42 | + # value = sum__.substitute(theta_a=x, theta_b=y) |
| 43 | + # print("theta = (%f, %f) -> sum = %s" % (x, y, value)) |
| 44 | + #values[ix, iy] = float(sum__.subs(theta_a, x).subs(theta_b, y).doit()) |
| 45 | + #print(values) |
| 46 | + |
| 47 | + #syms = (theta_a, theta_b) |
| 48 | + #syms = (theta_b,) |
| 49 | + syms = (theta_a,) |
| 50 | + sum_diff = sum_.diff(*syms) |
| 51 | + sum_diff = sum_diff.simplify() |
| 52 | + print("diff:", sum_diff) |
| 53 | + #for _ in range(5): |
| 54 | + # sum_diff = sum_diff.simplify() |
| 55 | + # print(sum_diff) |
| 56 | + # sum_diff = sum_diff.simplify() # -- also makes it harder? |
| 57 | + opts = solve(sum_diff == 0, *syms, domain=RR) |
| 58 | + print("num opts:", len(opts)) |
| 59 | + for opt in opts: |
| 60 | + print("opt:", opt) |
| 61 | + |
| 62 | + |
| 63 | +def main(): |
| 64 | + if len(sys.argv) >= 2: |
| 65 | + globals()[sys.argv[1]]() # eg test_ctc() |
| 66 | + return |
| 67 | + |
| 68 | + print("Usage: %s <func>" % __file__) |
| 69 | + sys.exit(1) |
| 70 | + |
| 71 | + |
| 72 | +if __name__ == '__main__': |
| 73 | + #import better_exchook |
| 74 | + #better_exchook.install() |
| 75 | + main() |
0 commit comments