Add your own featurizer
%load_ext autoreload
%autoreload 2
import numpy as np
Define your own calculator
Remember that a calculator is simply a Callable that takes in a molecule (either an RDKit `Chem.Mol` object or a SMILES string) and returns the computed features (the example below returns them as a list of numeric values).
We can thus easily define our own calculator!
from molfeat.trans import MoleculeTransformer
from rdkit.Chem.rdMolDescriptors import CalcNumHeteroatoms
def my_calculator(mol):
    """Compute a small, illustrative feature vector for one molecule.

    Args:
        mol: An RDKit molecule or a SMILES string. It is converted with
            ``dm.to_mol`` before any feature is computed.

    Returns:
        A list of four values: ``mol.GetNumAtoms()``, ``mol.GetNumBonds()``,
        the heteroatom count from ``CalcNumHeteroatoms``, and a
        pseudo-random float.

    Raises:
        ValueError: If ``mol`` cannot be converted to an RDKit molecule.
    """
    # Imported locally: the notebook only imports datamol in a later cell,
    # so relying on a global ``dm`` here fails on a top-to-bottom run.
    import datamol as dm

    rd_mol = dm.to_mol(mol)
    if rd_mol is None:
        # dm.to_mol returns None on parse failure; fail with a clear message
        # instead of an AttributeError on None further down.
        raise ValueError(f"Could not convert input to a molecule: {mol!r}")
    # A fresh generator seeded with 0 is created on every call, so the last
    # feature is the same value for every molecule (a reproducible dummy).
    rng = np.random.default_rng(0)
    return [
        rd_mol.GetNumAtoms(),
        rd_mol.GetNumBonds(),
        CalcNumHeteroatoms(rd_mol),
        rng.random(),
    ]
# A plain function is already a valid calculator, so MoleculeTransformer
# accepts it directly. The SMILES below is caffeine.
caffeine = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C"
trans = MoleculeTransformer(my_calculator)
trans([caffeine])
[array([14. , 15. , 6. , 0.63696169])]
If such functions get more complex, it may be easier to wrap them in a class. Subclassing `SerializableCalculator` also ensures the calculator remains serializable.
from molfeat.calc import SerializableCalculator
class MyCalculator(SerializableCalculator):
    """Serializable calculator returning a small illustrative feature vector.

    Each call returns four values for the input molecule: the atom count from
    ``GetNumAtoms()``, the bond count from ``GetNumBonds()``, the heteroatom
    count from ``CalcNumHeteroatoms``, and a pseudo-random float.
    """

    def __call__(self, mol):
        """Featurize a single molecule.

        Args:
            mol: An RDKit molecule or a SMILES string.

        Returns:
            A list ``[num_atoms, num_bonds, num_heteroatoms, random_value]``.

        Raises:
            ValueError: If ``mol`` cannot be converted to an RDKit molecule.
        """
        # Imported locally: datamol is only imported in a later notebook
        # cell, so a global ``dm`` is not yet defined at this point.
        import datamol as dm

        rd_mol = dm.to_mol(mol)
        if rd_mol is None:
            # dm.to_mol returns None on parse failure; surface that clearly.
            raise ValueError(f"Could not convert input to a molecule: {mol!r}")
        # Re-seeded on every call, so the last feature is identical for all
        # molecules -- a reproducible placeholder value.
        rng = np.random.default_rng(0)
        return [
            rd_mol.GetNumAtoms(),
            rd_mol.GetNumBonds(),
            CalcNumHeteroatoms(rd_mol),
            rng.random(),
        ]
# An instance of a SerializableCalculator subclass plugs into
# MoleculeTransformer exactly like a plain function does.
calculator = MyCalculator()
trans = MoleculeTransformer(calculator)
trans(["CN1C=NC2=C1C(=O)N(C(=O)N2C)C"])
2023-03-21 14:26:26.358 | WARNING | molfeat.calc.base:__init__:23 - The 'MyCalculator' interaction has been superseded by a new class with id 0x5624f1236620
[array([14. , 15. , 6. , 0.63696169])]
Define your own transformer
The example above shows that, in many cases, there is no need to create your own transformer class: you can simply use the MoleculeTransformer base class.
In more complex cases, such as with pretrained models, it is better to create your own subclass. Below, for example, we subclass PretrainedMolTransformer and implement its `_convert` and `_embed` methods.
import datamol as dm
from sklearn.ensemble import RandomForestRegressor
from molfeat.trans.pretrained import PretrainedMolTransformer
class MyFoundationModel(PretrainedMolTransformer):
    """Dummy "foundation model" transformer built on MACCS fingerprints.

    On construction, a random forest is trained to regress cLogP from the
    MACCS fingerprints of the FreeSolv molecules. Molecules are then embedded
    as their MACCS fingerprint multiplied element-wise by the forest's
    feature importances.
    """

    def __init__(self):
        super().__init__(dtype=np.float32)
        self._featurizer = MoleculeTransformer("maccs", dtype=np.float32)
        # Seeded so that feature_importances_ -- and therefore every
        # embedding -- is the same across instantiations of this model.
        self._model = RandomForestRegressor(random_state=0)
        self.train_dummy_model()

    def train_dummy_model(self):
        """Fit the internal model, standing in for loading pretrained weights.

        In this dummy example, a random forest is trained on the FreeSolv
        SMILES to predict each molecule's cLogP from its MACCS fingerprint.
        """
        smiles = dm.data.freesolv().smiles.values
        X = self._featurizer(smiles)
        y = np.array([dm.descriptors.clogp(dm.to_mol(smi)) for smi in smiles])
        self._model.fit(X, y)

    def _convert(self, inputs: list, **kwargs):
        """Convert the molecules into the MACCS features the model expects."""
        return self._featurizer(inputs)

    def _embed(self, mols: list, **kwargs):
        """Embed featurized molecules using the trained model.

        In this dummy example, each feature vector is simply multiplied
        element-wise by the random forest's feature importances.
        """
        importances = self._model.feature_importances_
        return [feats * importances for feats in mols]
2023-03-21 15:35:59.406 | WARNING | molfeat.trans.base:__init__:52 - The 'MyFoundationModel' interaction has been superseded by a new class with id 0x5624f1ddce60
# Building the model trains its dummy random forest; then embed one molecule
# and inspect the resulting array's shape.
trans = MyFoundationModel()
embeddings = trans(["CN1C=NC2=C1C(=O)N(C(=O)N2C)C"])
embeddings.shape
(1, 167)
Add it to the Model Store
Molfeat has a Model Store for publishing your models in a central place. By default, the store is a read-only GCP bucket, but you can point it to your own file storage instead, as done below with a local cache directory. This can be useful, for example, to share private featurizers with your team.
import platformdirs
from molfeat.store.modelstore import ModelStore
from molfeat.store import ModelInfo
# Use a directory inside molfeat's user cache as the model store location
# instead of the default bucket, then check how many models it holds.
cache_dir = platformdirs.user_cache_dir("molfeat")
path = dm.fs.join(cache_dir, "custom_model_store")
store = ModelStore(model_store_bucket=path)
len(store.available_models)
0
# Describe the model: identity (name, group, version), expected inputs,
# output representation, and provenance metadata.
info = ModelInfo(
    name="my_foundation_model",
    inputs="smiles",
    type="pretrained",
    group="my_group",
    version=0,
    submitter="Datamol",
    description="Solves chemistry!",
    representation="vector",
    require_3D=False,
    tags=["foundation_model", "random_forest"],
    authors=["Datamol"],
    reference="/fake/ref",
)

# Register the model card with the store, then list the store's contents.
store.register(info)
store.available_models
[ModelInfo(name='my_foundation_model', inputs='smiles', type='pretrained', version=0, group='my_group', submitter='Datamol', description='Solves chemistry!', representation='vector', require_3D=False, tags=['foundation_model', 'random_forest'], authors=['Datamol'], reference='/fake/ref', created_at=datetime.datetime(2023, 3, 21, 15, 37, 4, 59739), sha256sum='9c298d589a2158eb513cb52191144518a2acab2cb0c04f1df14fca0f712fa4a1')]
Share with the community
We invite you to share your featurizers with the community to help advance the field. To learn more, visit the developer documentation.