
Base Pretrained Models

Pretrained Model

PretrainedMolTransformer

Bases: MoleculeTransformer

Transformer based on a pretrained featurizer

Note

  • When implementing a subclass of this class, you need to define the _embed and optionally the _convert methods.
  • If your model is an instance of PretrainedModel that handles loading of the model from a store or through some other complex mechanism, you can decide whether to preload the true underlying model. You are then in charge of deciding when preloading is needed and when it is not. Note, however, that by default preloading is only attempted while the featurizer is still an instance of PretrainedModel.

Attributes:

  • featurizer (object): featurizer object
  • dtype (type, optional): Data type. Use `__call__` instead of `transform` to get features in this dtype.
  • precompute_cache (bool, optional): Whether to precompute the features into a local cache. Defaults to False. Note that due to molecular hashing, some pretrained featurizers might be better off not using any cache, as they can be faster without it. Furthermore, the cache is not saved when pickling the object; if you want to save the cache, you need to save the cache object separately.
  • _require_mols (bool): Whether the embedding takes mols or smiles as input.
  • preload (bool): Whether to preload the pretrained model from the store (if available) during initialization.
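
As an illustration of the subclassing contract described in the note above, here is a minimal sketch. The class name MyRandomTransformer, its length parameter, and the hash-seeded toy vectors are hypothetical stand-ins for a real pretrained model:

import numpy as np

from molfeat.trans.pretrained.base import PretrainedMolTransformer


class MyRandomTransformer(PretrainedMolTransformer):
    """Hypothetical subclass: emits a deterministic pseudo-random vector per input."""

    def __init__(self, length: int = 16, **kwargs):
        super().__init__(**kwargs)
        self.length = length
        self._require_mols = False  # _convert will pass SMILES strings to _embed

    def _embed(self, smiles: list, **kwargs):
        # one fixed-size vector per input; a real subclass would query its model here
        return [
            np.random.default_rng(abs(hash(s))).random(self.length) for s in smiles
        ]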

Source code in molfeat/trans/pretrained/base.py
class PretrainedMolTransformer(MoleculeTransformer):
    r"""
    Transformer based on a pretrained featurizer

    !!! note
        * When implementing a subclass of this class, you need to define the `_embed` and optionally the `_convert` methods.
        * If your model is an instance of PretrainedModel that handles loading of the model from a store or through some other complex mechanism,
          you can decide whether to preload the true underlying model. You are then in charge of deciding when preloading is needed and when it is not.
          Note, however, that by default preloading is only attempted while the featurizer is still an instance of PretrainedModel.


    Attributes
        featurizer (object): featurizer object
        dtype (type, optional): Data type. Use `__call__` instead of `transform` to get features in this dtype.
        precompute_cache (bool, optional): Whether to precompute the features into a local cache. Defaults to False.
            Note that due to molecular hashing, some pretrained featurizers might be better off not using any cache, as they can be faster without it.
            Furthermore, the cache is not saved when pickling the object. If you want to save the cache, you need to save the cache object separately.
        _require_mols (bool): Whether the embedding takes mols or smiles as input
        preload: whether to preload the pretrained model from the store (if available) during initialization.

    """

    def __init__(
        self,
        dtype: Optional[Callable] = None,
        precompute_cache: Optional[Union[bool, DataCache]] = None,
        preload: bool = False,
        **params,
    ):
        self._save_input_args()

        featurizer = params.pop("featurizer", None)
        super().__init__(dtype=dtype, featurizer="none", **params)
        self.featurizer = featurizer
        self._require_mols = False
        self.preload = preload
        self._feat_length = None
        if precompute_cache is False:
            precompute_cache = None
        if precompute_cache is True:
            name = str(self.__class__.__name__)
            precompute_cache = DataCache(name=name)
        self.precompute_cache = precompute_cache

    def set_cache(self, cache: DataCache):
        """Set the cache for the transformer

        Args:
            cache: cache object
        """
        self.precompute_cache = cache

    def _get_param_names(self):
        """Get parameter names for the estimator"""
        out = self._input_params.keys()
        out = [x for x in out if x != "featurizer"]
        return out

    def _embed(self, smiles: str, **kwargs):
        """Compute molecular embeddings for input list of smiles
        This functiom takes a list of smiles or molecules and return the featurization
        corresponding to the inputs.  In `transform` and `_transform`, this function is
        called after calling `_convert`

        Args:
            smiles: input smiles
        """
        raise NotImplementedError

    def _preload(self):
        """Preload the pretrained model for later queries"""
        if self.featurizer is not None and isinstance(self.featurizer, PretrainedModel):
            self.featurizer = self.featurizer.load()
            self.preload = True

    def __getstate__(self):
        """Getting state to allow pickling"""
        d = copy.deepcopy(self.__dict__)
        d["precompute_cache"] = None
        if isinstance(getattr(self, "featurizer", None), PretrainedModel) or self.preload:
            d.pop("featurizer", None)
        return d

    def __setstate__(self, d):
        """Setting state during reloading pickling"""
        self.__dict__.update(d)
        self._update_params()

    def fit(self, *args, **kwargs):
        return self

    def _convert(self, inputs: list, **kwargs):
        """Convert molecules to the right format

        In `transform` and `_transform`, this function is called before calling `_embed`

        Args:
            inputs: inputs to preprocess

        Returns:
            processed: pre-processed input list
        """
        if not self._require_mols:
            inputs = [dm.to_smiles(m) for m in inputs]
        return inputs

    def preprocess(self, inputs: list, labels: Optional[list] = None):
        """Run preprocessing on the input data
        Args:
            inputs: list of input molecules
            labels: list of labels
        """
        out = super().preprocess(inputs, labels)
        if self.precompute_cache not in [False, None]:
            try:
                self.transform(inputs)
            except Exception:
                pass
        return out

    def _transform(self, mol: dm.Mol, **kwargs):
        r"""
        Compute features for a single molecule.
        This method may need to be reimplemented by child classes.

        Args:
            mol (dm.Mol): molecule to transform into features

        Returns:
            feat: featurized input molecule

        """
        feat = None
        if self.precompute_cache is not None:
            feat = self.precompute_cache.get(mol)
        if feat is None:
            try:
                mols = [dm.to_mol(mol)]
                mols = self._convert(mols, **kwargs)
                feat = self._embed(mols, **kwargs)
                feat = feat[0]
            except Exception as e:
                if self.verbose:
                    logger.error(e)

            if self.precompute_cache is not None:
                self.precompute_cache[mol] = feat
        return feat

    def transform(self, smiles: List[str], **kwargs):
        """Perform featurization of the input molecules

        The featurization process is as follows:
        1. convert the input molecules into the format expected by the pre-trained model using `_convert`
        2. compute the embeddings of the molecules using `_embed`
        3. perform any model-specific postprocessing and cache update

        The dtype returned is the native datatype of the transformer.
        Use `__call__` to get the output in the format of the `dtype` attribute.

        Args:
            smiles: a list of SMILES strings or molecule objects

        Returns:
            out: featurized molecules
        """
        if isinstance(smiles, str) or not isinstance(smiles, Iterable):
            smiles = [smiles]

        n_mols = len(smiles)
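        # ind_to_compute maps each original input position to its position in the
        # batch that is actually embedded; without a cache it is the identity map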
        ind_to_compute = dict(zip(range(n_mols), range(n_mols)))
        pre_computed = [None] * n_mols

        if self.precompute_cache not in [False, None]:
            ind_to_compute = {}
            pre_computed = self.precompute_cache.fetch(smiles)
            ind = 0
            for i, v in enumerate(pre_computed):
                if v is None:
                    ind_to_compute[i] = ind
                    ind += 1

        parallel_kwargs = getattr(self, "parallel_kwargs", {})
        mols = dm.parallelized(
            dm.to_mol, smiles, n_jobs=getattr(self, "n_jobs", 1), **parallel_kwargs
        )
        mols = [mols[i] for i in ind_to_compute]

        if len(mols) > 0:
            converted_mols = self._convert(mols, **kwargs)
            out = self._embed(converted_mols, **kwargs)

            if not isinstance(out, list):
                out = list(out)

            if self.precompute_cache is not None:
                # cache value now
                self.precompute_cache.update(dict(zip(mols, out)))
        out = [
            out[ind_to_compute[i]] if i in ind_to_compute else pre_computed[i]
            for i in range(n_mols)
        ]
        return datatype.as_numpy_array_if_possible(out, self.dtype)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return str(self) == str(other)
        return False

    def _update_params(self):
        self._fitted = False

    def __len__(self):
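        # the embedding size is probed once by featurizing a dummy molecule ("CCC")
        # and cached in _feat_length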
        if self._feat_length is None:
            self._preload()
            tmp_mol = dm.to_mol("CCC")
            embs = self._transform(tmp_mol)
            self._feat_length = len(embs)
        return self._feat_length

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        return hash(repr(self))

    def __repr__(self):
        return "{}(dtype={})".format(
            self.__class__.__name__,
            _parse_to_evaluable_str(self.dtype),
        )
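
As a usage sketch, assuming the hypothetical MyRandomTransformer subclass from above: passing precompute_cache=True attaches a DataCache named after the class, transform returns the native datatype, and __call__ casts to the dtype attribute:

import numpy as np

trans = MyRandomTransformer(dtype=np.float32, precompute_cache=True)

feats = trans(["CCO", "c1ccccc1"])          # __call__: output cast to float32
raw = trans.transform(["CCO", "c1ccccc1"])  # transform: native datatype
size = len(trans)  # embedding length, probed with a dummy molecule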

__getstate__()

Getting state to allow pickling

Source code in molfeat/trans/pretrained/base.py
def __getstate__(self):
    """Getting state to allow pickling"""
    d = copy.deepcopy(self.__dict__)
    d["precompute_cache"] = None
    if isinstance(getattr(self, "featurizer", None), PretrainedModel) or self.preload:
        d.pop("featurizer", None)
    return d

__setstate__(d)

Set state when unpickling

Source code in molfeat/trans/pretrained/base.py
def __setstate__(self, d):
    """Setting state during reloading pickling"""
    self.__dict__.update(d)
    self._update_params()
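
A quick sketch of the round-trip behavior, again using the hypothetical subclass from above: the cache is dropped on pickling, so it must be persisted separately if you want to keep it:

import pickle

trans = MyRandomTransformer(precompute_cache=True)
trans.transform(["CCO"])  # warms the cache

restored = pickle.loads(pickle.dumps(trans))
# restored.precompute_cache is None: __getstate__ nulled it out, so save the
# cache object itself separately if you need it after unpickling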

preprocess(inputs, labels=None)

Run preprocessing on the input data.

  • inputs: list of input molecules
  • labels: list of labels

Source code in molfeat/trans/pretrained/base.py
def preprocess(self, inputs: list, labels: Optional[list] = None):
    """Run preprocessing on the input data
    Args:
        inputs: list of input molecules
        labels: list of labels
    """
    out = super().preprocess(inputs, labels)
    if self.precompute_cache not in [False, None]:
        try:
            self.transform(inputs)
        except Exception:
            pass
    return out
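
A minimal sketch of the effect, with the hypothetical subclass from above: when a cache is attached, preprocess eagerly featurizes the inputs so later transform calls are served from the cache:

trans = MyRandomTransformer(precompute_cache=True)
trans.preprocess(["CCO", "CCN"])  # eagerly featurizes into the cache
trans.transform(["CCO", "CCN"])   # now answered from the cache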

set_cache(cache)

Set the cache for the transformer

Parameters:

  • cache (DataCache): cache object. Required.
Source code in molfeat/trans/pretrained/base.py
def set_cache(self, cache: DataCache):
    """Set the cache for the transformer

    Args:
        cache: cache object
    """
    self.precompute_cache = cache
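
For example, a sketch of sharing one cache across transformers, assuming DataCache is importable from molfeat.utils.cache:

from molfeat.utils.cache import DataCache  # assumed import path

shared_cache = DataCache(name="shared")  # hypothetical cache name
trans = MyRandomTransformer()            # hypothetical subclass from above
trans.set_cache(shared_cache)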

transform(smiles, **kwargs)

Perform featurization of the input molecules

The featurization process is as follows:

 1. convert the input molecules into the format expected by the pre-trained model using _convert
 2. compute the embeddings of the molecules using _embed
 3. perform any model-specific postprocessing and cache update

The dtype returned is the native datatype of the transformer. Use __call__ to get the output in the format of the dtype attribute.

Parameters:

  • smiles (List[str]): a list of SMILES strings or molecule objects. Required.

Returns:

  • out: featurized molecules

Source code in molfeat/trans/pretrained/base.py
def transform(self, smiles: List[str], **kwargs):
    """Perform featurization of the input molecules

    The featurization process is as follows:
    1. convert the input molecules into the format expected by the pre-trained model using `_convert`
    2. compute the embeddings of the molecules using `_embed`
    3. perform any model-specific postprocessing and cache update

    The dtype returned is the native datatype of the transformer.
    Use `__call__` to get the output in the format of the `dtype` attribute.

    Args:
        smiles: a list of SMILES strings or molecule objects

    Returns:
        out: featurized molecules
    """
    if isinstance(smiles, str) or not isinstance(smiles, Iterable):
        smiles = [smiles]

    n_mols = len(smiles)
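    # ind_to_compute maps each original input position to its position in the
    # batch that is actually embedded; without a cache it is the identity map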
    ind_to_compute = dict(zip(range(n_mols), range(n_mols)))
    pre_computed = [None] * n_mols

    if self.precompute_cache not in [False, None]:
        ind_to_compute = {}
        pre_computed = self.precompute_cache.fetch(smiles)
        ind = 0
        for i, v in enumerate(pre_computed):
            if v is None:
                ind_to_compute[i] = ind
                ind += 1

    parallel_kwargs = getattr(self, "parallel_kwargs", {})
    mols = dm.parallelized(
        dm.to_mol, smiles, n_jobs=getattr(self, "n_jobs", 1), **parallel_kwargs
    )
    mols = [mols[i] for i in ind_to_compute]

    if len(mols) > 0:
        converted_mols = self._convert(mols, **kwargs)
        out = self._embed(converted_mols, **kwargs)

        if not isinstance(out, list):
            out = list(out)

        if self.precompute_cache is not None:
            # cache value now
            self.precompute_cache.update(dict(zip(mols, out)))
    out = [
        out[ind_to_compute[i]] if i in ind_to_compute else pre_computed[i]
        for i in range(n_mols)
    ]
    return datatype.as_numpy_array_if_possible(out, self.dtype)
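
Finally, a sketch of the cache interaction in transform, with the hypothetical subclass from above: only molecules missing from the cache are re-embedded, and the output is stacked into a numpy array when possible:

trans = MyRandomTransformer(precompute_cache=True)
trans.transform(["CCO"])               # "CCO" is embedded and cached
out = trans.transform(["CCO", "CCN"])  # only "CCN" is embedded; "CCO" is a cache hit
print(out.shape)                       # (2, 16) with the toy 16-dim embedding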