Skip to content
This repository was archived by the owner on Nov 18, 2023. It is now read-only.

Commit 6b028b5

Browse files
authored
Improve clarity of diagnosis example (#117)
## What is the goal of this PR?

Reformat the diagnosis example to make it as clear as possible to users how they can build their own KGCN.

## What are the changes implemented in this PR?

- Move main parameters to be easily visible
- Move out generic util methods for retrieving types and roles
- Improve docstrings
- Update some docstrings from an outdated format to Google format
1 parent 80ce7a5 commit 6b028b5

File tree

8 files changed

+176
-87
lines changed

8 files changed

+176
-87
lines changed

kglib/kgcn/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Knowledge Graph Convolutional Networks
22

3-
This project introduces a novel model: the *Knowledge Graph Convolutional Network* (KGCN). This work is in its second major iteration since inception.
3+
This project introduces a novel model: the *Knowledge Graph Convolutional Network* (KGCN).
44

55
### Getting Started - Running the Machine Learning Pipeline
66

kglib/kgcn/examples/diagnosis/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ py_library(
2626
"//kglib/kgcn/plot",
2727
"//kglib/kgcn/models",
2828
"//kglib/utils/grakn/synthetic",
29+
"//kglib/utils/grakn/type",
2930
"@graknlabs_client_python//:client_python",
3031
],
3132
visibility=['//visibility:public']

kglib/kgcn/examples/diagnosis/diagnosis.py

Lines changed: 68 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,58 @@
2525

2626
from kglib.kgcn.pipeline.pipeline import pipeline
2727
from kglib.utils.grakn.synthetic.examples.diagnosis.generate import generate_example_graphs
28+
from kglib.utils.grakn.type.type import get_thing_types, get_role_types
2829
from kglib.utils.graph.iterate import multidigraph_data_iterator
2930
from kglib.utils.graph.query.query_graph import QueryGraph
3031
from kglib.utils.graph.thing.queries_to_graph import build_graph_from_queries
3132

33+
KEYSPACE = "diagnosis"
34+
URI = "localhost:48555"
35+
36+
# Existing elements in the graph are those that pre-exist in the graph, and should be predicted to continue to exist
37+
PREEXISTS = dict(solution=0)
38+
39+
# Candidates are neither present in the input nor in the solution, they are negative samples
40+
CANDIDATE = dict(solution=1)
41+
42+
# Elements to infer are the graph elements whose existence we want to predict to be true, they are positive samples
43+
TO_INFER = dict(solution=2)
44+
45+
# Categorical Attribute types and the values of their categories
46+
CATEGORICAL_ATTRIBUTES = {'name': ['Diabetes Type II', 'Multiple Sclerosis', 'Blurred vision', 'Fatigue', 'Cigarettes',
47+
'Alcohol']}
48+
# Continuous Attribute types and their min and max values
49+
CONTINUOUS_ATTRIBUTES = {'severity': (0, 1), 'age': (7, 80), 'units-per-week': (3, 29)}
50+
51+
TYPES_TO_IGNORE = ['candidate-diagnosis', 'example-id', 'probability-exists', 'probability-non-exists', 'probability-preexists']
52+
ROLES_TO_IGNORE = ['candidate-patient', 'candidate-diagnosed-disease']
53+
54+
# The learner should see candidate relations the same as the ground truth relations, so adjust these candidates to
55+
# look like their ground truth counterparts
56+
TYPES_AND_ROLES_TO_OBFUSCATE = {'candidate-diagnosis': 'diagnosis',
57+
'candidate-patient': 'patient',
58+
'candidate-diagnosed-disease': 'diagnosed-disease'}
59+
3260

3361
def diagnosis_example(num_graphs=200,
3462
num_processing_steps_tr=5,
3563
num_processing_steps_ge=5,
3664
num_training_iterations=1000,
37-
keyspace="diagnosis", uri="localhost:48555"):
65+
keyspace=KEYSPACE, uri=URI):
66+
"""
67+
Run the diagnosis example from start to finish, including traceably ingesting predictions back into Grakn
68+
69+
Args:
70+
num_graphs: Number of graphs to use for training and testing combined
71+
num_processing_steps_tr: The number of message-passing steps for training
72+
num_processing_steps_ge: The number of message-passing steps for testing
73+
num_training_iterations: The number of training epochs
74+
keyspace: The name of the keyspace to retrieve example subgraphs from
75+
uri: The uri of the running Grakn instance
76+
77+
Returns:
78+
Final accuracies for training and for testing
79+
"""
3880

3981
tr_ge_split = int(num_graphs*0.5)
4082

@@ -48,7 +90,10 @@ def diagnosis_example(num_graphs=200,
4890
with session.transaction().read() as tx:
4991
# Change the terminology here onwards from thing -> node and role -> edge
5092
node_types = get_thing_types(tx)
93+
[node_types.remove(el) for el in TYPES_TO_IGNORE]
94+
5195
edge_types = get_role_types(tx)
96+
[edge_types.remove(el) for el in ROLES_TO_IGNORE]
5297
print(f'Found node types: {node_types}')
5398
print(f'Found edge types: {edge_types}')
5499

@@ -72,12 +117,17 @@ def diagnosis_example(num_graphs=200,
72117
return solveds_tr, solveds_ge
73118

74119

75-
CATEGORICAL_ATTRIBUTES = {'name': ['Diabetes Type II', 'Multiple Sclerosis', 'Blurred vision', 'Fatigue', 'Cigarettes',
76-
'Alcohol']}
77-
CONTINUOUS_ATTRIBUTES = {'severity': (0, 1), 'age': (7, 80), 'units-per-week': (3, 29)}
120+
def create_concept_graphs(example_indices, grakn_session):
121+
"""
122+
Builds an in-memory graph for each example, with an example_id as an anchor for each example subgraph.
123+
Args:
124+
example_indices: The values used to anchor the subgraph queries within the entire knowledge graph
125+
grakn_session: Grakn Session
78126
127+
Returns:
128+
In-memory graphs of Grakn subgraphs
129+
"""
79130

80-
def create_concept_graphs(example_indices, grakn_session):
81131
graphs = []
82132
infer = True
83133

@@ -90,37 +140,28 @@ def create_concept_graphs(example_indices, grakn_session):
90140

91141
# Remove label leakage - change type labels that indicate candidates into non-candidates
92142
for data in multidigraph_data_iterator(graph):
93-
typ = data['type']
94-
if typ == 'candidate-diagnosis':
95-
data.update(type='diagnosis')
96-
elif typ == 'candidate-patient':
97-
data.update(type='patient')
98-
elif typ == 'candidate-diagnosed-disease':
99-
data.update(type='diagnosed-disease')
143+
for label_to_obfuscate, with_label in TYPES_AND_ROLES_TO_OBFUSCATE.items():
144+
if data['type'] == label_to_obfuscate:
145+
data.update(type=with_label)
146+
break
100147

101148
graph.name = example_id
102149
graphs.append(graph)
103150

104151
return graphs
105152

106153

107-
# Existing elements in the graph are those that pre-exist in the graph, and should be predicted to continue to exist
108-
PREEXISTS = dict(solution=0)
109-
110-
# Candidates are neither present in the input nor in the solution, they are negative samples
111-
CANDIDATE = dict(solution=1)
112-
113-
# Elements to infer are the graph elements whose existence we want to predict to be true, they are positive samples
114-
TO_INFER = dict(solution=2)
115-
116-
117154
def get_query_handles(example_id):
118155
"""
119-
1. Supply a query
120-
2. Supply a `QueryGraph` object to represent that query. That itself is a subclass of a networkx graph
121-
3. Execute the query
122-
4. Make a graph of the query results by taking the variables you got back and arranging the concepts as they are in the `QueryGraph`. This gives one graph for each result, for each query.
123-
5. Combine all of these graphs into one single graph, and that’s your example subgraph
156+
Creates an iterable, each element containing a Graql query, a function to sample the answers, and a QueryGraph
157+
object which must be the Grakn graph representation of the query. This tuple is termed a "query_handle"
158+
159+
Args:
160+
example_id: A uniquely identifiable attribute value used to anchor the results of the queries to a specific
161+
subgraph
162+
163+
Returns:
164+
query handles
124165
"""
125166

126167
# === Hereditary Feature ===
@@ -165,7 +206,6 @@ def get_query_handles(example_id):
165206
$p isa person, has example-id {example_id}, has age $a;
166207
get;''')
167208

168-
169209
vars = p, a = 'p', 'a'
170210
g = QueryGraph()
171211
g.add_vars(*vars, **PREEXISTS)
@@ -248,48 +288,6 @@ def get_query_handles(example_id):
248288
]
249289

250290

251-
def get_thing_types(tx):
252-
"""
253-
Get all schema types, excluding those for implicit attribute relations, base types, and candidate types
254-
Args:
255-
tx: Grakn transaction
256-
257-
Returns:
258-
Grakn types
259-
"""
260-
schema_concepts = tx.query(
261-
"match $x sub thing; "
262-
"not {$x sub @has-attribute;}; "
263-
"not {$x sub @key-attribute;}; "
264-
"get;")
265-
thing_types = [schema_concept.get('x').label() for schema_concept in schema_concepts]
266-
[thing_types.remove(el) for el in
267-
['thing', 'relation', 'entity', 'attribute', 'candidate-diagnosis', 'example-id', 'probability-exists',
268-
'probability-non-exists', 'probability-preexists']]
269-
return thing_types
270-
271-
272-
def get_role_types(tx):
273-
"""
274-
Get all schema roles, excluding those for implicit attribute relations, the base role type, and candidate roles
275-
Args:
276-
tx: Grakn transaction
277-
278-
Returns:
279-
Grakn roles
280-
"""
281-
schema_concepts = tx.query(
282-
"match $x sub role; "
283-
"not{$x sub @key-attribute-value;}; "
284-
"not{$x sub @key-attribute-owner;}; "
285-
"not{$x sub @has-attribute-value;}; "
286-
"not{$x sub @has-attribute-owner;};"
287-
"get;")
288-
role_types = ['has'] + [role.get('x').label() for role in schema_concepts]
289-
[role_types.remove(el) for el in ['role', 'candidate-patient', 'candidate-diagnosed-disease']]
290-
return role_types
291-
292-
293291
def write_predictions_to_grakn(graphs, tx):
294292
"""
295293
Take predictions from the ML model, and insert representations of those predictions back into the graph.

kglib/kgcn/pipeline/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@ def duplicate_edges_in_reverse(graph):
2323
Takes in a directed multi graph, and creates duplicates of all edges, the duplicates having reversed direction to
2424
the originals. This is useful since directed edges constrain the direction of messages passed. We want to permit
2525
omni-directional message passing.
26-
:param graph: The graph
27-
:return: The graph with duplicated edges, reversed, with all original edge properties attached to the duplicates
26+
Args:
27+
graph: The graph
28+
29+
Returns:
30+
The graph with duplicated edges, reversed, with all original edge properties attached to the duplicates
2831
"""
2932
for sender, receiver, keys, data in graph.edges(data=True, keys=True):
3033
graph.add_edge(receiver, sender, keys, **data)

kglib/utils/grakn/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ py_library(
88
'//kglib/utils/grakn/test',
99
'//kglib/utils/grakn/object',
1010
'//kglib/utils/grakn/synthetic',
11+
'//kglib/utils/grakn/type',
1112
],
1213
visibility=['//visibility:public']
1314
)

kglib/utils/grakn/type/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
load("@io_bazel_rules_python//python:python.bzl", "py_library")
2+
load("@pypi_dependencies//:requirements.bzl", "requirement")
3+
4+
5+
py_library(
6+
name = "type",
7+
srcs = [
8+
'type.py',
9+
],
10+
visibility=['//visibility:public']
11+
)

kglib/utils/grakn/type/type.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
#
19+
20+
21+
def get_thing_types(tx):
22+
"""
23+
Get all schema types, excluding those for implicit attribute relations and base types
24+
Args:
25+
tx: Grakn transaction
26+
27+
Returns:
28+
Grakn types
29+
"""
30+
schema_concepts = tx.query(
31+
"match $x sub thing; "
32+
"not {$x sub @has-attribute;}; "
33+
"not {$x sub @key-attribute;}; "
34+
"get;")
35+
thing_types = [schema_concept.get('x').label() for schema_concept in schema_concepts]
36+
[thing_types.remove(el) for el in ['thing', 'relation', 'entity', 'attribute']]
37+
return thing_types
38+
39+
40+
def get_role_types(tx):
41+
"""
42+
Get all schema roles, excluding those for implicit attribute relations and the base role type
43+
Args:
44+
tx: Grakn transaction
45+
46+
Returns:
47+
Grakn roles
48+
"""
49+
schema_concepts = tx.query(
50+
"match $x sub role; "
51+
"not{$x sub @key-attribute-value;}; "
52+
"not{$x sub @key-attribute-owner;}; "
53+
"not{$x sub @has-attribute-value;}; "
54+
"not{$x sub @has-attribute-owner;};"
55+
"get;")
56+
role_types = ['has'] + [role.get('x').label() for role in schema_concepts]
57+
role_types.remove('role')
58+
return role_types

kglib/utils/graph/thing/queries_to_graph.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,26 @@ def concept_dict_from_concept_map(concept_map):
2929
"""
3030
Given a concept map, build a dictionary of the variables present and the concepts they refer to, locally storing any
3131
information required about those concepts.
32-
:param concept_map: A dict of Concepts provided by Grakn keyed by query variables
33-
:return: A dictionary of concepts keyed by query variables
32+
33+
Args:
34+
concept_map: A dict of Concepts provided by Grakn keyed by query variables
35+
36+
Returns:
37+
A dictionary of concepts keyed by query variables
3438
"""
3539
return {variable: build_thing(grakn_concept) for variable, grakn_concept in concept_map.map().items()}
3640

3741

3842
def combine_2_graphs(graph1, graph2):
3943
"""
4044
Combine two graphs into one. Do this by recognising common nodes between the two.
41-
:param graph1: Graph to compare
42-
:param graph2: Graph to compare
43-
:return: Combined graph
45+
46+
Args:
47+
graph1: Graph to compare
48+
graph2: Graph to compare
49+
50+
Returns:
51+
Combined graph
4452
"""
4553

4654
for node, data in graph1.nodes(data=True):
@@ -67,8 +75,12 @@ def combine_2_graphs(graph1, graph2):
6775
def combine_n_graphs(graphs_list):
6876
"""
6977
Combine N graphs into one. Do this by recognising common nodes between them.
70-
:param graphs_list: List of graphs to combine
71-
:return: Combined graph
78+
79+
Args:
80+
graphs_list: List of graphs to combine
81+
82+
Returns:
83+
Combined graph
7284
"""
7385
return reduce(lambda x, y: combine_2_graphs(x, y), graphs_list)
7486

@@ -78,14 +90,19 @@ def build_graph_from_queries(query_sampler_variable_graph_tuples, grakn_transact
7890
"""
7991
Builds a graph of Things, interconnected by roles (and *has*), from a set of queries and graphs representing those
8092
queries (variable graphs), over a Grakn transaction
81-
:param infer:
82-
:param query_sampler_variable_graph_tuples: A list of tuples, each tuple containing a query, a sampling function,
83-
and a variable_graph
84-
:param grakn_transaction: A Grakn transaction
85-
:param concept_dict_converter: The function to use to convert from concept_dicts to a Grakn model. This could be
86-
a typical model or a mathematical model
87-
:return: A networkx graph
93+
94+
Args:
95+
infer: whether to use Grakn's inference engine
96+
query_sampler_variable_graph_tuples: A list of tuples, each tuple containing a query, a sampling function,
97+
and a variable_graph
98+
grakn_transaction: A Grakn transaction
99+
concept_dict_converter: The function to use to convert from concept_dicts to a Grakn model. This could be
100+
a typical model or a mathematical model
101+
102+
Returns:
103+
A networkx graph
88104
"""
105+
89106
query_concept_graphs = []
90107

91108
for query, sampler, variable_graph in query_sampler_variable_graph_tuples:

0 commit comments

Comments
 (0)