Skip to content
This repository was archived by the owner on Nov 18, 2023. It is now read-only.

Commit 9b7459a

Browse files
jmsfltchr authored and Ganeshwara Herawan Hananda committed
Simplify QueryGraph interface (#118)
## What is the goal of this PR?

Simplify the user's job of building QueryGraphs as much as possible.

## What are the changes implemented in this PR?

- Make the `'solution'` class mandatory when building QueryGraphs; this reduces the possibility of error
- Make methods chainable and therefore more readable
1 parent 6b028b5 commit 9b7459a

File tree

5 files changed

+108
-80
lines changed

5 files changed

+108
-80
lines changed

kglib/kgcn/examples/diagnosis/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ py_test(
88
],
99
deps = [
1010
"diagnosis",
11+
"//kglib/utils/graph/test",
1112
requirement('numpy'),
1213
requirement('networkx'),
1314
requirement('decorator'),

kglib/kgcn/examples/diagnosis/diagnosis.py

Lines changed: 50 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@
3434
URI = "localhost:48555"
3535

3636
# Existing elements in the graph are those that pre-exist in the graph, and should be predicted to continue to exist
37-
PREEXISTS = dict(solution=0)
37+
PREEXISTS = 0
3838

3939
# Candidates are neither present in the input nor in the solution, they are negative samples
40-
CANDIDATE = dict(solution=1)
40+
CANDIDATE = 1
4141

4242
# Elements to infer are the graph elements whose existence we want to predict to be true, they are positive samples
43-
TO_INFER = dict(solution=2)
43+
TO_INFER = 2
4444

4545
# Categorical Attribute types and the values of their categories
4646
CATEGORICAL_ATTRIBUTES = {'name': ['Diabetes Type II', 'Multiple Sclerosis', 'Blurred vision', 'Fatigue', 'Cigarettes',
@@ -138,19 +138,23 @@ def create_concept_graphs(example_indices, grakn_session):
138138
# Build a graph from the queries, samplers, and query graphs
139139
graph = build_graph_from_queries(graph_query_handles, tx, infer=infer)
140140

141-
# Remove label leakage - change type labels that indicate candidates into non-candidates
142-
for data in multidigraph_data_iterator(graph):
143-
for label_to_obfuscate, with_label in TYPES_AND_ROLES_TO_OBFUSCATE.items():
144-
if data['type'] == label_to_obfuscate:
145-
data.update(type=with_label)
146-
break
141+
obfuscate_labels(graph, TYPES_AND_ROLES_TO_OBFUSCATE)
147142

148143
graph.name = example_id
149144
graphs.append(graph)
150145

151146
return graphs
152147

153148

149+
def obfuscate_labels(graph, types_and_roles_to_obfuscate):
150+
# Remove label leakage - change type labels that indicate candidates into non-candidates
151+
for data in multidigraph_data_iterator(graph):
152+
for label_to_obfuscate, with_label in types_and_roles_to_obfuscate.items():
153+
if data['type'] == label_to_obfuscate:
154+
data.update(type=with_label)
155+
break
156+
157+
154158
def get_query_handles(example_id):
155159
"""
156160
Creates an iterable, each element containing a Graql query, a function to sample the answers, and a QueryGraph
@@ -174,15 +178,13 @@ def get_query_handles(example_id):
174178
get;''')
175179

176180
vars = p, par, ps, d, diag, n = 'p', 'par', 'ps', 'd', 'diag', 'n'
177-
g = QueryGraph()
178-
g.add_vars(*vars, **PREEXISTS)
179-
g.add_role_edge(ps, p, 'child', **PREEXISTS)
180-
g.add_role_edge(ps, par, 'parent', **PREEXISTS)
181-
g.add_role_edge(diag, par, 'patient', **PREEXISTS)
182-
g.add_role_edge(diag, d, 'diagnosed-disease', **PREEXISTS)
183-
g.add_has_edge(d, n, **PREEXISTS)
184-
185-
hereditary_query_graph = g
181+
hereditary_query_graph = (QueryGraph()
182+
.add_vars(vars, PREEXISTS)
183+
.add_role_edge(ps, p, 'child', PREEXISTS)
184+
.add_role_edge(ps, par, 'parent', PREEXISTS)
185+
.add_role_edge(diag, par, 'patient', PREEXISTS)
186+
.add_role_edge(diag, d, 'diagnosed-disease', PREEXISTS)
187+
.add_has_edge(d, n, PREEXISTS))
186188

187189
# === Consumption Feature ===
188190
consumption_query = inspect.cleandoc(f'''match
@@ -192,26 +194,22 @@ def get_query_handles(example_id):
192194
has units-per-week $u; get;''')
193195

194196
vars = p, s, n, c, u = 'p', 's', 'n', 'c', 'u'
195-
g = QueryGraph()
196-
g.add_vars(*vars, **PREEXISTS)
197-
g.add_has_edge(s, n, **PREEXISTS)
198-
g.add_role_edge(c, p, 'consumer', **PREEXISTS)
199-
g.add_role_edge(c, s, 'consumed-substance', **PREEXISTS)
200-
g.add_has_edge(c, u, **PREEXISTS)
201-
202-
consumption_query_graph = g
197+
consumption_query_graph = (QueryGraph()
198+
.add_vars(vars, PREEXISTS)
199+
.add_has_edge(s, n, PREEXISTS)
200+
.add_role_edge(c, p, 'consumer', PREEXISTS)
201+
.add_role_edge(c, s, 'consumed-substance', PREEXISTS)
202+
.add_has_edge(c, u, PREEXISTS))
203203

204204
# === Age Feature ===
205205
person_age_query = inspect.cleandoc(f'''match
206206
$p isa person, has example-id {example_id}, has age $a;
207207
get;''')
208208

209209
vars = p, a = 'p', 'a'
210-
g = QueryGraph()
211-
g.add_vars(*vars, **PREEXISTS)
212-
g.add_has_edge(p, a, **PREEXISTS)
213-
214-
person_age_query_graph = g
210+
person_age_query_graph = (QueryGraph()
211+
.add_vars(vars, PREEXISTS)
212+
.add_has_edge(p, a, PREEXISTS))
215213

216214
# === Risk Factors Feature ===
217215
risk_factor_query = inspect.cleandoc(f'''match
@@ -221,12 +219,10 @@ def get_query_handles(example_id):
221219
get;''')
222220

223221
vars = p, d, r = 'p', 'd', 'r'
224-
g = QueryGraph()
225-
g.add_vars(*vars, **PREEXISTS)
226-
g.add_role_edge(r, p, 'person-at-risk', **PREEXISTS)
227-
g.add_role_edge(r, d, 'risked-disease', **PREEXISTS)
228-
229-
risk_factor_query_graph = g
222+
risk_factor_query_graph = (QueryGraph()
223+
.add_vars(vars, PREEXISTS)
224+
.add_role_edge(r, p, 'person-at-risk', PREEXISTS)
225+
.add_role_edge(r, d, 'risked-disease', PREEXISTS))
230226

231227
# === Diagnosis ===
232228
diagnosis_query = inspect.cleandoc(f'''match
@@ -239,26 +235,22 @@ def get_query_handles(example_id):
239235
get;''')
240236

241237
vars = p, s, sn, d, dn, sp, sev, c = 'p', 's', 'sn', 'd', 'dn', 'sp', 'sev', 'c'
242-
g = QueryGraph()
243-
g.add_vars(*vars, **PREEXISTS)
244-
g.add_has_edge(s, sn, **PREEXISTS)
245-
g.add_has_edge(d, dn, **PREEXISTS)
246-
g.add_role_edge(sp, s, 'presented-symptom', **PREEXISTS)
247-
g.add_has_edge(sp, sev, **PREEXISTS)
248-
g.add_role_edge(sp, p, 'symptomatic-patient', **PREEXISTS)
249-
g.add_role_edge(c, s, 'effect', **PREEXISTS)
250-
g.add_role_edge(c, d, 'cause', **PREEXISTS)
251-
252-
base_query_graph = g
253-
254-
g = copy.copy(base_query_graph)
238+
base_query_graph = (QueryGraph()
239+
.add_vars(vars, PREEXISTS)
240+
.add_has_edge(s, sn, PREEXISTS)
241+
.add_has_edge(d, dn, PREEXISTS)
242+
.add_role_edge(sp, s, 'presented-symptom', PREEXISTS)
243+
.add_has_edge(sp, sev, PREEXISTS)
244+
.add_role_edge(sp, p, 'symptomatic-patient', PREEXISTS)
245+
.add_role_edge(c, s, 'effect', PREEXISTS)
246+
.add_role_edge(c, d, 'cause', PREEXISTS))
255247

256248
diag, d, p = 'diag', 'd', 'p'
257-
g.add_vars(diag, **TO_INFER)
258-
g.add_role_edge(diag, d, 'diagnosed-disease', **TO_INFER)
259-
g.add_role_edge(diag, p, 'patient', **TO_INFER)
260249

261-
diagnosis_query_graph = g
250+
diagnosis_query_graph = (copy.copy(base_query_graph)
251+
.add_vars([diag], TO_INFER)
252+
.add_role_edge(diag, d, 'diagnosed-disease', TO_INFER)
253+
.add_role_edge(diag, p, 'patient', TO_INFER))
262254

263255
# === Candidate Diagnosis ===
264256
candidate_diagnosis_query = inspect.cleandoc(f'''match
@@ -270,13 +262,10 @@ def get_query_handles(example_id):
270262
$diag(candidate-patient: $p, candidate-diagnosed-disease: $d) isa candidate-diagnosis;
271263
get;''')
272264

273-
g = copy.copy(base_query_graph)
274-
275-
diag, d, p = 'diag', 'd', 'p'
276-
g.add_vars(diag, **CANDIDATE)
277-
g.add_role_edge(diag, d, 'candidate-diagnosed-disease', **CANDIDATE)
278-
g.add_role_edge(diag, p, 'candidate-patient', **CANDIDATE)
279-
candidate_diagnosis_query_graph = g
265+
candidate_diagnosis_query_graph = (copy.copy(base_query_graph)
266+
.add_vars([diag], CANDIDATE)
267+
.add_role_edge(diag, d, 'candidate-diagnosed-disease', CANDIDATE)
268+
.add_role_edge(diag, p, 'candidate-patient', CANDIDATE))
280269

281270
return [
282271
(diagnosis_query, lambda x: x, diagnosis_query_graph),

kglib/kgcn/examples/diagnosis/diagnosis_test.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@
2424
import networkx as nx
2525
import numpy as np
2626

27-
from kglib.kgcn.examples.diagnosis.diagnosis import write_predictions_to_grakn
27+
from kglib.kgcn.examples.diagnosis.diagnosis import write_predictions_to_grakn, obfuscate_labels
2828
from kglib.utils.grakn.object.thing import Thing
29+
from kglib.utils.graph.test.case import GraphTestCase
2930

3031

3132
class TestWritePredictionsToGrakn(unittest.TestCase):
@@ -90,5 +91,34 @@ def test_query_made_only_if_relation_wins(self):
9091
tx.commit.assert_called()
9192

9293

94+
class TestObfuscateLabels(GraphTestCase):
95+
96+
def test_labels_obfuscated_as_expected(self):
97+
98+
graph = nx.MultiDiGraph()
99+
100+
graph.add_node(0, type='person')
101+
graph.add_node(1, type='disease')
102+
graph.add_node(2, type='candidate-diagnosis')
103+
104+
graph.add_edge(2, 0, type='candidate-patient')
105+
graph.add_edge(2, 1, type='candidate-diagnosed-disease')
106+
107+
obfuscate_labels(graph, {'candidate-diagnosis': 'diagnosis',
108+
'candidate-patient': 'patient',
109+
'candidate-diagnosed-disease': 'diagnosed-disease'})
110+
111+
expected_graph = nx.MultiDiGraph()
112+
expected_graph.add_node(0, type='person')
113+
expected_graph.add_node(1, type='disease')
114+
expected_graph.add_node(2, type='diagnosis')
115+
116+
expected_graph.add_edge(2, 0, type='patient')
117+
expected_graph.add_edge(2, 1, type='diagnosed-disease')
118+
119+
self.assertGraphsEqual(graph, expected_graph)
120+
121+
122+
93123
if __name__ == "__main__":
94124
unittest.main()

kglib/utils/graph/query/query_graph.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,42 +25,45 @@ class QueryGraph(nx.MultiDiGraph):
2525
A custom graph to represent a query. Has additional helper methods specific to adding Graql patterns.
2626
"""
2727

28-
def add_vars(self, *vars, **attr):
28+
def add_vars(self, vars, solution):
2929
"""
3030
Add Graql variables, stored as nodes in the graph
3131
Args:
32-
*vars: String variables
33-
**attr: Properties to be added to the data stored on each variable node
32+
vars: String variables
33+
solution: Indicator of the ground truth class that the variables belong to
3434
3535
Returns:
36-
None
36+
self
3737
"""
3838
for var in vars:
39-
self.add_node(var, **attr)
39+
self.add_node(var, solution=solution)
40+
return self
4041

41-
def add_has_edge(self, owner_var, attribute_var, **attr):
42+
def add_has_edge(self, owner_var, attribute_var, solution):
4243
"""
4344
Add a "has" edge to represent ownership of an attribute
4445
Args:
4546
owner_var: The variable of the owner
4647
attribute_var: The variable of the owned attribute
47-
**attr: Properties to be added to the data stored on the "has" edge added
48+
solution: Indicator of the ground truth class that the edge belongs to
4849
4950
Returns:
50-
None
51+
self
5152
"""
52-
self.add_edge(owner_var, attribute_var, type='has', **attr)
53+
self.add_edge(owner_var, attribute_var, type='has', solution=solution)
54+
return self
5355

54-
def add_role_edge(self, relation_var, roleplayer_var, role_label, **attr):
56+
def add_role_edge(self, relation_var, roleplayer_var, role_label, solution):
5557
"""
5658
Add an edge to represent the role a variable plays in a relation
5759
Args:
5860
relation_var: The variable of the relation
5961
roleplayer_var: The variable of the roleplayer in the relation
6062
role_label: The role the roleplayer plays in the relation
61-
**attr: Properties to be added to the data stored on the role edge added
63+
solution: Indicator of the ground truth class that the edge belongs to
6264
6365
Returns:
64-
None
66+
self
6567
"""
66-
self.add_edge(relation_var, roleplayer_var, type=role_label, **attr)
68+
self.add_edge(relation_var, roleplayer_var, type=role_label, solution=solution)
69+
return self

kglib/utils/graph/query/query_graph_test.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,27 +24,32 @@
2424

2525
class TestQueryGraph(unittest.TestCase):
2626

27+
def test_add_single_var_adds_variable_node_as_expected(self):
28+
g = QueryGraph()
29+
g.add_vars(['a'], 0)
30+
self.assertDictEqual({'solution': 0}, g.nodes['a'])
31+
2732
def test_add_vars_adds_variable_nodes_as_expected(self):
2833
g = QueryGraph()
29-
g.add_vars('a', 'b')
34+
g.add_vars(['a', 'b'], 0)
3035
nodes = {node for node in g.nodes}
3136
self.assertSetEqual({'a', 'b'}, nodes)
3237

3338
def test_add_has_edge_adds_edge_as_expected(self):
3439
g = QueryGraph()
3540
g.add_vars('a', 'b')
36-
g.add_has_edge('a', 'b')
41+
g.add_has_edge('a', 'b', 0)
3742
edges = [edge for edge in g.edges]
3843
self.assertEqual(1, len(edges))
39-
self.assertEqual('has', g.edges['a', 'b', 0]['type'])
44+
self.assertDictEqual({'type': 'has', 'solution': 0}, g.edges['a', 'b', 0])
4045

4146
def test_add_role_edge_adds_role_as_expected(self):
4247
g = QueryGraph()
4348
g.add_vars('a', 'b')
44-
g.add_role_edge('a', 'b', 'role')
49+
g.add_role_edge('a', 'b', 'role_label', 1)
4550
edges = [edge for edge in g.edges]
4651
self.assertEqual(1, len(edges))
47-
self.assertEqual('role', g.edges['a', 'b', 0]['type'])
52+
self.assertDictEqual({'type': 'role_label', 'solution': 1}, g.edges['a', 'b', 0])
4853

4954

5055
if __name__ == "__main__":

0 commit comments

Comments
 (0)