Skip to content
This repository was archived by the owner on Nov 18, 2023. It is now read-only.

Commit 5b713c0

Browse files
authored
Fix QueryGraph construction (#119)
## What is the goal of this PR? Fix an issue where all `diagnosis` and `candidate-diagnosis` relations were given the same `solution` property, hence the learner could report 100% accuracy. ## What are the changes implemented in this PR? - Add an assertion to check for suspiciously high accuracy - Don't use graph copying, instead break queries up further
1 parent 9b7459a commit 5b713c0

File tree

2 files changed

+30
-21
lines changed

2 files changed

+30
-21
lines changed

kglib/kgcn/examples/diagnosis/diagnosis.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
# under the License.
1818
#
1919

20-
import copy
2120
import inspect
2221
import time
2322

@@ -224,50 +223,58 @@ def get_query_handles(example_id):
224223
.add_role_edge(r, p, 'person-at-risk', PREEXISTS)
225224
.add_role_edge(r, d, 'risked-disease', PREEXISTS))
226225

227-
# === Diagnosis ===
228-
diagnosis_query = inspect.cleandoc(f'''match
226+
# === Symptom ===
227+
vars = p, s, sn, d, dn, sp, sev, c = 'p', 's', 'sn', 'd', 'dn', 'sp', 'sev', 'c'
228+
229+
symptom_query = inspect.cleandoc(f'''match
229230
$p isa person, has example-id {example_id};
230231
$s isa symptom, has name $sn;
231232
$d isa disease, has name $dn;
232233
$sp(presented-symptom: $s, symptomatic-patient: $p) isa symptom-presentation, has severity $sev;
233234
$c(cause: $d, effect: $s) isa causality;
235+
get;''')
236+
237+
symptom_query_graph = (QueryGraph()
238+
.add_vars(vars, PREEXISTS)
239+
.add_has_edge(s, sn, PREEXISTS)
240+
.add_has_edge(d, dn, PREEXISTS)
241+
.add_role_edge(sp, s, 'presented-symptom', PREEXISTS)
242+
.add_has_edge(sp, sev, PREEXISTS)
243+
.add_role_edge(sp, p, 'symptomatic-patient', PREEXISTS)
244+
.add_role_edge(c, s, 'effect', PREEXISTS)
245+
.add_role_edge(c, d, 'cause', PREEXISTS))
246+
247+
# === Diagnosis ===
248+
249+
diag, d, p, dn = 'diag', 'd', 'p', 'dn'
250+
251+
diagnosis_query = inspect.cleandoc(f'''match
252+
$p isa person, has example-id {example_id};
253+
$d isa disease, has name $dn;
234254
$diag(patient: $p, diagnosed-disease: $d) isa diagnosis;
235255
get;''')
236256

237-
vars = p, s, sn, d, dn, sp, sev, c = 'p', 's', 'sn', 'd', 'dn', 'sp', 'sev', 'c'
238-
base_query_graph = (QueryGraph()
239-
.add_vars(vars, PREEXISTS)
240-
.add_has_edge(s, sn, PREEXISTS)
241-
.add_has_edge(d, dn, PREEXISTS)
242-
.add_role_edge(sp, s, 'presented-symptom', PREEXISTS)
243-
.add_has_edge(sp, sev, PREEXISTS)
244-
.add_role_edge(sp, p, 'symptomatic-patient', PREEXISTS)
245-
.add_role_edge(c, s, 'effect', PREEXISTS)
246-
.add_role_edge(c, d, 'cause', PREEXISTS))
247-
248-
diag, d, p = 'diag', 'd', 'p'
249-
250-
diagnosis_query_graph = (copy.copy(base_query_graph)
257+
diagnosis_query_graph = (QueryGraph()
251258
.add_vars([diag], TO_INFER)
259+
.add_vars([d, p, dn], PREEXISTS)
252260
.add_role_edge(diag, d, 'diagnosed-disease', TO_INFER)
253261
.add_role_edge(diag, p, 'patient', TO_INFER))
254262

255263
# === Candidate Diagnosis ===
256264
candidate_diagnosis_query = inspect.cleandoc(f'''match
257265
$p isa person, has example-id {example_id};
258-
$s isa symptom, has name $sn;
259266
$d isa disease, has name $dn;
260-
$sp(presented-symptom: $s, symptomatic-patient: $p) isa symptom-presentation, has severity $sev;
261-
$c(cause: $d, effect: $s) isa causality;
262267
$diag(candidate-patient: $p, candidate-diagnosed-disease: $d) isa candidate-diagnosis;
263268
get;''')
264269

265-
candidate_diagnosis_query_graph = (copy.copy(base_query_graph)
270+
candidate_diagnosis_query_graph = (QueryGraph()
266271
.add_vars([diag], CANDIDATE)
272+
.add_vars([d, p, dn], PREEXISTS)
267273
.add_role_edge(diag, d, 'candidate-diagnosed-disease', CANDIDATE)
268274
.add_role_edge(diag, p, 'candidate-patient', CANDIDATE))
269275

270276
return [
277+
(symptom_query, lambda x: x, symptom_query_graph),
271278
(diagnosis_query, lambda x: x, diagnosis_query_graph),
272279
(candidate_diagnosis_query, lambda x: x, candidate_diagnosis_query_graph),
273280
(risk_factor_query, lambda x: x, risk_factor_query_graph),

tests/end_to_end/kgcn/diagnosis.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ def tearDown(self):
3838
def test_learning_is_done(self):
3939
solveds_tr, solveds_ge = diagnosis_example()
4040
self.assertGreaterEqual(solveds_tr[-1], 0.7)
41+
self.assertLessEqual(solveds_tr[-1], 0.99)
4142
self.assertGreaterEqual(solveds_ge[-1], 0.7)
43+
self.assertLessEqual(solveds_ge[-1], 0.99)
4244

4345

4446
if __name__ == "__main__":

0 commit comments

Comments
 (0)