Skip to content

molfeat.store

ModelInfo

Bases: BaseModel

Source code in molfeat/store/modelcard.py
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
class ModelInfo(BaseModel):
    """Information card describing a featurizer model registered in the molfeat store."""

    model_config = ConfigDict(
        protected_namespaces=(
            "protected_",
        )  # Prevents warning from usage of model_ prefix in fields
    )

    name: str
    inputs: str = "smiles"
    type: Literal["pretrained", "hand-crafted", "hashed", "count"]
    version: int = 0
    group: Optional[str] = "all"
    submitter: str
    description: str
    representation: Literal["graph", "line-notation", "vector", "tensor", "other"]
    require_3D: Optional[bool] = False
    tags: Optional[List[str]] = []
    # default to None so cards without an author list still validate:
    # in pydantic v2 an Optional field *without* a default is still required
    authors: Optional[List[str]] = None
    reference: Optional[str] = None
    created_at: datetime = Field(default_factory=datetime.now)
    sha256sum: Optional[str] = None
    model_usage: Optional[str] = None

    def path(self, root_path: str):
        """Generate the folder path where to save this model

        Args:
            root_path: path to the root folder
        """
        # a falsy version (None or 0) maps to the "0" folder
        version = str(self.version or 0)
        return dm.fs.join(root_path, self.group, self.name, version)

    def match(self, new_card: Union["ModelInfo", dict], match_only: Optional[List[str]] = None):
        """Compare two model card information and returns True if they are the same

        Args:
            new_card: card to search for in the modelstore
            match_only: list of minimum attribute that should match between the two model information
        """

        self_content = self.model_dump().copy()
        if not isinstance(new_card, dict):
            new_card = new_card.model_dump()
        new_content = new_card.copy()
        # we always remove the datetime field: the creation timestamp is
        # irrelevant when deciding whether two cards describe the same model
        self_content.pop("created_at", None)
        new_content.pop("created_at", None)
        if match_only is not None:
            self_content = {k: self_content.get(k) for k in match_only}
            new_content = {k: new_content.get(k) for k in match_only}
        return self_content == new_content

    def set_usage(self, usage: str):
        """Set the usage of the model

        Args:
            usage: usage of the model
        """
        self.model_usage = usage

    def usage(self):
        """Return the usage of the model.

        Returns the custom usage snippet if one was set, otherwise builds a
        default code example from the card via `get_model_init`.
        """
        # a non-empty custom snippet takes precedence over the generated one
        if self.model_usage:
            return self.model_usage
        import_statement, loader_statement = get_model_init(self)
        comment = "# sanitize and standardize your molecules if needed"
        if self.require_3D:
            comment += "\n# <generate 3D coordinates here> "
        usage = f"""
        import datamol as dm
        {import_statement}
        smiles = dm.freesolv().iloc[:100].smiles
        {comment}
        transformer = {loader_statement}
        features = transformer(smiles)
        """
        # strip the indentation introduced by the triple-quoted literal
        usage = "\n".join([x.strip() for x in usage.split("\n")])
        return usage

match(new_card, match_only=None)

Compare two model cards and return True if they are the same

Parameters:

Name Type Description Default
new_card Union[ModelInfo, dict]

card to search for in the modelstore

required
match_only Optional[List[str]]

list of minimum attribute that should match between the two model information

None
Source code in molfeat/store/modelcard.py
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
def match(self, new_card: Union["ModelInfo", dict], match_only: Optional[List[str]] = None):
    """Check whether this card and `new_card` describe the same model.

    Args:
        new_card: card (or raw dict of card fields) to compare against
        match_only: restrict the comparison to this subset of attributes
    """
    mine = dict(self.model_dump())
    theirs = dict(new_card) if isinstance(new_card, dict) else dict(new_card.model_dump())
    # the creation timestamp never takes part in the comparison
    for content in (mine, theirs):
        content.pop("created_at", None)
    if match_only is not None:
        mine = {key: mine.get(key) for key in match_only}
        theirs = {key: theirs.get(key) for key in match_only}
    return mine == theirs

path(root_path)

Generate the folder path where to save this model

Parameters:

Name Type Description Default
root_path str

path to the root folder

required
Source code in molfeat/store/modelcard.py
74
75
76
77
78
79
80
81
def path(self, root_path: str):
    """Build the storage folder path for this model under `root_path`.

    Args:
        root_path: path to the root folder
    """
    # a falsy version (None or 0) maps to the "0" folder
    return dm.fs.join(root_path, self.group, self.name, str(self.version or 0))

set_usage(usage)

Set the usage of the model

Parameters:

Name Type Description Default
usage str

usage of the model

required
Source code in molfeat/store/modelcard.py
103
104
105
106
107
108
109
def set_usage(self, usage: str):
    """Attach a custom usage snippet to this model card.

    Args:
        usage: usage of the model
    """
    self.model_usage = usage

usage()

Return the usage of the model

Source code in molfeat/store/modelcard.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def usage(self):
    """Return the usage snippet of the model.

    A non-empty custom snippet (set via `set_usage`) takes precedence;
    otherwise a default code example is generated from the card.
    """
    if self.model_usage:
        return self.model_usage
    import_statement, loader_statement = get_model_init(self)
    comment = "# sanitize and standardize your molecules if needed"
    if self.require_3D:
        comment += "\n# <generate 3D coordinates here> "
    snippet = f"""
    import datamol as dm
    {import_statement}
    smiles = dm.freesolv().iloc[:100].smiles
    {comment}
    transformer = {loader_statement}
    features = transformer(smiles)
    """
    # drop the indentation introduced by the triple-quoted literal
    return "\n".join(line.strip() for line in snippet.split("\n"))

get_model_init(card)

Get the model initialization code

Parameters:

Name Type Description Default
card

model card to use

required
Source code in molfeat/store/modelcard.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def get_model_init(card):
    """Build the import and instantiation statements for a model card.

    Args:
        card: model card to use

    Returns:
        tuple of (import_statement, loader_statement) strings

    Raises:
        ValueError: when the card's group is not a known model family
    """
    group = card.group
    if group == "all" and card.type != "pretrained":
        return (
            "from molfeat.trans import MoleculeTransformer",
            f"MoleculeTransformer(featurizer='{card.name}', dtype=float)",
        )
    if group in ("rdkit", "fp", "shape"):
        return (
            "from molfeat.trans.fp import FPVecTransformer",
            f"FPVecTransformer(kind='{card.name}', dtype=float)",
        )
    if group == "dgllife":
        return (
            "from molfeat.trans.pretrained import PretrainedDGLTransformer",
            f"PretrainedDGLTransformer(kind='{card.name}', dtype=float)",
        )
    if group == "graphormer":
        return (
            "from molfeat.trans.pretrained import GraphormerTransformer",
            f"GraphormerTransformer(kind='{card.name}', dtype=float)",
        )
    if group == "fcd":
        return (
            "from molfeat.trans.pretrained import FCDTransformer",
            "FCDTransformer()",
        )
    if group == "pharmacophore":
        # factory name is the trailing part of e.g. "pharm2D-cats"
        factory = card.name.split("-")[-1]
        calc_class = "Pharmacophore3D" if card.require_3D else "Pharmacophore2D"
        return (
            "from molfeat.trans.base import MoleculeTransformer\n"
            f"from molfeat.calc.pharmacophore import {calc_class}",
            f"MoleculeTransformer(featurizer={calc_class}(factory='{factory}'), dtype=float)",
        )
    if group == "huggingface":
        return (
            "from molfeat.trans.pretrained.hf_transformers import PretrainedHFTransformer",
            f"PretrainedHFTransformer(kind='{card.name}', notation='{card.inputs}', dtype=float)",
        )
    raise ValueError(f"Unknown model group {group}")

ModelStore

A class for artefact serializing from any url

This class not only allows pretrained model serializing and loading, but also helps in listing model availability and registering models.

For simplicity: * There is no versioning. * Only one model should match a given name * Model deletion is not allowed (on the read-only default store) * Only a single store is supported per model store instance

Building a New Model Store

To create a new model store, you will mainly need a model store bucket path. The default model store bucket, located at gs://molfeat-store-prod/artifacts/, is read-only.

To build your own model store bucket, follow the instructions below:

  1. Create a local or remote cloud directory that can be accessed by fsspec (and the corresponding filesystem).
  2. [Optional] Sync the default model store bucket to your new path if you want to access the default models.
  3. Set the environment variable MOLFEAT_MODEL_STORE_BUCKET to your new path. This variable will be used as the default model store bucket when creating a new model store instance without specifying a path. Note that setting up this path is necessary if you want to access models directly by their names, without manually loading them from your custom model store.
Source code in molfeat/store/modelstore.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
class ModelStore:
    """A class for artefact serializing from any url

    This class not only allows pretrained model serializing and loading,
    but also helps in listing model availability and registering models.

    For simplicity:
        * There is no versioning.
        * Only one model should match a given name
        * Model deletion is not allowed (on the read-only default store)
        * Only a single store is supported per model store instance

    !!! note "Building a New Model Store"
        To create a new model store, you will mainly need a model store bucket path. The default model store bucket, located at `gs://molfeat-store-prod/artifacts/`, is **read-only**.

        To build your own model store bucket, follow the instructions below:

        1. Create a local or remote cloud directory that can be accessed by fsspec (and the corresponding filesystem).
        2. [Optional] Sync the default model store bucket to your new path if you want to access the default models.
        3. Set the environment variable `MOLFEAT_MODEL_STORE_BUCKET` to your new path. This variable will be used as the default model store bucket when creating a new model store instance without specifying a path.
            Note that setting up this path is necessary if you want to access models directly by their names, without manually loading them from your custom model store.


    """

    # EN: be careful not to recreate ada
    # EN: should we just use modelstore ?
    MODEL_STORE_BUCKET = "gs://molfeat-store-prod/artifacts/"
    MODEL_PATH_NAME = "model.save"
    METADATA_PATH_NAME = "metadata.json"

    def __init__(self, model_store_bucket: Optional[str] = None):
        """Initialize the store.

        Args:
            model_store_bucket: path of the model store bucket. Defaults to the
                `MOLFEAT_MODEL_STORE_BUCKET` environment variable, then to the
                read-only default bucket.
        """
        if model_store_bucket is None:
            model_store_bucket = os.getenv("MOLFEAT_MODEL_STORE_BUCKET", self.MODEL_STORE_BUCKET)
        self.model_store_bucket = model_store_bucket
        self._available_models = []
        self._update_store()

    def _update_store(self):
        """Initialize the store with all available models"""
        all_metadata = dm.fs.glob(dm.fs.join(self.model_store_bucket, "**/metadata.json"))
        self._available_models = []
        for mtd_file in all_metadata:
            with fsspec.open(mtd_file, "r") as IN:
                # metadata files are JSON; yaml.safe_load parses JSON as well
                mtd_content = yaml.safe_load(IN)
                model_info = ModelInfo(**mtd_content)
                self._available_models.append(model_info)

    @property
    def available_models(self):
        """Return a list of all models that have been serialized in molfeat"""
        return self._available_models

    def __len__(self):
        """Return the length of the model store"""
        return len(self.available_models)

    def register(
        self,
        modelcard: Union[ModelInfo, dict],
        model: Optional[Any] = None,
        chunk_size: int = 2048,
        save_fn: Optional[Callable] = None,
        save_fn_kwargs: Optional[dict] = None,
        force: bool = True,
    ):
        """
        Register a new model to the store

        !!! note `save_fn`
            You can pass additional kwargs for your `save_fn` through the `save_fn_kwargs` argument.
            It's expected that `save_fn` will be called as : `save_fn(model, <model_upload_path>, **save_fn_kwargs)`,
            with `<model_upload_path>` being provided by the model store, and that it will return the path to the serialized model.
            If not provided, `joblib.dump` is used by default.

        Args:
            modelcard: Model information
            model: A path to the model artifact or any object that needs to be saved
            chunk_size: the chunk size for the upload
            save_fn: any custom function for serializing the model, that takes the model, the upload path and parameters `save_fn_kwargs` as inputs.
            save_fn_kwargs: any additional kwargs to pass to save_fn
            force: whether to force upload to the bucket

        """
        if not isinstance(modelcard, ModelInfo):
            modelcard = ModelInfo(**modelcard)
        # we save the model first
        if self.exists(card=modelcard):
            logger.warning(f"Model {modelcard.name} exists already ...")
            if not force:
                return

        model_root_dir = modelcard.path(self.model_store_bucket)
        model_path = model
        model_upload_path = dm.fs.join(model_root_dir, self.MODEL_PATH_NAME)
        model_metadata_upload_path = dm.fs.join(model_root_dir, self.METADATA_PATH_NAME)

        save_fn_kwargs = save_fn_kwargs or {}
        local_tmp_path = None  # temp file created when serializing an in-memory model
        if save_fn is None:
            if not isinstance(model, (pathlib.Path, os.PathLike)):
                # the model is an in-memory object: serialize it with joblib first
                local_model_path = tempfile.NamedTemporaryFile(delete=False)
                with local_model_path:
                    joblib.dump(model, local_model_path)
                model_path = local_model_path.name
                local_tmp_path = model_path
            # Upload the artifact to the bucket
            dm.fs.copy_file(
                model_path,
                model_upload_path,
                progress=True,
                leave_progress=False,
                chunk_size=chunk_size,
                force=force,
            )
        else:
            model_path = save_fn(model, model_upload_path, **save_fn_kwargs)
            # fall back to the upload path if save_fn has not returned anything
            model_path = model_path or model_upload_path
        modelcard.sha256sum = commons.sha256sum(model_path)
        if local_tmp_path is not None:
            # clean up the temporary serialization file now that it has been
            # uploaded and checksummed (it was created with delete=False)
            os.unlink(local_tmp_path)
        # then we save the metadata as json
        with fsspec.open(model_metadata_upload_path, "w") as OUT:
            # pydantic v2 API; `.json()` is deprecated
            OUT.write(modelcard.model_dump_json())
        self._update_store()
        logger.info(f"Successfully registered model {modelcard.name} !")

    def _filelock(self, lock_name: str):
        """Create an empty lock file into `cache_dir_path/locks/lock_name`

        Args:
            lock_name: name of the lock file to create
        """

        lock_path = dm.fs.join(
            str(platformdirs.user_cache_dir("molfeat")), "_lock_files", lock_name
        )
        dm.fs.get_mapper(lock_path)
        # ensure file is created
        # out = mapper.fs.touch(lock_path) # does not work  -_-
        with fsspec.open(lock_path, "w", auto_mkdir=True):
            pass

        return filelock.FileLock(lock_path)

    def download(
        self,
        modelcard: ModelInfo,
        output_dir: Optional[Union[os.PathLike, pathlib.Path]] = None,
        chunk_size: int = 2048,
        force: bool = False,
    ):
        """Download an artifact locally

        Args:
            modelcard: information on the model to download
            output_dir: path where to save the downloaded artifact
            chunk_size: chunk size to use for download
            force: whether to force download even if the file exists already

        Returns:
            the local directory the artifact was downloaded into

        Raises:
            ModelStoreError: if the model is missing from the store or the
                downloaded artifact does not match the card's sha256 checksum
        """

        remote_dir = modelcard.path(self.model_store_bucket)
        model_name = modelcard.name
        if not self.exists(modelcard, check_remote=True):
            raise ModelStoreError(f"Model {model_name} does not exist in the model store !")

        if output_dir is None:
            output_dir = dm.fs.join(platformdirs.user_cache_dir("molfeat"), model_name)

        dm.fs.mkdir(output_dir, exist_ok=True)

        model_remote_path = dm.fs.join(remote_dir, self.MODEL_PATH_NAME)
        model_dest_path = dm.fs.join(output_dir, self.MODEL_PATH_NAME)
        metadata_remote_path = dm.fs.join(remote_dir, self.METADATA_PATH_NAME)
        metadata_dest_path = dm.fs.join(output_dir, self.METADATA_PATH_NAME)

        # avoid downloading if the file exists already
        if (
            not (
                dm.fs.exists(metadata_dest_path)
                and (dm.fs.exists(model_dest_path) == dm.fs.exists(model_remote_path))
            )
            or force
        ):
            # metadata should exists if the model exists
            with self._filelock(f"{model_name}.metadata.json.lock"):
                dm.fs.copy_file(
                    metadata_remote_path,
                    metadata_dest_path,
                    progress=True,
                    leave_progress=False,
                    force=True,
                )

            if dm.fs.exists(model_remote_path):
                with self._filelock(f"{model_name}.lock"):
                    if dm.fs.is_dir(model_remote_path):
                        # we copy the model dir
                        dm.fs.copy_dir(
                            model_remote_path,
                            model_dest_path,
                            progress=True,
                            leave_progress=False,
                            chunk_size=chunk_size,
                            force=force,
                        )
                    else:
                        # we copy the model file
                        dm.fs.copy_file(
                            model_remote_path,
                            model_dest_path,
                            progress=True,
                            leave_progress=False,
                            chunk_size=chunk_size,
                            force=force,
                        )

        # verify the integrity of the cached artifact against the card checksum
        cache_sha256sum = commons.sha256sum(model_dest_path)
        if modelcard.sha256sum is not None and cache_sha256sum != modelcard.sha256sum:
            # drop the corrupted cache so the next call re-downloads it
            mapper = dm.fs.get_mapper(output_dir)
            mapper.fs.delete(output_dir, recursive=True)
            raise ModelStoreError(
                f"""The destination artifact at {model_dest_path} has a different sha256sum ({cache_sha256sum}) """
                f"""than the Remote artifact sha256sum ({modelcard.sha256sum}). The destination artifact has been removed !"""
            )

        return output_dir

    def load(
        self,
        model_name: Union[str, dict, ModelInfo],
        load_fn: Optional[Callable] = None,
        load_fn_kwargs: Optional[dict] = None,
        download_output_dir: Optional[Union[os.PathLike, pathlib.Path]] = None,
        chunk_size: int = 2048,
        force: bool = False,
    ):
        """
        Load a model by its name

        Args:
            model_name: name of the model to load
            load_fn: Custom loading function to load the model
            load_fn_kwargs: Optional dict of additional kwargs to provide to the loading function
            download_output_dir: Argument for download function to specify the download folder
            chunk_size: chunk size for download
            force: whether to reforce the download of the file

        Returns:
            model: Optional model, if the model requires download or loading weights
            model_info: model information card
        """
        if isinstance(model_name, str):
            # find the model with the same name
            modelcard = self.search(name=model_name)[0]
        else:
            modelcard = model_name
        output_dir = self.download(
            modelcard=modelcard,
            output_dir=download_output_dir,
            chunk_size=chunk_size,
            force=force,
        )
        if load_fn is None:
            load_fn = joblib.load
        model_path = dm.fs.join(output_dir, self.MODEL_PATH_NAME)
        metadata_path = dm.fs.join(output_dir, self.METADATA_PATH_NAME)

        # deal with non-pretrained models that might not have a serialized file
        model = None
        load_fn_kwargs = load_fn_kwargs or {}
        if dm.fs.exists(model_path):
            model = load_fn(model_path, **load_fn_kwargs)
        with fsspec.open(metadata_path, "r") as IN:
            model_info_dict = yaml.safe_load(IN)
        model_info = ModelInfo(**model_info_dict)
        return model, model_info

    def __contains__(self, card: Optional[ModelInfo] = None):
        """Membership test, delegating to `exists` without remote checking"""
        return self.exists(card)

    def exists(
        self,
        card: Optional[ModelInfo] = None,
        check_remote: bool = False,
        **match_params,
    ) -> bool:
        """Returns True if a model is registered in the store

        Args:
            card: card of the model to check
            check_remote: whether to check if the remote path of the model exists
            match_params: parameters for matching as expected by `ModelInfo.match`
        """

        if card is None:
            # nothing to match against; previously this raised inside match
            return False
        found = False
        for model_info in self.available_models:
            if model_info.match(card, **match_params):
                found = True
                break
        return found and (not check_remote or dm.fs.exists(card.path(self.model_store_bucket)))

    def search(self, modelcard: Optional[ModelInfo] = None, **search_kwargs):
        """Return all model cards that match the required search parameters

        Args:
            modelcard: model card to search for
            search_kwargs: search parameters to use
        """
        search_infos = {}
        found = []
        if modelcard is not None:
            # pydantic v2 API; `.dict()` is deprecated
            search_infos = modelcard.model_dump().copy()
        search_infos.update(search_kwargs)
        for model in self.available_models:
            # only compare on the attributes provided in the search
            if model.match(search_infos, match_only=list(search_infos.keys())):
                found.append(model)
        return found

available_models property

Return a list of all models that have been serialized in molfeat

__len__()

Return the length of the model store

Source code in molfeat/store/modelstore.py
78
79
80
def __len__(self):
    """Number of models currently registered in the store."""
    model_count = len(self.available_models)
    return model_count

download(modelcard, output_dir=None, chunk_size=2048, force=False)

Download an artifact locally

Parameters:

Name Type Description Default
modelcard ModelInfo

information on the model to download

required
output_dir Optional[Union[PathLike, Path]]

path where to save the downloaded artifact

None
chunk_size int

chunk size to use for download

2048
force bool

whether to force download even if the file exists already

False
Source code in molfeat/store/modelstore.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
def download(
    self,
    modelcard: ModelInfo,
    output_dir: Optional[Union[os.PathLike, pathlib.Path]] = None,
    chunk_size: int = 2048,
    force: bool = False,
):
    """Download an artifact locally

    Args:
        modelcard: information on the model to download
        output_dir: path where to save the downloaded artifact
        chunk_size: chunk size to use for download
        force: whether to force download even if the file exists already

    Returns:
        the local directory the artifact was downloaded into

    Raises:
        ModelStoreError: if the model is missing from the store or the
            downloaded artifact does not match the card's sha256 checksum
    """

    remote_dir = modelcard.path(self.model_store_bucket)
    model_name = modelcard.name
    if not self.exists(modelcard, check_remote=True):
        raise ModelStoreError(f"Model {model_name} does not exist in the model store !")

    # default download location is the user-level molfeat cache
    if output_dir is None:
        output_dir = dm.fs.join(platformdirs.user_cache_dir("molfeat"), model_name)

    dm.fs.mkdir(output_dir, exist_ok=True)

    model_remote_path = dm.fs.join(remote_dir, self.MODEL_PATH_NAME)
    model_dest_path = dm.fs.join(output_dir, self.MODEL_PATH_NAME)
    metadata_remote_path = dm.fs.join(remote_dir, self.METADATA_PATH_NAME)
    metadata_dest_path = dm.fs.join(output_dir, self.METADATA_PATH_NAME)

    # avoid downloading if the file exists already:
    # the cache is considered valid when the metadata is present locally AND
    # the model file's local existence matches its remote existence (some
    # models legitimately have no serialized artifact)
    if (
        not (
            dm.fs.exists(metadata_dest_path)
            and (dm.fs.exists(model_dest_path) == dm.fs.exists(model_remote_path))
        )
        or force
    ):
        # metadata should exists if the model exists
        # a file lock guards against concurrent downloads of the same card
        with self._filelock(f"{model_name}.metadata.json.lock"):
            dm.fs.copy_file(
                metadata_remote_path,
                metadata_dest_path,
                progress=True,
                leave_progress=False,
                force=True,
            )

        if dm.fs.exists(model_remote_path):
            with self._filelock(f"{model_name}.lock"):
                if dm.fs.is_dir(model_remote_path):
                    # we copy the model dir
                    dm.fs.copy_dir(
                        model_remote_path,
                        model_dest_path,
                        progress=True,
                        leave_progress=False,
                        chunk_size=chunk_size,
                        force=force,
                    )
                else:
                    # we copy the model file
                    dm.fs.copy_file(
                        model_remote_path,
                        model_dest_path,
                        progress=True,
                        leave_progress=False,
                        chunk_size=chunk_size,
                        force=force,
                    )

    # verify the integrity of the cached artifact against the card checksum
    cache_sha256sum = commons.sha256sum(model_dest_path)
    if modelcard.sha256sum is not None and cache_sha256sum != modelcard.sha256sum:
        # drop the corrupted cache so a later call re-downloads it
        mapper = dm.fs.get_mapper(output_dir)
        mapper.fs.delete(output_dir, recursive=True)
        raise ModelStoreError(
            f"""The destination artifact at {model_dest_path} has a different sha256sum ({cache_sha256sum}) """
            f"""than the Remote artifact sha256sum ({modelcard.sha256sum}). The destination artifact has been removed !"""
        )

    return output_dir

exists(card=None, check_remote=False, **match_params)

Returns True if a model is registered in the store

Parameters:

Name Type Description Default
card Optional[ModelInfo]

card of the model to check

None
check_remote bool

whether to check if the remote path of the model exists

False
match_params

parameters for matching as expected by ModelInfo.match

{}
Source code in molfeat/store/modelstore.py
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
def exists(
    self,
    card: Optional[ModelInfo] = None,
    check_remote: bool = False,
    **match_params,
) -> bool:
    """Check whether a model matching `card` is registered in the store.

    Args:
        card: card of the model to check
        check_remote: also require the model's remote path to exist
        match_params: keyword arguments forwarded to `ModelInfo.match`
    """

    # stop at the first registered card that matches
    found = any(
        registered.match(card, **match_params) for registered in self.available_models
    )
    if not found:
        return False
    return not check_remote or dm.fs.exists(card.path(self.model_store_bucket))

load(model_name, load_fn=None, load_fn_kwargs=None, download_output_dir=None, chunk_size=2048, force=False)

Load a model by its name

Parameters:

Name Type Description Default
model_name Union[str, dict, ModelInfo]

name of the model to load

required
load_fn Optional[Callable]

Custom loading function to load the model

None
load_fn_kwargs Optional[dict]

Optional dict of additional kwargs to provide to the loading function

None
download_output_dir Optional[Union[PathLike, Path]]

Argument for download function to specify the download folder

None
chunk_size int

chunk size for download

2048
force bool

whether to reforce the download of the file

False

Returns:

Name Type Description
model

Optional model, if the model requires download or loading weights

model_info

model information card

Source code in molfeat/store/modelstore.py
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
def load(
    self,
    model_name: Union[str, dict, ModelInfo],
    load_fn: Optional[Callable] = None,
    load_fn_kwargs: Optional[dict] = None,
    download_output_dir: Optional[Union[os.PathLike, pathlib.Path]] = None,
    chunk_size: int = 2048,
    force: bool = False,
):
    """
    Load a model by its name

    Args:
        model_name: name of the model to load, or its card
        load_fn: custom loading function to load the model
        load_fn_kwargs: optional dict of additional kwargs for the loading function
        download_output_dir: argument for the download function to specify the download folder
        chunk_size: chunk size for download
        force: whether to force re-download of the file

    Returns:
        model: Optional model, if the model requires download or loading weights
        model_info: model information card
    """
    # A plain string is resolved to the first matching card in the store.
    if isinstance(model_name, str):
        modelcard = self.search(name=model_name)[0]
    else:
        modelcard = model_name

    artifact_dir = self.download(
        modelcard=modelcard,
        output_dir=download_output_dir,
        chunk_size=chunk_size,
        force=force,
    )
    loader = load_fn if load_fn is not None else joblib.load
    serialized_model_path = dm.fs.join(artifact_dir, self.MODEL_PATH_NAME)
    card_path = dm.fs.join(artifact_dir, self.METADATA_PATH_NAME)

    # Non-pretrained models (e.g. hand-crafted featurizers) may not ship a
    # serialized weight file at all, in which case `model` stays None.
    model = None
    if dm.fs.exists(serialized_model_path):
        model = loader(serialized_model_path, **(load_fn_kwargs or {}))
    with fsspec.open(card_path, "r") as stream:
        model_info = ModelInfo(**yaml.safe_load(stream))
    return model, model_info

register(modelcard, model=None, chunk_size=2048, save_fn=None, save_fn_kwargs=None, force=True)

Register a new model to the store

!!! note save_fn You can pass additional kwargs for your save_fn through the save_fn_kwargs argument. It's expected that save_fn will be called as : save_fn(model, <model_upload_path>, **save_fn_kwargs), with <model_upload_path> being provided by the model store, and that it will return the path to the serialized model. If not provided, joblib.dump is used by default.

Parameters:

Name Type Description Default
modelcard Union[ModelInfo, dict]

Model information

required
model Optional[Any]

A path to the model artifact or any object that needs to be saved

None
chunk_size int

the chunk size for the upload

2048
save_fn Optional[Callable]

any custom function for serializing the model, that takes the model, the upload path and parameters save_fn_kwargs as inputs.

None
save_fn_kwargs Optional[dict]

any additional kwargs to pass to save_fn

None
force bool

whether to force upload to the bucket

True
Source code in molfeat/store/modelstore.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def register(
    self,
    modelcard: Union[ModelInfo, dict],
    model: Optional[Any] = None,
    chunk_size: int = 2048,
    save_fn: Optional[Callable] = None,
    save_fn_kwargs: Optional[dict] = None,
    force: bool = True,
):
    """
    Register a new model to the store

    !!! note `save_fn`
        You can pass additional kwargs for your `save_fn` through the `save_fn_kwargs` argument.
        It's expected that `save_fn` will be called as : `save_fn(model, <model_upload_path>, **save_fn_kwargs)`,
        with `<model_upload_path>` being provided by the model store, and that it will return the path to the serialized model.
        If not provided, `joblib.dump` is used by default.

    Args:
        modelcard: Model information
        model: A path to the model artifact or any object that needs to be saved
        chunk_size: the chunk size for the upload
        save_fn: any custom function for serializing the model, that takes the model, the upload path and parameters `save_fn_kwargs` as inputs.
        save_fn_kwargs: any additional kwargs to pass to save_fn
        force: whether to force upload to the bucket

    """
    if not isinstance(modelcard, ModelInfo):
        modelcard = ModelInfo(**modelcard)
    # warn (and optionally bail out) when the model is already registered
    if self.exists(card=modelcard):
        logger.warning(f"Model {modelcard.name} exists already ...")
        if not force:
            return

    model_root_dir = modelcard.path(self.model_store_bucket)
    model_path = model
    model_upload_path = dm.fs.join(model_root_dir, self.MODEL_PATH_NAME)
    model_metadata_upload_path = dm.fs.join(model_root_dir, self.METADATA_PATH_NAME)

    save_fn_kwargs = save_fn_kwargs or {}
    # Temp file used when we have to serialize an in-memory object; it is
    # removed once the checksum has been computed (it previously leaked).
    temp_model_file = None
    try:
        if save_fn is None:
            # NOTE(review): a plain `str` path does not match this check and
            # would be pickled as-is instead of uploaded as a file — confirm
            # callers always pass `pathlib.Path`/`os.PathLike` for paths.
            if not isinstance(model, (pathlib.Path, os.PathLike)):
                temp_model_file = tempfile.NamedTemporaryFile(delete=False)
                with temp_model_file:
                    joblib.dump(model, temp_model_file)
                model_path = temp_model_file.name
            # Upload the artifact to the bucket
            dm.fs.copy_file(
                model_path,
                model_upload_path,
                progress=True,
                leave_progress=False,
                chunk_size=chunk_size,
                force=force,
            )
        else:
            model_path = save_fn(model, model_upload_path, **save_fn_kwargs)
            # we reset to the upload path if the save_fn has not returned anything
            model_path = model_path or model_upload_path
        # checksum must be computed while the (possibly temporary) file still exists
        modelcard.sha256sum = commons.sha256sum(model_path)
    finally:
        if temp_model_file is not None:
            os.unlink(temp_model_file.name)
    # then we save the metadata as json (pydantic v2 API, consistent with
    # `ModelInfo.match` which uses `model_dump`)
    with fsspec.open(model_metadata_upload_path, "w") as OUT:
        OUT.write(modelcard.model_dump_json())
    self._update_store()
    logger.info(f"Successfully registered model {modelcard.name} !")

search(modelcard=None, **search_kwargs)

Return all model cards that match the required search parameters

Parameters:

Name Type Description Default
modelcard Optional[ModelInfo]

model card to search for

None
search_kwargs

search parameters to use

{}
Source code in molfeat/store/modelstore.py
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
def search(self, modelcard: Optional[ModelInfo] = None, **search_kwargs):
    """Return all model cards that match the required search parameters

    Args:
        modelcard: model card to search for
        search_kwargs: search parameters to use; these take precedence over
            the card's own fields
    """
    search_infos = {}
    found = []
    if modelcard is not None:
        # pydantic v2 API, consistent with `ModelInfo.match` (``.dict()`` is deprecated)
        search_infos = modelcard.model_dump().copy()
    search_infos.update(search_kwargs)
    for model in self.available_models:
        if model.match(search_infos, match_only=list(search_infos.keys())):
            found.append(model)
    return found

PretrainedModel

Bases: ABC

Base class for loading pretrained models

Source code in molfeat/store/loader.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class PretrainedModel(abc.ABC):
    """Abstract interface for loading pretrained models.

    Subclasses must implement :meth:`load`; the two classmethod hooks below
    are stubs that concrete stores are expected to provide.
    """

    @abc.abstractmethod
    def load(self):
        """Load the model"""
        ...

    @classmethod
    def _artifact_load(cls, name: str, download_path: Optional[os.PathLike] = None, **kwargs):
        """Load an artifact based on its name.

        Args:
            name: name of the model to load
            download_path: path to a directory where to save the downloaded files
        """
        ...

    @classmethod
    def _load_or_raise(
        cls,
        name: str,
        download_path: Optional[os.PathLike] = None,
        **kwargs,
    ):
        """Load a model or raise an exception.

        Args:
            name: name of the model to load
            download_path: local download path of the model
        """
        ...

load() abstractmethod

Load the model

Source code in molfeat/store/loader.py
40
41
42
43
@abc.abstractmethod
def load(self):
    """Load the model.

    Concrete subclasses must implement the actual loading logic.
    """
    ...

PretrainedStoreModel

Bases: PretrainedModel

Class for loading pretrained models from the model zoo

Source code in molfeat/store/loader.py
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
class PretrainedStoreModel(PretrainedModel):
    r"""
    Class for loading pretrained models from the model zoo
    """

    def __init__(
        self,
        name: str,
        cache_path: Optional[os.PathLike] = None,
        store: Optional[ModelStore] = None,
    ):
        """Interface for pretrained model from the default modelstore

        Args:
            name: name of the pretrained transformer in the model store
            cache_path: optional local cache path.
            store: ModelStore to use for loading the pretrained model
        """
        self.name = name
        self.cache_path = cache_path
        if store is None:
            store = ModelStore()
        self.store = store

    @classmethod
    def _artifact_load(cls, name: str, download_path: Optional[os.PathLike] = None, **kwargs):
        """Load internal artifact from the model store

        Args:
            name: name of the model to load
            download_path: path to a directory where to save the downloaded files
        """
        # Invalidate the memoized download when the local copy is gone, so a
        # fresh download is triggered instead of returning a stale path.
        if not dm.fs.exists(download_path):
            cls._load_or_raise.cache_clear()
        return cls._load_or_raise(name, download_path, **kwargs)

    @classmethod
    @lru_cache(maxsize=100)
    def _load_or_raise(
        cls,
        name: str,
        download_path: Optional[os.PathLike] = None,
        store: Optional[ModelStore] = None,
        **kwargs,
    ):
        """Download a model from the store or raise a ``ModelStoreError``

        Args:
            name: name of the model to load
            download_path: local download path of the model
            store: model store to query; a default ``ModelStore`` is built when None
        """
        if store is None:
            store = ModelStore()
        try:
            modelcard = store.search(name=name)[0]
            artifact_dir = store.download(modelcard, download_path, **kwargs)
        except Exception as err:
            mess = f"Can't retrieve model {name} from the store !"
            # chain the original failure so the root cause stays visible
            raise ModelStoreError(mess) from err
        return artifact_dir

    def load(self):
        """Load the model"""
        raise NotImplementedError

__init__(name, cache_path=None, store=None)

Interface for pretrained model from the default modelstore

Parameters:

Name Type Description Default
name str

name of the pretrained transformer in the model store

required
cache_path Optional[PathLike]

optional local cache path.

None
store Optional[ModelStore]

ModelStore to use for loading the pretrained model

None
Source code in molfeat/store/loader.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def __init__(
    self,
    name: str,
    cache_path: Optional[os.PathLike] = None,
    store: Optional[ModelStore] = None,
):
    """Create a handle on a pretrained model backed by a model store.

    Args:
        name: name of the pretrained transformer in the model store
        cache_path: optional local cache path.
        store: ModelStore to use for loading the pretrained model;
            a default ``ModelStore`` is built when not provided.
    """
    self.name = name
    self.cache_path = cache_path
    self.store = store if store is not None else ModelStore()

load()

Load the model

Source code in molfeat/store/loader.py
106
107
108
def load(self):
    """Load the model.

    Raises:
        NotImplementedError: concrete subclasses are expected to override this.
    """
    raise NotImplementedError