Create a custom modelstore
%load_ext autoreload
%autoreload 2
Modelstore¶
Introduction¶
In the last tutorial, we learned how to create our own custom featurizer. In this tutorial, we will guide you through the process of creating a new modelstore to store our featurizer cards.
Key concepts of the modelstore¶
The ModelStore class allows you to serialize/register and load models. A modelstore is simply a path to a bucket or folder that contains information about our models (artifacts such as model weights, plus a description of each model in the form of a model card). It also provides functionality to list the available models in the store. Before creating a new model store, let's understand the key concepts:
- Model Artifact: the serialized representation of a model, typically the model weights or any object that needs to be saved.
- Model Card: information about a model, such as its name, description, and other metadata. The ModelInfo class represents a model card.
- Model Store: a path to a bucket or a folder where the model artifacts and model cards are stored.
import os
from molfeat.store import ModelStore
from molfeat.store import ModelInfo
Model Info Card¶
A model (info) card is a data structure that describes a model. It has some required fields, such as:
# name of the featurizer
name: ~
# list of authors
authors:
- author 1
# describe the featurizer
description: ~
# which type of input does the featurizer expect ?
inputs: ~
# reference of the featurizer (a paper or a link)
reference: ~
# what does the featurizer return as output for molecular representation ?
# one of ['graph', 'line-notation', 'vector', 'tensor', 'other']
representation: ~
# does the featurizer require 3D information ?
require_3D: ~
# type of the featurizer, one of ["pretrained", "hand-crafted", "hashed", "count"]
type: ~
# name of the person that is submitting the featurizer
submitter: ~
To register a model, you will always need to provide its model info card.
ModelInfo.schema()["properties"]
{'name': {'title': 'Name', 'type': 'string'}, 'inputs': {'title': 'Inputs', 'default': 'smiles', 'type': 'string'}, 'type': {'title': 'Type', 'enum': ['pretrained', 'hand-crafted', 'hashed', 'count'], 'type': 'string'}, 'version': {'title': 'Version', 'default': 0, 'type': 'integer'}, 'group': {'title': 'Group', 'default': 'all', 'type': 'string'}, 'submitter': {'title': 'Submitter', 'type': 'string'}, 'description': {'title': 'Description', 'type': 'string'}, 'representation': {'title': 'Representation', 'enum': ['graph', 'line-notation', 'vector', 'tensor', 'other'], 'type': 'string'}, 'require_3D': {'title': 'Require 3D', 'default': False, 'type': 'boolean'}, 'tags': {'title': 'Tags', 'type': 'array', 'items': {'type': 'string'}}, 'authors': {'title': 'Authors', 'type': 'array', 'items': {'type': 'string'}}, 'reference': {'title': 'Reference', 'type': 'string'}, 'created_at': {'title': 'Created At', 'type': 'string', 'format': 'date-time'}, 'sha256sum': {'title': 'Sha256Sum', 'type': 'string'}, 'model_usage': {'title': 'Model Usage', 'type': 'string'}}
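As an illustration, a minimal card built from the schema above might look like the following sketch; every value here is a hypothetical placeholder, not a real featurizer.
my_card = ModelInfo(
    name="my_fingerprint",
    inputs="smiles",
    type="hashed",  # one of: pretrained, hand-crafted, hashed, count
    version=0,
    submitter="Your Name",
    description="A hypothetical hashed fingerprint featurizer.",
    representation="vector",  # one of: graph, line-notation, vector, tensor, other
    require_3D=False,
    tags=["fingerprint", "example"],
    authors=["Your Name"],
    reference="https://example.com/reference",
)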
Creating an instance of ModelStore¶
The current implementation of the modelstore has some limitations:
- Lack of versioning: the current implementation does not support versioning of models. It treats each model as a unique entity without distinguishing between versions.
- Unique model names: each model name must be unique within the store; duplicate model names are not allowed.
- Single store support: a ModelStore instance can only handle a single store, which means it can index and manage models from only one bucket path at a time.
Now, let's examine the default store.
os.environ.pop("MOLFEAT_MODEL_STORE_BUCKET", None)
'/var/folders/zt/ck4vrp4n4vsb0v16tnlh9h9m0000gn/T/tmp4dqvgaqh'
default_store = ModelStore()
default_store.model_store_bucket
'gs://molfeat-store-prod/artifacts/'
# the length of the model store corresponds to the number of model cards that have been registered
len(default_store)
45
default_store.available_models[:10]
[ModelInfo(name='cats2d', inputs='smiles', type='hashed', version=0, group='all', submitter='Datamol', description='2D version of the 6 Potential Pharmacophore Points CATS (Chemically Advanced Template Search) pharmacophore. This version differs from `pharm2D-cats` on the process to make the descriptors fuzzy, which is closer to the original paper implementation. Implementation is based on work by Rajarshi Guha (08/26/07) and Chris Arthur (1/11/2015)', representation='vector', require_3D=False, tags=['CATS', 'hashed', '2D', 'pharmacophore', 'search'], authors=['Michael Reutlinger', 'Christian P Koch', 'Daniel Reker', 'Nickolay Todoroff', 'Petra Schneider', 'Tiago Rodrigues', 'Gisbert Schneider', 'Rajarshi Guha', 'Chris Arthur'], reference='https://doi.org/10.1021/ci050413p', created_at=datetime.datetime(2023, 5, 3, 0, 7, 6, 534648), sha256sum='9c298d589a2158eb513cb52191144518a2acab2cb0c04f1df14fca0f712fa4a1', model_usage=None), ModelInfo(name='cats3d', inputs='mol', type='hashed', version=0, group='all', submitter='Datamol', description='3D version of the 6 Potential Pharmacophore Points CATS (Chemically Advanced Template Search) pharmacophore. This version differs from `pharm3D-cats` on the process to make the descriptors fuzzy, which is closer to the original paper implementation. This version uses the 3D distance matrix between pharmacophoric points', representation='vector', require_3D=True, tags=['CATS', 'hashed', '3D', 'pharmacophore', 'search'], authors=['Michael Reutlinger', 'Christian Koch', 'Daniel Reker', 'Nickolay Todoroff', 'Petra Schneider', 'Tiago Rodrigues', 'Gisbert Schneider', 'Rajarshi Guha', 'Chris Arthur'], reference='https://doi.org/10.1021/ci050413p', created_at=datetime.datetime(2023, 5, 3, 0, 7, 9, 952490), sha256sum='9c298d589a2158eb513cb52191144518a2acab2cb0c04f1df14fca0f712fa4a1', model_usage=None), ModelInfo(name='mordred', inputs='mol', type='hand-crafted', version=0, group='all', submitter='Datamol', description='Mordred calculates over 1800 molecular descriptors, including constitutional, topological, electronic, and geometrical descriptors, among others. 
Both 2D and 3D descriptors are supported and optional.', representation='vector', require_3D=False, tags=['topological', 'physchem', 'mordred'], authors=['Hirotomo Moriwaki', 'Yu-Shi Tian', 'Norihito Kawashita', 'Tatsuya Takagi'], reference='https://jcheminf.biomedcentral.com/articles/10.1186/s13321-018-0258-y', created_at=datetime.datetime(2023, 5, 3, 0, 7, 26, 38783), sha256sum='9c298d589a2158eb513cb52191144518a2acab2cb0c04f1df14fca0f712fa4a1', model_usage=None), ModelInfo(name='scaffoldkeys', inputs='smiles', type='hand-crafted', version=0, group='all', submitter='Datamol', description='Scaffold Keys are a method for representing scaffold using substructure features and were proposed by Peter Ertl in: Identification of Bioisosteric Scaffolds using Scaffold Keys', representation='vector', require_3D=False, tags=['scaffold', 'bioisosters', 'search'], authors=['Peter Ertl'], reference='https://chemrxiv.org/engage/chemrxiv/article-details/60c7558aee301c5479c7b1be', created_at=datetime.datetime(2023, 5, 3, 0, 7, 22, 832077), sha256sum='9c298d589a2158eb513cb52191144518a2acab2cb0c04f1df14fca0f712fa4a1', model_usage=None), ModelInfo(name='gin_supervised_contextpred', inputs='smiles', type='pretrained', version=0, group='dgllife', submitter='Datamol', description='GIN neural network model pre-trained with supervised learning and context prediction on molecules from ChEMBL.', representation='graph', require_3D=False, tags=['GIN', 'dgl', 'pytorch', 'graph'], authors=['Weihua Hu', 'Bowen Liu', 'Joseph Gomes', 'Marinka Zitnik', 'Percy Liang', 'Vijay Pande', 'Jure Leskovec'], reference='https://arxiv.org/abs/1905.12265', created_at=datetime.datetime(2023, 2, 2, 19, 51, 17, 228390), sha256sum='72dc062936b78b515ed5d0989f909ab7612496d698415d73826b974c9171504a', model_usage=None), ModelInfo(name='gin_supervised_edgepred', inputs='smiles', type='pretrained', version=0, group='dgllife', submitter='Datamol', description='GIN neural network model pre-trained with supervised learning and edge prediction on molecules from ChEMBL.', representation='graph', require_3D=False, tags=['GIN', 'dgl', 'pytorch', 'graph'], authors=['Weihua Hu', 'Bowen Liu', 'Joseph Gomes', 'Marinka Zitnik', 'Percy Liang', 'Vijay Pande', 'Jure Leskovec'], reference='https://arxiv.org/abs/1905.12265', created_at=datetime.datetime(2023, 2, 14, 17, 42, 4, 710823), sha256sum='c1198b37239c3b733f5b48cf265af4c3a1e8c448e2e26cb53e3517fd096213de', model_usage=None), ModelInfo(name='gin_supervised_infomax', inputs='smiles', type='pretrained', version=0, group='dgllife', submitter='Datamol', description='GIN neural network model pre-trained with mutual information maximisation on molecules from ChEMBL.', representation='graph', require_3D=False, tags=['GIN', 'dgl', 'pytorch', 'graph'], authors=['Weihua Hu', 'Bowen Liu', 'Joseph Gomes', 'Marinka Zitnik', 'Percy Liang', 'Vijay Pande', 'Jure Leskovec'], reference='https://arxiv.org/abs/1905.12265', created_at=datetime.datetime(2023, 2, 14, 17, 42, 6, 967631), sha256sum='78dc0f76cde2151f5aa403cbbffead0f24aeac4ce0b48dbfa2689e1a87b95216', model_usage=None), ModelInfo(name='gin_supervised_masking', inputs='smiles', type='pretrained', version=0, group='dgllife', submitter='Datamol', description='GIN neural network model pre-trained with masked modelling on molecules from ChEMBL.', representation='graph', require_3D=False, tags=['GIN', 'dgl', 'pytorch', 'graph'], authors=['Weihua Hu', 'Bowen Liu', 'Joseph Gomes', 'Marinka Zitnik', 'Percy Liang', 'Vijay Pande', 'Jure Leskovec'], 
reference='https://arxiv.org/abs/1905.12265', created_at=datetime.datetime(2023, 2, 14, 17, 42, 9, 221083), sha256sum='c1c797e18312ad44ff089159cb1ed79fd4c67b3d5673c562f61621d95a5d7632', model_usage=None), ModelInfo(name='jtvae_zinc_no_kl', inputs='smiles', type='pretrained', version=0, group='dgllife', submitter='Datamol', description='A JTVAE pre-trained on ZINC for molecule generation, without KL regularization', representation='other', require_3D=False, tags=['JTNN', 'JTVAE', 'dgl', 'pytorch', 'junction-tree', 'graph'], authors=['Wengong Jin', 'Regina Barzilay', 'Tommi Jaakkola'], reference='https://arxiv.org/abs/1802.04364v4', created_at=datetime.datetime(2023, 2, 2, 19, 51, 20, 468939), sha256sum='eab8ecb8a7542a8cdf97410cb27f72aaf374fefef6a1f53642cc5b310cf2b7f6', model_usage=None), ModelInfo(name='map4', inputs='smiles', type='hashed', version=0, group='fp', submitter='Datamol', description='MinHashed atom-pair fingerprint up to a diameter of four bonds (MAP4) is suitable for both small and large molecules by combining substructure and atom-pair concepts. In this fingerprint the circular substructures with radii of r\u2009=\u20091 and r\u2009=\u20092 bonds around each atom in an atom-pair are written as two pairs of SMILES, each pair being combined with the topological distance separating the two central atoms. These so-called atom-pair molecular shingles are hashed, and the resulting set of hashes is MinHashed to form the MAP4 fingerprint.', representation='vector', require_3D=False, tags=['minhashed', 'map4', 'atompair', 'substructure', 'morgan'], authors=['Alice Capecchi', 'Daniel Probst', 'Jean-Louis Reymond'], reference='https://doi.org/10.1186/s13321-020-00445-4', created_at=datetime.datetime(2023, 2, 16, 10, 29, 8, 550063), sha256sum='9c298d589a2158eb513cb52191144518a2acab2cb0c04f1df14fca0f712fa4a1', model_usage=None)]
You can also search the existing model store using either a model card (which may be only partially instantiated) or whatever information you have about the model.
# you can use a model card for the search
my_model_card = default_store.available_models[0]
default_store.search(my_model_card)
[ModelInfo(name='cats2d', inputs='smiles', type='hashed', version=0, group='all', submitter='Datamol', description='2D version of the 6 Potential Pharmacophore Points CATS (Chemically Advanced Template Search) pharmacophore. This version differs from `pharm2D-cats` on the process to make the descriptors fuzzy, which is closer to the original paper implementation. Implementation is based on work by Rajarshi Guha (08/26/07) and Chris Arthur (1/11/2015)', representation='vector', require_3D=False, tags=['CATS', 'hashed', '2D', 'pharmacophore', 'search'], authors=['Michael Reutlinger', 'Christian P Koch', 'Daniel Reker', 'Nickolay Todoroff', 'Petra Schneider', 'Tiago Rodrigues', 'Gisbert Schneider', 'Rajarshi Guha', 'Chris Arthur'], reference='https://doi.org/10.1021/ci050413p', created_at=datetime.datetime(2023, 5, 3, 0, 7, 6, 534648), sha256sum='9c298d589a2158eb513cb52191144518a2acab2cb0c04f1df14fca0f712fa4a1', model_usage=None)]
# search by name
default_store.search(name="jtvae_zinc_no_kl")
[ModelInfo(name='jtvae_zinc_no_kl', inputs='smiles', type='pretrained', version=0, group='dgllife', submitter='Datamol', description='A JTVAE pre-trained on ZINC for molecule generation, without KL regularization', representation='other', require_3D=False, tags=['JTNN', 'JTVAE', 'dgl', 'pytorch', 'junction-tree', 'graph'], authors=['Wengong Jin', 'Regina Barzilay', 'Tommi Jaakkola'], reference='https://arxiv.org/abs/1802.04364v4', created_at=datetime.datetime(2023, 2, 2, 19, 51, 20, 468939), sha256sum='eab8ecb8a7542a8cdf97410cb27f72aaf374fefef6a1f53642cc5b310cf2b7f6', model_usage=None)]
# Assume you forgot the name of the model, but know it is one of the pretrained models that uses graph as representation
default_store.search(type="pretrained", representation="graph")
[ModelInfo(name='gin_supervised_contextpred', inputs='smiles', type='pretrained', version=0, group='dgllife', submitter='Datamol', description='GIN neural network model pre-trained with supervised learning and context prediction on molecules from ChEMBL.', representation='graph', require_3D=False, tags=['GIN', 'dgl', 'pytorch', 'graph'], authors=['Weihua Hu', 'Bowen Liu', 'Joseph Gomes', 'Marinka Zitnik', 'Percy Liang', 'Vijay Pande', 'Jure Leskovec'], reference='https://arxiv.org/abs/1905.12265', created_at=datetime.datetime(2023, 2, 2, 19, 51, 17, 228390), sha256sum='72dc062936b78b515ed5d0989f909ab7612496d698415d73826b974c9171504a', model_usage=None), ModelInfo(name='gin_supervised_edgepred', inputs='smiles', type='pretrained', version=0, group='dgllife', submitter='Datamol', description='GIN neural network model pre-trained with supervised learning and edge prediction on molecules from ChEMBL.', representation='graph', require_3D=False, tags=['GIN', 'dgl', 'pytorch', 'graph'], authors=['Weihua Hu', 'Bowen Liu', 'Joseph Gomes', 'Marinka Zitnik', 'Percy Liang', 'Vijay Pande', 'Jure Leskovec'], reference='https://arxiv.org/abs/1905.12265', created_at=datetime.datetime(2023, 2, 14, 17, 42, 4, 710823), sha256sum='c1198b37239c3b733f5b48cf265af4c3a1e8c448e2e26cb53e3517fd096213de', model_usage=None), ModelInfo(name='gin_supervised_infomax', inputs='smiles', type='pretrained', version=0, group='dgllife', submitter='Datamol', description='GIN neural network model pre-trained with mutual information maximisation on molecules from ChEMBL.', representation='graph', require_3D=False, tags=['GIN', 'dgl', 'pytorch', 'graph'], authors=['Weihua Hu', 'Bowen Liu', 'Joseph Gomes', 'Marinka Zitnik', 'Percy Liang', 'Vijay Pande', 'Jure Leskovec'], reference='https://arxiv.org/abs/1905.12265', created_at=datetime.datetime(2023, 2, 14, 17, 42, 6, 967631), sha256sum='78dc0f76cde2151f5aa403cbbffead0f24aeac4ce0b48dbfa2689e1a87b95216', model_usage=None), ModelInfo(name='gin_supervised_masking', inputs='smiles', type='pretrained', version=0, group='dgllife', submitter='Datamol', description='GIN neural network model pre-trained with masked modelling on molecules from ChEMBL.', representation='graph', require_3D=False, tags=['GIN', 'dgl', 'pytorch', 'graph'], authors=['Weihua Hu', 'Bowen Liu', 'Joseph Gomes', 'Marinka Zitnik', 'Percy Liang', 'Vijay Pande', 'Jure Leskovec'], reference='https://arxiv.org/abs/1905.12265', created_at=datetime.datetime(2023, 2, 14, 17, 42, 9, 221083), sha256sum='c1c797e18312ad44ff089159cb1ed79fd4c67b3d5673c562f61621d95a5d7632', model_usage=None), ModelInfo(name='pcqm4mv2_graphormer_base', inputs='smiles', type='pretrained', version=0, group='graphormer', submitter='Datamol', description='Pretrained Graph Transformer on PCQM4Mv2 Homo-Lumo energy gap prediction using 2D molecular graphs.', representation='graph', require_3D=False, tags=['Graphormer', 'PCQM4Mv2', 'graph', 'pytorch', 'Microsoft'], authors=['Chengxuan Ying', 'Tianle Cai', 'Shengjie Luo', 'Shuxin Zheng', 'Guolin Ke', 'Di He', 'Yanming Shen', 'Tie-Yan Liu'], reference='https://arxiv.org/abs/2106.05234', created_at=datetime.datetime(2023, 2, 2, 19, 51, 19, 330147), sha256sum='9c298d589a2158eb513cb52191144518a2acab2cb0c04f1df14fca0f712fa4a1', model_usage=None)]
You can also load models by name. However, this is not the recommended way to explore models, because some models require a custom loading function.
gin_model, gin_model_info = default_store.load(model_name="gin_supervised_infomax")
gin_model_info
ModelInfo(name='gin_supervised_infomax', inputs='smiles', type='pretrained', version=0, group='dgllife', submitter='Datamol', description='GIN neural network model pre-trained with mutual information maximisation on molecules from ChEMBL.', representation='graph', require_3D=False, tags=['GIN', 'dgl', 'pytorch', 'graph'], authors=['Weihua Hu', 'Bowen Liu', 'Joseph Gomes', 'Marinka Zitnik', 'Percy Liang', 'Vijay Pande', 'Jure Leskovec'], reference='https://arxiv.org/abs/1905.12265', created_at=datetime.datetime(2023, 2, 14, 17, 42, 6, 967631), sha256sum='78dc0f76cde2151f5aa403cbbffead0f24aeac4ce0b48dbfa2689e1a87b95216', model_usage=None)
Creating a Custom Model Store¶
Model Store Bucket¶
To create a custom model store, you need to specify a model store bucket path. By default, the code uses the read-only model store bucket located at gs://molfeat-store-prod/artifacts/. If you want to create a custom modelstore that includes the content of the default bucket, you can copy its contents to your custom bucket.
There are two key concepts to understand when building a custom modelstore:
- Readable and Writable Path: you need to provide a local or remote folder path that is compatible with fsspec. This path will serve as your model store bucket, allowing you to store and access models.
- Multiple Model Stores: you can create multiple instances of the modelstore, each representing a different model store. However, unless you manually load a model, only models present in the default model store bucket path of the modelstore can be loaded by their names. You can override the default model store bucket path by setting the MOLFEAT_MODEL_STORE_BUCKET environment variable. Currently, we use a single source of truth to simplify the registration and loading process.
With these concepts in mind, you can create a custom model store by specifying a suitable bucket path and, optionally, copying over the content of the default bucket. This gives you control over your model store and lets you manage models according to your specific needs.
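As an illustration (not part of this tutorial's execution flow), a custom store can be created by passing a bucket path directly or by overriding the default bucket through the environment variable; the s3:// path below is a hypothetical placeholder.
import os
from molfeat.store import ModelStore

# option 1: point a ModelStore instance at your own fsspec-compatible bucket
my_store = ModelStore("s3://my-org-bucket/molfeat-artifacts/")

# option 2: override the default bucket so that models can be loaded by name
os.environ["MOLFEAT_MODEL_STORE_BUCKET"] = "s3://my-org-bucket/molfeat-artifacts/"
store_from_env = ModelStore()  # now indexes the custom bucket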
import tempfile
import os
Let's start with a simple local temporary model store.
temp_dir = tempfile.TemporaryDirectory()
temp_model_store = ModelStore(temp_dir.name)
temp_model_store.available_models
[]
Let's look at the contents of the temp dir.
%%bash -s "$temp_dir.name"
tree $1
/var/folders/zt/ck4vrp4n4vsb0v16tnlh9h9m0000gn/T/tmp5k07n15a 0 directories, 0 files
Let's add the GIN model we downloaded earlier. To register a new model, you need the following:
- Model weights (or None for pretrained models)
- Model card
- A serialization function for the model, which determines how to save it.
You can also pass additional keyword arguments for your save_fn through the save_fn_kwargs parameter. The save_fn is expected to be called as follows:
save_fn(model, <model_upload_path>, **save_fn_kwargs)
For example, here's a dummy save_fn using PyTorch:
import torch
import fsspec
def my_torch_save_fn(model, path, **kwargs):
    # open the destination path (local or remote) with fsspec and
    # serialize the model with torch.save
    with fsspec.open(path, "wb") as f:
        torch.save(model, f, **kwargs)
    return path
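To illustrate, registering with such a custom save function might look like the sketch below; my_card and my_model are hypothetical placeholders, and save_fn/save_fn_kwargs follow the calling convention described above.
# hypothetical sketch: register a model with the custom save function above
temp_model_store.register(
    modelcard=my_card,
    model=my_model,
    save_fn=my_torch_save_fn,
    save_fn_kwargs={"pickle_protocol": 4},  # forwarded to torch.save via **kwargs
)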
Note that if you provide a custom saving function, you are responsible for providing the matching loading function. If you're unsure, it's recommended to use the default loading function, which covers most cases you will encounter.
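If you do write your own, here is a minimal sketch of a loader matching the dummy my_torch_save_fn above; my_torch_load_fn is a hypothetical helper, not part of the molfeat API.
def my_torch_load_fn(path, **kwargs):
    # matching loader for my_torch_save_fn: open the artifact with fsspec
    # and deserialize it with torch.load
    with fsspec.open(path, "rb") as f:
        return torch.load(f, **kwargs)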
tmp_gin_model_info = gin_model_info.copy()
tmp_gin_model_info.name = "tmp_gin_supervised_infomax"
temp_model_store.register(modelcard=tmp_gin_model_info, model=gin_model)
0%| | 0.00/7.12M [00:00<?, ?B/s]
2023-05-19 15:14:52.082 | INFO | molfeat.store.modelstore:register:150 - Successfuly registered model tmp_gin_supervised_infomax !
temp_model_store.available_models
[ModelInfo(name='tmp_gin_supervised_infomax', inputs='smiles', type='pretrained', version=0, group='dgllife', submitter='Datamol', description='GIN neural network model pre-trained with mutual information maximisation on molecules from ChEMBL.', representation='graph', require_3D=False, tags=['GIN', 'dgl', 'pytorch', 'graph'], authors=['Weihua Hu', 'Bowen Liu', 'Joseph Gomes', 'Marinka Zitnik', 'Percy Liang', 'Vijay Pande', 'Jure Leskovec'], reference='https://arxiv.org/abs/1905.12265', created_at=datetime.datetime(2023, 2, 14, 17, 42, 6, 967631), sha256sum='7a100a7eb680e62d98b0e1d3a906bf740f140dceda55b69a86ddf3fd78ace245', model_usage=None)]
temp_model_store.load("tmp_gin_supervised_infomax")
0%| | 0.00/663 [00:00<?, ?B/s]
0%| | 0.00/7.12M [00:00<?, ?B/s]
(GIN( (dropout): Dropout(p=0.5, inplace=False) (node_embeddings): ModuleList( (0): Embedding(120, 300) (1): Embedding(3, 300) ) (gnn_layers): ModuleList( (0): GINLayer( (mlp): Sequential( (0): Linear(in_features=300, out_features=600, bias=True) (1): ReLU() (2): Linear(in_features=600, out_features=300, bias=True) ) (edge_embeddings): ModuleList( (0): Embedding(6, 300) (1): Embedding(3, 300) ) (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): GINLayer( (mlp): Sequential( (0): Linear(in_features=300, out_features=600, bias=True) (1): ReLU() (2): Linear(in_features=600, out_features=300, bias=True) ) (edge_embeddings): ModuleList( (0): Embedding(6, 300) (1): Embedding(3, 300) ) (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): GINLayer( (mlp): Sequential( (0): Linear(in_features=300, out_features=600, bias=True) (1): ReLU() (2): Linear(in_features=600, out_features=300, bias=True) ) (edge_embeddings): ModuleList( (0): Embedding(6, 300) (1): Embedding(3, 300) ) (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (3): GINLayer( (mlp): Sequential( (0): Linear(in_features=300, out_features=600, bias=True) (1): ReLU() (2): Linear(in_features=600, out_features=300, bias=True) ) (edge_embeddings): ModuleList( (0): Embedding(6, 300) (1): Embedding(3, 300) ) (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (4): GINLayer( (mlp): Sequential( (0): Linear(in_features=300, out_features=600, bias=True) (1): ReLU() (2): Linear(in_features=600, out_features=300, bias=True) ) (edge_embeddings): ModuleList( (0): Embedding(6, 300) (1): Embedding(3, 300) ) (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) ), ModelInfo(name='tmp_gin_supervised_infomax', inputs='smiles', type='pretrained', version=0, group='dgllife', submitter='Datamol', description='GIN neural network model pre-trained with mutual information maximisation on molecules from ChEMBL.', representation='graph', require_3D=False, tags=['GIN', 'dgl', 'pytorch', 'graph'], authors=['Weihua Hu', 'Bowen Liu', 'Joseph Gomes', 'Marinka Zitnik', 'Percy Liang', 'Vijay Pande', 'Jure Leskovec'], reference='https://arxiv.org/abs/1905.12265', created_at=datetime.datetime(2023, 2, 14, 17, 42, 6, 967631), sha256sum='7a100a7eb680e62d98b0e1d3a906bf740f140dceda55b69a86ddf3fd78ace245', model_usage=None))
We can see that the new model tmp_gin_supervised_infomax has been saved.
Loading a model from a custom model store¶
# clean existing env file
! rm -rf .env
import os
from molfeat.trans.pretrained.dgl_pretrained import PretrainedDGLTransformer
# making sure that the default bucket is not set
model = PretrainedDGLTransformer(kind="tmp_gin_supervised_infomax", dtype=float)
model.featurizer.store.model_store_bucket
'gs://molfeat-store-prod/artifacts/'
model(["CCO", "CCN"])
--------------------------------------------------------------------------- IndexError Traceback (most recent call last) File ~/Code/datamol-org/molfeat-core/molfeat/store/loader.py:99, in PretrainedStoreModel._load_or_raise(cls, name, download_path, store, **kwargs) 98 try: ---> 99 modelcard = store.search(name=name)[0] 100 artifact_dir = store.download(modelcard, download_path, **kwargs) IndexError: list index out of range During handling of the above exception, another exception occurred: ModelStoreError Traceback (most recent call last) /Users/manu/Code/datamol-org/molfeat-core/docs/tutorials/custom_model_store.ipynb Cell 35 in <cell line: 1>() ----> <a href='vscode-notebook-cell:/Users/manu/Code/datamol-org/molfeat-core/docs/tutorials/custom_model_store.ipynb#Y111sZmlsZQ%3D%3D?line=0'>1</a> model(["CCO", "CCN"]) File ~/Code/datamol-org/molfeat-core/molfeat/trans/base.py:376, in MoleculeTransformer.__call__(self, mols, enforce_dtype, ignore_errors, **kwargs) 351 def __call__( 352 self, 353 mols: List[Union[dm.Mol, str]], (...) 356 **kwargs, 357 ): 358 r""" 359 Calculate features for molecules. Using __call__, instead of transform. 360 If ignore_error is True, a list of features and valid ids are returned. (...) 374 375 """ --> 376 features = self.transform(mols, ignore_errors=ignore_errors, enforce_dtype=False, **kwargs) 377 ids = np.arange(len(features)) 378 if ignore_errors: File ~/.miniconda/envs/molfeat-dev/lib/python3.10/site-packages/sklearn/utils/_set_output.py:140, in _wrap_method_output.<locals>.wrapped(self, X, *args, **kwargs) 138 @wraps(f) 139 def wrapped(self, X, *args, **kwargs): --> 140 data_to_wrap = f(self, X, *args, **kwargs) 141 if isinstance(data_to_wrap, tuple): 142 # only wrap the first output for cross decomposition 143 return ( 144 _wrap_data_with_container(method, data_to_wrap[0], X, self), 145 *data_to_wrap[1:], 146 ) File ~/Code/datamol-org/molfeat-core/molfeat/trans/pretrained/base.py:208, in PretrainedMolTransformer.transform(self, smiles, **kwargs) 206 if len(mols) > 0: 207 converted_mols = self._convert(mols, **kwargs) --> 208 out = self._embed(converted_mols, **kwargs) 210 if not isinstance(out, list): 211 out = list(out) File ~/Code/datamol-org/molfeat-core/molfeat/trans/pretrained/dgl_pretrained.py:221, in PretrainedDGLTransformer._embed(self, smiles, **kwargs) 219 def _embed(self, smiles: List[str], **kwargs): 220 """Embed molecules into a latent space""" --> 221 self._preload() 222 dataset, successes = self.graph_featurizer(smiles, kind=self.kind) 223 if self.kind in DGLModel.available_models(query="^jtvae"): File ~/Code/datamol-org/molfeat-core/molfeat/trans/pretrained/base.py:90, in PretrainedMolTransformer._preload(self) 88 """Preload the pretrained model for later queries""" 89 if self.featurizer is not None and isinstance(self.featurizer, PretrainedModel): ---> 90 self.featurizer = self.featurizer.load() 91 self.preload = True File ~/Code/datamol-org/molfeat-core/molfeat/trans/pretrained/dgl_pretrained.py:85, in DGLModel.load(self) 83 if self._model is not None: 84 return self._model ---> 85 download_output_dir = self._artifact_load( 86 name=self.name, download_path=self.cache_path, store=self.store 87 ) 88 model_path = dm.fs.join(download_output_dir, self.store.MODEL_PATH_NAME) 89 with fsspec.open(model_path, "rb") as f: File ~/Code/datamol-org/molfeat-core/molfeat/store/loader.py:81, in PretrainedStoreModel._artifact_load(cls, name, download_path, **kwargs) 79 if not dm.fs.exists(download_path): 80 cls._load_or_raise.cache_clear() ---> 81 
return cls._load_or_raise(name, download_path, **kwargs) File ~/Code/datamol-org/molfeat-core/molfeat/store/loader.py:103, in PretrainedStoreModel._load_or_raise(cls, name, download_path, store, **kwargs) 101 except Exception as e: 102 mess = f"Can't retrieve model {name} from the store !" --> 103 raise ModelStoreError(mess) 104 return artifact_dir ModelStoreError: Can't retrieve model tmp_gin_supervised_infomax from the store !
As expected, this does not work. We need to either explicitly load the model from our store first, or change the default loading bucket using the MOLFEAT_MODEL_STORE_BUCKET environment variable.
%%bash -s "$temp_dir.name"
echo "export MOLFEAT_MODEL_STORE_BUCKET=$1" > .env
# we reimport to reload the store information without restarting the kernel
import dotenv
dotenv.load_dotenv(override=True)
from molfeat.store import ModelStore
model = PretrainedDGLTransformer(kind="tmp_gin_supervised_infomax", dtype=float)
model.featurizer.store.model_store_bucket
'/var/folders/zt/ck4vrp4n4vsb0v16tnlh9h9m0000gn/T/tmp5k07n15a'
model(["CCO", "CCN"]).shape
(2, 300)
Going a bit further: serializing a custom pretrained model into a private model store¶
In the following example, we will explore how to set up a complex pretrained featurizer and load it from a personal modelstore.
First, we need to install the following library to provide the embeddings. We follow the template from the previous tutorial to show how we can serialize the custom astrochem_embedding model into our private store.
pip install astrochem_embedding
For our custom model, we need to define the loading process, which dictates how the model should be loaded from the store. It's recommended to inherit from PretrainedStoreModel if you have a complex loading process.
import datamol as dm
import joblib
import fsspec
import torch
from molfeat.trans.pretrained import PretrainedMolTransformer
from molfeat.store.loader import PretrainedStoreModel
class AstroPretrainedStoreModel(PretrainedStoreModel):
    r"""
    Define a loading class to load the astrochem model from the store
    """

    def load(self):
        """Load VICGAE model"""
        download_output_dir = self._artifact_load(
            name=self.name, download_path=self.cache_path, store=self.store
        )
        model_path = dm.fs.join(download_output_dir, self.store.MODEL_PATH_NAME)
        with fsspec.open(model_path, "rb") as f:
            model = joblib.load(f)
        model.eval()
        return model
# We define the model class for loading and transforming data
class MyAstroChemFeaturizer(PretrainedMolTransformer):
    """
    In this more practical example, we use embeddings from VICGAE, a variance-invariance-covariance
    regularized GRU autoencoder trained on SELFIES strings.
    """

    def __init__(self, name="astrochem_embedding", *args, **kwargs):
        super().__init__(*args, **kwargs)
        # we load the model from the store
        self.model = AstroPretrainedStoreModel(name=name).load()

    def _embed(self, smiles, **kwargs):
        return [self.model.embed_smiles(x) for x in smiles]
from astrochem_embedding import VICGAE
model = VICGAE.from_pretrained()
# Let's define our model's info card and then save the model to the store
info = ModelInfo(
    name="astrochem_embedding",
    inputs="selfies",
    type="pretrained",
    group="astrochem",
    version=0,
    submitter="Datamol",
    description="A variance-invariance-covariance regularized GRU autoencoder for astrochemistry using selfies strings!",
    representation="vector",
    require_3D=False,
    tags=["pretrained", "astrochemistry", "selfies"],
    authors=["Datamol"],
    reference="Lee, K. L. K. (2021). Language models for astrochemistry (Version 0.1.2) [Computer software]. https://github.com/laserkelvin/astrochem_embedding",
)
# We define how to use the model using a string that can be displayed in the docs
usage_string = """
import torch
import datamol as dm
# <how to import MyAstroChemFeaturizer if needed>
transformer = MyAstroChemFeaturizer(dtype=torch.float)
transformer(dm.freesolv()["smiles"][:10]).shape
"""
info.set_usage(usage_string)
# we register the model; this simple model can just be pickled to the store bucket
temp_model_store.register(info, model=model)
0%| | 0.00/509k [00:00<?, ?B/s]
2023-05-19 15:16:59.293 | INFO | molfeat.store.modelstore:register:150 - Successfuly registered model astrochem_embedding !
print(info.usage())
import torch
import datamol as dm
# <how to import MyAstroChemFeaturizer if needed>
transformer = MyAstroChemFeaturizer(dtype=torch.float)
transformer(dm.freesolv()["smiles"][:10]).shape
# let's execute the test example and check
transformer = MyAstroChemFeaturizer(dtype=torch.float)
transformer(dm.freesolv()["smiles"][:10]).shape
0%| | 0.00/882 [00:00<?, ?B/s]
0%| | 0.00/509k [00:00<?, ?B/s]
torch.Size([10, 32])
# a bit of cleaning
temp_dir.cleanup()
! rm -rf .env
You can now create a private modelstore to save, index, and share your custom models.