Skip to content

molfeat.trans.base

BaseFeaturizer

Bases: BaseEstimator

Molecule featurizer base class that all featurizers must implement. This featurizer is compatible with scikit-learn estimators and can therefore be plugged into a pipeline.

Source code in molfeat/trans/base.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
class BaseFeaturizer(BaseEstimator):
    """
    Molecule featurizer base class that needs to be implemented by all featurizers.
    This featurizer is compatible with scikit-learn estimators and thus can be plugged into a pipeline
    """

    def __init__(
        self,
        n_jobs: int = 1,
        verbose: bool = True,
        dtype: Optional[Union[str, Callable]] = None,
        parallel_kwargs: Optional[Dict[str, Any]] = None,
        **params,
    ):
        """Initialize the featurizer.

        Args:
            n_jobs: number of concurrent jobs. Stored as given; the `-1` expansion
                performed by the `n_jobs` setter is not applied here.
            verbose: verbosity flag.
            dtype: output data type (string alias or callable), or None.
            parallel_kwargs: extra keyword arguments for parallel execution.
                Defaults to an empty dict.
            **params: additional parameters, each stored as an attribute of the same name.
        """
        self._n_jobs = n_jobs
        self.dtype = dtype
        self.verbose = verbose
        self.parallel_kwargs = parallel_kwargs or {}
        for k, v in params.items():
            setattr(self, k, v)
        # Constructor arguments used by `copy()` and `_get_param_names()`.
        # NOTE: `parallel_kwargs` is not recorded here, so `copy()` carries it over separately.
        self._input_params = dict(n_jobs=n_jobs, dtype=dtype, verbose=verbose, **params)

    @property
    def n_jobs(self):
        """Get the number of concurrent jobs to run with this featurizer"""
        return self._n_jobs

    @n_jobs.setter
    def n_jobs(self, val):
        # Values >= 1 are stored as is and -1 expands to the CPU count.
        # Any other value (0, < -1) is ignored and the previous setting is kept.
        if val >= 1:
            self._n_jobs = val
        elif val == -1:
            self._n_jobs = joblib.cpu_count()

    def _get_param_names(self):
        """Get parameter names for the estimator (the recorded constructor arguments)"""
        return self._input_params.keys()

    def _update_params(self):
        """Update parameters of the current estimator.

        Hook called by `set_params`; the base implementation does nothing.
        """
        ...

    def set_params(self, **params):
        """Set the parameters of this estimator.

        Known constructor arguments are also updated in the recorded input parameters,
        then `_update_params` is invoked so subclasses can react to the change.

        Returns:
            self: estimator instance
        """
        super().set_params(**params)
        for k, v in params.items():
            if k in self._input_params:
                self._input_params[k] = v
        self._update_params()
        return self

    def copy(self):
        """Return a copy of this object.

        The copy is built by calling the constructor with the recorded input
        parameters. Any instance attribute the new object does not already have is
        then deep-copied over. Attributes the constructor sets come from the
        constructor, not from this instance. The exception is `parallel_kwargs`,
        which is carried over explicitly.
        """
        copy_obj = self.__class__(**self._input_params)
        for k, v in self.__dict__.items():
            if not hasattr(copy_obj, k):
                setattr(copy_obj, k, copy.deepcopy(v))
        # `parallel_kwargs` is not part of `_input_params`, so the constructor resets
        # it to `{}`. Keep the original parallelization settings on the copy.
        if "parallel_kwargs" in self.__dict__:
            copy_obj.parallel_kwargs = copy.copy(self.parallel_kwargs)
        return copy_obj

    def preprocess(self, inputs: list, labels: Optional[list] = None):
        """Preprocess input

        Args:
            inputs: inputs to preprocess
            labels: labels to preprocess (optional)

        Returns:
            processed: `(inputs, labels)` tuple; the base implementation returns them unchanged
        """
        return inputs, labels

    def get_collate_fn(self, *args, **kwargs):
        """
        Get the collate function of this featurizer. Implementations should set the
        relevant attributes or arguments of the underlying collate function
        (e.g. via functools.partial) and return the function itself.

        Returns:
            fn: Collate function for pytorch or None
        """
        return None

n_jobs property writable

Get the number of concurrent jobs to run with this featurizer

copy()

Return a copy of this object.

Source code in molfeat/trans/base.py
116
117
118
119
120
121
122
def copy(self):
    """Return a copy of this object.

    The copy is built by calling the constructor with the recorded input
    parameters (`self._input_params`). Any instance attribute the new object does
    not already have is then deep-copied over. `parallel_kwargs` is not among the
    recorded parameters, so it is carried over explicitly.
    """
    # Imported locally so the module stays reachable even where the name `copy`
    # refers to this function.
    import copy as _copy_module

    copy_obj = self.__class__(**self._input_params)
    for k, v in self.__dict__.items():
        if not hasattr(copy_obj, k):
            setattr(copy_obj, k, _copy_module.deepcopy(v))
    # The constructor resets `parallel_kwargs` to `{}` because it is not part of
    # `_input_params`. Keep the original parallelization settings on the copy.
    if "parallel_kwargs" in self.__dict__:
        copy_obj.parallel_kwargs = _copy_module.copy(self.parallel_kwargs)
    return copy_obj

get_collate_fn(*args, **kwargs)

Get the collate function of this featurizer. Implementations should set the relevant attributes or arguments of the underlying collate function (e.g. via functools.partial) and return the function itself.

Returns:

Name Type Description
fn

Collate function for pytorch or None

Source code in molfeat/trans/base.py
136
137
138
139
140
141
142
143
144
145
def get_collate_fn(self, *args, **kwargs):
    """
    Return the collate function associated with this featurizer.

    Subclasses that support batching should bind the relevant attributes or
    arguments onto their underlying collate function (for example with
    `functools.partial`) and hand that function back. This base version accepts
    and ignores any arguments.

    Returns:
        fn: Collate function for pytorch, or None when there is none
    """
    collate_fn = None
    return collate_fn

preprocess(inputs, labels=None)

Preprocess input

Parameters:

Name Type Description Default
inputs list

inputs to preprocess

required
labels Optional[list]

labels to preprocess (optional)

None

Returns:

Name Type Description
processed

pre-processed input list

Source code in molfeat/trans/base.py
124
125
126
127
128
129
130
131
132
133
134
def preprocess(self, inputs: list, labels: Optional[list] = None):
    """Prepare inputs (and optional labels) before featurization.

    This base version is the identity: both arguments come back untouched.

    Args:
        inputs: inputs to preprocess
        labels: labels to preprocess (optional)

    Returns:
        processed: an `(inputs, labels)` tuple
    """
    processed = (inputs, labels)
    return processed

set_params(**params)

Set the parameters of this estimator.

Returns:

Name Type Description
self

estimator instance

Source code in molfeat/trans/base.py
103
104
105
106
107
108
109
110
111
112
113
114
def set_params(self, **params):
    """Update this estimator's parameters.

    The parameters are first applied through the parent implementation. Those that
    match a recorded constructor argument are then mirrored into
    `self._input_params`, and finally the `_update_params` hook runs.

    Returns:
        self: estimator instance
    """
    super().set_params(**params)
    recorded = {name: value for name, value in params.items() if name in self._input_params}
    self._input_params.update(recorded)
    self._update_params()
    return self

MoleculeTransformer

Bases: TransformerMixin, BaseFeaturizer

Base class for molecular data transformers such as Fingerprinter. If you create a subclass of this featurizer, make sure that the input arguments of `__init__` are kept as-is in the object attributes.

Note

The transformer supports a variety of datatypes; they are only enforced when passing the enforce_dtype=True argument to __call__. For pandas dataframes, use 'pandas'|'df'|'dataframe'|pd.DataFrame

Using a custom Calculator

You can use your own calculator for featurization. It's recommended to subclass molfeat.calc.base.SerializableCalculator. If your calculator also implements a batch_compute method, it will be used for batch featurization and parallelization options will be passed to it.

Source code in molfeat/trans/base.py
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
class MoleculeTransformer(TransformerMixin, BaseFeaturizer, metaclass=_TransformerMeta):
    """
    Base class for molecular data transformer such as Fingerprinter etc.
    If you create a subclass of this featurizer, you will need to make sure that the
    input arguments of `__init__` are kept as is in the object attributes.

    !!! note
        The transformer supports a variety of datatypes, they are only enforced when passing the
        `enforce_dtype=True` argument to `__call__`. For pandas dataframes, use `'pandas'|'df'|'dataframe'|pd.DataFrame`

    ???+ tip "Using a custom Calculator"
        You can use your own calculator for featurization. It's recommended to subclass `molfeat.calc.base.SerializableCalculator`.
        If your calculator also implements a `batch_compute` method, it will be used for batch featurization and parallelization options will be passed to it.
    """

    def __init__(
        self,
        featurizer: Union[str, Callable],
        n_jobs: int = 1,
        verbose: bool = False,
        dtype: Optional[Union[str, Callable]] = None,
        parallel_kwargs: Optional[Dict[str, Any]] = None,
        **params,
    ):
        """Mol transformer base class

        Args:
            featurizer: featurizer to use. A callable is used as is; anything else is
                resolved through `get_calculator` together with the extra `params`.
            n_jobs (int, optional): Number of job to run in parallel. Defaults to 1.
            verbose (bool, optional): Verbosity level. Defaults to False.
            dtype (callable, optional): Output data type. Defaults to None, where numpy arrays are returned.
            parallel_kwargs (dict, optional): Optional kwargs to pass to the dm.parallelized function. Defaults to None.

        Raises:
            AttributeError: if the resolved featurizer is neither a string nor a callable.
        """
        super().__init__(
            n_jobs=n_jobs,
            verbose=verbose,
            dtype=dtype,
            featurizer=featurizer,
            parallel_kwargs=parallel_kwargs,
            **params,
        )
        if callable(featurizer):
            self.featurizer = featurizer
        else:
            self.featurizer = get_calculator(featurizer, **params)

        self.cols_to_keep = None
        self._fitted = False

        self._save_input_args()
        if self.featurizer and not (
            isinstance(self.featurizer, str) or is_callable(self.featurizer)
        ):
            raise AttributeError(f"Featurizer {self.featurizer} must be a callable or a string")

    def _save_input_args(self):
        """Save the input arguments of a transformer to the attribute
        `_input_args` of the object.
        """

        # NOTE(hadim): don't override existing _input_args so
        # it's possible to use MoleculeTransformer as a featurizer
        # instead of simply a base class.
        if not hasattr(self, "_input_args"):
            self._input_args = get_input_args()

    def _update_params(self):
        """Rebuild the calculator if `featurizer` is not callable, then mark the transformer as unfitted."""
        if not callable(self.featurizer):
            params = copy.deepcopy(self._input_params)
            params.pop("featurizer")
            self.featurizer = get_calculator(self.featurizer, **params)
        self._fitted = False

    def __setstate__(self, state):
        """Restore from a pickled state: drop `callbacks`, default `parallel_kwargs` to `{}`, refresh params."""
        state.pop("callbacks", None)
        self.__dict__.update(state)
        self.__dict__["parallel_kwargs"] = state.get("parallel_kwargs", {})
        self._update_params()

    def fit(self, X: List[Union[dm.Mol, str]], y: Optional[list] = None, **fit_params):
        """Fit the current transformer on given dataset.

        The goal of fitting is for example to identify nan columns values
        that needs to be removed from the dataset

        Args:
            X: input list of molecules
            y (list, optional): Optional list of molecular properties. Defaults to None.

        Returns:
            self: MolTransformer instance after fitting
        """
        # Start from the full feature space. `transform` applies `cols_to_keep`,
        # so keeping a previous selection would compute the new indices
        # relative to the already-reduced feature vector.
        self.cols_to_keep = None
        feats = self.transform(X, ignore_errors=True)
        lengths = [len(x) for x in feats if not datatype.is_null(x)]
        if lengths:
            # we will ignore all nan
            feats = datatype.to_numpy([f for f in feats if not datatype.is_null(f)])
            self.cols_to_keep = (~np.any(np.isnan(feats), axis=0)).nonzero()[0]
        self._fitted = True
        return self

    def _transform(self, mol: dm.Mol):
        r"""
        Compute features for a single molecule.
        This method would potentially need to be reimplemented by child classes

        Args:
            mol (dm.Mol): molecule to transform into features

        Returns
            feat: featurized input molecule, or None if featurization raised
                (the error is logged when `verbose` is set)

        """
        feat = None
        try:
            feat = datatype.to_numpy(self.featurizer(mol))
            if self.cols_to_keep is not None:
                feat = feat[self.cols_to_keep]
        except Exception as e:
            if self.verbose:
                logger.error(e)
        return feat

    def transform(
        self,
        mols: List[Union[dm.Mol, str]],
        ignore_errors: bool = False,
        **kwargs,
    ):
        r"""
        Compute the features for a set of molecules.

        !!! note
            Note that depending on the `ignore_errors` argument, all failed
            featurization (caused whether by invalid smiles or error during
            data transformation) will be substituted by None features for the
            corresponding molecule. This is done, so you can find the positions
            of these molecules and filter them out according to your own logic.

        Args:
            mols: a list containing smiles or mol objects
            ignore_errors (bool, optional): Whether to silently ignore errors


        Returns:
            features: a list of features for each molecule in the input set

        Raises:
            ValueError: if a molecule could not be featurized and `ignore_errors` is False.
        """
        # Convert single mol to iterable format
        if isinstance(mols, pd.DataFrame):
            mols = mols[mols.columns[0]]
        if isinstance(mols, (str, dm.Mol)) or not isinstance(mols, Iterable):
            mols = [mols]

        def _to_mol(x):
            return dm.to_mol(x) if x else None

        parallel_kwargs = getattr(self, "parallel_kwargs", {})

        if hasattr(self.featurizer, "batch_compute") and callable(self.featurizer.batch_compute):
            # this calculator can be batched which will be faster
            features = self.featurizer.batch_compute(mols, n_jobs=self.n_jobs, **parallel_kwargs)
        else:
            mols = dm.parallelized(_to_mol, mols, n_jobs=self.n_jobs, **parallel_kwargs)
            if self.n_jobs not in [0, 1]:
                # use a proxy model to run in parallel
                cpy = self.copy()
                # `copy()` rebuilds the object through its constructor, which resets
                # the fitted column selection. Carry it over so the parallel path
                # produces the same features as the sequential one.
                cpy.cols_to_keep = getattr(self, "cols_to_keep", None)
                features = dm.parallelized(
                    cpy._transform,
                    mols,
                    n_jobs=self.n_jobs,
                    # use this instance's settings: the proxy may not have kept them
                    **parallel_kwargs,
                )
            else:
                features = [self._transform(mol) for mol in mols]
        if not ignore_errors:
            for ind, feat in enumerate(features):
                if feat is None:
                    raise ValueError(
                        f"Cannot transform molecule at index {ind}. Please check logs (set verbose to True) to see errors!"
                    )

        # sklearn feature validation for sklearn pipeline
        return datatype.as_numpy_array_if_possible(features, self.dtype)

    def __len__(self):
        """Compute featurizer length"""

        # check length and _length attribute
        cols_to_keep = getattr(self, "cols_to_keep", None)
        cur_length = None

        if cols_to_keep is not None:
            cur_length = len(cols_to_keep)
        else:
            cur_length = getattr(self, "length", getattr(self, "_length", None))
            # then check the featurizer length if it's a callable and not a string/None
            if (
                cur_length is None
                and callable(self.featurizer)
                and hasattr(self.featurizer, "__len__")
            ):
                cur_length = len(self.featurizer)

        if cur_length is None:
            raise ValueError(
                f"Cannot auto-determine length of this MolTransformer: {self.__class__.__name__}"
            )

        return cur_length

    def __call__(
        self,
        mols: List[Union[dm.Mol, str]],
        enforce_dtype: bool = True,
        ignore_errors: bool = False,
        **kwargs,
    ):
        r"""
        Calculate features for molecules. Use `__call__` instead of `transform`
        when you want the ids of the valid molecules as well.
        If ignore_errors is True, a list of features and valid ids are returned.
        Note that most Transformers allow you to specify
        a return datatype.

        Args:
            mols:  Mol or SMILES of the molecules to be transformed
            enforce_dtype: whether to enforce the instance dtype in the generated fingerprint
            ignore_errors: Whether to ignore errors during featurization or raise an error.
            kwargs: Named parameters for the transform method

        Returns:
            feats: list of valid features
            ids: all valid molecule positions that did not fail during featurization.
                Only returned when ignore_errors is True.

        """
        features = self.transform(mols, ignore_errors=ignore_errors, enforce_dtype=False, **kwargs)
        ids = np.arange(len(features))
        if ignore_errors:
            features, ids = self._filter_none(features)
        if self.dtype is not None and enforce_dtype:
            features = datatype.cast(features, dtype=self.dtype, columns=self.columns)
        if not ignore_errors:
            return features
        return features, ids

    @staticmethod
    def _filter_none(features):
        """Drop null entries from `features` and return `(features, kept_ids)`."""
        ids_bad = set()
        # If the features are a list, filter the None ids
        if isinstance(features, (tuple, list, np.ndarray)):
            for f_id, feat in enumerate(features):
                if datatype.is_null(feat):
                    ids_bad.add(f_id)
            ids_to_keep = [
                this_id for this_id in np.arange(0, len(features)) if this_id not in ids_bad
            ]
            features = [features[ii] for ii in ids_to_keep]

        # If the features are a dict or DataFrame, filter the ids when any key id is None
        elif isinstance(features, (dict, pd.DataFrame)):
            if isinstance(features, dict):
                features = pd.DataFrame(features)
            for feat_col in features.columns:
                for f_id, feat in enumerate(features[feat_col].values.flatten()):
                    if feat is None:
                        ids_bad.add(f_id)
            all_ids = np.arange(0, features.shape[0])
            ids_to_keep = [this_id for this_id in all_ids if this_id not in ids_bad]
            features = features.iloc[ids_to_keep, :]

        else:
            ids_to_keep = np.arange(0, features.shape[0])
        return features, list(ids_to_keep)

    @property
    def columns(self):
        """Get the list of columns for this molecular descriptor

        Returns:
            columns (list): Name of the columns of the descriptor
        """
        columns = getattr(self.featurizer, "columns", None)
        cols_to_keep = getattr(self, "cols_to_keep", None)
        if columns is not None and cols_to_keep is not None and len(cols_to_keep) > 0:
            columns = [columns[i] for i in cols_to_keep]
        return columns

    @staticmethod
    def batch_transform(
        transformer: Callable,
        mols: List[Union[dm.Mol, str]],
        batch_size: int = 256,
        n_jobs: Optional[int] = None,
        concatenate: bool = True,
        progress: bool = True,
        leave_progress: bool = False,
        **parallel_kwargs,
    ):
        """
        Batched computation of featurization of a list of molecules

        Args:
            transformer: Fingerprint transformer
            mols: List of molecules to featurize
            batch_size: Batch size
            n_jobs: number of jobs to run in parallel
            concatenate: Whether to concatenate the results or return the list of batched results
            progress: whether to show progress bar
            leave_progress: whether to leave progress bar after completion
            parallel_kwargs: additional arguments to pass to dm.parallelized

        Returns:
            List of batches, or a single concatenated array when `concatenate` is True
            (failed entries are filled with NaN rows of length `len(transformer)`).
        """

        step_size = int(np.ceil(len(mols) / batch_size))
        batched_mols = np.array_split(mols, step_size)

        # Work on a copy so a `tqdm_kwargs` dict supplied by the caller is not mutated.
        tqdm_kwargs = dict(parallel_kwargs.get("tqdm_kwargs") or {})
        tqdm_kwargs.update(leave=leave_progress, desc="Batch compute:")
        parallel_kwargs["tqdm_kwargs"] = tqdm_kwargs

        # it's recommended to use a precomputed molecule transformer
        # instead of the internal cache for pretrained models
        cache_attr = "cache"
        existing_cache = getattr(transformer, cache_attr, None)
        if existing_cache is None:
            cache_attr = "precompute_cache"
            existing_cache = getattr(transformer, cache_attr, None)

        use_mp_cache = (
            existing_cache is not None
            and not isinstance(existing_cache, MPDataCache)
            and n_jobs not in [None, 0, 1]  # this is based on datamol sequential vs parallel
        )
        if use_mp_cache:
            # we need to change the cache system to one that works with multiprocessing
            # to have a shared memory
            new_cache = MPDataCache()
            new_cache.update(existing_cache)
            setattr(transformer, cache_attr, new_cache)

        transformed = dm.parallelized(
            transformer,
            batched_mols,
            n_jobs=n_jobs,
            progress=progress,
            **parallel_kwargs,
        )
        if use_mp_cache:
            # we set back the original transformation while updating it with
            # all the missing values
            existing_cache.update(getattr(transformer, cache_attr, {}))
            setattr(transformer, cache_attr, existing_cache)

        if concatenate:
            # if we ask for concatenation, then we would need to fix None values ideally
            fixed_transformations = []
            for batch, computed_trans in zip(batched_mols, transformed):
                if computed_trans is None:
                    # the whole batch failed: one NaN row per molecule of the batch
                    computed_trans = np.full((len(batch), len(transformer)), np.nan)
                else:
                    for i, x in enumerate(computed_trans):
                        if x is None:
                            computed_trans[i] = np.full(len(transformer), np.nan)
                fixed_transformations.append(computed_trans)
            return np.concatenate(fixed_transformations)
        return transformed

    # Featurizer to state methods

    def to_state_dict(self) -> dict:
        """Serialize the featurizer to a state dict.

        Raises:
            ValueError: if the transformer has no recorded `_input_args`.
        """

        if getattr(self, "_input_args", None) is None:
            raise ValueError(f"Cannot save state for this transformer '{self.__class__.__name__}'")

        # Process the input arguments before building the state
        args = copy.deepcopy(self._input_args)

        # Deal with dtype
        if "dtype" in args and not isinstance(args["dtype"], str):
            args["dtype"] = map_dtype(args["dtype"])

        ## Deal with graph atom/bond featurizers
        # NOTE(hadim): it's important to highlight that atom/bond featurizers can't be
        # customized with this logic.
        if args.get("atom_featurizer") is not None:
            if hasattr(args.get("atom_featurizer"), "to_state_dict"):
                args["atom_featurizer"] = args["atom_featurizer"].to_state_dict()
                args["_atom_featurizer_is_pickled"] = False
            else:
                logger.warning(
                    "You are attempting to pickle an atom featurizer without a `to_state_dict` function into a hex string"
                )
                args["atom_featurizer"] = fn_to_hex(args["atom_featurizer"])
                args["_atom_featurizer_is_pickled"] = True

        # deal with bond featurizer
        if args.get("bond_featurizer") is not None:
            if hasattr(args.get("bond_featurizer"), "to_state_dict"):
                args["bond_featurizer"] = args["bond_featurizer"].to_state_dict()
                args["_bond_featurizer_is_pickled"] = False
            else:
                logger.warning(
                    "You are attempting to pickle a bond featurizer without a `to_state_dict` function into a hex string"
                )
                args["bond_featurizer"] = fn_to_hex(args["bond_featurizer"])
                args["_bond_featurizer_is_pickled"] = True

        ## Deal with custom featurizer
        if "featurizer" in args and isinstance(args["featurizer"], Callable):
            if hasattr(args["featurizer"], "to_state_dict"):
                args["featurizer"] = args["featurizer"].to_state_dict()
                args["_featurizer_is_pickled"] = False
            else:
                logger.warning(
                    "You are attempting to pickle a callable without a `to_state_dict` function into a hex string"
                )
                args["featurizer"] = fn_to_hex(args["featurizer"])
                args["_featurizer_is_pickled"] = True

        # Build the state
        state = {}
        state["name"] = self.__class__.__name__
        state["args"] = args
        state["_molfeat_version"] = MOLFEAT_VERSION
        return state

    def to_state_json(self) -> str:
        return json.dumps(self.to_state_dict())

    def to_state_yaml(self) -> str:
        return yaml.dump(self.to_state_dict(), Dumper=yaml.SafeDumper)

    def to_state_json_file(self, filepath: Union[str, Path]):
        with fsspec.open(filepath, "w") as f:
            f.write(self.to_state_json())  # type: ignore

    def to_state_yaml_file(self, filepath: Union[str, Path]):
        with fsspec.open(filepath, "w") as f:
            f.write(self.to_state_yaml())  # type: ignore

    # State to featurizer methods

    @staticmethod
    def from_state_dict(state: dict, override_args: Optional[dict] = None) -> "MoleculeTransformer":
        """Reload a featurizer from a state dict.

        NOTE(review): hex-pickled featurizers are decoded with `hex_to_fn`; only load
        states from trusted sources.
        """

        # Don't alter the original state dict
        state = copy.deepcopy(state)

        # MoleculeTransformer is a special case that has his own logic
        if state["name"] == "PrecomputedMolTransformer":
            return PrecomputedMolTransformer.from_state_dict(
                state=state,
                override_args=override_args,
            )

        # Get the name
        transformer_class = _TRANSFORMERS.get(state["name"])
        if transformer_class is None:
            raise ValueError(f"The featurizer '{state['name']}' is not supported.")
        if isinstance(transformer_class, str):
            # Get the transformer class from its path
            transformer_class = import_from_string(transformer_class)

        # Process the state as needed
        args = state.get("args", {})

        # Deal with dtype
        if "dtype" in args and isinstance(args["dtype"], str):
            args["dtype"] = map_dtype(args["dtype"])

        ## Deal with graph atom/bond featurizers
        if args.get("atom_featurizer") is not None:
            if not args.get("_atom_featurizer_is_pickled"):
                klass_name = args["atom_featurizer"].get("name")
                args["atom_featurizer"] = ATOM_FEATURIZER_MAPPING_REVERSE[
                    klass_name
                ].from_state_dict(args["atom_featurizer"])
            else:
                args["atom_featurizer"] = hex_to_fn(args["atom_featurizer"])
            args.pop("_atom_featurizer_is_pickled", None)
        if args.get("bond_featurizer") is not None:
            if not args.get("_bond_featurizer_is_pickled"):
                klass_name = args["bond_featurizer"].get("name")
                args["bond_featurizer"] = BOND_FEATURIZER_MAPPING_REVERSE[
                    klass_name
                ].from_state_dict(args["bond_featurizer"])
            else:
                args["bond_featurizer"] = hex_to_fn(args["bond_featurizer"])
            args.pop("_bond_featurizer_is_pickled", None)
        ## Deal with custom featurizer
        if "featurizer" in args:
            if args.get("_featurizer_is_pickled") is True:
                args["featurizer"] = hex_to_fn(args["featurizer"])
            elif (
                isinstance(args["featurizer"], Mapping)
                and args["featurizer"].get("name") in _CALCULATORS
            ):
                # we have found a known calculator
                klass_name = args["featurizer"].get("name")
                args["featurizer"] = _CALCULATORS[klass_name].from_state_dict(args["featurizer"])
            # The flag is internal serialization metadata, not a constructor argument.
            args.pop("_featurizer_is_pickled", None)

        if override_args is not None:
            args.update(override_args)

        # Create the transformer
        featurizer = transformer_class(**args)
        return featurizer

    @staticmethod
    def from_state_json(
        state_json: str,
        override_args: Optional[dict] = None,
    ) -> "MoleculeTransformer":
        state_dict = json.loads(state_json)
        return MoleculeTransformer.from_state_dict(state_dict, override_args=override_args)

    @staticmethod
    def from_state_yaml(
        state_yaml: str,
        override_args: Optional[dict] = None,
    ) -> "MoleculeTransformer":
        state_dict = yaml.load(state_yaml, Loader=yaml.SafeLoader)
        return MoleculeTransformer.from_state_dict(state_dict, override_args=override_args)

    @staticmethod
    def from_state_json_file(
        filepath: Union[str, Path],
        override_args: Optional[dict] = None,
    ) -> "MoleculeTransformer":
        with fsspec.open(filepath, "r") as f:
            featurizer = MoleculeTransformer.from_state_json(f.read(), override_args=override_args)  # type: ignore
        return featurizer

    @staticmethod
    def from_state_yaml_file(
        filepath: Union[str, Path],
        override_args: Optional[dict] = None,
    ) -> "MoleculeTransformer":
        with fsspec.open(filepath, "r") as f:
            featurizer = MoleculeTransformer.from_state_yaml(f.read(), override_args=override_args)  # type: ignore
        return featurizer

columns property

Get the list of columns for this molecular descriptor

Returns:

Name Type Description
columns list

Name of the columns of the descriptor

__call__(mols, enforce_dtype=True, ignore_errors=False, **kwargs)

Calculate features for molecules using __call__ instead of transform. If ignore_errors is True, a list of features and the valid ids are returned. Note that most Transformers allow you to specify a return datatype.

Parameters:

Name Type Description Default
mols List[Union[Mol, str]]

Mol or SMILES of the molecules to be transformed

required
enforce_dtype bool

whether to enforce the instance dtype in the generated fingerprint

True
ignore_errors bool

Whether to ignore errors during featurization or raise an error.

False
kwargs

Named parameters for the transform method

{}

Returns:

Name Type Description
feats

list of valid features

ids

all valid molecule positions that did not fail during featurization. Only returned when ignore_errors is True.

Source code in molfeat/trans/base.py
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
def __call__(
    self,
    mols: List[Union[dm.Mol, str]],
    enforce_dtype: bool = True,
    ignore_errors: bool = False,
    **kwargs,
):
    r"""
    Calculate features for molecules. Using __call__, instead of transform.
    If ignore_errors is True, a tuple of (valid features, valid ids) is returned.
    Note that most Transformers allow you to specify
    a return datatype.

    Args:
        mols:  Mol or SMILES of the molecules to be transformed
        enforce_dtype: whether to enforce the instance dtype in the generated fingerprint
        ignore_errors: Whether to ignore errors during featurization or raise an error.
        kwargs: Named parameters for the transform method

    Returns:
        feats: list of valid features
        ids: all valid molecule positions that did not fail during featurization.
            Only returned when ignore_errors is True.

    """
    # `transform` is asked not to enforce the dtype; casting to the instance dtype is
    # done below, after failed (None) entries have optionally been filtered out.
    features = self.transform(mols, ignore_errors=ignore_errors, enforce_dtype=False, **kwargs)
    ids = None
    if ignore_errors:
        # ids are only meaningful (and only returned) when failures are filtered out
        features, ids = self._filter_none(features)
    if self.dtype is not None and enforce_dtype:
        features = datatype.cast(features, dtype=self.dtype, columns=self.columns)
    if not ignore_errors:
        return features
    return features, ids

__init__(featurizer, n_jobs=1, verbose=False, dtype=None, parallel_kwargs=None, **params)

Mol transformer base class

Parameters:

Name Type Description Default
featurizer Union[str, Callable]

featurizer to use

required
n_jobs int

Number of jobs to run in parallel. Defaults to 1.

1
verbose bool

Verbosity level. Defaults to False.

False
dtype callable

Output data type. Defaults to None, where numpy arrays are returned.

None
parallel_kwargs dict

Optional kwargs to pass to the dm.parallelized function. Defaults to None.

None
Source code in molfeat/trans/base.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def __init__(
    self,
    featurizer: Union[str, Callable],
    n_jobs: int = 1,
    verbose: bool = False,
    dtype: Optional[Union[str, Callable]] = None,
    parallel_kwargs: Optional[Dict[str, Any]] = None,
    **params,
):
    """Mol transformer base class

    Args:
        featurizer: featurizer to use. A callable is used as-is; otherwise it is
            resolved with `get_calculator(featurizer, **params)`.
        n_jobs (int, optional): Number of jobs to run in parallel. Defaults to 1.
        verbose (bool, optional): Verbosity level. Defaults to False.
        dtype (callable, optional): Output data type. Defaults to None, where numpy arrays are returned.
        parallel_kwargs (dict, optional): Optional kwargs to pass to the dm.parallelized function. Defaults to None.
        **params: extra parameters forwarded to the base class constructor and, for a
            non-callable `featurizer`, to `get_calculator`.

    Raises:
        AttributeError: if the resolved featurizer is truthy but is neither a string
            nor accepted by `is_callable`.
    """
    super().__init__(
        n_jobs=n_jobs,
        verbose=verbose,
        dtype=dtype,
        featurizer=featurizer,
        parallel_kwargs=parallel_kwargs,
        **params,
    )
    if callable(featurizer):
        self.featurizer = featurizer
    else:
        self.featurizer = get_calculator(featurizer, **params)

    # indices of the columns kept after `fit`; None means "keep everything"
    self.cols_to_keep = None
    self._fitted = False

    self._save_input_args()
    if self.featurizer and not (
        isinstance(self.featurizer, str) or is_callable(self.featurizer)
    ):
        raise AttributeError(f"Featurizer {self.featurizer} must be a callable or a string")

__len__()

Compute featurizer length

Source code in molfeat/trans/base.py
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
def __len__(self):
    """Return the number of features produced by this featurizer.

    The length is resolved in this order:
    1. the number of columns kept after fitting (`cols_to_keep`), when set;
    2. the `length` attribute, or `_length` when `length` does not exist;
    3. `len(self.featurizer)` when the featurizer is callable and sized.

    Raises:
        ValueError: if none of the above yields a length.
    """
    kept_columns = getattr(self, "cols_to_keep", None)
    if kept_columns is not None:
        return len(kept_columns)

    size = getattr(self, "length", getattr(self, "_length", None))
    if size is None:
        inner = self.featurizer
        if callable(inner) and hasattr(inner, "__len__"):
            size = len(inner)

    if size is None:
        raise ValueError(
            f"Cannot auto-determine length of this MolTransformer: {self.__class__.__name__}"
        )
    return size

batch_transform(transformer, mols, batch_size=256, n_jobs=None, concatenate=True, progress=True, leave_progress=False, **parallel_kwargs) staticmethod

Batched computation of featurization of a list of molecules

Parameters:

Name Type Description Default
transformer Callable

Fingerprint transformer

required
mols List[Union[Mol, str]]

List of molecules to featurize

required
batch_size int

Batch size

256
n_jobs Optional[int]

number of jobs to run in parallel

None
concatenate bool

Whether to concatenate the results or return the list of batched results

True
progress bool

whether to show progress bar

True
leave_progress bool

whether to leave progress bar after completion

False
parallel_kwargs

additional arguments to pass to dm.parallelized

{}

Returns:

Type Description

List of batches

Source code in molfeat/trans/base.py
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
@staticmethod
def batch_transform(
    transformer: Callable,
    mols: List[Union[dm.Mol, str]],
    batch_size: int = 256,
    n_jobs: Optional[int] = None,
    concatenate: bool = True,
    progress: bool = True,
    leave_progress: bool = False,
    **parallel_kwargs,
):
    """
    Batched computation of featurization of a list of molecules

    Args:
        transformer: Fingerprint transformer
        mols: List of molecules to featurize
        batch_size: Batch size
        n_jobs: number of jobs to run in parallel
        concatenate: Whether to concatenate the results or return the list of batched results
        progress: whether to show progress bar
        leave_progress: whether to leave progress bar after completion
        parallel_kwargs: additional arguments to pass to dm.parallelized

    Returns:
        If `concatenate` is True, a single array with the features of all molecules, where
        failed entries (or whole failed batches) are replaced by rows of NaN of width
        `len(transformer)`. Otherwise, the list of per-batch results.
    """

    n_batches = int(np.ceil(len(mols) / batch_size))
    batched_mols = np.array_split(mols, n_batches)

    # Copy so that a `tqdm_kwargs` dict supplied by the caller is not mutated in place.
    tqdm_kwargs = dict(parallel_kwargs.get("tqdm_kwargs") or {})
    tqdm_kwargs.update(leave=leave_progress, desc="Batch compute:")
    parallel_kwargs["tqdm_kwargs"] = tqdm_kwargs

    # it's recommended to use a precomputed molecule transformer
    # instead of the internal cache for pretrained models
    cache_attr = "cache"
    existing_cache = getattr(transformer, cache_attr, None)
    if existing_cache is None:
        cache_attr = "precompute_cache"
        existing_cache = getattr(transformer, cache_attr, None)

    use_mp_cache = (
        existing_cache is not None
        and not isinstance(existing_cache, MPDataCache)
        and n_jobs not in [None, 0, 1]  # this is based on datamol sequential vs parallel
    )
    if use_mp_cache:
        # we need to change the cache system to one that works with multiprocessing
        # to have a shared memory
        new_cache = MPDataCache()
        new_cache.update(existing_cache)
        setattr(transformer, cache_attr, new_cache)

    transformed = dm.parallelized(
        transformer,
        batched_mols,
        n_jobs=n_jobs,
        progress=progress,
        **parallel_kwargs,
    )
    if use_mp_cache:
        # we set back the original transformation while updating it with
        # all the missing values
        existing_cache.update(getattr(transformer, cache_attr, {}))
        setattr(transformer, cache_attr, existing_cache)

    if not concatenate:
        return transformed

    # Concatenation requires every entry to have the same width, so failed
    # (None) results are replaced by NaN rows of width len(transformer).
    n_features = len(transformer)
    fixed_transformations = []
    for batch, computed_trans in zip(batched_mols, transformed):
        if computed_trans is None:
            # the whole batch failed: one NaN row per molecule of the input batch
            computed_trans = np.full((len(batch), n_features), np.nan)
        else:
            for i, x in enumerate(computed_trans):
                if x is None:
                    computed_trans[i] = np.full(n_features, np.nan)
        fixed_transformations.append(computed_trans)
    return np.concatenate(fixed_transformations)

fit(X, y=None, **fit_params)

Fit the current transformer on given dataset.

The goal of fitting is, for example, to identify columns containing NaN values that need to be removed from the dataset.

Parameters:

Name Type Description Default
X List[Union[Mol, str]]

input list of molecules

required
y list

Optional list of molecular properties. Defaults to None.

None

Returns:

Name Type Description
self

MolTransformer instance after fitting

Source code in molfeat/trans/base.py
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
def fit(self, X: List[Union[dm.Mol, str]], y: Optional[list] = None, **fit_params):
    """Fit the current transformer on given dataset.

    The goal of fitting is, for example, to identify columns containing NaN
    values that need to be removed from the dataset.

    Args:
        X: input list of molecules
        y (list, optional): Optional list of molecular properties. Defaults to None.
        **fit_params: accepted for scikit-learn compatibility; unused here.

    Returns:
        self: MolTransformer instance after fitting
    """
    feats = self.transform(X, ignore_errors=True)
    # Failed featurizations are ignored; filter them out once.
    valid_feats = [f for f in feats if not datatype.is_null(f)]
    if valid_feats:
        feats = datatype.to_numpy(valid_feats)
        # keep the indices of columns that contain no NaN across all valid molecules
        self.cols_to_keep = (~np.any(np.isnan(feats), axis=0)).nonzero()[0]
    self._fitted = True
    return self

from_state_dict(state, override_args=None) staticmethod

Reload a featurizer from a state dict.

Source code in molfeat/trans/base.py
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
@staticmethod
def from_state_dict(state: dict, override_args: Optional[dict] = None) -> "MoleculeTransformer":
    """Reload a featurizer from a state dict.

    Args:
        state: state dict, as produced by `to_state_dict`. It is deep-copied and
            never modified.
        override_args: optional arguments that take precedence over `state["args"]`.

    Returns:
        The re-instantiated featurizer.

    Raises:
        ValueError: if `state["name"]` is not a registered transformer.
    """

    # Don't alter the original state dict
    state = copy.deepcopy(state)

    # PrecomputedMolTransformer is a special case that has its own logic
    if state["name"] == "PrecomputedMolTransformer":
        return PrecomputedMolTransformer.from_state_dict(
            state=state,
            override_args=override_args,
        )

    # Get the name
    transformer_class = _TRANSFORMERS.get(state["name"])
    if transformer_class is None:
        raise ValueError(f"The featurizer '{state['name']}' is not supported.")
    if isinstance(transformer_class, str):
        # Get the transformer class from its path
        transformer_class = import_from_string(transformer_class)

    # Process the state as needed
    args = state.get("args", {})

    # Deal with dtype
    if "dtype" in args and isinstance(args["dtype"], str):
        args["dtype"] = map_dtype(args["dtype"])

    ## Deal with graph atom/bond featurizers
    if args.get("atom_featurizer") is not None:
        if not args.get("_atom_featurizer_is_pickled"):
            klass_name = args["atom_featurizer"].get("name")
            args["atom_featurizer"] = ATOM_FEATURIZER_MAPPING_REVERSE[
                klass_name
            ].from_state_dict(args["atom_featurizer"])
        else:
            args["atom_featurizer"] = hex_to_fn(args["atom_featurizer"])
        args.pop("_atom_featurizer_is_pickled", None)
    if args.get("bond_featurizer") is not None:
        if not args.get("_bond_featurizer_is_pickled"):
            klass_name = args["bond_featurizer"].get("name")
            args["bond_featurizer"] = BOND_FEATURIZER_MAPPING_REVERSE[
                klass_name
            ].from_state_dict(args["bond_featurizer"])
        else:
            args["bond_featurizer"] = hex_to_fn(args["bond_featurizer"])
        args.pop("_bond_featurizer_is_pickled", None)

    ## Deal with custom featurizer
    if "featurizer" in args:
        # The flag only records how the featurizer was encoded, so it is always dropped
        # before building the transformer. It may be missing (e.g. string featurizers
        # or hand-written states), hence the default.
        featurizer_is_pickled = args.pop("_featurizer_is_pickled", None)
        if featurizer_is_pickled is True:
            args["featurizer"] = hex_to_fn(args["featurizer"])
        elif (
            isinstance(args["featurizer"], Mapping)
            and args["featurizer"].get("name") in _CALCULATORS
        ):
            # we have found a known calculator
            klass_name = args["featurizer"].get("name")
            args["featurizer"] = _CALCULATORS[klass_name].from_state_dict(args["featurizer"])

    if override_args is not None:
        args.update(override_args)

    # Create the transformer
    featurizer = transformer_class(**args)
    return featurizer

to_state_dict()

Serialize the featurizer to a state dict.

Source code in molfeat/trans/base.py
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
def to_state_dict(self) -> dict:
    """Serialize the featurizer to a state dict.

    The state holds the class name, a processed copy of the input arguments saved at
    construction time, and the molfeat version. Arguments are processed as follows:

    - a non-string `dtype` is mapped with `map_dtype`;
    - `atom_featurizer`, `bond_featurizer` and a callable `featurizer` are serialized
      with their own `to_state_dict` when they have one. Otherwise they are pickled
      into a hex string with `fn_to_hex` and a warning is logged. A
      `_<name>_is_pickled` flag records which encoding was used.

    Raises:
        ValueError: if the input arguments of this transformer are not available.
    """

    # Use a default so that a missing attribute raises the intended ValueError
    # rather than an AttributeError.
    if getattr(self, "_input_args", None) is None:
        raise ValueError(f"Cannot save state for this transformer '{self.__class__.__name__}'")

    # Process the input arguments before building the state
    args = copy.deepcopy(self._input_args)

    # Deal with dtype
    if "dtype" in args and not isinstance(args["dtype"], str):
        args["dtype"] = map_dtype(args["dtype"])

    ## Deal with graph atom/bond featurizers
    # NOTE(hadim): it's important to highlight that atom/bond featurizers can't be
    # customized with this logic.
    if args.get("atom_featurizer") is not None:
        if hasattr(args.get("atom_featurizer"), "to_state_dict"):
            args["atom_featurizer"] = args["atom_featurizer"].to_state_dict()
            args["_atom_featurizer_is_pickled"] = False
        else:
            logger.warning(
                "You are attempting to pickle an atom featurizer without a `to_state_dict` function into a hex string"
            )
            args["atom_featurizer"] = fn_to_hex(args["atom_featurizer"])
            args["_atom_featurizer_is_pickled"] = True

    # deal with bond featurizer
    if args.get("bond_featurizer") is not None:
        if hasattr(args.get("bond_featurizer"), "to_state_dict"):
            args["bond_featurizer"] = args["bond_featurizer"].to_state_dict()
            args["_bond_featurizer_is_pickled"] = False
        else:
            logger.warning(
                "You are attempting to pickle a bond featurizer without a `to_state_dict` function into a hex string"
            )
            args["bond_featurizer"] = fn_to_hex(args["bond_featurizer"])
            args["_bond_featurizer_is_pickled"] = True

    ## Deal with custom featurizer
    if "featurizer" in args and isinstance(args["featurizer"], Callable):
        if hasattr(args["featurizer"], "to_state_dict"):
            args["featurizer"] = args["featurizer"].to_state_dict()
            args["_featurizer_is_pickled"] = False
        else:
            logger.warning(
                "You are attempting to pickle a callable without a `to_state_dict` function into a hex string"
            )
            args["featurizer"] = fn_to_hex(args["featurizer"])
            args["_featurizer_is_pickled"] = True

    # Build the state
    state = {}
    state["name"] = self.__class__.__name__
    state["args"] = args
    state["_molfeat_version"] = MOLFEAT_VERSION
    return state

transform(mols, ignore_errors=False, **kwargs)

Compute the features for a set of molecules.

Note

Note that, depending on the ignore_errors argument, every failed featurization (whether caused by an invalid SMILES or by an error during data transformation) will be substituted by a None feature for the corresponding molecule. This is done so you can find the positions of these molecules and filter them out according to your own logic.

Parameters:

Name Type Description Default
mols List[Union[Mol, str]]

a list containing smiles or mol objects

required
ignore_errors bool

Whether to silently ignore errors

False

Returns:

Name Type Description
features

a list of features for each molecule in the input set

Source code in molfeat/trans/base.py
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
def transform(
    self,
    mols: List[Union[dm.Mol, str]],
    ignore_errors: bool = False,
    **kwargs,
):
    r"""
    Compute the features for a set of molecules.

    !!! note
        Note that depending on the `ignore_errors` argument, all failed
        featurization (whether caused by invalid smiles or error during
        data transformation) will be substituted by None features for the
        corresponding molecule. This is done so you can find the positions
        of these molecules and filter them out according to your own logic.

    Args:
        mols: a list containing smiles or mol objects. A single molecule is wrapped
            into a list; for a DataFrame, its first column is used.
        ignore_errors (bool, optional): Whether to silently ignore errors
        kwargs: accepted for interface compatibility; unused here.

    Returns:
        features: a list of features for each molecule in the input set

    Raises:
        ValueError: if a molecule could not be featurized and `ignore_errors` is False.
    """
    # Convert single mol to iterable format
    if isinstance(mols, pd.DataFrame):
        mols = mols[mols.columns[0]]
    if isinstance(mols, (str, dm.Mol)) or not isinstance(mols, Iterable):
        mols = [mols]

    def _to_mol(x):
        return dm.to_mol(x) if x else None

    # Resolved once and used by every code path below, so a missing/unset
    # attribute cannot make only the parallel branch fail.
    parallel_kwargs = getattr(self, "parallel_kwargs", None) or {}

    if hasattr(self.featurizer, "batch_compute") and callable(self.featurizer.batch_compute):
        # this calculator can be batched which will be faster
        features = self.featurizer.batch_compute(mols, n_jobs=self.n_jobs, **parallel_kwargs)
    else:
        mols = dm.parallelized(_to_mol, mols, n_jobs=self.n_jobs, **parallel_kwargs)
        if self.n_jobs not in [0, 1]:
            # use a proxy model to run in parallel
            cpy = self.copy()
            features = dm.parallelized(
                cpy._transform,
                mols,
                n_jobs=self.n_jobs,
                **parallel_kwargs,
            )
        else:
            features = [self._transform(mol) for mol in mols]
    if not ignore_errors:
        for ind, feat in enumerate(features):
            if feat is None:
                raise ValueError(
                    f"Cannot transform molecule at index {ind}. Please check logs (set verbose to True) to see errors!"
                )

    # sklearn feature validation for sklearn pipeline
    return datatype.as_numpy_array_if_possible(features, self.dtype)

PrecomputedMolTransformer

Bases: MoleculeTransformer

Convenience class for storing precomputed features.

Source code in molfeat/trans/base.py
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
class PrecomputedMolTransformer(MoleculeTransformer):
    """Convenience class for storing precomputed features.

    Features are looked up in a cache keyed by molecule. Molecules missing from the
    cache are featurized with the optional base featurizer, and the result is written
    back into the cache.
    """

    def __init__(
        self,
        cache: Optional[Union[_Cache, Mapping[Any, Any], CacheList]] = None,
        cache_dict: Optional[Dict[str, Union[_Cache, Mapping[Any, Any], CacheList]]] = None,
        cache_key: Optional[str] = None,
        *args,
        featurizer: Optional[Union[MoleculeTransformer, str]] = None,
        state_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Transformer that supports precomputation of features. You can either pass an empty cache or a pre-populated cache.

        Args:
            cache: a datastructure of type mapping that maps each molecule to the precomputed features
            cache_dict: A dictionary of cache objects. This is a convenient structure when using multiple
                data caches for model selection.
            cache_key: The key of cache object to use.
            featurizer: optional featurizer used to compute the features of values not in the cache.
                Either the featurizer object or a string.
            state_path: optional state file path used to initiate the transformer object at the initialization

        Raises:
            ValueError: if `state_path` is given together with cache arguments.
            AttributeError: if no cache is specified, or if the cache is empty and no
                base featurizer is available to determine the feature length.
        """
        if (state_path is not None) and (
            (cache is not None) or (cache_dict is not None and cache_key is not None)
        ):
            raise ValueError(
                "`PrecomputedMolTransformer` can only be initiated by either `state_path` or"
                " the rest of parameters for cache and featurizer. But both are given."
            )

        super().__init__(*args, featurizer="none", **kwargs)

        if state_path is not None:
            self.__dict__ = self.from_state_file(state_path=state_path).__dict__.copy()
        else:
            if cache_dict is not None and cache_key is not None:
                self.cache_key = cache_key
                self.cache = cache_dict[self.cache_key]
            elif cache is not None:
                self.cache = cache
            else:
                raise AttributeError("The cache is not specified.")

            if isinstance(featurizer, str):
                self.base_featurizer = MoleculeTransformer(featurizer, *args, **kwargs)
            else:
                self.base_featurizer = featurizer

        # Set the length of the featurizer. Only the first cached value is needed,
        # so avoid materializing every cached feature vector into a list.
        if len(self.cache) > 0:
            self.length = len(next(iter(self.cache.values())))
        elif self.base_featurizer is not None:
            self.length = len(self.base_featurizer)
        else:
            raise AttributeError(
                "The cache is empty and the base featurizer is not specified. It's impossible"
                " to determine the length of the featurizer."
            )

    def _transform(self, mol: dm.Mol):
        r"""
        Return precomputed feature for a single molecule

        Args:
            mol (dm.Mol): molecule to transform into features

        Returns:
            feat: featurized input molecule

        """
        feat = self.cache.get(mol)
        # if feat is None and we have an existing featurizer, we can update the cache
        if feat is None and self.base_featurizer is not None:
            feat = self.base_featurizer._transform(mol)
            self.cache[mol] = feat

        try:
            feat = datatype.to_numpy(feat)
            if self.cols_to_keep is not None:
                feat = feat[self.cols_to_keep]
        except Exception as e:
            # best effort: the unconverted feature is returned and the error is logged
            if self.verbose:
                logger.error(e)
        return feat

    def update(self, feat_dict: Mapping[Any, Any]):
        r"""
        Fill the cache with a new set of features.

        Args:
            feat_dict: A dictionary of molecules to features.
        """
        self.cache.update(feat_dict)

    def __getstate__(self):
        """Get the state for pickling.

        All attributes except `cache` are deep-copied. A `FileCache` is replaced by
        the arguments needed to re-create it; any other cache is kept by reference.
        """
        state = {k: copy.deepcopy(v) for k, v in self.__dict__.items() if k not in ["cache"]}
        if isinstance(self.cache, FileCache):
            state["file_cache_args"] = dict(
                cache_file=self.cache.cache_file,
                name=self.cache.name,
                mol_hasher=self.cache.mol_hasher,
                n_jobs=self.cache.n_jobs,
                verbose=self.cache.verbose,
                file_type=self.cache.file_type,
                parquet_kwargs=self.cache.parquet_kwargs,
            )
        else:
            # we do not copy the cache
            state["cache"] = self.cache
        return state

    def __setstate__(self, state):
        """Restore the state, re-creating a `FileCache` from its stored arguments."""
        if "file_cache_args" in state:
            cache = FileCache(**state.pop("file_cache_args"))
            state["cache"] = cache
        return super().__setstate__(state)

    def to_state_dict(self, save_to_file: bool = True) -> dict:
        """Serialize a PrecomputedMolTransformer object to a state dict.

        Notes:
            - The base_featurizer must be set or a ValueError will be raised.
            - The cache must be a FileCache object or a ValueError will be raised.

        Args:
            save_to_file: whether to save the cache to file.
        """

        if self.base_featurizer is None:
            raise ValueError(
                "You can't serialize a PrecomputedMolTransformer that does not contain a"
                " featurizer."
            )

        if not isinstance(self.cache, FileCache):
            raise ValueError("The cache must be a FileCache object.")

        state = {}
        state["name"] = "PrecomputedMolTransformer"
        state["base_featurizer"] = self.base_featurizer.to_state_dict()
        state["cache"] = self.cache.to_state_dict(save_to_file=save_to_file)
        state["_molfeat_version"] = MOLFEAT_VERSION

        return state

    @staticmethod
    def from_state_dict(
        state: dict,
        override_args: Optional[dict] = None,
    ) -> "PrecomputedMolTransformer":
        """Rebuild a PrecomputedMolTransformer from a state dict.

        Args:
            state: state dict as produced by `to_state_dict`; it is not modified.
            override_args: optional constructor arguments overriding the loaded ones.
                `state_path` is always ignored.
        """
        # Don't alter the original state dict
        state = copy.deepcopy(state)

        args = {}

        # Load the FileCache object
        args["cache"] = FileCache.from_state_dict(state["cache"])

        # Load the base featurizer
        args["featurizer"] = MoleculeTransformer.from_state_dict(state["base_featurizer"])

        if override_args is not None:
            args.update(override_args)

        # Doesn't allow state_path in the initiation args
        args.pop("state_path", None)
        return PrecomputedMolTransformer(**args)

    def from_state_file(
        self,
        state_path: Union[str, Path],
        override_args: Optional[dict] = None,
    ) -> "PrecomputedMolTransformer":
        """Load a transformer from a YAML (`yaml`/`yml`) or JSON (`json`) state file.

        Args:
            state_path: path of the state file; path-like objects are accepted.
            override_args: optional arguments overriding the stored ones.

        Raises:
            ValueError: if the file name does not end with `yaml`, `yml` or `json`.
        """
        # str() lets pathlib.Path (or other path-like objects) be used as well
        path_str = str(state_path)
        if path_str.endswith(("yaml", "yml")):
            return self.from_state_yaml_file(filepath=state_path, override_args=override_args)
        if path_str.endswith("json"):
            return self.from_state_json_file(filepath=state_path, override_args=override_args)
        raise ValueError(
            "Only files with 'yaml' or 'json' format are allowed. "
            "The filename must be ending with `yaml`, 'yml' or 'json'."
        )

__getstate__()

Get the state for pickling

Source code in molfeat/trans/base.py
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
def __getstate__(self):
    """Build the picklable state of this transformer.

    Every attribute except `cache` is deep-copied. A `FileCache` is not stored
    itself: only the arguments needed to re-create it are kept under
    `file_cache_args`. Any other cache object is stored by reference.
    """
    state = {
        key: copy.deepcopy(value)
        for key, value in self.__dict__.items()
        if key != "cache"
    }
    cache = self.cache
    if not isinstance(cache, FileCache):
        # the cache object is shared, not copied
        state["cache"] = cache
        return state
    state["file_cache_args"] = {
        "cache_file": cache.cache_file,
        "name": cache.name,
        "mol_hasher": cache.mol_hasher,
        "n_jobs": cache.n_jobs,
        "verbose": cache.verbose,
        "file_type": cache.file_type,
        "parquet_kwargs": cache.parquet_kwargs,
    }
    return state

__init__(cache=None, cache_dict=None, cache_key=None, *args, featurizer=None, state_path=None, **kwargs)

Transformer that supports precomputation of features. You can either pass an empty cache or a pre-populated cache

Parameters:

Name Type Description Default
cache Optional[Union[_Cache, Mapping[Any, Any], CacheList]]

a datastructure of type mapping that maps each molecule to the precomputed features

None
cache_dict Optional[Dict[str, Union[_Cache, Mapping[Any, Any], CacheList]]]

A dictionary of cache objects. This is a convenient structure when using multiple data caches for model selection.

None
cache_key Optional[str]

The key of cache object to use.

None
featurizer Optional[Union[MoleculeTransformer, str]]

optional featurizer used to compute the features of values not in the cache. Either the featurizer object or a string.

None
state_path Optional[str]

optional state file path used to initiate the transformer object at the initialization

None
Source code in molfeat/trans/base.py
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
def __init__(
    self,
    cache: Optional[Union[_Cache, Mapping[Any, Any], CacheList]] = None,
    cache_dict: Optional[Dict[str, Union[_Cache, Mapping[Any, Any], CacheList]]] = None,
    cache_key: Optional[str] = None,
    *args,
    featurizer: Optional[Union[MoleculeTransformer, str]] = None,
    state_path: Optional[str] = None,
    **kwargs,
):
    """
    Transformer that supports precomputation of features. You can either pass an empty cache or a pre-populated cache.

    Args:
        cache: a datastructure of type mapping that maps each molecule to the precomputed features
        cache_dict: A dictionary of cache objects. This is a convenient structure when using multiple
            data caches for model selection.
        cache_key: The key of cache object to use.
        featurizer: optional featurizer used to compute the features of values not in the cache.
            Either the featurizer object or a string.
        state_path: optional state file path used to initiate the transformer object at the initialization

    Raises:
        ValueError: if `state_path` is given together with cache arguments.
        AttributeError: if no cache is specified, or if the cache is empty and no
            base featurizer is available to determine the feature length.
    """
    if (state_path is not None) and (
        (cache is not None) or (cache_dict is not None and cache_key is not None)
    ):
        raise ValueError(
            "`PrecomputedMolTransformer` can only be initiated by either `state_path` or"
            " the rest of parameters for cache and featurizer. But both are given."
        )

    super().__init__(*args, featurizer="none", **kwargs)

    if state_path is not None:
        self.__dict__ = self.from_state_file(state_path=state_path).__dict__.copy()
    else:
        if cache_dict is not None and cache_key is not None:
            self.cache_key = cache_key
            self.cache = cache_dict[self.cache_key]
        elif cache is not None:
            self.cache = cache
        else:
            raise AttributeError("The cache is not specified.")

        if isinstance(featurizer, str):
            self.base_featurizer = MoleculeTransformer(featurizer, *args, **kwargs)
        else:
            self.base_featurizer = featurizer

    # Set the length of the featurizer. Only the first cached value is needed,
    # so avoid materializing every cached feature vector into a list.
    if len(self.cache) > 0:
        self.length = len(next(iter(self.cache.values())))
    elif self.base_featurizer is not None:
        self.length = len(self.base_featurizer)
    else:
        raise AttributeError(
            "The cache is empty and the base featurizer is not specified. It's impossible"
            " to determine the length of the featurizer."
        )

to_state_dict(save_to_file=True)

Serialize a PrecomputedMolTransformer object to a state dict.

Notes
  • The base_featurizer must be set or a ValueError will be raised.
  • The cache must be a FileCache object or a ValueError will be raised.

Parameters:

Name Type Description Default
save_to_file bool

whether to save the cache to file.

True
Source code in molfeat/trans/base.py
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
def to_state_dict(self, save_to_file: bool = True) -> dict:
    """Serialize a PrecomputedMolTransformer object to a state dict.

    Notes:
        - The base_featurizer must be set or a ValueError will be raised.
        - The cache must be a FileCache object or a ValueError will be raised.

    Args:
        save_to_file: whether to save the cache to file.

    Returns:
        A dict with the class name, the serialized base featurizer, the serialized
        cache and the molfeat version.
    """
    if self.base_featurizer is None:
        raise ValueError(
            "You can't serialize a PrecomputedMolTransformer that does not contain a"
            " featurizer."
        )
    if not isinstance(self.cache, FileCache):
        raise ValueError("The cache must be a FileCache object.")

    return {
        "name": "PrecomputedMolTransformer",
        "base_featurizer": self.base_featurizer.to_state_dict(),
        "cache": self.cache.to_state_dict(save_to_file=save_to_file),
        "_molfeat_version": MOLFEAT_VERSION,
    }

update(feat_dict)

Fill the cache with a new set of features for the molecules in feat_dict.

Parameters:

Name Type Description Default
feat_dict Mapping[Any, Any]

A dictionary of molecules to features.

required
Source code in molfeat/trans/base.py
790
791
792
793
794
795
796
797
def update(self, feat_dict: Mapping[Any, Any]):
    r"""
    Merge precomputed features into the cache.

    Args:
        feat_dict: A dictionary of molecules to features.
    """
    target_cache = self.cache
    target_cache.update(feat_dict)