Finetuning a pretrained transformer
%load_ext autoreload
%autoreload 2
import torch
import copy
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
HuggingFace Transformer Finetuning¶
Community contribution
Curious how one would run this tutorial on Graphcore IPUs? See this tutorial contributed by @s-maddrellmander:
We have previously shown how Molfeat integrates with PyTorch in general and even with Pytorch Geometric. Now we will demonstrate how to use molfeat to finetune a pretrained transformer. This tutorial will walk you through an example of finetuning the ChemBERTa pretrained model for molecular property prediction. These same principles can be applied to any pretrained transformers available in molfeat.
To run this tutorial, you will need to install `transformers` and `tokenizers`:
mamba install -c conda-forge transformers "tokenizers <0.13.2"
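If you use pip instead of mamba/conda, an equivalent install (keeping the same version pin on tokenizers) should work as well:

pip install transformers "tokenizers<0.13.2"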
Advanced users
This tutorial is for advanced users who are comfortable with the APIs of molfeat and Hugging Face transformers.
from molfeat.utils.converters import SmilesConverter
from molfeat.trans.pretrained import PretrainedHFTransformer
Featurizer¶
Pretrained transformer featurizers in molfeat have an underlying `featurizer` object that can handle both tokenization and embedding.
We will leverage this structure in molfeat to initialize our transformer model, and also to tokenize our molecules.
We start by defining our featurizer. Here we will use the ChemBERTa pretrained model.
featurizer = PretrainedHFTransformer(kind="ChemBERTa-77M-MLM", pooling="bert", preload=True)
- Note the use of `preload` to preload the model in the `__init__`.
- Note how we define a pooling mechanism here. Molfeat provides several poolers that you can explore in the API. Because the pooling layer can already be specified here and will be accessible through the `_pooling_obj` attribute, we will not bother defining one later; instead we will just retrieve the one from the featurizer.
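As an optional sanity check, the featurizer can be called directly on a few SMILES strings; it returns one pooled embedding per molecule, with dimension `len(featurizer)` (which should be 384 for ChemBERTa-77M-MLM):

# optional check: the featurizer is directly callable on SMILES strings
sample_embeddings = np.asarray(featurizer(["CCO", "c1ccccc1"]))
print(sample_embeddings.shape)  # expected: (2, len(featurizer))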
Dataset¶
For the dataset, we will use the BBBP dataset, which contains binary labels of blood-brain barrier penetration.
df = pd.read_csv("https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv")
df.head()
|   | num | name | p_np | smiles |
|---|-----|------|------|--------|
| 0 | 1 | Propanolol | 1 | [Cl].CC(C)NCC(O)COc1cccc2ccccc12 |
| 1 | 2 | Terbutylchlorambucil | 1 | C(=O)(OC(C)(C)C)CCCc1ccc(cc1)N(CCCl)CCCl |
| 2 | 3 | 40730 | 1 | c12c3c(N4CCN(C)CC4)c(F)cc1c(c(C(O)=O)cn2C(C)CO... |
| 3 | 4 | 24 | 1 | C1CCN(CC1)Cc1cccc(c1)OCCCNC(=O)C |
| 4 | 5 | cloxacillin | 1 | Cc1onc(c2ccccc2Cl)c1C(=O)N[C@H]3[C@H]4SC(C)(C)... |
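Before splitting, a quick look at the label balance can be informative (an optional check using plain pandas):

# fraction of positive (p_np = 1) vs negative examples
print(df.p_np.value_counts(normalize=True))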
Now we just need to define our PyTorch Dataset. As discussed above, we will leverage the internal structure of our transformer featurizer.
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from transformers import default_data_collator
class DTset(Dataset):
    def __init__(self, smiles, y, mf_featurizer):
        super().__init__()
        self.smiles = smiles
        self.mf_featurizer = mf_featurizer
        self.y = torch.tensor(y).float()
        # here we use the molfeat mf_featurizer to convert the smiles into
        # the corresponding tokens with its internal tokenizer;
        # we just want the data from the batch encoding object
        self.transformed_mols = self.mf_featurizer._convert(smiles)

    @property
    def embedding_dim(self):
        return len(self.mf_featurizer)

    @property
    def max_length(self):
        return self.transformed_mols.shape[-1]

    def __len__(self):
        return self.y.shape[0]

    def collate_fn(self, **kwargs):
        # the default collate fn, self.mf_featurizer.get_collate_fn(**kwargs),
        # returns None here, so the DataLoader falls back to its default
        # collation, which simply stacks the inputs.
        # You could also use `transformers.default_data_collator` instead.
        return self.mf_featurizer.get_collate_fn(**kwargs)

    def __getitem__(self, index):
        datapoint = dict((name, val[index]) for name, val in self.transformed_mols.items())
        datapoint["y"] = self.y[index]
        return datapoint
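To make the structure of a single datapoint concrete, here is an optional check on a small throwaway dataset; the exact token keys come from the underlying Hugging Face tokenizer:

# build a tiny dataset just to inspect one item
tmp_dt = DTset(["CCO", "c1ccccc1O"], [0, 1], featurizer)
item = tmp_dt[0]
print(list(item.keys()))  # typically ['input_ids', 'attention_mask', 'y']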
dataset = DTset(df.smiles.values, df.p_np.values, featurizer)
generator = torch.Generator().manual_seed(42)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dt, test_dt = torch.utils.data.random_split(dataset, [train_size, test_size], generator=generator)
BATCH_SIZE = 64
train_loader = DataLoader(train_dt, batch_size=BATCH_SIZE, shuffle=True, collate_fn=dataset.collate_fn())
test_loader = DataLoader(test_dt, batch_size=BATCH_SIZE, shuffle=False, collate_fn=dataset.collate_fn())
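You can also peek at one batch to confirm that collation yields padded token tensors plus a label vector (the key names depend on the tokenizer):

batch = next(iter(train_loader))
print({name: tuple(val.shape) for name, val in batch.items()})
# e.g. {'input_ids': (64, max_length), 'attention_mask': (64, max_length), 'y': (64,)}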
Network + Training¶
We are ready to go. Now we just need to define our model for finetuning the pretrained ChemBERTa on the BBBP task.
class AwesomeNet(torch.nn.Module):
    def __init__(self, mf_featurizer, hidden_size=128, dropout=0.1, output_size=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        # we get the underlying model from the molfeat featurizer:
        # here we fetch the "base" huggingface transformer model
        # and not the wrapper around it for MLM,
        # mainly to get a smaller model and better training efficiency
        base_pretrained_model = getattr(
            mf_featurizer.featurizer.model, mf_featurizer.featurizer.model.base_model_prefix
        )
        self.embedding_layer = copy.deepcopy(base_pretrained_model)
        self.embedding_dim = mf_featurizer.featurizer.model.config.hidden_size
        # given that we are not concatenating layers, the following is equivalent
        # self.embedding_dim = len(mf_featurizer)
        # we get the pooling layer from the molfeat featurizer
        self.pooling_layer = mf_featurizer._pooling_obj
        self.hidden_layer = torch.nn.Sequential(
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(len(mf_featurizer), self.hidden_size),
            torch.nn.ReLU(),
        )
        self.output_layer = torch.nn.Linear(self.hidden_size, self.output_size)

    def forward(self, *, y=None, **kwargs):
        # get embeddings
        x = self.embedding_layer(**kwargs)
        # we take the last hidden state
        # (you could also set `output_hidden_states=True` above
        # and take x["hidden_states"][-1] instead)
        emb = x["last_hidden_state"]
        # run the pooling
        h = self.pooling_layer(
            emb,
            kwargs["input_ids"],
            mask=kwargs.get("attention_mask"),
        )
        # run through our custom and optional hidden layer
        h = self.hidden_layer(h)
        # run through the output layer to get logits
        return self.output_layer(h)
DEVICE = "cpu"
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3
model = AwesomeNet(featurizer, hidden_size=64, dropout=0.1, output_size=1)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.BCEWithLogitsLoss()
model = model.to(DEVICE).float()
model = model.train()
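Before launching the full training loop, an optional smoke test on a single batch helps confirm that the shapes line up and that the loss is finite:

batch = next(iter(train_loader))
logits = model(**batch)
print(logits.shape)                                  # expected: (BATCH_SIZE, 1)
print(loss_fn(logits.squeeze(), batch["y"]).item())  # should be a finite number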
You might want to have a look at the model summary for a sanity check.
! pip install torchinfo
from torchinfo import summary
summary(model)
You should see output similar to the following:
==========================================================================================
Layer (type:depth-idx) Param #
==========================================================================================
AwesomeNet --
├─RobertaForMaskedLM: 1-1 --
│ └─RobertaModel: 2-1 --
│ │ └─RobertaEmbeddings: 3-1 429,312
│ │ └─RobertaEncoder: 3-2 2,850,288
│ └─RobertaLMHead: 2-2 --
│ │ └─Linear: 3-3 147,840
│ │ └─LayerNorm: 3-4 768
│ │ └─Linear: 3-5 231,000
├─BertPooler: 1-2 --
│ └─Linear: 2-3 147,840
│ └─Tanh: 2-4 --
├─Sequential: 1-3 --
│ └─Dropout: 2-5 --
│ └─Linear: 2-6 24,640
│ └─ReLU: 2-7 --
├─Linear: 1-4 65
==========================================================================================
Total params: 3,831,753
Trainable params: 3,831,753
Non-trainable params: 0
==========================================================================================
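If you would rather not add torchinfo as a dependency, a plain parameter count gives a similar sanity check:

num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {num_params:,} (trainable: {num_trainable:,})")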
# Train
with tqdm(range(NUM_EPOCHS)) as pbar:
    for epoch in pbar:
        losses = []
        for data in train_loader:
            optimizer.zero_grad()
            out = model(**data)
            loss = loss_fn(out.squeeze(), data["y"])
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        pbar.set_description(f"Epoch {epoch} - Loss {np.mean(losses):.3f}")
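A common refinement when finetuning transformers is a warmup-then-decay learning rate schedule. It is not used above; the sketch below shows how one could be added with `get_linear_schedule_with_warmup` from transformers (you would then call `scheduler.step()` right after `optimizer.step()` inside the loop):

from transformers import get_linear_schedule_with_warmup

# hypothetical schedule: 10% of the optimization steps as warmup, then linear decay
num_training_steps = NUM_EPOCHS * len(train_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),
    num_training_steps=num_training_steps,
)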
Testing¶
We can now test our model.
from sklearn.metrics import accuracy_score, roc_auc_score
from matplotlib import pyplot as plt
model.eval()
test_y_hat = []
test_y_true = []
with torch.no_grad():
    for data in test_loader:
        out = model(**data)
        # we apply a sigmoid to turn logits into probabilities
        out = torch.sigmoid(out)
        test_y_hat.append(out.detach().cpu().squeeze())
        test_y_true.append(data["y"])
test_y_hat = torch.cat(test_y_hat).squeeze().numpy()
test_y_true = torch.cat(test_y_true).squeeze().numpy()
roc_auc = roc_auc_score(test_y_true, test_y_hat)
acc = accuracy_score(test_y_true, test_y_hat >= 0.5)
print(f"Test ROC AUC: {roc_auc:.3f}\nTest Accuracy: {acc:.3f}")
Test ROC AUC: 0.964
Test Accuracy: 0.905
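If the metrics look reasonable, you may want to persist the finetuned weights. A minimal sketch (the filename is arbitrary):

# save the finetuned weights
torch.save(model.state_dict(), "awesomenet_bbbp.pt")

# later, restore them into a freshly constructed model
restored = AwesomeNet(featurizer, hidden_size=64, dropout=0.1, output_size=1)
restored.load_state_dict(torch.load("awesomenet_bbbp.pt", map_location="cpu"))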