Training a GNN with PyG

Author: Cas — leading the Polaris project at Valence Labs

This tutorial was last updated in August 2023. For the most up to date tutorials, please see our documentation site.

%load_ext autoreload
%autoreload 2

import torch
import pandas as pd
import numpy as np
import torch.nn.functional as F
from tqdm.auto import tqdm

PyG integration

Community contribution

Curious how to run this tutorial on Graphcore IPUs? See the tutorial contributed by @s-maddrellmander.

As seen in the molfeat integration tutorial, molfeat integrates easily with the PyTorch ecosystem. In this tutorial, we demonstrate how to integrate molfeat with PyG to train state-of-the-art (SOTA) GNNs.

To run this tutorial, you will need to install PyTorch Geometric (PyG):

mamba install -c conda-forge pytorch_geometric

from molfeat.trans.graph.adj import PYGGraphTransformer
from molfeat.calc.atom import AtomCalculator
from molfeat.calc.bond import EdgeMatCalculator

Featurizer

We start by defining our featurizer: the PYGGraphTransformer from molfeat, configured with an atom featurizer and a bond featurizer.

# Per-atom and per-bond (edge) feature calculators, combined into a single
# molfeat transformer that emits PyG graphs.
atom_calculator = AtomCalculator()
bond_calculator = EdgeMatCalculator()
featurizer = PYGGraphTransformer(
    atom_featurizer=atom_calculator,
    bond_featurizer=bond_calculator,
)

Dataset

For the dataset, we use the Lipophilicity (LogD) dataset from MoleculeNet, which contains experimental measurements of the octanol/water distribution coefficient at pH 7.4.

# MoleculeNet Lipophilicity (LogD) table, hosted by DeepChem.
LIPOPHILICITY_CSV = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv"
df = pd.read_csv(LIPOPHILICITY_CSV)
df.head()

   CMPD_CHEMBLID    exp  smiles
0  CHEMBL596271    3.54  Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14
1  CHEMBL1951080  -1.18  COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...
2  CHEMBL1771      3.69  COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl
3  CHEMBL234951    3.37  OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...
4  CHEMBL565079    3.10  Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...

Since training a network with PyTorch requires a dataset and a dataloader, we define a custom dataset that takes (1) the SMILES, (2) the LogD measurements, and (3) our molfeat transformer as input, and generates the data points needed for model training.

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch_geometric.utils import degree


class DTset(Dataset):
    """Map-style PyTorch dataset pairing molfeat PyG graphs with regression targets.

    All SMILES are featurized once, eagerly, in ``__init__``; ``__getitem__``
    then returns ``(graph, target)`` pairs.

    Args:
        smiles: Sequence of SMILES strings passed to ``featurizer``.
        y: One target value per molecule. For 1-D input it is stored as a
            float32 tensor of shape ``(n, 1)``.
        featurizer: molfeat graph transformer (e.g. ``PYGGraphTransformer``).
            NOTE: ``featurizer.auto_self_loop()`` is called on the object that
            is passed in, so the caller's featurizer is modified in place.
    """

    def __init__(self, smiles, y, featurizer):
        super().__init__()
        self.smiles = smiles
        self.featurizer = featurizer
        # Presumably makes the featurizer add self-loops to every graph
        # (see molfeat docs); it must run before the graphs are built below.
        self.featurizer.auto_self_loop()
        self.y = torch.tensor(y).unsqueeze(-1).float()
        self.transformed_mols = self.featurizer(smiles)
        # Cache for the ``degree`` property, filled on first access.
        self._degrees = None

    @property
    def num_atom_features(self):
        """Size of the per-atom feature vector produced by the featurizer."""
        return self.featurizer.atom_dim

    @property
    def num_bond_features(self):
        """Size of the per-bond (edge) feature vector produced by the featurizer."""
        return self.featurizer.bond_dim

    @property
    def num_output(self):
        """Number of target columns (last dimension of ``self.y``)."""
        return self.y.shape[-1]

    @property
    def degree(self):
        """Histogram of node in-degrees over every molecule in the dataset.

        Entry ``k`` is the number of nodes (across all graphs) whose in-degree,
        i.e. the number of times the node appears in ``edge_index[1]``, equals
        ``k``. This is the ``deg`` tensor handed to PyG's ``PNA``. Graphs with
        no nodes contribute nothing. The result is computed once and cached.
        """
        if self._degrees is None:
            # Per-graph in-degrees, computed in a single pass. This is the same
            # value as ``torch_geometric.utils.degree(edge_index[1], num_nodes)``.
            per_graph = [
                torch.bincount(data.edge_index[1], minlength=data.num_nodes)
                for data in self.transformed_mols
            ]
            # Skip empty graphs: ``.max()`` on an empty tensor raises.
            max_degree = max(
                (int(d.max()) for d in per_graph if d.numel() > 0), default=-1
            )
            deg = torch.zeros(max_degree + 1, dtype=torch.long)
            for d in per_graph:
                deg += torch.bincount(d, minlength=deg.numel())
            self._degrees = deg
        return self._degrees

    def collate_fn(self, **kwargs):
        """Return the PyG collate function provided by the molfeat featurizer.

        Keyword arguments (e.g. ``return_pair``) are forwarded unchanged to
        ``featurizer.get_collate_fn``.
        """
        return self.featurizer.get_collate_fn(**kwargs)

    def __len__(self):
        return len(self.transformed_mols)

    def __getitem__(self, index):
        """Return the ``(graph, target)`` pair at ``index``."""
        return self.transformed_mols[index], self.y[index]
# Featurize every molecule and pair it with its LogD value.
dataset = DTset(df.smiles.values, df.exp.values, featurizer)

# Reproducible 80/20 train/test split.
generator = torch.Generator().manual_seed(42)
n_molecules = len(dataset)
train_size = int(0.8 * n_molecules)
test_size = n_molecules - train_size
train_dt, test_dt = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=generator,
)

# Batches are assembled with molfeat's PyG collate function.
BATCH_SIZE = 64
train_loader = DataLoader(
    train_dt,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=dataset.collate_fn(return_pair=False),
)
test_loader = DataLoader(
    test_dt,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=dataset.collate_fn(return_pair=False),
)

Network + Training

We are almost ready to go; we just need to define our GNN. Here we use PNA (Principal Neighbourhood Aggregation).

from torch_geometric.nn.models import PNA
from torch_geometric.nn import global_add_pool

DEVICE = "cpu"
NUM_EPOCHS = 10
LEARNING_RATE = 5e-4
PNA_AGGREGATORS = ['mean', 'min', 'max', 'std']
PNA_SCALERS = ['identity', 'amplification', 'attenuation']

# PNA receives the in-degree histogram via `deg`; `dataset.degree` computes it
# over every molecule in `dataset`.
model = PNA(
    in_channels=dataset.num_atom_features,
    hidden_channels=128,
    num_layers=3,
    out_channels=dataset.num_output,
    dropout=0.1,
    act="relu",
    edge_dim=dataset.num_bond_features,
    aggregators=PNA_AGGREGATORS,
    scalers=PNA_SCALERS,
    deg=dataset.degree,
)
# Move the model to its target device *before* creating the optimizer, as the
# torch.optim documentation recommends.
model = model.to(DEVICE).float()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train
model.train()
with tqdm(range(NUM_EPOCHS)) as pbar:
    for epoch in pbar:
        losses = []
        for data in train_loader:
            data = data.to(DEVICE)
            optimizer.zero_grad()
            # Node-level outputs, summed into one prediction per molecule.
            out = model(data.x, data.edge_index, edge_attr=data.edge_attr)
            out = global_add_pool(out, data.batch)
            # Flatten both sides to (batch_size,). Unlike .squeeze(), .view(-1)
            # keeps a batch of one 1-D, so prediction and target shapes always match.
            loss = F.mse_loss(out.view(-1), data.y.view(-1))
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        pbar.set_description(f"Epoch {epoch} - Loss {np.mean(losses):.3f}")
  0%|          | 0/10 [00:00<?, ?it/s]

Testing

We can now test our model. For simplicity, this tutorial performs no hyperparameter search and does not evaluate which atom/bond featurization works best; this inevitably limits performance.

from sklearn.metrics import r2_score, mean_absolute_error
from matplotlib import pyplot as plt

# Collect test-set predictions and ground truth, then report R^2 / MAE.
model.eval()
test_y_hat = []
test_y_true = []
with torch.no_grad():
    for data in test_loader:
        data = data.to(DEVICE)
        out = model(data.x, data.edge_index, edge_attr=data.edge_attr)
        out = global_add_pool(out, data.batch)
        # .view(-1) keeps every chunk 1-D (for a batch of one, .squeeze() gives a
        # 0-d tensor that torch.cat rejects). .cpu() lets .numpy() work below
        # when DEVICE is a GPU.
        test_y_hat.append(out.view(-1).cpu())
        test_y_true.append(data.y.view(-1).cpu())

test_y_hat = torch.cat(test_y_hat).numpy()
test_y_true = torch.cat(test_y_true).numpy()

r2 = r2_score(test_y_true, test_y_hat)
mae = mean_absolute_error(test_y_true, test_y_hat)

plt.scatter(test_y_true, test_y_hat)
_ = plt.gca().annotate(
    "$R^2 = {:.2f}$\nMAE = {:.2f}".format(r2, mae),
    xy=(0.05, 0.9),
    xycoords='axes fraction',
    size=8,
    bbox=dict(boxstyle="round", fc=(1.0, 0.7, 0.7), ec="none"),
)
Figure: scatter plot of predicted versus true LogD values on the test set, annotated with the R² and MAE.