-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlearning_lambeq.py
More file actions
89 lines (62 loc) · 2.65 KB
/
learning_lambeq.py
File metadata and controls
89 lines (62 loc) · 2.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import warnings
warnings.filterwarnings('ignore')
from lambeq import AtomicType, BobcatParser, TensorAnsatz, SpiderAnsatz, IQPAnsatz,Sim14Ansatz, Sim15Ansatz
from lambeq.backend.tensor import Dim
from discopy.quantum.gates import CX, Rx, H, Bra, Id
# Shorthand names for the atomic grammatical types used in this script.
N, S, C = AtomicType.NOUN, AtomicType.SENTENCE, AtomicType.CONJUNCTION

# Ansatz class that the report message and the branches near the end of the
# script compare against.
ansatz_to_use = Sim15Ansatz

# Turn one sentence into a diagram with the Bobcat parser and keep it as a
# one-element list of training diagrams.
parser = BobcatParser()
diagram = parser.sentence2diagram('Alice Loves Bob ')
train_diagrams = [diagram]
from lambeq import AtomicType, IQPAnsatz, RemoveCupsRewriter

# IQP ansatz configured with NOUN -> 1 and SENTENCE -> 0, one layer and
# n_single_qubit_params=3.
ansatz = IQPAnsatz(
    {AtomicType.NOUN: 1, AtomicType.SENTENCE: 0},
    n_layers=1,
    n_single_qubit_params=3,
)
remove_cups = RemoveCupsRewriter()

# Pass every training diagram through the cup remover first and the ansatz
# second, collecting the results in order.
train_circuits = [
    ansatz(cup_free) for cup_free in map(remove_cups, train_diagrams)
]

# Draw the circuit obtained for the first training diagram.
train_circuits[0].draw(figsize=(9, 10))
from tqdm import tqdm
from lambeq import RemoveCupsRewriter, Rewriter

# Build the cup remover from its class rather than importing a function named
# ``remove_cups`` from lambeq: that import would replace the rewriter instance
# bound earlier in this script and fails if the installed lambeq does not
# export such a function.
remove_cups = RemoveCupsRewriter()

# Each rewrite rule is listed exactly once.
rewriter = Rewriter(
    ['prepositional_phrase', 'determiner', 'coordination', 'connector']
)

# Rewrite, normalise and cup-strip every parsed training diagram.
# ``train_diagrams`` is the list of parsed diagrams this script defines; the
# name ``train_diags_raw`` is not defined anywhere in it.
train_X = [
    remove_cups(rewriter(raw_diagram).normal_form())
    for raw_diagram in tqdm(train_diagrams)
]
test_X = []
# Re-bind the atomic-type shorthands and add one for prepositional phrases.
N = AtomicType.NOUN
S = AtomicType.SENTENCE
P = AtomicType.PREPOSITIONAL_PHRASE
# Gate expression: CX, then H in parallel with Rx(0.5), then Bra(0) in
# parallel with Id(1).
equality_comparator = (CX >> (H @ Rx(0.5)) >> (Bra(0) @ Id(1)))
# Re-bind ``ansatz`` to an IQP ansatz that maps N, S and P to 1 each.
ansatz = IQPAnsatz({N: 1, S: 1, P:1}, n_layers=1, n_single_qubit_params=3)
# Despite the plural name, this is the ansatz applied to the single parsed
# ``diagram`` (cups are not removed here).
train_circs = ansatz(diagram)
# Re-bind ``ansatz`` again, now to a SpiderAnsatz with NOUN -> Dim(4) and
# SENTENCE -> Dim(2); ``tensor_diagram`` is built with this one.
ansatz = SpiderAnsatz({AtomicType.NOUN: Dim(4),
AtomicType.SENTENCE: Dim(2)})
tensor_diagram = ansatz(diagram)
# These names are already imported near the top of the script.
from discopy.quantum.gates import CX, Rx, H, Bra, Id
# NOTE(review): this composes the SpiderAnsatz result with an expression built
# from discopy quantum gates; the two look like different diagram kinds, so
# this line may raise -- TODO confirm. ``b`` is not read anywhere in the
# visible code.
b= tensor_diagram>> equality_comparator
# tensor_diagram.draw(figsize=(12,5), fontsize=12)
# The message reports ``ansatz_to_use``, whereas ``tensor_diagram`` above was
# produced by the SpiderAnsatz instance.
print(f"when using bobcat parser and {ansatz_to_use} ")
print(tensor_diagram.free_symbols)
def _max_trailing_index(symbols):
    """Return the largest integer suffix found in the symbols' names.

    The suffix of a name is whatever follows its last underscore. Names
    whose suffix is not a plain decimal number are skipped instead of
    raising.

    Args:
        symbols: Iterable of objects exposing a string ``name`` attribute.

    Returns:
        The largest suffix as an ``int``, or ``None`` when no name carries a
        numeric suffix (this includes an empty ``symbols``).
    """
    suffixes = (symbol.name.rsplit('_', 1)[-1] for symbol in symbols)
    return max(
        (int(suffix) for suffix in suffixes if suffix.isdecimal()),
        default=None,
    )


# Sentinel kept whenever no index can be read from the symbol names.
max_word_param_length = 9999

# Both ansatz choices read the index the same way. The former Sim15Ansatz
# branch called the ansatz instance without a diagram, which raises, and the
# former SpiderAnsatz branch converted everything after the first underscore
# of a name to int, which fails for multi-part names.
if ansatz_to_use in (Sim15Ansatz, SpiderAnsatz):
    highest_index = _max_trailing_index(tensor_diagram.free_symbols)
    if highest_index is not None:
        # Same "+ 1" as the original computation: largest index plus one.
        max_word_param_length = highest_index + 1

print(f"value of max_word_param_length={max_word_param_length}")

# Stop the script here.
import sys
sys.exit()