Embedding and downstream tasks tutorial

This tutorial illustrates an example of a co-occurrence graph and guides the user through the graph representation learning and all it’s downstream tasks including node similarity queries, node classification, edge prediction and embedding pipeline building. The source notebook can be found here.

import pandas as pd
import numpy as np

from sklearn import model_selection
from sklearn import mixture
from sklearn.svm import LinearSVC
from bluegraph.core import PandasPGFrame
from bluegraph.preprocess.generators import CooccurrenceGenerator
from bluegraph.preprocess.encoders import ScikitLearnPGEncoder

from bluegraph.core.embed.embedders import GraphElementEmbedder
from bluegraph.backends.stellargraph import StellarGraphNodeEmbedder

from bluegraph.downstream import EmbeddingPipeline, transform_to_2d, plot_2d, get_classification_scores
from bluegraph.downstream.similarity import (FaissSimilarityIndex,
                                             SimilarityProcessor,
                                             NodeSimilarityProcessor)
from bluegraph.downstream.node_classification import NodeClassifier
from bluegraph.downstream.link_prediction import (generate_negative_edges,
                                                  EdgePredictor)

Data preparation

Fist, we read the source dataset with mentions of entities in different paragraphs

mentions = pd.read_csv("../data/labeled_entity_occurrence.csv")
# Extract unique paper/seciton/paragraph identifiers
mentions = mentions.rename(columns={"occurrence": "paragraph"})
number_of_paragraphs = len(mentions["paragraph"].unique())
mentions
entity paragraph
0 lithostathine-1-alpha 1
1 pulmonary 1
2 host 1
3 lithostathine-1-alpha 2
4 surfactant protein d measurement 2
... ... ...
2281346 covid-19 227822
2281347 covid-19 227822
2281348 viral infection 227823
2281349 lipid 227823
2281350 inflammation 227823

2281351 rows × 2 columns

We will also load a dataset that contains definitions of entities and their types

entity_data = pd.read_csv("../data/entity_types_defs.csv")
entity_data
entity entity_type definition
0 (e3-independent) e2 ubiquitin-conjugating enzyme PROTEIN (E3-independent) E2 ubiquitin-conjugating enzy...
1 (h115d)vhl35 peptide CHEMICAL A peptide vaccine derived from the von Hippel-...
2 1,1-dimethylhydrazine DRUG A clear, colorless, flammable, hygroscopic liq...
3 1,2-dimethylhydrazine CHEMICAL A compound used experimentally to induce tumor...
4 1,25-dihydroxyvitamin d(3) 24-hydroxylase, mit... PROTEIN 1,25-dihydroxyvitamin D(3) 24-hydroxylase, mit...
... ... ... ...
28127 zygomycosis DISEASE Any infection due to a fungus of the Zygomycot...
28128 zygomycota ORGANISM A phylum of fungi that are characterized by ve...
28129 zygosity ORGANISM The genetic condition of a zygote, especially ...
28130 zygote CELL_COMPARTMENT The cell formed by the union of two gametes, e...
28131 zyxin ORGANISM Zyxin (572 aa, ~61 kDa) is encoded by the huma...

28132 rows × 3 columns

Generation of a co-occurrence graph

We first create a graph whose nodes are entities

graph = PandasPGFrame()
entity_nodes = mentions["entity"].unique()
graph.add_nodes(entity_nodes)
graph.add_node_types({n: "Entity" for n in entity_nodes})

entity_props = entity_data.rename(columns={"entity": "@id"}).set_index("@id")
graph.add_node_properties(entity_props["entity_type"], prop_type="category")
graph.add_node_properties(entity_props["definition"], prop_type="text")
paragraph_prop = pd.DataFrame({"paragraphs": mentions.groupby("entity").aggregate(set)["paragraph"]})
graph.add_node_properties(paragraph_prop, prop_type="category")
graph.nodes(raw_frame=True)
@type entity_type definition paragraphs
@id
lithostathine-1-alpha Entity PROTEIN Lithostathine-1-alpha (166 aa, ~19 kDa) is enc... {1, 2, 3, 195589, 104454, 104455, 104456, 5120...
pulmonary Entity ORGAN Relating to the lungs as the intended site of ... {1, 196612, 196613, 196614, 196621, 196623, 16...
host Entity ORGANISM An organism that nourishes and supports anothe... {1, 114689, 3, 221193, 180243, 180247, 28, 180...
surfactant protein d measurement Entity PROTEIN The determination of the amount of surfactant ... {145537, 2, 3, 4, 5, 6, 51202, 103939, 103940,...
communication response Entity PATHWAY A statement (either spoken or written) that is... {46592, 64000, 2, 28162, 166912, 226304, 88585...
... ... ... ... ...
drug binding site Entity PATHWAY The reactive parts of a macromolecule that dir... {225082, 225079}
carbaril Entity CHEMICAL A synthetic carbamate acetylcholinesterase inh... {225408, 225409, 225415, 225419, 225397}
ny-eso-1 positive tumor cells present Entity CELL_TYPE An indication that Cancer/Testis Antigen 1 exp... {225544, 226996}
mustelidae Entity ORGANISM Taxonomic family which includes the Ferret. {225901, 225903}
friulian language Entity ORGANISM An Indo-European Romance language spoken in th... {225901, 225903}

17989 rows × 4 columns

For each node we will add the frequency property that counts the total number of paragraphs where the entity was mentioned.

frequencies = graph._nodes["paragraphs"].apply(len)
frequencies.name = "frequency"
graph.add_node_properties(frequencies)
graph.nodes(raw_frame=True)
@type entity_type definition paragraphs frequency
@id
lithostathine-1-alpha Entity PROTEIN Lithostathine-1-alpha (166 aa, ~19 kDa) is enc... {1, 2, 3, 195589, 104454, 104455, 104456, 5120... 80
pulmonary Entity ORGAN Relating to the lungs as the intended site of ... {1, 196612, 196613, 196614, 196621, 196623, 16... 8295
host Entity ORGANISM An organism that nourishes and supports anothe... {1, 114689, 3, 221193, 180243, 180247, 28, 180... 2660
surfactant protein d measurement Entity PROTEIN The determination of the amount of surfactant ... {145537, 2, 3, 4, 5, 6, 51202, 103939, 103940,... 268
communication response Entity PATHWAY A statement (either spoken or written) that is... {46592, 64000, 2, 28162, 166912, 226304, 88585... 160
... ... ... ... ... ...
drug binding site Entity PATHWAY The reactive parts of a macromolecule that dir... {225082, 225079} 2
carbaril Entity CHEMICAL A synthetic carbamate acetylcholinesterase inh... {225408, 225409, 225415, 225419, 225397} 5
ny-eso-1 positive tumor cells present Entity CELL_TYPE An indication that Cancer/Testis Antigen 1 exp... {225544, 226996} 2
mustelidae Entity ORGANISM Taxonomic family which includes the Ferret. {225901, 225903} 2
friulian language Entity ORGANISM An Indo-European Romance language spoken in th... {225901, 225903} 2

17989 rows × 5 columns

Now, for constructing co-occurrence network we will select only 1000 most frequent entities.

nodes_to_include = graph._nodes.nlargest(1000, "frequency").index

The CooccurrenceGenerator class allows us to generate co-occurrence edges from overlaps in node property values or edge (or edge properties). In this case we consider the paragraph node property and construct co-occurrence edges from overlapping sets of paragraphs. In addition, we will compute some co-occurrence statistics: total co-occurrence frequency and normalized pointwise mutual information (NPMI).

%%time
generator = CooccurrenceGenerator(graph.subgraph(nodes=nodes_to_include))
paragraph_cooccurrence_edges = generator.generate_from_nodes(
    "paragraphs", total_factor_instances=number_of_paragraphs,
    compute_statistics=["frequency", "npmi"],
    parallelize=True, cores=8)
CPU times: user 13.9 s, sys: 3.65 s, total: 17.6 s
Wall time: 1min 44s
cutoff = paragraph_cooccurrence_edges["npmi"].mean()
paragraph_cooccurrence_edges = paragraph_cooccurrence_edges[paragraph_cooccurrence_edges["npmi"] > cutoff]

We add generated edges to the original graph

graph._edges = paragraph_cooccurrence_edges
graph.edge_prop_as_numeric("frequency")
graph.edge_prop_as_numeric("npmi")
graph.edges(raw_frame=True)
common_factors frequency npmi
@source_id @target_id
surfactant protein d measurement microorganism {2, 3, 7810, 17, 19, 21, 100502, 26, 41, 7850,... 19 0.235263
lung {2, 103939, 51202, 5, 4, 103940, 15, 145438, 3... 93 0.221395
alveolar {223872, 2, 51202, 100502, 7831, 149657, 19522... 25 0.336175
epithelial cell {2, 4, 5, 222298, 7825, 7732, 7733, 169174, 7738} 9 0.175923
molecule {2, 7750, 49991, 134504, 206448, 49, 52, 20645... 10 0.113611
... ... ... ... ...
sars-cov-2 cardiac valve injury {196614, 207366, 186391, 190497, 196641, 18947... 123 0.213579
chloroquine {168961, 202755, 203276, 202765, 217102, 19868... 195 0.290027
severe acute respiratory syndrome {215556, 182277, 221190, 221191, 200710, 22119... 211 0.241288
caax prenyl protease 2 {226304, 208386, 215559, 209415, 208397, 21556... 150 0.343314
transmembrane protease serine 2 {192518, 200748, 200756, 204855, 188475, 19873... 380 0.420739

161332 rows × 3 columns

Recall that we have generated edges only for the 1000 most frequent entities, the rest of the entities will be isolated (having no incident edges). Let us remove all the isolated nodes.

graph.remove_node_properties("paragraphs")
graph.remove_edge_properties("common_factors")
graph.remove_isolated_nodes()
graph.number_of_nodes()
1000

Next, we save the generated co-occurrence graph.

graph.export_json("../data/cooccurrence_graph.json")
graph = PandasPGFrame.load_json("../data/cooccurrence_graph.json")

Node feature extraction

We extract node features from entity definitions using the tfidf model.

encoder = ScikitLearnPGEncoder(
    node_properties=["definition"],
    text_encoding_max_dimension=512)
%%time
transformed_graph = encoder.fit_transform(graph)
CPU times: user 959 ms, sys: 26.4 ms, total: 986 ms
Wall time: 1.02 s

We can have a glance at the vocabulary that the encoder constructed for the ‘definition’ property

vocabulary = encoder._node_encoders["definition"].model.vocabulary_
list(vocabulary.keys())[:10]
['relating',
 'lungs',
 'site',
 'administration',
 'product',
 'usually',
 'action',
 'lower',
 'respiratory',
 'tract']

We will add additional properties to our transformed graph corresponding to the entity type labels. We will also add NPMI as an edge property to this transformed graph.

transformed_graph.add_node_properties(
    graph.get_node_property_values("entity_type"))
transformed_graph.add_edge_properties(
    graph.get_edge_property_values("npmi"), prop_type="numeric")
transformed_graph.nodes(raw_frame=True)
features @type entity_type
@id
pulmonary [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity ORGAN
host [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity ORGANISM
surfactant protein d measurement [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity PROTEIN
microorganism [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity ORGANISM
lung [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity ORGAN
... ... ... ...
candida parapsilosis [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity ORGANISM
ciliated bronchial epithelial cell [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity CELL_TYPE
cystic fibrosis pulmonary exacerbation [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity DISEASE
caax prenyl protease 2 [0.0, 0.0, 0.3198444339599345, 0.0, 0.0, 0.0, ... Entity PROTEIN
transmembrane protease serine 2 [0.0, 0.0, 0.2853086240289885, 0.0, 0.0, 0.0, ... Entity PROTEIN

1000 rows × 3 columns

Node embedding and downstream tasks

Node embedding using StellarGraph

Using StellarGraphNodeEmbedder we construct three different embeddings of our transformed graph corresponding to different embedding techniques.

node2vec_embedder = StellarGraphNodeEmbedder(
    "node2vec", edge_weight="npmi", embedding_dimension=64, length=10, number_of_walks=20)
node2vec_embedding = node2vec_embedder.fit_model(transformed_graph)
attri2vec_embedder = StellarGraphNodeEmbedder(
    "attri2vec", feature_vector_prop="features",
    length=5, number_of_walks=10,
    epochs=10, embedding_dimension=128, edge_weight="npmi")
attri2vec_embedding = attri2vec_embedder.fit_model(transformed_graph)
link_classification: using 'ip' method to combine node embeddings into edge embeddings
gcn_dgi_embedder = StellarGraphNodeEmbedder(
    "gcn_dgi", feature_vector_prop="features", epochs=250, embedding_dimension=512)
gcn_dgi_embedding = gcn_dgi_embedder.fit_model(transformed_graph)
Using GCN (local pooling) filters...

The fit_model method produces a dataframe of the following shape

node2vec_embedding
embedding
pulmonary [0.13196799159049988, -0.23611457645893097, 0....
host [-0.6323956847190857, 0.36397579312324524, -0....
surfactant protein d measurement [-0.5495556592941284, 0.14938104152679443, 0.0...
microorganism [-0.4700668454170227, 0.5236756801605225, 0.14...
lung [-0.2819957435131073, 0.08759381622076035, 0.0...
... ...
candida parapsilosis [-0.18134233355522156, 0.14365115761756897, 0....
ciliated bronchial epithelial cell [-0.6209977865219116, 0.2375614047050476, 0.00...
cystic fibrosis pulmonary exacerbation [-0.1944447010755539, 0.06318975239992142, 0.1...
caax prenyl protease 2 [-0.2207261174917221, -0.071625716984272, 0.11...
transmembrane protease serine 2 [-0.40691250562667847, 0.07031852006912231, 0....

1000 rows × 1 columns

Let us add the embedding vectors obtained using different models as node properties of our graph.

transformed_graph.add_node_properties(
    node2vec_embedding.rename(columns={"embedding": "node2vec"}))
transformed_graph.add_node_properties(
    attri2vec_embedding.rename(columns={"embedding": "attri2vec"}))
transformed_graph.add_node_properties(
    gcn_dgi_embedding.rename(columns={"embedding": "gcn_dgi"}))
transformed_graph.nodes(raw_frame=True)
features @type entity_type node2vec attri2vec gcn_dgi
@id
pulmonary [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity ORGAN [0.13196799159049988, -0.23611457645893097, 0.... [0.034921467304229736, 0.016040265560150146, 0... [0.01300269179046154, 0.0, 0.03357855603098869...
host [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity ORGANISM [-0.6323956847190857, 0.36397579312324524, -0.... [0.07983770966529846, 0.02787071466445923, 0.0... [0.0, 0.0, 0.028662730008363724, 0.00578320631...
surfactant protein d measurement [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity PROTEIN [-0.5495556592941284, 0.14938104152679443, 0.0... [0.026128143072128296, 0.030555397272109985, 0... [0.0, 0.0, 0.02776358649134636, 0.005184333305...
microorganism [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity ORGANISM [-0.4700668454170227, 0.5236756801605225, 0.14... [0.2282787561416626, 0.05689656734466553, 0.07... [0.0, 0.0, 0.04060275852680206, 0.0, 0.0, 0.05...
lung [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity ORGAN [-0.2819957435131073, 0.08759381622076035, 0.0... [0.01818174123764038, 0.014254063367843628, 0.... [0.0, 0.0, 0.03078138828277588, 0.008552972227...
... ... ... ... ... ... ...
candida parapsilosis [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity ORGANISM [-0.18134233355522156, 0.14365115761756897, 0.... [0.373728483915329, 0.05336388945579529, 0.090... [0.0, 0.0, 0.02676139771938324, 0.0, 0.0, 0.03...
ciliated bronchial epithelial cell [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity CELL_TYPE [-0.6209977865219116, 0.2375614047050476, 0.00... [0.03760749101638794, 0.00703778862953186, 0.0... [0.0, 0.0, 0.032069120556116104, 0.00537745608...
cystic fibrosis pulmonary exacerbation [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... Entity DISEASE [-0.1944447010755539, 0.06318975239992142, 0.1... [0.10799965262413025, 0.07695361971855164, 0.0... [0.0, 0.0, 0.031117763370275497, 0.0, 0.0, 0.0...
caax prenyl protease 2 [0.0, 0.0, 0.3198444339599345, 0.0, 0.0, 0.0, ... Entity PROTEIN [-0.2207261174917221, -0.071625716984272, 0.11... [0.006837755441665649, 0.01296880841255188, 0.... [0.010648305527865887, 0.0, 0.0312722884118557...
transmembrane protease serine 2 [0.0, 0.0, 0.2853086240289885, 0.0, 0.0, 0.0, ... Entity PROTEIN [-0.40691250562667847, 0.07031852006912231, 0.... [0.00615808367729187, 0.02638322114944458, 0.0... [0.0, 0.0, 0.03197368606925011, 0.010241100564...

1000 rows × 6 columns

Plotting the embeddings

Having produced the embedding vectors, we can project them into a 2D space using dimensionality reduction techniques such as TSNE (t-distributed Stochastic Neighbor Embedding).

node2vec_2d = transform_to_2d(transformed_graph._nodes["node2vec"].tolist())
attri2vec_2d = transform_to_2d(transformed_graph._nodes["attri2vec"].tolist())
gcn_dgi_2d = transform_to_2d(transformed_graph._nodes["gcn_dgi"].tolist())

We can now plot these 2D vectors using the plot_2d util provided by bluegraph.

plot_2d(transformed_graph, vectors=node2vec_2d, label_prop="entity_type")
../../_images/output_60_0.png
plot_2d(transformed_graph, vectors=attri2vec_2d, label_prop="entity_type")
../../_images/output_61_0.png
plot_2d(transformed_graph, vectors=gcn_dgi_2d, label_prop="entity_type")
../../_images/output_62_0.png

Node similarity

We would like to be able to search for similar nodes using the computed vector embeddings. For this we can use the NodeSimilarityProcessor interfaces provided as a part of bluegraph.

We construct similarity processors for different embeddings and query top 10 most similar nodes to the terms glucose and covid-19.

node2vec_l2 = NodeSimilarityProcessor(transformed_graph, "node2vec", similarity="euclidean")
node2vec_cosine = NodeSimilarityProcessor(
    transformed_graph, "node2vec", similarity="cosine")
node2vec_l2.get_neighbors(["glucose", "covid-19"], k=10)
{'glucose': {0.0: 'glucose',
  0.016042586: 'diabetic nephropathy',
  0.020855632: 'nonalcoholic fatty liver disease',
  0.020919867: 'hyperglycemia',
  0.027952814: 'metabolic syndrome',
  0.04255097: 'visceral',
  0.049424335: 'obesity',
  0.05932623: 'citrate',
  0.061201043: 'tissue factor',
  0.06682069: 'liver and intrahepatic bile duct disorder'},
 'covid-19': {0.0: 'covid-19',
  0.023866901: 'fatal',
  0.049039844: 'procalcitonin measurement',
  0.05976087: 'acute respiratory distress syndrome',
  0.08363058: 'neuromuscular',
  0.08448325: 'sterile',
  0.084664375: 'hydroxychloroquine',
  0.103314176: 'tidal volume',
  0.10976424: 'caspase-5',
  0.11111233: 'status epilepticus'}}
node2vec_cosine.get_neighbors(["glucose", "covid-19"], k=10)
{'glucose': {0.99999994: 'glucose',
  0.99718344: 'diabetic nephropathy',
  0.9968226: 'hyperglycemia',
  0.9958539: 'nonalcoholic fatty liver disease',
  0.9947761: 'metabolic syndrome',
  0.99151814: 'visceral',
  0.991088: 'respiration',
  0.9901221: 'obesity',
  0.9887427: 'liver and intrahepatic bile duct disorder',
  0.9885775: 'citrate'},
 'covid-19': {1.0: 'covid-19',
  0.99730766: 'fatal',
  0.9942852: 'procalcitonin measurement',
  0.9897085: 'acute respiratory distress syndrome',
  0.98890024: 'chronic obstructive pulmonary disease',
  0.9888062: 'sterile',
  0.98763454: 'neuromuscular',
  0.98537326: 'hydroxychloroquine',
  0.98534656: 'lopinavir/ritonavir',
  0.98470575: 'pulmonary'}}
attri2vec_l2 = NodeSimilarityProcessor(transformed_graph, "attri2vec")
attri2vec_cosine = NodeSimilarityProcessor(
    transformed_graph, "attri2vec", similarity="cosine")
attri2vec_l2.get_neighbors(["glucose", "covid-19"], k=10)
{'glucose': {0.0: 'glucose',
  0.0071316347: 'digestion',
  0.00823471: 'hepatocellular',
  0.0091231465: 'adipose tissue',
  0.010375342: 'axon',
  0.010453261: 'hemoglobin',
  0.010671802: 'bile',
  0.0106950635: 'vitamin',
  0.011250288: 'tissue',
  0.011955512: 'small intestine'},
 'covid-19': {0.0: 'covid-19',
  0.00061282323: 'chronic obstructive pulmonary disease',
  0.0009526084: 'vasculitis',
  0.0009802075: 'pulmonary edema',
  0.0010977304: 'liver failure',
  0.0011182561: 'inflammatory disorder',
  0.0011229385: 'parenteral',
  0.0012357396: 'osteoporosis',
  0.001249002: 'h1n1 influenza',
  0.0012659363: 'morphine'}}
attri2vec_cosine.get_neighbors(["glucose", "covid-19"], k=10)
{'glucose': {1.0: 'glucose',
  0.9778094: 'digestion',
  0.97610795: 'degradation',
  0.97395945: 'creatine',
  0.9727266: 'hepatocellular',
  0.9708393: 'adipose tissue',
  0.9704221: 'vitamin',
  0.9702778: 'astrocyte',
  0.9700098: 'hematopoietic stem cell',
  0.9698795: 'lymph node'},
 'covid-19': {1.0: 'covid-19',
  0.97816277: 'severe acute respiratory syndrome',
  0.9777578: 'middle east respiratory syndrome',
  0.9767103: 'respiratory failure',
  0.97613215: 'childhood-onset systemic lupus erythematosus',
  0.97379327: 'h1n1 influenza',
  0.9727: 'dengue fever',
  0.9719033: 'chronic obstructive pulmonary disease',
  0.97159684: 'arthritis',
  0.9704671: 'delirium'}}
gcn_l2 = NodeSimilarityProcessor(transformed_graph, "gcn_dgi")
gcn_cosine = NodeSimilarityProcessor(
    transformed_graph, "gcn_dgi", similarity="cosine")
gcn_l2.get_neighbors(["glucose", "covid-19"], k=10)
{'glucose': {0.0: 'glucose',
  0.0030039286: 'glucose tolerance test',
  0.0034940867: 'triglycerides',
  0.003617311: 'insulin',
  0.0036187829: 'high density lipoprotein',
  0.004899253: 'cholesterol',
  0.0056207227: 'organic phosphate',
  0.0057664528: 'uric acid',
  0.0058270395: 'fetus',
  0.006129055: 'diabetic nephropathy'},
 'covid-19': {0.0: 'covid-19',
  0.0009082245: 'coronavirus',
  0.002618216: 'fatal',
  0.0026699416: 'acute respiratory distress syndrome',
  0.0042233844: 'sars-cov-2',
  0.004636312: 'severe acute respiratory syndrome',
  0.004916654: 'middle east respiratory syndrome',
  0.005095474: 'myocarditis',
  0.0056914845: 'angiotensin ii receptor antagonist',
  0.0057702293: 'cardiac valve injury'}}
gcn_cosine.get_neighbors(["glucose", "covid-19"], k=10)
{'glucose': {1.0000001: 'glucose',
  0.98359084: 'triglycerides',
  0.9822164: 'cholesterol',
  0.981979: 'insulin',
  0.98167336: 'glucose tolerance test',
  0.979028: 'high density lipoprotein',
  0.9727696: 'low density lipoprotein',
  0.9723866: 'plasma',
  0.97019887: 'skeletal muscle tissue',
  0.9700538: 'atherosclerosis'},
 'covid-19': {0.99999994: 'covid-19',
  0.99609506: 'coronavirus',
  0.9897146: 'fatal',
  0.98897403: 'acute respiratory distress syndrome',
  0.98260605: 'sars-cov-2',
  0.980789: 'severe acute respiratory syndrome',
  0.9791904: 'middle east respiratory syndrome',
  0.97802055: 'myocarditis',
  0.97669864: 'angiotensin ii receptor antagonist',
  0.9753277: 'sars coronavirus'}}

Node clustering

We can cluster nodes according to their node embeddings. Often such clustering helps to reveal the community structure encoded in the underlying networks.

In this example we will use the BayesianGaussianMixture model provided by the scikit-learn to cluster the nodes according to different embeddings into 5 clusters.

N = 5
X = transformed_graph.get_node_property_values("node2vec").to_list()
gmm = mixture.BayesianGaussianMixture(n_components=N, covariance_type='full').fit(X)
node2vec_clusters = gmm.predict(X)
X = transformed_graph.get_node_property_values("attri2vec").to_list()
gmm = mixture.BayesianGaussianMixture(n_components=5, covariance_type='full').fit(X)
attri2vec_clusters = gmm.predict(X)
X = transformed_graph.get_node_property_values("gcn_dgi").to_list()
gmm = mixture.BayesianGaussianMixture(n_components=5, covariance_type='full').fit(X)
gcn_dgi_clusters = gmm.predict(X)

Below we inspect the most frequent cluster members.

def show_top_members(clusters, N):
    for i in range(N):
        df = transformed_graph._nodes.iloc[np.where(clusters == i)]
        df.loc[:, "frequency"] = df.index.map(lambda x: graph._nodes.loc[x, "frequency"])
        print(f"#{i}: ", ", ".join(df.nlargest(10, columns=["frequency"]).index))
show_top_members(node2vec_clusters, N)
#0:  blood, heart, pulmonary, death, renal, hypertension, cardiovascular system, septicemia, oral cavity, fever
#1:  lung, survival, cancer, organ, plasma, angiotensin-converting enzyme 2, vascular, insulin, neutrophil, antibody
#2:  bacteria, antibiotic, pneumonia, escherichia coli, staphylococcus aureus, pathogen, klebsiella pneumoniae, microorganism, mucoid pseudomonas aeruginosa, organism
#3:  human, mouse, inflammation, animal, cytokine, interleukin-6, neoplasm, dna, tissue, proliferation
#4:  covid-19, infectious disorder, diabetes mellitus, sars-cov-2, liver, virus, brain, glucose, kidney, serum
/Users/oshurko/opt/anaconda3/envs/bg/lib/python3.7/site-packages/pandas/core/indexing.py:1667: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.obj[key] = value
show_top_members(attri2vec_clusters, N)
#0:  antibiotic, escherichia coli, staphylococcus aureus, klebsiella pneumoniae, mucoid pseudomonas aeruginosa, vancomycin, pseudomonas aeruginosa, ciprofloxacin, community-acquired pneumonia, staphylococcus
#1:  human, renal, survival, brain, hypertension, obesity, respiratory system, oral cavity, injury, oxygen
#2:  death, person, proliferation, molecule, lower, failure, intestinal, transfer, organism, sterile
#3:  dog, cat, water, depression, horse, anxiety, nasal, subarachnoid hemorrhage, proximal, brother
#4:  covid-19, blood, infectious disorder, heart, diabetes mellitus, lung, sars-cov-2, mouse, pulmonary, bacteria
show_top_members(gcn_dgi_clusters, N)
#0:  lung, sars-cov-2, liver, survival, virus, brain, glucose, kidney, cancer, serum
#1:  covid-19, blood, heart, diabetes mellitus, pulmonary, death, renal, hypertension, cardiovascular system, dog
#2:  bacteria, antibiotic, escherichia coli, staphylococcus aureus, pathogen, klebsiella pneumoniae, microorganism, mucoid pseudomonas aeruginosa, organism, sputum
#3:  infectious disorder, respiratory system, oral cavity, pneumonia, skin, fever, cystic fibrosis, urine, human immunodeficiency virus, influenza
#4:  human, mouse, inflammation, animal, cytokine, plasma, interleukin-6, insulin, neoplasm, dna

We can also use the previously plot_2d util and color our 2D nore representation according to the clusters they belong to.

plot_2d(transformed_graph, vectors=node2vec_2d, labels=node2vec_clusters)
../../_images/output_86_0.png
plot_2d(transformed_graph, vectors=attri2vec_2d, labels=attri2vec_clusters)
../../_images/output_87_0.png
plot_2d(transformed_graph, vectors=gcn_dgi_2d, labels=gcn_dgi_clusters)
../../_images/output_88_0.png

Node classification

Another downstream task that we would like to perform is node classification. We would like to automatically assign entity types according to their node embeddings. For this we will build predictive models for entity type prediction based on:

  • Only node features

  • Node2vec embeddings (only structure)

  • Attri2vec embeddings (structure and node features)

  • GCN Deep Graph Infomax embeddings (structure and node features)

First of all, we split the graph nodes into the train and the test sets.

train_nodes, test_nodes = model_selection.train_test_split(
    transformed_graph.nodes(), train_size=0.8)

Now we use the NodeClassifier interface to create our classification models. As the base model we will use the linear SVM classifier (LinearSVC) provided by scikit-learn.

features_classifier = NodeClassifier(LinearSVC(), feature_vector_prop="features")
features_classifier.fit(transformed_graph, train_elements=train_nodes, label_prop="entity_type")
features_pred = features_classifier.predict(transformed_graph, predict_elements=test_nodes)
node2vec_classifier = NodeClassifier(LinearSVC(), feature_vector_prop="node2vec")
node2vec_classifier.fit(transformed_graph, train_elements=train_nodes, label_prop="entity_type")
node2vec_pred = node2vec_classifier.predict(transformed_graph, predict_elements=test_nodes)
attri2vec_classifier = NodeClassifier(LinearSVC(), feature_vector_prop="attri2vec")
attri2vec_classifier.fit(transformed_graph, train_elements=train_nodes, label_prop="entity_type")
attri2vec_pred = attri2vec_classifier.predict(transformed_graph, predict_elements=test_nodes)
/Users/oshurko/opt/anaconda3/envs/bg/lib/python3.7/site-packages/sklearn/svm/_base.py:986: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
  "the number of iterations.", ConvergenceWarning)
gcn_dgi_classifier = NodeClassifier(LinearSVC(), feature_vector_prop="gcn_dgi")
gcn_dgi_classifier.fit(transformed_graph, train_elements=train_nodes, label_prop="entity_type")
gcn_dgi_pred = gcn_dgi_classifier.predict(transformed_graph, predict_elements=test_nodes)

Let us have a look at the scores of different node classification models we have produced.

true_labels = transformed_graph._nodes.loc[test_nodes, "entity_type"]
get_classification_scores(true_labels, features_pred, multiclass=True)
{'accuracy': 0.59,
 'precision': 0.59,
 'recall': 0.59,
 'f1_score': 0.59,
 'roc_auc_score': 0.7847725250984877}
get_classification_scores(true_labels, node2vec_pred, multiclass=True)
{'accuracy': 0.36,
 'precision': 0.36,
 'recall': 0.36,
 'f1_score': 0.36,
 'roc_auc_score': 0.6786980556614562}
get_classification_scores(true_labels, attri2vec_pred, multiclass=True)
{'accuracy': 0.46,
 'precision': 0.46,
 'recall': 0.46,
 'f1_score': 0.46,
 'roc_auc_score': 0.7230397763375269}
get_classification_scores(true_labels, gcn_dgi_pred, multiclass=True)
{'accuracy': 0.33,
 'precision': 0.33,
 'recall': 0.33,
 'f1_score': 0.33,
 'roc_auc_score': 0.6585176007116533}

Creating and saving embedding pipelines

bluegraph allows to create emebedding pipelines (using the EmbeddingPipeline class) that represent a useful wrapper around a sequence of steps necessary to produce embeddings and compute point similarities. In the example below we create a pipeline for producing attri2vec node embeddings and computing their cosine similarity.

We first create an encoder object that will be used in our pipeline as a preprocessing step.

definition_encoder = ScikitLearnPGEncoder(
    node_properties=["definition"], text_encoding_max_dimension=512)

We then create an embedder object.

D = 128
params = {
    "length": 5,
    "number_of_walks": 10,
    "epochs": 5,
    "embedding_dimension": D
}
attri2vec_embedder = StellarGraphNodeEmbedder(
    "attri2vec", feature_vector_prop="features", edge_weight="npmi", **params)

And finally we create a pipeline object. Note that in the code below we use the SimilarityProcessor interface and not NodeSimilarityProcessor, as we have done it previously. We use this lower abstraction level interface, because the EmbeddingPipeline is designed to work with any embedding models (not only node embedding models).

attri2vec_pipeline = EmbeddingPipeline(
    preprocessor=definition_encoder,
    embedder=attri2vec_embedder,
    similarity_processor=SimilarityProcessor(
        FaissSimilarityIndex(
            similarity="cosine", dimension=D)))

We run the fitting process, which given the input data: 1. fits the encoder 2. transforms the data 3. fits the embedder 4. produces the embedding table 5. fits the similarity processor index

attri2vec_pipeline.run_fitting(graph)
link_classification: using 'ip' method to combine node embeddings into edge embeddings

How we can save our pipeline to the file system.

attri2vec_pipeline.save(
    "../data/attri2vec_test_model",
    compress=True)

And we can load the pipeline back into memory:

pipeline = EmbeddingPipeline.load(
    "../data/attri2vec_test_model.zip",
    embedder_interface=GraphElementEmbedder,
    embedder_ext="zip")

We can use retrieve_embeddings and get_similar_points methods of the pipeline object to respectively get embedding vectors and top most similar nodes for the input nodes.

pipeline.retrieve_embeddings(["covid-19", "glucose"])
[[0.07280001044273376,
  0.08163794130086899,
  0.08893375843763351,
  0.09304069727659225,
  0.11964225769042969,
  0.08136298507452011,
  0.0790518969297409,
  0.08503866195678711,
  0.08987397700548172,
  0.13234665989875793,
  0.06845631450414658,
  0.09433518350124359,
  0.057276081293821335,
  0.08183374255895615,
  0.0636567771434784,
  0.10424472391605377,
  0.06787201017141342,
  0.08923638612031937,
  0.07220311462879181,
  0.07509997487068176,
  0.09238457679748535,
  0.06531045585870743,
  0.0759056881070137,
  0.14457547664642334,
  0.08505883812904358,
  0.06661373376846313,
  0.07629712671041489,
  0.07443031668663025,
  0.07806529849767685,
  0.08416897058486938,
  0.12059333175420761,
  0.0758424922823906,
  0.10647209733724594,
  0.07496806234121323,
  0.09789688140153885,
  0.10009769350290298,
  0.09310337901115417,
  0.08175752311944962,
  0.08274300396442413,
  0.07131325453519821,
  0.12208940088748932,
  0.06224219128489494,
  0.09508002549409866,
  0.14279678463935852,
  0.057057347148656845,
  0.0588308647274971,
  0.08901730924844742,
  0.08926397562026978,
  0.0662379041314125,
  0.09682483226060867,
  0.07646792382001877,
  0.07486658543348312,
  0.070854052901268,
  0.054801177233457565,
  0.07894912362098694,
  0.060327619314193726,
  0.10469762980937958,
  0.07393162697553635,
  0.09346463531255722,
  0.09142538905143738,
  0.08995286375284195,
  0.057934362441301346,
  0.09345584362745285,
  0.09328961372375488,
  0.07854010164737701,
  0.07263723015785217,
  0.12583819031715393,
  0.06582190096378326,
  0.07038778066635132,
  0.06997384876012802,
  0.07740046083927155,
  0.0648268535733223,
  0.0915069580078125,
  0.1107659563422203,
  0.10443656146526337,
  0.06657622754573822,
  0.09377510845661163,
  0.06837121397256851,
  0.09725506603717804,
  0.060706377029418945,
  0.1157352551817894,
  0.0791042298078537,
  0.08426657319068909,
  0.06966130435466766,
  0.07881376147270203,
  0.06591648608446121,
  0.12842406332492828,
  0.09824175387620926,
  0.07571471482515335,
  0.0666264072060585,
  0.13996072113513947,
  0.10810025036334991,
  0.08261056989431381,
  0.062233999371528625,
  0.0959680825471878,
  0.0712309181690216,
  0.09311872720718384,
  0.08855060487985611,
  0.10211314260959625,
  0.0744297131896019,
  0.13628296554088593,
  0.07632824778556824,
  0.09952477365732193,
  0.09145186096429825,
  0.05990583822131157,
  0.08039164543151855,
  0.09073426574468613,
  0.0997760146856308,
  0.07251497358083725,
  0.06577309966087341,
  0.13079826533794403,
  0.08491260558366776,
  0.06395302712917328,
  0.04059096425771713,
  0.13386057317256927,
  0.07978139072656631,
  0.11739350110292435,
  0.05938231945037842,
  0.09113242477178574,
  0.04842013493180275,
  0.05951233580708504,
  0.0531817302107811,
  0.07620435208082199,
  0.0648634135723114,
  0.07864787429571152,
  0.16829492151737213,
  0.08553200215101242,
  0.10460848361253738],
 [0.10236917436122894,
  0.09674006700515747,
  0.07649692893028259,
  0.0845288410782814,
  0.0760805606842041,
  0.09261447936296463,
  0.09488159418106079,
  0.12473700195550919,
  0.0718981921672821,
  0.1021432876586914,
  0.09268027544021606,
  0.09814798831939697,
  0.09521770477294922,
  0.10098892450332642,
  0.09244446456432343,
  0.0635334774851799,
  0.09584149718284607,
  0.08556737005710602,
  0.0852125957608223,
  0.07645734399557114,
  0.08095100522041321,
  0.09593727439641953,
  0.08347492665052414,
  0.08885250240564346,
  0.08701310306787491,
  0.09694880247116089,
  0.11121281236410141,
  0.08294625580310822,
  0.08726843446493149,
  0.0701715424656868,
  0.09523919224739075,
  0.07785829901695251,
  0.09603790938854218,
  0.0824458971619606,
  0.08737047761678696,
  0.08853974938392639,
  0.06570149958133698,
  0.10123683512210846,
  0.07348940521478653,
  0.06943066418170929,
  0.1299903839826584,
  0.08817175030708313,
  0.06109187752008438,
  0.08437755703926086,
  0.08351798355579376,
  0.08457473665475845,
  0.07322832942008972,
  0.09192510694265366,
  0.08886606246232986,
  0.07747369259595871,
  0.07242843508720398,
  0.09057212620973587,
  0.10816606134176254,
  0.09043016284704208,
  0.09076884388923645,
  0.09677130728960037,
  0.08017739653587341,
  0.10074104368686676,
  0.07700169831514359,
  0.07268036901950836,
  0.07325926423072815,
  0.07274069637060165,
  0.06991708278656006,
  0.0845450609922409,
  0.06915223598480225,
  0.0702526643872261,
  0.09593337029218674,
  0.09438585489988327,
  0.08171636611223221,
  0.07945361733436584,
  0.0642147958278656,
  0.08085450530052185,
  0.0607246495783329,
  0.08492715656757355,
  0.07719805836677551,
  0.10578399896621704,
  0.10591499507427216,
  0.09201952069997787,
  0.0818672627210617,
  0.08240731060504913,
  0.06790471076965332,
  0.07807260751724243,
  0.0730040892958641,
  0.1071859821677208,
  0.11890396475791931,
  0.056871384382247925,
  0.09596915543079376,
  0.07900075614452362,
  0.09519974142313004,
  0.10644269734621048,
  0.08464374393224716,
  0.10578206926584244,
  0.10132604092359543,
  0.07531124353408813,
  0.09358139336109161,
  0.07341431826353073,
  0.09914236515760422,
  0.07994917780160904,
  0.06680438667535782,
  0.07904554903507233,
  0.09318091720342636,
  0.08036279678344727,
  0.07590607553720474,
  0.07815994322299957,
  0.10222751647233963,
  0.11459968239068985,
  0.0987963154911995,
  0.08063937723636627,
  0.10191671550273895,
  0.11327352374792099,
  0.08440998196601868,
  0.09114128351211548,
  0.0879993736743927,
  0.0869138091802597,
  0.1110539585351944,
  0.08841552585363388,
  0.08597182482481003,
  0.09037397056818008,
  0.07773328572511673,
  0.09250291436910629,
  0.09562606364488602,
  0.07948072999715805,
  0.08507171273231506,
  0.08046958595514297,
  0.08189624547958374,
  0.07476285845041275,
  0.10559207946062088,
  0.10403718799352646]]
a = pipeline.retrieve_embeddings(["covid-19", "glucose"])
pipeline.get_neighbors(existing_points=["covid-19", "glucose"], k=5)
([array([1.0000001 , 0.98876834, 0.9861363 , 0.9855296 , 0.98494315],
        dtype=float32),
  array([1.0000001 , 0.98885393, 0.98832536, 0.9882704 , 0.9882704 ],
        dtype=float32)],
 [Index(['covid-19', 'middle east respiratory syndrome',
         'severe acute respiratory syndrome',
         'childhood-onset systemic lupus erythematosus', 'h1n1 influenza'],
        dtype='object', name='@id'),
  Index(['glucose', 'fatigue', 'anorexia', 'congenital abnormality', 'proximal'], dtype='object', name='@id')])