
molfeat.utils

Cache

CacheList

Proxy for supporting search across a list of caches

Source code in molfeat/utils/cache.py
class CacheList:
    """Proxy for supporting search using a list of cache"""

    def __init__(self, *caches):
        self.caches = caches

    def __getitem__(self, key):
        for cache in self.caches:
            val = cache.get(key)
            if val is not None:
                return val
        raise KeyError(f"{key} not found in any cache")

    def __contains__(self, key: Any):
        """Check whether a key is in the cache
        Args:
            key: key to check in the cache
        """
        return any(key in cache for cache in self.caches)

    def __len__(self):
        """Return the length of the cache"""
        return sum(len(c) for c in self.caches)

    def __iter__(self):
        """Iterate over all the caches"""
        # chain the keys of every underlying cache
        return itertools.chain.from_iterable(self.caches)

    def __setitem__(self, key: Any, item: Any):
        """Add an item to the cache

        Args:
            key: input key to set
            item: value of the key to set
        """
        # select a random cache and add the item to the cache
        cache = random.choice(self.caches)
        cache.update({key: item})

    def __call__(self, *args, **kwargs):
        """
        Compute the features for a list of molecules and save them to the cache
        """

        raise NotImplementedError(
            "Dynamic updating of a cache list using a featurizer is not supported!"
        )

    def clear(self, *args, **kwargs):
        """Clear all the caches and make them inaccesible"""
        for cache in self.caches:
            cache.clear(*args, **kwargs)

    def update(self, new_cache: Mapping[Any, Any]):
        cache = random.choice(self.caches)
        cache.update(new_cache)

    def get(self, key, default: Optional[Any] = None):
        """Get the cached value for a specific key
        Args:
            key: key to get
            default: default value to return when the key is not found
        """
        for cache in self.caches:
            val = cache.get(key)
            if val is not None:
                return val
        return default

    def keys(self):
        """Get list of keys in the cache"""
        return list(itertools.chain(*(c.keys() for c in self.caches)))

    def values(self):
        """Get list of values in the cache"""
        return list(itertools.chain(*(c.values() for c in self.caches)))

    def items(self):
        """Return iterator of key, values in the cache"""
        return list(itertools.chain(*(c.items() for c in self.caches)))

    def to_dict(self):
        """Convert current cache to a dictionary"""
        return dict(self.items())

    def fetch(
        self,
        mols: List[Union[dm.Mol, str]],
    ):
        """Get the representation for a single

        Args:
            mols: list of molecules
        """
        if isinstance(mols, str) or not isinstance(mols, Iterable):
            mols = [mols]
        return [self.get(mol) for mol in mols]
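
For illustration, here is a minimal usage sketch. It uses plain dicts as stand-ins for molfeat cache objects, since CacheList only relies on the mapping-style interface shown above:

from molfeat.utils.cache import CacheList

# Plain dicts stand in for DataCache/FileCache objects here; CacheList only
# needs the mapping-style interface (get, update, keys, __contains__, ...).
fp_cache = {"CCO": [0, 1, 0], "c1ccccc1": [1, 0, 1]}
desc_cache = {"CCN": [1, 1, 0]}

caches = CacheList(fp_cache, desc_cache)

assert "CCN" in caches            # key search spans every cache
assert len(caches) == 3
print(caches["CCO"])              # -> [0, 1, 0]
print(caches.get("missing", -1))  # -> -1 instead of a KeyError
caches["CCC"] = [1, 1, 1]         # stored in a randomly chosen cache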

__call__(*args, **kwargs)

Compute the features for a list of molecules and save them to the cache

Source code in molfeat/utils/cache.py
def __call__(self, *args, **kwargs):
    """
    Compute the features for a list of molecules and save them to the cache
    """

    raise NotImplementedError(
        "Dynamic updating of a cache list using a featurizer is not supported!"
    )

__contains__(key)

Check whether a key is in the cache

Parameters:

- key: key to check in the cache

Source code in molfeat/utils/cache.py
def __contains__(self, key: Any):
    """Check whether a key is in the cache
    Args:
        key: key to check in the cache
    """
    return any(key in cache for cache in self.caches)

__iter__()

Iterate over all the caches

Source code in molfeat/utils/cache.py
def __iter__(self):
    """Iterate over all the caches"""
    # chain the keys of every underlying cache
    return itertools.chain.from_iterable(self.caches)

__len__()

Return the length of the cache

Source code in molfeat/utils/cache.py
def __len__(self):
    """Return the length of the cache"""
    return sum(len(c) for c in self.caches)

__setitem__(key, item)

Add an item to the cache

Parameters:

- key (Any): input key to set. (required)
- item (Any): value of the key to set. (required)
Source code in molfeat/utils/cache.py
def __setitem__(self, key: Any, item: Any):
    """Add an item to the cache

    Args:
        key: input key to set
        item: value of the key to set
    """
    # select a random cache and add the item to the cache
    cache = random.choice(self.caches)
    cache.update({key: item})

clear(*args, **kwargs)

Clear all the caches and make them inaccessible

Source code in molfeat/utils/cache.py
def clear(self, *args, **kwargs):
    """Clear all the caches and make them inaccesible"""
    for cache in self.caches:
        cache.clear(*args, **kwargs)

fetch(mols)

Get the cached representations for a list of molecules

Parameters:

- mols (List[Union[Mol, str]]): list of molecules. (required)
Source code in molfeat/utils/cache.py
def fetch(
    self,
    mols: List[Union[dm.Mol, str]],
):
    """Get the representation for a single

    Args:
        mols: list of molecules
    """
    if isinstance(mols, str) or not isinstance(mols, Iterable):
        mols = [mols]
    return [self.get(mol) for mol in mols]

get(key, default=None)

Get the cached value for a specific key

Parameters:

- key: key to get
- default: default value to return when the key is not found

Source code in molfeat/utils/cache.py
def get(self, key, default: Optional[Any] = None):
    """Get the cached value for a specific key
    Args:
        key: key to get
        default: default value to return when the key is not found
    """
    for cache in self.caches:
        val = cache.get(key)
        if val is not None:
            return val
    return default

items()

Return iterator of key, values in the cache

Source code in molfeat/utils/cache.py
def items(self):
    """Return iterator of key, values in the cache"""
    return list(itertools.chain(*(c.items() for c in self.caches)))

keys()

Get list of keys in the cache

Source code in molfeat/utils/cache.py
def keys(self):
    """Get list of keys in the cache"""
    return list(itertools.chain(*(c.keys() for c in self.caches)))

to_dict()

Convert current cache to a dictionary

Source code in molfeat/utils/cache.py
def to_dict(self):
    """Convert current cache to a dictionary"""
    return dict(self.items())

values()

Get list of values in the cache

Source code in molfeat/utils/cache.py
def values(self):
    """Get list of values in the cache"""
    return list(itertools.chain(*(c.values() for c in self.caches)))

DataCache

Bases: _Cache

Molecular features caching system that caches computed values in memory for later reuse

Source code in molfeat/utils/cache.py
class DataCache(_Cache):
    """
    Molecular features caching system that caches computed values in memory for later reuse
    """

    def __init__(
        self,
        name: str,
        n_jobs: int = -1,
        mol_hasher: Optional[Union[Callable, str, MolToKey]] = None,
        verbose: Union[bool, int] = False,
        cache_file: Optional[Union[os.PathLike, bool]] = None,
        delete_on_exit: bool = False,
        clear_on_exit: bool = True,
    ):
        """Precomputed fingerprint caching callback

        Args:
            name: name of the cache
            n_jobs: number of parallel jobs to use when performing any computation
            mol_hasher: function to use to hash molecules. If not provided, `dm.unique_id` is used by default
            verbose: whether to print progress. Defaults to False
            cache_file: Cache location. Defaults to None, which will use in-memory caching.
            delete_on_exit: Whether to delete the cache file on exit. Defaults to False.
            clear_on_exit: Whether to clear the cache on exit of the interpreter. Defaults to True
        """
        super().__init__(name=name, mol_hasher=mol_hasher, n_jobs=n_jobs, verbose=verbose)

        if cache_file is True:
            cache_file = pathlib.Path(
                platformdirs.user_cache_dir(appname="molfeat")
            ) / "precomputed/{}_{}.db".format(self.name, str(uuid.uuid4())[:8])

            cache_file = str(cache_file)
        self.cache_file = cache_file
        self.cache = {}
        self._initialize_cache()
        self.delete_on_exit = delete_on_exit
        self.clear_on_exit = clear_on_exit
        if self.clear_on_exit:
            atexit.register(partial(self.clear, delete=delete_on_exit))

    def _initialize_cache(self):
        if self.cache_file not in [None, False]:
            # force creation of cache directory
            cache_parent = pathlib.Path(self.cache_file).parent
            cache_parent.mkdir(parents=True, exist_ok=True)
            self.cache = shelve.open(self.cache_file)
        else:
            self.cache = {}

    def clear(self, delete: bool = False):
        """Clear cache memory if needed.
        Note that a cleared cache cannot be used anymore

        Args:
            delete: whether to delete the cache file if on disk
        """
        self.cache.clear()
        if isinstance(self.cache, shelve.Shelf):
            self.cache.close()
            # EN: temporary set it to a dict before reopening
            # this needs to be done to prevent operating on close files
            self.cache = {}
        if delete:
            if self.cache_file is not None:
                for path in glob.glob(str(self.cache_file) + "*"):
                    try:
                        os.unlink(path)
                    except Exception:  # noqa
                        pass
        else:
            self._initialize_cache()

    def update(self, new_cache: Mapping[Any, Any]):
        """Update the cache with new values

        Args:
            new_cache: new cache with items to use to update current cache
        """
        for k, v in new_cache.items():
            k = self.mol_hasher(k)
            self.cache[k] = v
        return self

    def _sync_cache(self):
        """Perform a cache sync to ensure values are up to date"""
        if isinstance(self.cache, shelve.Shelf):
            self.cache.sync()

    @classmethod
    def load_from_file(cls, filepath: Union[os.PathLike, str]):
        """Load a datache from a file (including remote file)

        Args:
            filepath: path to the file to load
        """
        cached_data = None
        with fsspec.open(filepath, "rb") as f:
            cached_data = joblib.load(f)
        data = cached_data.pop("data", {})
        new_cache = cls(**cached_data)
        new_cache.update(data)
        return new_cache

    def save_to_file(self, filepath: Union[os.PathLike, str]):
        """Save the cache to a file

        Args:
            filepath: path to the file to save
        """
        information = dict(
            name=self.name,
            n_jobs=self.n_jobs,
            mol_hasher=self.mol_hasher,
            verbose=self.verbose,
            cache_file=(self.cache_file is not None),
            delete_on_exit=self.delete_on_exit,
        )
        information["data"] = self.to_dict()
        with fsspec.open(filepath, "wb") as f:
            joblib.dump(information, f)
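
A minimal sketch of typical usage, assuming an in-memory cache and joblib-serializable values (the file name below is only illustrative):

import numpy as np
from molfeat.utils.cache import DataCache

# In-memory cache; pass cache_file=True to get an on-disk shelve instead.
cache = DataCache(name="ecfp_demo")

# Keys are molecules (SMILES or rdkit Mol); update() hashes them with mol_hasher.
cache.update({"CCO": np.array([0, 1, 0, 1])})

# Round trip through a file.
cache.save_to_file("ecfp_demo_cache.joblib")
restored = DataCache.load_from_file("ecfp_demo_cache.joblib")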

__init__(name, n_jobs=-1, mol_hasher=None, verbose=False, cache_file=None, delete_on_exit=False, clear_on_exit=True)

Precomputed fingerprint caching callback

Parameters:

- name (str): name of the cache. (required)
- n_jobs (int): number of parallel jobs to use when performing any computation. (default: -1)
- mol_hasher (Optional[Union[Callable, str, MolToKey]]): function to use to hash molecules. If not provided, `dm.unique_id` is used by default. (default: None)
- verbose (Union[bool, int]): whether to print progress. (default: False)
- cache_file (Optional[Union[PathLike, bool]]): Cache location. Defaults to None, which will use in-memory caching. (default: None)
- delete_on_exit (bool): whether to delete the cache file on exit. (default: False)
- clear_on_exit (bool): whether to clear the cache on exit of the interpreter. (default: True)
Source code in molfeat/utils/cache.py
def __init__(
    self,
    name: str,
    n_jobs: int = -1,
    mol_hasher: Optional[Union[Callable, str, MolToKey]] = None,
    verbose: Union[bool, int] = False,
    cache_file: Optional[Union[os.PathLike, bool]] = None,
    delete_on_exit: bool = False,
    clear_on_exit: bool = True,
):
    """Precomputed fingerprint caching callback

    Args:
        name: name of the cache
        n_jobs: number of parallel jobs to use when performing any computation
        mol_hasher: function to use to hash molecules. If not provided, `dm.unique_id` is used by default
        verbose: whether to print progress. Defaults to False
        cache_file: Cache location. Defaults to None, which will use in-memory caching.
        delete_on_exit: Whether to delete the cache file on exit. Defaults to False.
        clear_on_exit: Whether to clear the cache on exit of the interpreter. Defaults to True
    """
    super().__init__(name=name, mol_hasher=mol_hasher, n_jobs=n_jobs, verbose=verbose)

    if cache_file is True:
        cache_file = pathlib.Path(
            platformdirs.user_cache_dir(appname="molfeat")
        ) / "precomputed/{}_{}.db".format(self.name, str(uuid.uuid4())[:8])

        cache_file = str(cache_file)
    self.cache_file = cache_file
    self.cache = {}
    self._initialize_cache()
    self.delete_on_exit = delete_on_exit
    self.clear_on_exit = clear_on_exit
    if self.clear_on_exit:
        atexit.register(partial(self.clear, delete=delete_on_exit))

clear(delete=False)

Clear cache memory if needed. Note that a cleared cache cannot be used anymore

Parameters:

- delete (bool): whether to delete the cache file if on disk. (default: False)
Source code in molfeat/utils/cache.py
def clear(self, delete: bool = False):
    """Clear cache memory if needed.
    Note that a cleared cache cannot be used anymore

    Args:
        delete: whether to delete the cache file if on disk
    """
    self.cache.clear()
    if isinstance(self.cache, shelve.Shelf):
        self.cache.close()
        # EN: temporary set it to a dict before reopening
        # this needs to be done to prevent operating on close files
        self.cache = {}
    if delete:
        if self.cache_file is not None:
            for path in glob.glob(str(self.cache_file) + "*"):
                try:
                    os.unlink(path)
                except Exception:  # noqa
                    pass
    else:
        self._initialize_cache()

load_from_file(filepath) classmethod

Load a DataCache from a file (including remote file)

Parameters:

- filepath (Union[PathLike, str]): path to the file to load. (required)
Source code in molfeat/utils/cache.py
@classmethod
def load_from_file(cls, filepath: Union[os.PathLike, str]):
    """Load a datache from a file (including remote file)

    Args:
        filepath: path to the file to load
    """
    cached_data = None
    with fsspec.open(filepath, "rb") as f:
        cached_data = joblib.load(f)
    data = cached_data.pop("data", {})
    new_cache = cls(**cached_data)
    new_cache.update(data)
    return new_cache

save_to_file(filepath)

Save the cache to a file

Parameters:

- filepath (Union[PathLike, str]): path to the file to save. (required)
Source code in molfeat/utils/cache.py
def save_to_file(self, filepath: Union[os.PathLike, str]):
    """Save the cache to a file

    Args:
        filepath: path to the file to save
    """
    information = dict(
        name=self.name,
        n_jobs=self.n_jobs,
        mol_hasher=self.mol_hasher,
        verbose=self.verbose,
        cache_file=(self.cache_file is not None),
        delete_on_exit=self.delete_on_exit,
    )
    information["data"] = self.to_dict()
    with fsspec.open(filepath, "wb") as f:
        joblib.dump(information, f)

update(new_cache)

Update the cache with new values

Parameters:

- new_cache (Mapping[Any, Any]): new cache with items to use to update current cache. (required)
Source code in molfeat/utils/cache.py
def update(self, new_cache: Mapping[Any, Any]):
    """Update the cache with new values

    Args:
        new_cache: new cache with items to use to update current cache
    """
    for k, v in new_cache.items():
        k = self.mol_hasher(k)
        self.cache[k] = v
    return self

FileCache

Bases: _Cache

Read-only cache that holds precomputed data in a pickle, csv, parquet or h5py file.

The convention used requires the 'keys' and 'values' columns when the input file needs to be loaded as a dataframe.

Source code in molfeat/utils/cache.py
class FileCache(_Cache):
    """
    Read-only cache that holds precomputed data in a pickle, csv, parquet or h5py file.

    The convention used requires the 'keys' and  'values' columns when
    the input file needs to be loaded as a dataframe.
    """

    _PICKLE_PROTOCOL = 4
    SUPPORTED_TYPES = ["pickle", "pkl", "csv", "parquet", "pq", "hdf5", "h5"]

    def __init__(
        self,
        cache_file: Union[os.PathLike, str],
        name: Optional[str] = None,
        mol_hasher: Optional[Union[Callable, str, MolToKey]] = None,
        n_jobs: Optional[int] = None,
        verbose: Union[bool, int] = False,
        file_type: str = "parquet",
        clear_on_exit: bool = True,
        parquet_kwargs: Optional[Dict[Any, Any]] = None,
    ):
        """Precomputed fingerprint caching callback

        !!! note
            Do not pickle this object, instead use the provided saving methods.

        Args:
            cache_file: Cache location. Can be a local file or a remote file
            name: optional name to give the cache
            mol_hasher: function to use to hash molecules. If not provided, `dm.unique_id` is used by default
            n_jobs: number of parallel jobs to use when performing any computation
            verbose: whether to print information about the cache
            clear_on_exit: whether to clear the cache on exit of the interpreter
            file_type: File type that was provided. One of "csv", "pickle", "hdf5" and "parquet"
                For "csv" and "parquet", we expect columns "keys" and "values"
                For a pickle, we expect either a mapping or a dataframe with "keys" and "values" columns
            parquet_kwargs: Argument to pass to the parquet reader.
        """
        super().__init__(name=name, mol_hasher=mol_hasher, n_jobs=n_jobs, verbose=verbose)

        self.cache_file = cache_file
        self.file_type = file_type
        self.parquet_kwargs = parquet_kwargs or {}
        self.clear_on_exit = clear_on_exit

        if self.file_type not in FileCache.SUPPORTED_TYPES:
            raise ValueError(
                f"Unsupported file type, expected one of {FileCache.SUPPORTED_TYPES}, got '{self.file_type}'"
            )

        if self.cache_file is not None and dm.fs.exists(self.cache_file):
            self._load_cache()
        else:
            self.cache = {}

        if self.clear_on_exit:
            atexit.register(self.clear)

    def clear(self):
        """Clear cache memory at exit and close any open file
        Note that a cleared cache cannot be used anymore!
        """
        if self.file_type in ["hdf5", "h5"]:
            self.cache.close()
        else:
            del self.cache
        # reset cache to empty
        self.cache = {}

    def items(self):
        """Return iterator of key, values in the cache"""
        if self.file_type in ["hdf5", "h5"]:
            return ((k, np.asarray(v)) for k, v in self.cache.items())
        return super().items()

    def _load_cache(self):
        """Load cache internally if needed"""

        file_exists = dm.utils.fs.exists(self.cache_file)

        if self.file_type in ["hdf5", "h5"]:
            f = fsspec.open("simplecache::" + self.cache_file, "rb+").open()
            self.cache = h5py.File(f, "r+")

        elif not file_exists:
            self.cache = {}

        elif self.file_type in ["pickle", "pkl"]:
            with fsspec.open(self.cache_file, "rb") as IN:
                self.cache = joblib.load(IN)

        elif self.file_type == "csv":
            with fsspec.open(self.cache_file, "rb") as IN:
                # Allow the CSV file to exist but with an empty content
                try:
                    self.cache = pd.read_csv(
                        IN,
                        converters={"values": lambda x: commons.unpack_bits(ast.literal_eval(x))},
                    )
                except pandas.errors.EmptyDataError:
                    self.cache = {}

        elif self.file_type in ["parquet", "pq"]:
            self.cache = pd.read_parquet(
                self.cache_file,
                columns=["keys", "values"],
                **self.parquet_kwargs,
            )
        # convert dataframe to dict if needed
        if isinstance(self.cache, pd.DataFrame):
            self.cache = self.cache.set_index("keys").to_dict()["values"]

    def update(self, new_cache: Mapping[Any, Any]):
        """Update the cache with new values

        Args:
            new_cache: new cache with items to use to update current cache
        """
        for k, v in new_cache.items():
            key = self.mol_hasher(k)
            if self.file_type in ["hdf5", "h5"]:
                self.cache.create_dataset(key, data=v)
            else:
                self.cache[key] = v
        return self

    @classmethod
    def load_from_file(cls, filepath: Union[os.PathLike, str], **kwargs):
        """Load a FileCache from a file

        Args:
            filepath: path to the file to load
            kwargs: keyword arguments to pass to the constructor
        """
        new_cache = cls(cache_file=filepath, **kwargs)
        return new_cache

    def to_dataframe(self, pack_bits: bool = False):
        """Convert the cache to a dataframe. The converted dataframe would have `keys` and `values` columns

        Args:
            pack_bits: whether to pack the values columns into bits.
                By using molfeat.utils.commons.unpack_bits, the values column can be reloaded as an array
        """
        if pack_bits:
            loaded_items = [
                (k, commons.pack_bits(x, protocol=self._PICKLE_PROTOCOL)) for k, x in self.items()
            ]
        else:
            loaded_items = self.items()
        df = pd.DataFrame(loaded_items, columns=["keys", "values"])
        return df

    def save_to_file(
        self,
        filepath: Optional[Union[os.PathLike, str]] = None,
        file_type: Optional[str] = None,
        **kwargs,
    ):
        """Save the cache to a file

        Args:
            filepath: path to the file to save. If None, the cache is saved to the original file.
            file_type: format used to save the cache to file one of "pickle", "csv", "hdf5", "parquet".
                If None, the original file type is used.
            kwargs: keyword arguments to pass to the serializer to disk (e.g to pass to pd.to_csv or pd.to_parquet)
        """

        if filepath is None:
            filepath = self.cache_file

        if file_type is None:
            file_type = self.file_type

        if file_type in ["pkl", "pickle"]:
            with fsspec.open(filepath, "wb") as f:
                joblib.dump(self.to_dict(), f)

        elif file_type in ["csv", "parquet", "pq"]:
            df = self.to_dataframe(pack_bits=(file_type == "csv"))

            if file_type == "csv":
                with fsspec.open(filepath, "w") as f:
                    df.to_csv(f, index=False, **kwargs)
            else:
                df.to_parquet(filepath, index=False, **kwargs)

        elif file_type in ["hdf5", "h5"]:
            with fsspec.open(filepath, "wb") as IN:
                with h5py.File(IN, "w") as f:
                    for k, v in self.items():
                        f.create_dataset(k, data=v)
        else:
            raise ValueError("Unsupported output protocol: {}".format(file_type))

    def to_state_dict(self, save_to_file: bool = True) -> dict:
        """Serialize the cache to a state dict.

        Args:
            save_to_file: whether to save the cache to file.
        """

        if save_to_file is True:
            self.save_to_file()

        state = {}
        state["_cache_name"] = "FileCache"
        state["cache_file"] = self.cache_file
        state["name"] = self.name
        state["n_jobs"] = self.n_jobs
        state["verbose"] = self.verbose
        state["file_type"] = self.file_type
        state["clear_on_exit"] = self.clear_on_exit
        state["parquet_kwargs"] = self.parquet_kwargs
        state["mol_hasher"] = self.mol_hasher.to_state_dict()

        return state

    @staticmethod
    def from_state_dict(state: dict, override_args: Optional[dict] = None) -> "FileCache":
        # Don't alter the original state dict
        state = copy.deepcopy(state)

        cache_name = state.pop("_cache_name")

        if cache_name != "FileCache":
            raise ValueError(f"The cache object name is invalid: {cache_name}")

        # Load the MolToKey object
        state["mol_hasher"] = MolToKey.from_state_dict(state["mol_hasher"])

        if override_args is not None:
            state.update(override_args)

        return FileCache(**state)
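
A minimal sketch of a parquet-backed FileCache (the file name is illustrative; update followed by save_to_file writes the expected 'keys'/'values' columns):

import numpy as np
from molfeat.utils.cache import FileCache

cache = FileCache(cache_file="precomputed_fps.parquet", file_type="parquet")
cache.update({"CCO": np.array([0, 1, 0, 1])})  # keys are hashed via mol_hasher
cache.save_to_file()  # persists to the original cache_file

# Later, reload the precomputed values:
reloaded = FileCache.load_from_file("precomputed_fps.parquet", file_type="parquet")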

__init__(cache_file, name=None, mol_hasher=None, n_jobs=None, verbose=False, file_type='parquet', clear_on_exit=True, parquet_kwargs=None)

Precomputed fingerprint caching callback

Note

Do not pickle this object, instead use the provided saving methods.

Parameters:

- cache_file (Union[PathLike, str]): Cache location. Can be a local file or a remote file. (required)
- name (Optional[str]): optional name to give the cache. (default: None)
- mol_hasher (Optional[Union[Callable, str, MolToKey]]): function to use to hash molecules. If not provided, dm.unique_id is used by default. (default: None)
- n_jobs (Optional[int]): number of parallel jobs to use when performing any computation. (default: None)
- verbose (Union[bool, int]): whether to print information about the cache. (default: False)
- clear_on_exit (bool): whether to clear the cache on exit of the interpreter. (default: True)
- file_type (str): File type that was provided. One of "csv", "pickle", "hdf5" and "parquet". For "csv" and "parquet", we expect columns "keys" and "values". For a pickle, we expect either a mapping or a dataframe with "keys" and "values" columns. (default: 'parquet')
- parquet_kwargs (Optional[Dict[Any, Any]]): Argument to pass to the parquet reader. (default: None)
Source code in molfeat/utils/cache.py
def __init__(
    self,
    cache_file: Union[os.PathLike, str],
    name: Optional[str] = None,
    mol_hasher: Optional[Union[Callable, str, MolToKey]] = None,
    n_jobs: Optional[int] = None,
    verbose: Union[bool, int] = False,
    file_type: str = "parquet",
    clear_on_exit: bool = True,
    parquet_kwargs: Optional[Dict[Any, Any]] = None,
):
    """Precomputed fingerprint caching callback

    !!! note
        Do not pickle this object, instead use the provided saving methods.

    Args:
        cache_file: Cache location. Can be a local file or a remote file
        name: optional name to give the cache
        mol_hasher: function to use to hash molecules. If not provided, `dm.unique_id` is used by default
        n_jobs: number of parallel jobs to use when performing any computation
        verbose: whether to print information about the cache
        clear_on_exit: whether to clear the cache on exit of the interpreter
        file_type: File type that was provided. One of "csv", "pickle", "hdf5" and "parquet"
            For "csv" and "parquet", we expect columns "keys" and "values"
            For a pickle, we expect either a mapping or a dataframe with "keys" and "values" columns
        parquet_kwargs: Argument to pass to the parquet reader.
    """
    super().__init__(name=name, mol_hasher=mol_hasher, n_jobs=n_jobs, verbose=verbose)

    self.cache_file = cache_file
    self.file_type = file_type
    self.parquet_kwargs = parquet_kwargs or {}
    self.clear_on_exit = clear_on_exit

    if self.file_type not in FileCache.SUPPORTED_TYPES:
        raise ValueError(
            f"Unsupported file type, expected one of {FileCache.SUPPORTED_TYPES}, got '{self.file_type}'"
        )

    if self.cache_file is not None and dm.fs.exists(self.cache_file):
        self._load_cache()
    else:
        self.cache = {}

    if self.clear_on_exit:
        atexit.register(self.clear)

clear()

Clear cache memory at exit and close any open file. Note that a cleared cache cannot be used anymore!

Source code in molfeat/utils/cache.py
def clear(self):
    """Clear cache memory at exit and close any open file
    Note that a cleared cache cannot be used anymore!
    """
    if self.file_type in ["hdf5", "h5"]:
        self.cache.close()
    else:
        del self.cache
    # reset cache to empty
    self.cache = {}

items()

Return iterator of key, values in the cache

Source code in molfeat/utils/cache.py
def items(self):
    """Return iterator of key, values in the cache"""
    if self.file_type in ["hdf5", "h5"]:
        return ((k, np.asarray(v)) for k, v in self.cache.items())
    return super().items()

load_from_file(filepath, **kwargs) classmethod

Load a FileCache from a file

Parameters:

- filepath (Union[PathLike, str]): path to the file to load. (required)
- kwargs: keyword arguments to pass to the constructor. (default: {})
Source code in molfeat/utils/cache.py
@classmethod
def load_from_file(cls, filepath: Union[os.PathLike, str], **kwargs):
    """Load a FileCache from a file

    Args:
        filepath: path to the file to load
        kwargs: keyword arguments to pass to the constructor
    """
    new_cache = cls(cache_file=filepath, **kwargs)
    return new_cache

save_to_file(filepath=None, file_type=None, **kwargs)

Save the cache to a file

Parameters:

- filepath (Optional[Union[PathLike, str]]): path to the file to save. If None, the cache is saved to the original file. (default: None)
- file_type (Optional[str]): format used to save the cache to file, one of "pickle", "csv", "hdf5", "parquet". If None, the original file type is used. (default: None)
- kwargs: keyword arguments to pass to the serializer to disk (e.g. to pass to pd.to_csv or pd.to_parquet). (default: {})
Source code in molfeat/utils/cache.py
def save_to_file(
    self,
    filepath: Optional[Union[os.PathLike, str]] = None,
    file_type: Optional[str] = None,
    **kwargs,
):
    """Save the cache to a file

    Args:
        filepath: path to the file to save. If None, the cache is saved to the original file.
        file_type: format used to save the cache to file one of "pickle", "csv", "hdf5", "parquet".
            If None, the original file type is used.
        kwargs: keyword arguments to pass to the serializer to disk (e.g to pass to pd.to_csv or pd.to_parquet)
    """

    if filepath is None:
        filepath = self.cache_file

    if file_type is None:
        file_type = self.file_type

    if file_type in ["pkl", "pickle"]:
        with fsspec.open(filepath, "wb") as f:
            joblib.dump(self.to_dict(), f)

    elif file_type in ["csv", "parquet", "pq"]:
        df = self.to_dataframe(pack_bits=(file_type == "csv"))

        if file_type == "csv":
            with fsspec.open(filepath, "w") as f:
                df.to_csv(f, index=False, **kwargs)
        else:
            df.to_parquet(filepath, index=False, **kwargs)

    elif file_type in ["hdf5", "h5"]:
        with fsspec.open(filepath, "wb") as IN:
            with h5py.File(IN, "w") as f:
                for k, v in self.items():
                    f.create_dataset(k, data=v)
    else:
        raise ValueError("Unsupported output protocol: {}".format(file_type))

to_dataframe(pack_bits=False)

Convert the cache to a dataframe. The resulting dataframe will have keys and values columns

Parameters:

- pack_bits (bool): whether to pack the values column into bits. By using molfeat.utils.commons.unpack_bits, the values column can be reloaded as an array. (default: False)
Source code in molfeat/utils/cache.py
def to_dataframe(self, pack_bits: bool = False):
    """Convert the cache to a dataframe. The converted dataframe would have `keys` and `values` columns

    Args:
        pack_bits: whether to pack the values columns into bits.
            By using molfeat.utils.commons.unpack_bits, the values column can be reloaded as an array
    """
    if pack_bits:
        loaded_items = [
            (k, commons.pack_bits(x, protocol=self._PICKLE_PROTOCOL)) for k, x in self.items()
        ]
    else:
        loaded_items = self.items()
    df = pd.DataFrame(loaded_items, columns=["keys", "values"])
    return df

to_state_dict(save_to_file=True)

Serialize the cache to a state dict.

Parameters:

- save_to_file (bool): whether to save the cache to file. (default: True)
Source code in molfeat/utils/cache.py
def to_state_dict(self, save_to_file: bool = True) -> dict:
    """Serialize the cache to a state dict.

    Args:
        save_to_file: whether to save the cache to file.
    """

    if save_to_file is True:
        self.save_to_file()

    state = {}
    state["_cache_name"] = "FileCache"
    state["cache_file"] = self.cache_file
    state["name"] = self.name
    state["n_jobs"] = self.n_jobs
    state["verbose"] = self.verbose
    state["file_type"] = self.file_type
    state["clear_on_exit"] = self.clear_on_exit
    state["parquet_kwargs"] = self.parquet_kwargs
    state["mol_hasher"] = self.mol_hasher.to_state_dict()

    return state

update(new_cache)

Update the cache with new values

Parameters:

- new_cache (Mapping[Any, Any]): new cache with items to use to update current cache. (required)
Source code in molfeat/utils/cache.py
def update(self, new_cache: Mapping[Any, Any]):
    """Update the cache with new values

    Args:
        new_cache: new cache with items to use to update current cache
    """
    for k, v in new_cache.items():
        key = self.mol_hasher(k)
        if self.file_type in ["hdf5", "h5"]:
            self.cache.create_dataset(key, data=v)
        else:
            self.cache[key] = v
    return self

MPDataCache

Bases: DataCache

A datacache that supports multiprocessing natively

Source code in molfeat/utils/cache.py
class MPDataCache(DataCache):
    """A datacache that supports multiprocessing natively"""

    def __init__(
        self,
        name: Optional[str] = None,
        n_jobs: int = -1,
        mol_hasher: Optional[Union[Callable, str, MolToKey]] = None,
        verbose: Union[bool, int] = False,
        clear_on_exit: bool = False,
    ):
        """Multiprocessing datacache that save cache into a shared memory

        Args:
            name: name of the cache
            n_jobs: number of parallel jobs to use when performing any computation
            mol_hasher: function to use to hash molecules. If not provided, `dm.unique_id` is used by default
            verbose: whether to print progress. Defaults to False
            clear_on_exit: Whether to clear the cache on exit. Default is False to allow sharing the cache content
        """
        super().__init__(
            name=name,
            n_jobs=n_jobs,
            mol_hasher=mol_hasher,
            cache_file=None,
            verbose=verbose,
            delete_on_exit=False,
            clear_on_exit=clear_on_exit,
        )

    def _initialize_cache(self):
        """Initialize empty cache using a shared dict"""
        manager = mp.Manager()  # this might not be a great idea to initialize every time...
        self.cache = manager.dict()

__init__(name=None, n_jobs=-1, mol_hasher=None, verbose=False, clear_on_exit=False)

Multiprocessing datacache that saves the cache into shared memory

Parameters:

- name (Optional[str]): name of the cache. (default: None)
- n_jobs (int): number of parallel jobs to use when performing any computation. (default: -1)
- mol_hasher (Optional[Union[Callable, str, MolToKey]]): function to use to hash molecules. If not provided, `dm.unique_id` is used by default. (default: None)
- verbose (Union[bool, int]): whether to print progress. (default: False)
- clear_on_exit (bool): whether to clear the cache on exit. Default is False to allow sharing the cache content. (default: False)
Source code in molfeat/utils/cache.py
def __init__(
    self,
    name: Optional[str] = None,
    n_jobs: int = -1,
    mol_hasher: Optional[Union[Callable, str, MolToKey]] = None,
    verbose: Union[bool, int] = False,
    clear_on_exit: bool = False,
):
    """Multiprocessing datacache that save cache into a shared memory

    Args:
        name: name of the cache
        n_jobs: number of parallel jobs to use when performing any computation
        mol_hasher: function to use to hash molecules. If not provided, `dm.unique_id` is used by default
        verbose: whether to print progress. Defaults to False
        clear_on_exit: Whether to clear the cache on exit. Default is False to allow sharing the cache content
    """
    super().__init__(
        name=name,
        n_jobs=n_jobs,
        mol_hasher=mol_hasher,
        cache_file=None,
        verbose=verbose,
        delete_on_exit=False,
        clear_on_exit=clear_on_exit,
    )

MolToKey

Convert a molecule to a key

Source code in molfeat/utils/cache.py
class MolToKey:
    """Convert a molecule to a key"""

    SUPPORTED_HASH_FN = {
        "dm.unique_id": dm.unique_id,
        "dm.to_inchikey": dm.to_inchikey,
    }

    def __init__(self, hash_fn: Optional[Union[Callable, str]] = "dm.unique_id"):
        """Init function for molecular key generator.

        Args:
            hash_fn: hash function to use for the molecular key
        """

        if isinstance(hash_fn, str):
            if hash_fn not in self.SUPPORTED_HASH_FN:
                raise ValueError(
                    f"Hash function {hash_fn} is not supported. "
                    f"Supported hash functions are: {self.SUPPORTED_HASH_FN.keys()}"
                )

            self.hash_name = hash_fn
            self.hash_fn = self.SUPPORTED_HASH_FN[hash_fn]

        else:
            self.hash_fn = hash_fn
            self.hash_name = None

            if self.hash_fn is None:
                self.hash_fn = dm.unique_id
                self.hash_name = "dm.unique_id"

    def __call__(self, mol: dm.Mol):
        """Convert a molecule object to a key that can be used for the cache system

        Args:
            mol: input molecule object
        """
        with dm.without_rdkit_log():
            is_mol = dm.to_mol(mol) is not None
            if is_mol and self.hash_fn is not None:
                return self.hash_fn(mol)
        return mol

    def to_state_dict(self):
        """Serialize MolToKey to a state dict."""

        if self.hash_name is None:
            raise ValueError(
                "The hash function has been provided as a function and not a string. "
                "So it's impossible to save the state. You must specifiy the hash function as a string instead."
            )

        state = {}
        state["hash_name"] = self.hash_name
        return state

    @staticmethod
    def from_state_dict(state: dict) -> "MolToKey":
        """Load a MolToKey object from a state dict."""
        return MolToKey(hash_fn=state["hash_name"])

__call__(mol)

Convert a molecule object to a key that can be used for the cache system

Parameters:

- mol (Mol): input molecule object. (required)
Source code in molfeat/utils/cache.py
def __call__(self, mol: dm.Mol):
    """Convert a molecule object to a key that can be used for the cache system

    Args:
        mol: input molecule object
    """
    with dm.without_rdkit_log():
        is_mol = dm.to_mol(mol) is not None
        if is_mol and self.hash_fn is not None:
            return self.hash_fn(mol)
    return mol

__init__(hash_fn='dm.unique_id')

Init function for molecular key generator.

Parameters:

- hash_fn (Optional[Union[Callable, str]]): hash function to use for the molecular key. (default: 'dm.unique_id')
Source code in molfeat/utils/cache.py
def __init__(self, hash_fn: Optional[Union[Callable, str]] = "dm.unique_id"):
    """Init function for molecular key generator.

    Args:
        hash_fn: hash function to use for the molecular key
    """

    if isinstance(hash_fn, str):
        if hash_fn not in self.SUPPORTED_HASH_FN:
            raise ValueError(
                f"Hash function {hash_fn} is not supported. "
                f"Supported hash functions are: {self.SUPPORTED_HASH_FN.keys()}"
            )

        self.hash_name = hash_fn
        self.hash_fn = self.SUPPORTED_HASH_FN[hash_fn]

    else:
        self.hash_fn = hash_fn
        self.hash_name = None

        if self.hash_fn is None:
            self.hash_fn = dm.unique_id
            self.hash_name = "dm.unique_id"

from_state_dict(state) staticmethod

Load a MolToKey object from a state dict.

Source code in molfeat/utils/cache.py
@staticmethod
def from_state_dict(state: dict) -> "MolToKey":
    """Load a MolToKey object from a state dict."""
    return MolToKey(hash_fn=state["hash_name"])

to_state_dict()

Serialize MolToKey to a state dict.

Source code in molfeat/utils/cache.py
def to_state_dict(self):
    """Serialize MolToKey to a state dict."""

    if self.hash_name is None:
        raise ValueError(
            "The hash function has been provided as a function and not a string. "
            "So it's impossible to save the state. You must specifiy the hash function as a string instead."
        )

    state = {}
    state["hash_name"] = self.hash_name
    return state

Common utils

Common utility functions

align_conformers(mols, ref_id=0, copy=True, conformer_id=-1)

Align a list of molecules to a reference molecule.

Note: consider adding me to datamol.

Parameters:

- mols (List[Mol]): List of molecules to align. All the molecules must have a conformer. (required)
- ref_id (int): Index of the reference molecule. By default, the first molecule in the list will be used as reference. (default: 0)
- copy (bool): Whether to copy the molecules before performing the alignment. (default: True)
- conformer_id (int): Conformer id to use. (default: -1)

Returns:

- mols: The aligned molecules.
- scores: The scores of the alignment.

Source code in molfeat/utils/commons.py
def align_conformers(
    mols: List[dm.Mol],
    ref_id: int = 0,
    copy: bool = True,
    conformer_id: int = -1,
):
    """Align a list of molecules to a reference molecule.

    Note: consider adding me to `datamol`.

    Args:
        mols: List of molecules to align. All the molecules must have a conformer.
        ref_id: Index of the reference molecule. By default, the first molecule in the list
            will be used as reference.
        copy: Whether to copy the molecules before performing the alignment.
        conformer_id: Conformer id to use.

    Returns:
        mols: The aligned molecules.
        scores: The scores of the alignment.
    """

    # Check all input molecules has a conformer
    if not all([mol.GetNumConformers() >= 1 for mol in mols]):
        raise ValueError("One or more input molecules is missing a conformer.")

    # Make a copy of the molecules since they are going to be modified
    if copy:
        mols = [dm.copy_mol(mol) for mol in mols]

    # Compute Crippen contributions for every atoms and molecules
    crippen_contribs = [rdMolDescriptors._CalcCrippenContribs(mol) for mol in mols]

    # Split reference and probe molecules
    crippen_contrib_ref = crippen_contribs[ref_id]
    crippen_contrib_probes = crippen_contribs
    mol_ref = mols[ref_id]
    mol_probes = mols

    # Loop and align
    scores = []
    for i, mol in enumerate(mol_probes):
        crippenO3A = rdMolAlign.GetCrippenO3A(
            prbMol=mol,
            refMol=mol_ref,
            prbCrippenContribs=crippen_contrib_probes[i],
            refCrippenContribs=crippen_contrib_ref,
            prbCid=conformer_id,
            refCid=conformer_id,
            maxIters=50,
        )
        crippenO3A.Align()

        scores.append(crippenO3A.Score())

    scores = np.array(scores)

    return mols, scores
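
A short sketch, using datamol's conformer generator to satisfy the conformer requirement:

import datamol as dm
from molfeat.utils.commons import align_conformers

# Every molecule must carry at least one conformer before alignment.
mols = [dm.conformers.generate(dm.to_mol(smi)) for smi in ["CCO", "CCCO"]]
aligned_mols, scores = align_conformers(mols, ref_id=0)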

concat_dict(prop_dict, new_name, order=None)

Concat properties in dict into a single key dict

Parameters:

- prop_dict (dict): Input dict of property names and their computed values. (required)
- new_name (str): new name under which the concatenated property dict will be returned. (required)
- order (Optional[Iterable[str]]): Optional list of keys that specifies the order in which concatenation should be done. Defaults to the sorted list of keys. (default: None)

Returns:

- dict: dictionary of concatenated output values with a single key corresponding to new_name

Source code in molfeat/utils/commons.py
def concat_dict(prop_dict: dict, new_name: str, order: Optional[Iterable[str]] = None):
    """Concat properties in dict into a single key dict

    Args:
        prop_dict (dict): Input dict of property names and their computed values
        new_name (str): new name under which the concatenated property dict will be returned
        order: Optional list of keys that specifies the order in which concatenation should be done. Defaults to the sorted list of keys.

    Returns:
        dict: dictionary of concatenated output values with a single key corresponding to new_name
    """
    if not order:
        order = list(sorted(prop_dict.keys()))

    output_dict = {}
    if len(order) > 0:
        concatenated_val = np.concatenate([prop_dict[x] for x in order], axis=1)
        output_dict = {new_name: concatenated_val}
    return output_dict

ensure_picklable(fn)

Ensure a function is picklable

Parameters:

- fn (Callable): function to be pickled. (required)
Source code in molfeat/utils/commons.py
def ensure_picklable(fn: Callable):
    """Ensure a function is picklable

    Args:
        fn: function to be pickled
    """
    if inspect.isfunction(fn) and fn.__name__ == "<lambda>":
        return wrap_non_picklable_objects(fn)
    return fn

filter_arguments(fn, params)

Filter the argument of a function to only retain the valid ones

Parameters:

- fn (Callable): Function for which arguments will be checked. (required)
- params (dict): key-val dictionary of arguments to pass to the input function. (required)

Returns:

- params_filtered (dict): dict of filtered arguments for the function

Source code in molfeat/utils/commons.py
def filter_arguments(fn: Callable, params: dict):
    """Filter the argument of a function to only retain the valid ones

    Args:
        fn: Function for which arguments will be checked
        params: key-val dictionary of arguments to pass to the input function

    Returns:
        params_filtered (dict): dict of filtered arguments for the function
    """
    accepted_dict = inspect.signature(fn).parameters
    accepted_list = []
    for key in accepted_dict.keys():
        param = str(accepted_dict[key])
        if param[0] != "*":
            accepted_list.append(param)
    params_filtered = {key: params[key] for key in list(set(accepted_list) & set(params.keys()))}
    return params_filtered
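
For example (note that the implementation matches parameters by their string form, so plain parameters without defaults or annotations are the safe case):

from molfeat.utils.commons import filter_arguments

def compute(radius, n_bits):
    return radius, n_bits

params = {"radius": 2, "n_bits": 1024, "unused_flag": True}
kwargs = filter_arguments(compute, params)  # drops "unused_flag"
print(compute(**kwargs))  # -> (2, 1024)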

fn_to_hex(fn)

Pickle an object and return its hex representation

Parameters:

- fn: object to pickle. (required)

Returns:

- str: hex representation of object

Source code in molfeat/utils/commons.py
def fn_to_hex(fn):
    """Pickle an object and return its hex representation

    Args:
        fn: object to pickle

    Returns:
        str: hex representation of object
    """
    bytes_str = pickle.dumps(ensure_picklable(fn))
    return bytes_str.hex()

fold_count_fp(fp, dim=2 ** 10, binary=False)

Fast folding of a count fingerprint to the specified dimension

Parameters:

- fp (Iterable): iterable fingerprint. (required)
- dim (int): dimension of the folded array. (default: 2**10)
- binary (bool): whether to fold into a binary array or use a count vector. (default: False)

Returns:

- folded: returns the array folded to the provided dimension

Source code in molfeat/utils/commons.py
def fold_count_fp(fp: Iterable, dim: int = 2**10, binary: bool = False):
    """Fast folding of a count fingerprint to the specified dimension

    Args:
        fp: iterable fingerprint
        dim: dimension of the folded array. Defaults to 2**10.
        binary: whether to fold into a binary array or use a count vector

    Returns:
        folded: returns folded array to the provided dimension
    """
    if hasattr(fp, "GetNonzeroElements"):
        tmp = fp.GetNonzeroElements()
    elif hasattr(fp, "GetOnBits"):
        # try to get the dict of onbit
        on_bits = fp.GetOnBits()
        tmp = dict(zip(on_bits, np.ones(len(on_bits))))
    else:
        raise ValueError(f"Format {type(fp)} is not supported")
    out = (
        coo_matrix(
            (
                list(tmp.values()),
                (np.repeat(0, len(tmp)), [i % dim for i in tmp.keys()]),
            ),
            shape=(1, dim),
        )
        .toarray()
        .flatten()
    )
    if binary:
        out = np.clip(out, a_min=0, a_max=1)
    return out
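
For example, folding an unfolded RDKit Morgan count fingerprint (which exposes GetNonzeroElements):

import datamol as dm
from rdkit.Chem import rdMolDescriptors
from molfeat.utils.commons import fold_count_fp

mol = dm.to_mol("CCO")
fp = rdMolDescriptors.GetMorganFingerprint(mol, 2)  # unfolded count fingerprint
folded = fold_count_fp(fp, dim=2**10)
print(folded.shape)  # -> (1024,)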

get_class_name(cls)

Get class full name

Parameters:

- cls (Type): name of the class. (required)
Source code in molfeat/utils/commons.py
def get_class_name(cls: Type):
    """Get class full name

    Args:
        cls: name of the class
    """
    module = cls.__module__
    name = cls.__qualname__
    if module is not None and module != "__builtin__":
        name = module + "." + name
    return name

hex_to_fn(hex)

Load a hex string as a callable. Raises an error on failure

Parameters:

- hex (str): hex string to load as a callable. (required)

Returns:

- callable: callable loaded from the hex string

Source code in molfeat/utils/commons.py
def hex_to_fn(hex: str):
    """Load a hex string as a callable. Raise error on fail

    Args:
        hex: hex string to load as a callable

    Returns:
        callable: callable loaded from the hex string
    """
    # EN: pickling with pickle is probably faster
    fn = pickle.loads(bytes.fromhex(hex))
    return fn
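
fn_to_hex and hex_to_fn are inverses of each other:

from molfeat.utils.commons import fn_to_hex, hex_to_fn

def double(x):
    return 2 * x

payload = fn_to_hex(double)   # hex string of the pickled function
restored = hex_to_fn(payload)
assert restored(21) == 42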

is_callable(func)

Check if func is a function or a callable

Source code in molfeat/utils/commons.py
def is_callable(func):
    r"""
    Check if func is a function or a callable
    """
    return func and (isinstance(func, FUNCTYPES) or callable(func))

one_hot_encoding(val, allowable_set, encode_unknown=False, dtype=int)

Converts a single value to a one-hot vector.

Parameters:

- val (int): class to be converted into a one-hot vector. (required)
- allowable_set (Iterable): a list or 1D array of allowed choices for val to take. (required)
- dtype (Callable): data type of the return. (default: int)
- encode_unknown (bool): whether to map inputs not in the allowable set to an additional last element. (default: False)

Returns:

- A numpy 1D array of length len(allowable_set) (+1 if encode_unknown is True)

Source code in molfeat/utils/commons.py
def one_hot_encoding(
    val: int,
    allowable_set: Iterable,
    encode_unknown: bool = False,
    dtype: Callable = int,
):
    r"""Converts a single value to a one-hot vector.

    Args:
        val: class to be converted into a one hot vector
        allowable_set: a list or 1D array of allowed choices for val to take
        dtype: data type of the return. Default = int.
        encode_unknown: whether to map inputs not in allowable set to an additional last element.

    Returns:
        A numpy 1D array of length len(allowable_set) (+1 if encode_unknown is True)
    """

    encoding = np.zeros(len(allowable_set) + int(encode_unknown), dtype=dtype)
    # not using .index() in case there are duplicates in the allowed choices
    for i, v in enumerate(allowable_set):
        if v == val:
            encoding[i] = 1
    if np.sum(encoding) == 0 and encode_unknown:  # aka not found
        encoding[-1] = 1
    return encoding
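For instance, encoding a value against a fixed choice set, with and without the unknown bucket:

from molfeat.utils.commons import one_hot_encoding

print(one_hot_encoding(2, [0, 1, 2, 3]))
# [0 0 1 0]
print(one_hot_encoding(7, [0, 1, 2, 3], encode_unknown=True))
# [0 0 0 0 1]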

pack_bits(obj, protocol=4)

Pack an object into a bits representation

Parameters:

- obj (required): object to pack

Returns:

- bytes: byte-packed version of the object

Source code in molfeat/utils/commons.py
def pack_bits(obj, protocol=4):
    """Pack an object into a bits representation

    Args:
        obj: object to pack

    Returns:
        bytes: byte-packed version of object
    """
    return pickle.dumps(obj, protocol=protocol)
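pack_bits and its counterpart unpack_bits (documented below) are thin pickle wrappers, so a roundtrip is straightforward:

from molfeat.utils.commons import pack_bits, unpack_bits

packed = pack_bits({"fp": [1, 0, 1]})
print(unpack_bits(packed))  # {'fp': [1, 0, 1]}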

pack_graph(batch_G, batch_x)

Pack a batch of graph and atom features into a single graph

Parameters:

- batch_G (List[FloatTensor], required): list of adjacency matrices, each of size (n_i, n_i)
- batch_x (List[FloatTensor], required): list of atom feature matrices, each of size (n_i, F), F being the number of features

Returns:

- new_batch_G, new_batch_x (torch.LongTensor 2D, torch.Tensor 2D): a new arbitrary graph that contains the whole batch, and the corresponding atom feature matrix. new_batch_G has size (N, N), with N = sum_i n_i, while new_batch_x has size (N, F)

Source code in molfeat/utils/commons.py
def pack_graph(
    batch_G: List[torch.FloatTensor],
    batch_x: List[torch.FloatTensor],
):
    r"""
    Pack a batch of graph and atom features into a single graph

    Args:
        batch_G: List of adjacency graph, each of size (n_i, n_i).
        batch_x: List of atom feature matrices, each of size (n_i, F), F being the number of features

    Returns:
        new_batch_G, new_batch_x: torch.LongTensor 2D, torch.Tensor 2D
            This tuple represents a new arbitrary graph that contains the whole batch,
            and the corresponding atom feature matrix. new_batch_G has a size (N, N), with :math:`N = \sum_i n_i`,
            while new_batch_x has size (N, F)
    """

    new_batch_x = torch.cat(tuple(batch_x), dim=0)
    n_neigb = new_batch_x.shape[0]
    # should be on the same device
    new_batch_G = batch_G[0].new_zeros((n_neigb, n_neigb))
    cur_ind = 0
    for g in batch_G:
        g_size = g.shape[0] + cur_ind
        new_batch_G[cur_ind:g_size, cur_ind:g_size] = g
        cur_ind = g_size
    return new_batch_G, new_batch_x
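A sketch packing two small graphs into one block-diagonal adjacency matrix with stacked features:

import torch

from molfeat.utils.commons import pack_graph

g1, g2 = torch.ones(2, 2), torch.ones(3, 3)    # adjacency matrices
x1, x2 = torch.randn(2, 4), torch.randn(3, 4)  # 4 features per atom

G, X = pack_graph([g1, g2], [x1, x2])
print(G.shape, X.shape)  # torch.Size([5, 5]) torch.Size([5, 4])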

requires_conformer(calculator)

Decorator for any descriptor calculator that requires conformers

Source code in molfeat/utils/commons.py
def requires_conformer(calculator: Callable):
    """Decorator for any descriptor calculator that requires conformers"""

    # this is a method or __call__
    if inspect.getfullargspec(calculator).args[0] == "self":

        @functools.wraps(calculator)
        def calculator_wrapper(ref, mol, *args, **kwargs):
            mol = dm.to_mol(mol)
            if mol.GetNumConformers() < 1:
                raise ValueError("Expected a molecule with conformers information.")
            return calculator(ref, mol, *args, **kwargs)

    else:

        @functools.wraps(calculator)
        def calculator_wrapper(mol, *args, **kwargs):
            mol = dm.to_mol(mol)
            if mol.GetNumConformers() < 1:
                raise ValueError("Expected a molecule with conformers information.")
            return calculator(mol, *args, **kwargs)

    return calculator_wrapper
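A sketch of a decorated calculator; calling it on a molecule without conformers raises a ValueError:

import datamol as dm

from molfeat.utils.commons import requires_conformer

@requires_conformer
def n_conformers(mol):
    return mol.GetNumConformers()

mol = dm.conformers.generate(dm.to_mol("CCO"), n_confs=1)
print(n_conformers(mol))  # 1
n_conformers("CCO")       # raises ValueError: no conformers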

requires_standardization(calculator=None, *, disconnect_metals=True, remove_salt=True, **standardize_kwargs)

Decorator for any descriptor calculator that requires standardization of the molecules.

Parameters:

- calculator: calculator to wrap
- disconnect_metals: whether to force metal disconnection
- remove_salt: whether to remove salt from the molecule

Source code in molfeat/utils/commons.py
def requires_standardization(
    calculator: Callable = None,
    *,
    disconnect_metals: bool = True,
    remove_salt: bool = True,
    **standardize_kwargs,
):
    """Decorator for any descriptor calculator that required standardization of the molecules
    Args:
        calculator: calculator to wrap
        disconnect_metals: whether to force metal disconnection
        remove_salt: whether to remove salt from the molecule
    """

    def _standardize_mol(calculator):
        @functools.wraps(calculator)
        def wrapped_function(mol, *args, **kwargs):
            mol = _clean_mol_for_descriptors(
                mol,
                disconnect_metals=disconnect_metals,
                remove_salt=remove_salt,
                **standardize_kwargs,
            )
            return calculator(mol, *args, **kwargs)

        @functools.wraps(calculator)
        def class_wrapped_function(ref, mol, *args, **kwargs):
            if not getattr(ref, "do_not_standardize", False):
                mol = _clean_mol_for_descriptors(
                    mol,
                    disconnect_metals=disconnect_metals,
                    remove_salt=remove_salt,
                    **standardize_kwargs,
                )
            return calculator(ref, mol, *args, **kwargs)

        if inspect.getfullargspec(calculator).args[0] == "self":
            return class_wrapped_function
        return wrapped_function

    if calculator is not None:
        return _standardize_mol(calculator)
    return _standardize_mol
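Usage mirrors requires_conformer, except the decorator also accepts keyword configuration (a sketch):

from molfeat.utils.commons import requires_standardization

@requires_standardization(disconnect_metals=True, remove_salt=True)
def heavy_atom_count(mol):
    return mol.GetNumHeavyAtoms()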

sha256sum(filepath)

Return the sha256 sum hash of a file or a directory

Parameters:

- filepath (Union[str, os.PathLike], required): the path to the file (or directory) to compute the SHA-256 hash on
Source code in molfeat/utils/commons.py
def sha256sum(filepath: Union[str, os.PathLike]):
    """Return the sha256 sum hash of a file or a directory

    Args:
        filepath: The path to the file or directory to compute the SHA-256 hash on.
    """
    if dm.fs.is_dir(filepath):
        files = list(dm.fs.glob(os.path.join(filepath, "**", "*")))
    else:
        files = [filepath]
    file_hash = hashlib.sha256()
    for filepath in sorted(files):
        with fsspec.open(filepath) as f:
            file_hash.update(f.read())  # type: ignore
    file_hash = file_hash.hexdigest()
    return file_hash
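A quick sketch on a temporary local file:

import tempfile

from molfeat.utils.commons import sha256sum

with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
    f.write(b"hello molfeat")

print(sha256sum(f.name))  # 64-character hex digest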

unpack_bits(bvalues)

Unpack a bytes representation back into the original object

Parameters:

- bvalues (required): bytes to be unpacked

Returns:

- obj: object that was packed

Source code in molfeat/utils/commons.py
def unpack_bits(bvalues):
    """Pack an object into a bits representation

    Args:
        bvalues: bytes to be unpacked

    Returns:
        obj: object that was packed
    """
    return pickle.loads(bvalues)

Require module

check(module, min_version=None, max_version=None) (cached)

Check if module is available for import

Parameters:

- module (str, required): name of the module to check
- min_version (Optional[str], default: None): optional minimum version string to check
- max_version (Optional[str], default: None): optional maximum version string to check
Source code in molfeat/utils/requires.py
@functools.lru_cache()
def check(module: str, min_version: Optional[str] = None, max_version: Optional[str] = None):
    """Check if module is available for import

    Args:
        module: name of the module to check
        min_version: optional minimum version string to check
        max_version: optional maximum version string to check
    """
    imported_module = None
    version = None
    min_version = pkg_version.parse(min_version) if min_version is not None else None
    max_version = pkg_version.parse(max_version) if max_version is not None else None
    try:
        imported_module = importlib.import_module(module)
        version = getattr(imported_module, "__version__", None)
    except ImportError:
        return False
    if version is not None:
        try:
            version = pkg_version.parse(version)
        except pkg_version.InvalidVersion:
            # EN: packaging v22 removed LegacyVersion which has consequences
            version = None
    return version is None or (
        (min_version is None or version >= min_version)
        and (max_version is None or version <= max_version)
    )
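For example, gating an optional import on the installed version:

from molfeat.utils.requires import check

if check("torch", min_version="1.8"):
    import torch
else:
    print("torch is missing or too old")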

mock(name)

Mock a function to raise an error

Parameters:

- name (str, required): name of the module or function to mock
Source code in molfeat/utils/requires.py
def mock(name: str):
    """Mock a function to raise an error

    Args:
        name: name of the module or function to mock

    """
    return lambda: (_ for _ in ()).throw(Exception(f"{name} is not available"))
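The returned thunk raises only when it is invoked, which makes it a drop-in placeholder for an optional dependency (a sketch; "map4" is just an illustrative name):

from molfeat.utils.requires import mock

map4_calculator = mock("map4")
try:
    map4_calculator()
except Exception as err:
    print(err)  # map4 is not available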

Datatype Conversion

as_numpy_array_if_possible(arr, dtype)

Convert an input array to a numpy datatype if possible.

Parameters:

- arr: input array
- dtype: optional numpy datatype

Source code in molfeat/utils/datatype.py
def as_numpy_array_if_possible(arr, dtype: Optional[Callable]):
    """Convert an input array to a numpy datatype if possible
    Args:
        arr: input array
        dtype: optional numpy datatype
    """
    with suppress(Exception):
        # we only consider auto-casting to numpy
        # when the user requests a numpy datatype.
        if (dtype is not None and is_dtype_numpy(dtype)) or (
            dtype in [pd.DataFrame, "dataframe", "pandas", "df"]
        ):
            # skip any non compatible type
            # meaning it should be a list of list or a list of numpy array or a 2D numpy array.
            if (
                isinstance(arr, (list, np.ndarray))
                and isinstance(arr[0], (np.ndarray, list))
                and np.isscalar(arr[0][0])
            ):
                return sk_utils.check_array(
                    arr, accept_sparse=True, force_all_finite=False, ensure_2d=False, allow_nd=True
                )
    return arr

cast(fp, dtype=None, columns=None)

Change the datatype of a list of input array

Parameters:

- fp (array, required): input array to cast (2D)
- dtype (Optional[Callable], default: None): datatype to cast to
- columns (Optional[Iterable], default: None): column names for pandas dataframe
Source code in molfeat/utils/datatype.py
def cast(fp, dtype: Optional[Callable] = None, columns: Optional[Iterable] = None):
    """Change the datatype of a list of input array

    Args:
        fp (array): Input array to cast (2D)
        dtype: datatype to cast to
        columns: column names for pandas dataframe
    """
    if fp is None or dtype is None:
        return fp
    if isinstance(fp, dict):
        fp = {k: cast(v, dtype=dtype, columns=columns) for k, v in fp.items()}
    elif dtype in [tuple, list]:
        fp = list(fp)
    elif is_dtype_numpy(dtype):
        if isinstance(fp, (list, tuple)) and not np.isscalar(fp[0]):
            fp = [to_numpy(fp_i, dtype=dtype) for fp_i in fp]
            fp = to_numpy(fp, dtype=dtype)
        else:
            fp = to_numpy(fp, dtype=dtype)
    elif is_dtype_tensor(dtype):
        if isinstance(fp, (list, tuple)) and not np.isscalar(fp[0]):
            tmp_fp = to_numpy(fp[0])
            if len(tmp_fp.shape) > 1:
                fp = torch.cat([to_tensor(fp_i, dtype=dtype) for fp_i in fp])
            else:
                fp = torch.stack([to_tensor(fp_i, dtype=dtype) for fp_i in fp])
        else:
            fp = to_tensor(fp, dtype=dtype)
    elif dtype in [pd.DataFrame, "dataframe", "pandas", "df"]:
        fp = [feat if feat is not None else [] for feat in fp]
        fp = pd.DataFrame(fp)
        if columns is not None:
            fp.columns = columns
    elif is_dtype_bitvect(dtype):
        fp = [to_fp(feat, sparse=(dtype == SparseBitVect)) for feat in fp]
    else:
        raise TypeError("The type {} is not supported".format(dtype))
    return fp
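A sketch casting a list of per-molecule feature vectors to a stacked torch tensor or a pandas dataframe:

import numpy as np
import torch

from molfeat.utils.datatype import cast

feats = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
print(cast(feats, dtype=torch.float).shape)  # torch.Size([2, 2])
print(cast(feats, dtype="dataframe", columns=["a", "b"]))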

ensure_explicit(x)

Ensure that the input vector is not a sparse bit vector

Parameters:

- x (Union[SparseBitVect, ExplicitBitVect], required): input vector

Returns:

- converted: ExplicitBitVect if the input is a SparseBitVect, else the input as-is

Source code in molfeat/utils/datatype.py
def ensure_explicit(x: Union[SparseBitVect, ExplicitBitVect]):
    """Ensure that the input vector is not a sparse bit vector

    Args:
        x: input vector

    Returns:
        converted: ExplicitBitVect if input is SparseBitVect, else input as is
    """
    if isinstance(x, SparseBitVect):
        x = ConvertToExplicit(x)
    return x

is_dtype_bitvect(dtype)

Verify if the dtype is a bitvect type

Parameters:

- dtype (callable, required): the dtype of a value, e.g. np.int32, str, torch.float

Returns:

- A boolean saying whether the dtype is a bitvect type

Source code in molfeat/utils/datatype.py
def is_dtype_bitvect(dtype):
    """
    Verify if the dtype is a bitvect type

    Args:
        dtype (callable): The dtype of a value. E.g. np.int32, str, torch.float

    Returns:
        A boolean saying if the dtype is a bitvect type
    """
    return dtype in [ExplicitBitVect, SparseBitVect] or isinstance(
        dtype, (ExplicitBitVect, SparseBitVect)
    )

is_dtype_numpy(dtype)

Verify if the dtype is a numpy dtype

Parameters:

- dtype (callable, required): the dtype of a value, e.g. np.int32, str, torch.float

Returns:

- A boolean saying whether the dtype is a numpy dtype

Source code in molfeat/utils/datatype.py
def is_dtype_numpy(dtype):
    r"""
    Verify if the dtype is a numpy dtype

    Args:
        dtype (callable): The dtype of a value. E.g. np.int32, str, torch.float
    Returns:
        A boolean saying if the dtype is a numpy dtype
    """
    # special case where user provides a type
    if isinstance(dtype, str):
        with suppress(Exception):
            dtype = np.dtype(dtype).type
    is_torch = is_dtype_tensor(dtype)
    is_num = dtype in (int, float, complex)
    if hasattr(dtype, "__module__"):
        is_numpy = dtype.__module__ == "numpy"
    else:
        is_numpy = False
    return (is_num or is_numpy) and not is_torch

is_dtype_tensor(dtype)

Verify if the dtype is a torch dtype

Parameters:

- dtype (callable, required): the dtype of a value, e.g. np.int32, str, torch.float

Returns:

- A boolean saying whether the dtype is a torch dtype

Source code in molfeat/utils/datatype.py
def is_dtype_tensor(dtype):
    r"""
    Verify if the dtype is a torch dtype

    Args:
        dtype (callable): The dtype of a value. E.g. np.int32, str, torch.float

    Returns:
        A boolean saying if the dtype is a torch dtype
    """
    return isinstance(dtype, torch.dtype) or (dtype == torch.Tensor)

is_null(obj)

Check if an obj is null (nan, None or array of nan)

Source code in molfeat/utils/datatype.py
def is_null(obj):
    """Check if an obj is null (nan, None or array of nan)"""
    array_nan = False
    all_none = False
    try:
        tmp = to_numpy(obj)
        array_nan = np.all(np.isnan(tmp))
    except Exception:
        pass
    try:
        all_none = all(x is None for x in obj)
    except Exception:
        pass
    return obj is None or all_none or array_nan

to_fp(arr, bitvect=True, sparse=False)

Convert numpy array to fingerprint

Parameters:

- arr (ndarray, required): numpy array to convert to a bit vector
- bitvect (bool, default: True): whether to assume the data is a bitvect or an intvect
- sparse (bool, default: False): whether to convert to a sparse bit vector

Returns:

- fp: RDKit bit vector

Source code in molfeat/utils/datatype.py
def to_fp(arr: np.ndarray, bitvect: bool = True, sparse: bool = False):
    """Convert numpy array to fingerprint

    Args:
        arr: Numpy array to convert to bitvec
        bitvect: whether to assume the data is a bitvect or intvect
        sparse: whether to convert to sparse bit vect

    Returns:
        fp: RDKit bit vector
    """
    if not isinstance(arr, list) and arr.ndim > 1:
        raise ValueError("Expect a 1D array as input !")
    if not bitvect:
        fp = UIntSparseIntVect(len(arr))
        for ix, value in enumerate(arr):
            fp[ix] = int(value)
    elif sparse:
        onbits = np.where(arr == 1)[0].tolist()
        fp = SparseBitVect(arr.shape[0])
        fp.SetBitsFromList(onbits)
    else:
        arr = np.asarray(arr)
        bitstring = "".join(arr.astype(str))
        fp = CreateFromBitString(bitstring)
    return fp
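A sketch converting a binary numpy vector into the two RDKit bit-vector flavors:

import numpy as np

from molfeat.utils.datatype import to_fp

arr = np.array([1, 0, 1, 1])
print(type(to_fp(arr)).__name__)               # ExplicitBitVect
print(type(to_fp(arr, sparse=True)).__name__)  # SparseBitVect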

to_numpy(x, copy=False, dtype=None)

Convert a tensor to numpy array.

Parameters:

- x (Object, required): the Python object to convert
- copy (bool, default: False): whether to copy the memory. By default, if a tensor is already on CPU, the numpy array will be a view of the tensor
- dtype (callable, default: None): optional type to cast the values to

Returns:

- A new Python object with the same structure as x but where the tensors are now numpy arrays. Unsupported types are left as references in the new object.

Source code in molfeat/utils/datatype.py
def to_numpy(x, copy=False, dtype=None):
    r"""
    Convert a tensor to numpy array.

    Args:
        x (Object): The Python object to convert.
        copy (bool, optional): Whether to copy the memory.
            By default, if a tensor is already on CPU, the
            Numpy array will be a view of the tensor.
        dtype (callable, optional): Optional type to cast the values to

    Returns:
        A new Python object with the same structure as `x` but where the tensors are now Numpy
        arrays. Unsupported types are left as references in the new object.
    """
    if isinstance(x, (list, tuple, np.ndarray)) and torch.is_tensor(x[0]):
        x = [to_numpy(xi, copy=copy, dtype=dtype) for xi in x]
    if isinstance(x, np.ndarray):
        pass
    elif torch.is_tensor(x):
        x = x.cpu().detach().numpy()
        if copy:
            x = x.copy()
    elif isinstance(x, SparseBitVect):
        tmp = np.zeros(x.GetNumBits(), dtype=int)
        for n_bit in list(x.GetOnBits()):
            tmp[n_bit] = 1
        x = tmp
    elif isinstance(x, ExplicitBitVect):
        x = dm.fp_to_array(x)
    elif hasattr(x, "GetNonzeroElements"):
        # one of the other rdkit type
        tmp = np.zeros(x.GetLength())
        bit_idx, values = np.array(list(x.GetNonzeroElements().items())).T
        tmp[bit_idx] = values
        x = tmp
    else:
        x = np.asarray(x)
    if dtype is not None:
        x = x.astype(dtype)
    return x

to_sparse(x, dtype=None)

Converts dense tensor x to sparse format

Parameters:

- x (torch.Tensor, required): tensor to convert
- dtype (torch.dtype, default: None): enforces new data type for the output. If None, it keeps the same datatype as x

Returns:

- New torch.sparse Tensor

Source code in molfeat/utils/datatype.py
def to_sparse(x, dtype=None):
    r"""
    Converts dense tensor x to sparse format

    Args:
        x (torch.Tensor): tensor to convert
        dtype (torch.dtype, optional): Enforces new data type for the output.
            If None, it keeps the same datatype as x (Default: None)
    Returns:
        new torch.sparse Tensor
    """

    if dtype is not None:
        x = x.type(dtype)

    x_typename = torch.typename(x).split(".")[-1]
    sparse_tensortype = getattr(torch.sparse, x_typename)

    indices = torch.nonzero(x)
    if len(indices.shape) == 0:  # if all elements are zeros
        return sparse_tensortype(*x.shape)
    indices = indices.t()
    values = x[tuple(indices[i] for i in range(indices.shape[0]))]
    return sparse_tensortype(indices, values, x.size())
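A quick sketch (note that the legacy torch.sparse.* constructors this relies on are deprecated in recent torch releases):

import torch

from molfeat.utils.datatype import to_sparse

x = torch.tensor([[0.0, 2.0], [0.0, 0.0]])
sp = to_sparse(x)
print(sp.to_dense())  # recovers the original dense tensor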

to_tensor(x, gpu=False, dtype=None)

Convert a numpy array to tensor. The tensor type will be the same as the original array, unless specify otherwise

Parameters:

- x (numpy.ndarray, required): numpy array to convert to tensor type
- gpu (bool, default: False): whether to move the tensor to gpu
- dtype (torch.dtype, default: None): enforces new data type for the output

Returns:

- New torch.Tensor

Source code in molfeat/utils/datatype.py
def to_tensor(x, gpu=False, dtype=None):
    r"""
    Convert a numpy array to tensor. The tensor type will be
    the same as the original array, unless specify otherwise

    Args:
        x (numpy.ndarray): Numpy array to convert to tensor type
        gpu (bool optional): Whether to move tensor to gpu. Default False
        dtype (torch.dtype, optional): Enforces new data type for the output

    Returns:
        New torch.Tensor
    """
    if not torch.is_tensor(x):
        try:
            if torch.is_tensor(x[0]):
                x = torch.stack(x)
        except Exception:
            pass
        x = torch.as_tensor(x)
    if dtype is not None:
        x = x.to(dtype=dtype)
    if gpu and torch.cuda.is_available():
        x = x.cuda()
    return x
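A roundtrip sketch between numpy and torch using to_tensor and to_numpy:

import numpy as np

from molfeat.utils.datatype import to_numpy, to_tensor

t = to_tensor(np.arange(4, dtype=np.float32))
print(t.dtype)      # torch.float32
print(to_numpy(t))  # [0. 1. 2. 3.]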

Pooling

BartPooler

Bases: Module

Default Bart pooler as implemented in huggingface transformers. The Bart pooling function focuses on the eos token ([EOS]) to get a sentence representation.

Source code in molfeat/utils/pooler.py
class BartPooler(nn.Module):
    """
    Default Bart pooler as implemented in huggingface transformers.
    The Bart pooling function focuses on the eos token ([EOS]) to get a sentence representation.
    """

    def __init__(self, config, **kwargs):
        super().__init__()
        self.config = config

    def forward(
        self, h: torch.Tensor, inputs: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        """Forward pass of the pooling layer

        Args:
            h: hidden representation of the input sequence to pool over
            inputs: input tokens to the underlying bart model

        Returns:
            pooled_output: pooled representation of the input sequence
        """
        eos_mask = inputs.eq(self.config.get("eos_token_id"))
        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        pooled_output = h[eos_mask, :].view(h.size(0), -1, h.size(-1))[:, -1, :]
        return pooled_output

forward(h, inputs=None, **kwargs)

Forward pass of the pooling layer

Parameters:

- h (Tensor, required): hidden representation of the input sequence to pool over
- inputs (Optional[Tensor], default: None): input tokens to the underlying bart model

Returns:

- pooled_output (Tensor): pooled representation of the input sequence

Source code in molfeat/utils/pooler.py
def forward(
    self, h: torch.Tensor, inputs: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
    """Forward pass of the pooling layer

    Args:
        h: hidden representation of the input sequence to pool over
        inputs: input tokens to the underlying bart model

    Returns:
        pooled_output: pooled representation of the input sequence
    """
    eos_mask = inputs.eq(self.config.get("eos_token_id"))
    if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
        raise ValueError("All examples must have the same number of <eos> tokens.")
    pooled_output = h[eos_mask, :].view(h.size(0), -1, h.size(-1))[:, -1, :]
    return pooled_output

BertPooler

Bases: Module

Default Bert pooler as implemented in huggingface transformers The bert pooling function focuses on a projection of the first token ([CLS]) to get a sentence representation.

Source code in molfeat/utils/pooler.py
class BertPooler(nn.Module):
    """
    Default Bert pooler as implemented in huggingface transformers
    The bert pooling function focuses on a projection of the first token ([CLS]) to get a sentence representation.
    """

    def __init__(
        self,
        config,
        activation: Optional[Callable] = None,
        random_seed: int = None,
        **kwargs,
    ):
        super().__init__()
        self.config = config
        self.random_seed = random_seed
        if self.random_seed is not None:
            torch.manual_seed(self.random_seed)
        hidden_size = config.get("hidden_size")
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh() if activation is None else activation

    def forward(
        self, h: torch.Tensor, inputs: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        """Forward pass of the pooling layer

        Args:
            h: hidden representation of the input sequence to pool over
            inputs: optional input that has been provided to the underlying bert model

        Returns:
            pooled_output: pooled representation of the input sequence
        """
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = h[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

forward(h, inputs=None, **kwargs)

Forward pass of the pooling layer

Parameters:

- h (Tensor, required): hidden representation of the input sequence to pool over
- inputs (Optional[Tensor], default: None): optional input that has been provided to the underlying bert model

Returns:

- pooled_output (Tensor): pooled representation of the input sequence

Source code in molfeat/utils/pooler.py
def forward(
    self, h: torch.Tensor, inputs: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
    """Forward pass of the pooling layer

    Args:
        h: hidden representation of the input sequence to pool over
        inputs: optional input that has been provided to the underlying bert model

    Returns:
        pooled_output: pooled representation of the input sequence
    """
    # We "pool" the model by simply taking the hidden state corresponding
    # to the first token.
    first_token_tensor = h[:, 0]
    pooled_output = self.dense(first_token_tensor)
    pooled_output = self.activation(pooled_output)
    return pooled_output

GPTPooler

Bases: Module

Default GPT pooler as implemented in huggingface transformers. The GPT pooling function focuses on the last non-padding token, given the sequence lengths, to get a sentence representation.

Source code in molfeat/utils/pooler.py
class GPTPooler(nn.Module):
    """
    Default GPT pooler as implemented in huggingface transformers.
    The GPT pooling function focuses on the last non-padding token, given the sequence lengths, to get a sentence representation.
    """

    def __init__(self, config, **kwargs):
        super().__init__()
        self.config = config
        self.pad_token_id = config.get("pad_token_id")

    def forward(
        self, h: torch.Tensor, inputs: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        """Forward pass of the pooling layer

        Args:
            h: hidden representation of the input sequence to pool over
            inputs: input tokens to the underlying gpt model

        Returns:
            pooled_output: pooled representation of the input sequence
        """
        batch_size, sequence_lengths = inputs.shape[:2]

        assert (
            self.pad_token_id is not None or batch_size == 1
        ), "Cannot handle batch sizes > 1 if no padding token is defined."
        if self.pad_token_id is None:
            sequence_lengths = -1
            logger.warning(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )
        else:
            sequence_lengths = torch.ne(inputs, self.pad_token_id).sum(-1) - 1
        pooled_output = h[torch.arange(batch_size), sequence_lengths]
        return pooled_output

forward(h, inputs=None, **kwargs)

Forward pass of the pooling layer

Parameters:

- h (Tensor, required): hidden representation of the input sequence to pool over
- inputs (Optional[Tensor], default: None): input tokens to the underlying gpt model

Returns:

- pooled_output (Tensor): pooled representation of the input sequence

Source code in molfeat/utils/pooler.py
def forward(
    self, h: torch.Tensor, inputs: Optional[torch.Tensor] = None, **kwargs
) -> torch.Tensor:
    """Forward pass of the pooling layer

    Args:
        h: hidden representation of the input sequence to pool over
        inputs: inputs tokens to the bart underlying model

    Returns:
        pooled_output: pooled representation of the input sequence
    """
    batch_size, sequence_lengths = inputs.shape[:2]

    assert (
        self.pad_token_id is not None or batch_size == 1
    ), "Cannot handle batch sizes > 1 if no padding token is defined."
    if self.pad_token_id is None:
        sequence_lengths = -1
        logger.warning(
            f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
            f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
        )
    else:
        sequence_lengths = torch.ne(inputs, self.pad_token_id).sum(-1) - 1
    pooled_output = h[torch.arange(batch_size), sequence_lengths]
    return pooled_output

HFPooler

Bases: Module

Default Pooler based on Molfeat Pooling layer

Source code in molfeat/utils/pooler.py
class HFPooler(nn.Module):
    """Default Pooler based on Molfeat Pooling layer"""

    def __init__(self, config, dim: int = 1, name: str = "mean", **kwargs):
        super().__init__()
        self.config = config
        self.pooling = Pooling(dim=dim, name=name)

    def forward(
        self,
        h: torch.Tensor,
        inputs: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        ignore_padding: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Forward pass of the pooling layer

        Args:
            h: hidden representation of the input sequence to pool over
            inputs: optional input that has been provided to the underlying bert model
            mask: optional mask to use in place of computing the padding specific mask
            ignore_padding: whether to ignore padding tokens when pooling

        Returns:
            pooled_output: pooled representation of the input sequence
        """
        if mask is None and ignore_padding:
            mask = inputs.ne(self.config.get("pad_token_id"))
        if mask.ndim == 2:
            mask = mask.unsqueeze(-1)  # B, S, 1
        return self.pooling(h, indices=None, mask=mask)

forward(h, inputs=None, mask=None, ignore_padding=True, **kwargs)

Forward pass of the pooling layer

Parameters:

- h (Tensor, required): hidden representation of the input sequence to pool over
- inputs (Optional[Tensor], default: None): optional input that has been provided to the underlying bert model
- mask (Optional[Tensor], default: None): optional mask to use in place of computing the padding-specific mask
- ignore_padding (bool, default: True): whether to ignore padding tokens when pooling

Returns:

- pooled_output (Tensor): pooled representation of the input sequence

Source code in molfeat/utils/pooler.py
def forward(
    self,
    h: torch.Tensor,
    inputs: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
    ignore_padding: bool = True,
    **kwargs,
) -> torch.Tensor:
    """Forward pass of the pooling layer

    Args:
        h: hidden representation of the input sequence to pool over
        inputs: optional input that has been provided to the underlying bert model
        mask: optional mask to use in place of computing the padding specific mask
        ignore_padding: whether to ignore padding tokens when pooling

    Returns:
        pooled_output: pooled representation of the input sequence
    """
    if mask is None and ignore_padding:
        mask = inputs.ne(self.config.get("pad_token_id"))
    if mask.ndim == 2:
        mask = mask.unsqueeze(-1)  # B, S, 1
    return self.pooling(h, indices=None, mask=mask)

Pooling

Bases: Module

Perform simple pooling on a tensor over one dimension

Source code in molfeat/utils/pooler.py
class Pooling(nn.Module):
    """
    Perform simple pooling on a tensor over one dimension
    """

    SUPPORTED_POOLING = ["mean", "avg", "max", "sum", "clf", None]

    def __init__(self, dim: int = 1, name: str = "max"):
        """
        Pooling for embeddings

        Args:
            dim: dimension to pool over, default is 1
            name: pooling type. Default is 'max'.
        """
        super().__init__()
        self.dim = dim
        self.name = name

    def forward(self, x, indices: List[int] = None, mask: torch.Tensor = None) -> torch.Tensor:
        """Perform a pooling operation on the input tensor

        Args:
            x: input tensor to pool over
            indices: Subset of indices to pool over. Defaults to None for all indices.
            mask: binary mask to apply when pooling. Defaults to None, which is a matrix of 1.
                If mask is provided it takes precedence over indices.
        """
        x = torch.as_tensor(x)
        if mask is None:
            mask = torch.ones_like(x)
        if indices is not None:
            mask[:, indices] = 0
        neg_inf = torch.finfo(x.dtype).min
        if mask.ndim == 2:
            mask = mask.unsqueeze(-1)  # B, S, 1
        if self.name == "clf":
            return x[:, 0, :]
        if self.name == "max":
            tmp = x.masked_fill(mask == 0, neg_inf)  # fill excluded positions before max
            return torch.max(tmp, dim=self.dim)[0]
        elif self.name in ["mean", "avg"]:
            return torch.sum(x * mask, dim=self.dim) / mask.sum(self.dim)
        elif self.name == "sum":
            return torch.sum(x * mask, dim=self.dim)
        return x

__init__(dim=1, name='max')

Pooling for embeddings

Parameters:

- dim (int, default: 1): dimension to pool over
- name (str, default: 'max'): pooling type
Source code in molfeat/utils/pooler.py
def __init__(self, dim: int = 1, name: str = "max"):
    """
    Pooling for embeddings

    Args:
        dim: dimension to pool over, default is 1
        name: pooling type. Default is 'max'.
    """
    super().__init__()
    self.dim = dim
    self.name = name

forward(x, indices=None, mask=None)

Perform a pooling operation on the input tensor

Parameters:

- x (required): input tensor to pool over
- indices (List[int], default: None): subset of indices to pool over. Defaults to None for all indices.
- mask (Tensor, default: None): binary mask to apply when pooling. Defaults to None, which is a matrix of 1. If mask is provided, it takes precedence over indices.
Source code in molfeat/utils/pooler.py
def forward(self, x, indices: List[int] = None, mask: torch.Tensor = None) -> torch.Tensor:
    """Perform a pooling operation on the input tensor

    Args:
        x: input tensor to pool over
        indices: Subset of indices to pool over. Defaults to None for all indices.
        mask: binary mask to apply when pooling. Defaults to None, which is a matrix of 1.
            If mask is provided it takes precedence over indices.
    """
    x = torch.as_tensor(x)
    if mask is None:
        mask = torch.ones_like(x)
    if indices is not None:
        mask[:, indices] = 0
    neg_inf = torch.finfo(x.dtype).min
    if mask.ndim == 2:
        mask = mask.unsqueeze(-1)  # B, S, 1
    if self.name == "clf":
        return x[:, 0, :]
    if self.name == "max":
        tmp = x.masked_fill(mask == 0, neg_inf)  # fill excluded positions before max
        return torch.max(tmp, dim=self.dim)[0]
    elif self.name in ["mean", "avg"]:
        return torch.sum(x * mask, dim=self.dim) / mask.sum(self.dim)
    elif self.name == "sum":
        return torch.sum(x * mask, dim=self.dim)
    return x
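A sketch of masked mean pooling over the sequence dimension of a padded batch:

import torch

from molfeat.utils.pooler import Pooling

pooler = Pooling(dim=1, name="mean")
h = torch.randn(2, 5, 8)   # batch of 2, 5 tokens, 8 features
mask = torch.ones(2, 5, 1)
mask[1, 3:] = 0            # second example has only 3 valid tokens
print(pooler(h, mask=mask).shape)  # torch.Size([2, 8])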

get_default_hgf_pooler(name, config, **kwargs)

Get default HuggingFace pooler based on the model name.

Parameters:

- name: name of the model
- config: config of the model
- kwargs: additional arguments to pass to the pooler

Source code in molfeat/utils/pooler.py
def get_default_hgf_pooler(name, config, **kwargs):
    """Get default HuggingFace pooler based on the model name
    Args:
        name: name of the model
        config: config of the model
        kwargs: additional arguments to pass to the pooler
    """
    model_type = config.get("model_type", None)
    if name not in ["bert", "roberta", "gpt", "bart"] and name in Pooling.SUPPORTED_POOLING[:-1]:
        return HFPooler(config, name=name, **kwargs)
    names = [name]
    if model_type is not None:
        names += [model_type]
    if any(x in ["bert", "roberta"] for x in names):
        return BertPooler(config, **kwargs)
    elif any(x.startswith("gpt") for x in names):
        return GPTPooler(config, **kwargs)
    elif any(x == "bart" for x in names):
        return BartPooler(config, **kwargs)
    return None
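For example, with a minimal dict config (a sketch; real configs come from the underlying transformers model):

from molfeat.utils.pooler import get_default_hgf_pooler

config = {"model_type": "bert", "hidden_size": 32}
pooler = get_default_hgf_pooler("bert", config)
print(type(pooler).__name__)  # BertPooler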

Mol Format Converters

SmilesConverter

Molecule line notation conversion from smiles to selfies or inchi

Source code in molfeat/utils/converters.py
class SmilesConverter:
    """Molecule line notation conversion from smiles to selfies or inchi"""

    SUPPORTED_LINE_NOTATIONS = [
        "none",
        "smiles",
        "selfies",
        "inchi",
    ]

    def __init__(self, target: str = None):
        """
        Convert input smiles to a target line notation

        Args:
            target: target representation.
        """
        self.target = target

        if self.target is not None and self.target not in self.SUPPORTED_LINE_NOTATIONS:
            raise ValueError(
                f"{target} is not a supported line representation. Choose from {self.SUPPORTED_LINE_NOTATIONS}"
            )

        if self.target == "smiles" or (self.target is None or self.target == "none"):
            self.converter = None
        elif self.target == "inchi":
            self.converter = types.SimpleNamespace(decode=dm.from_inchi, encode=dm.to_inchi)
        elif self.target == "selfies":
            self.converter = types.SimpleNamespace(decode=dm.from_selfies, encode=dm.to_selfies)

    def decode(self, inp: str):
        """Decode inputs into smiles

        Args:
            inp: input representation to decode
        """
        if self.converter is None:
            return inp
        with dm.without_rdkit_log():
            try:
                decoded = self.converter.decode(inp)
                return decoded.strip()
            except Exception:  # (deepsmiles.DecodeError, ValueError, AttributeError, IndexError):
                return None

    def encode(self, smiles: str):
        """Encode a input smiles into target line notation

        Args:
            smiles: input smiles to encode
        """
        if self.converter is None:
            return smiles
        with dm.without_rdkit_log():
            try:
                encoded = self.converter.encode(smiles)
                return encoded.strip()
            except Exception:
                return None
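A roundtrip sketch between smiles and selfies (requires the selfies package used by datamol):

from molfeat.utils.converters import SmilesConverter

converter = SmilesConverter(target="selfies")
encoded = converter.encode("CCO")
print(encoded)                    # [C][C][O]
print(converter.decode(encoded))  # CCO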

__init__(target=None)

Convert input smiles to a target line notation

Parameters:

- target (str, default: None): target representation
Source code in molfeat/utils/converters.py
def __init__(self, target: str = None):
    """
    Convert input smiles to a target line notation

    Args:
        target: target representation.
    """
    self.target = target

    if self.target is not None and self.target not in self.SUPPORTED_LINE_NOTATIONS:
        raise ValueError(
            f"{target} is not a supported line representation. Choose from {self.SUPPORTED_LINE_NOTATIONS}"
        )

    if self.target == "smiles" or (self.target is None or self.target == "none"):
        self.converter = None
    elif self.target == "inchi":
        self.converter = types.SimpleNamespace(decode=dm.from_inchi, encode=dm.to_inchi)
    elif self.target == "selfies":
        self.converter = types.SimpleNamespace(decode=dm.from_selfies, encode=dm.to_selfies)

decode(inp)

Decode inputs into smiles

Parameters:

- inp (str, required): input representation to decode
Source code in molfeat/utils/converters.py
def decode(self, inp: str):
    """Decode inputs into smiles

    Args:
        inp: input representation to decode
    """
    if self.converter is None:
        return inp
    with dm.without_rdkit_log():
        try:
            decoded = self.converter.decode(inp)
            return decoded.strip()
        except Exception:  # (deepsmiles.DecodeError, ValueError, AttributeError, IndexError):
            return None

encode(smiles)

Encode an input smiles into the target line notation

Parameters:

- smiles (str, required): input smiles to encode
Source code in molfeat/utils/converters.py
def encode(self, smiles: str):
    """Encode a input smiles into target line notation

    Args:
        smiles: input smiles to encode
    """
    if self.converter is None:
        return smiles
    with dm.without_rdkit_log():
        try:
            encoded = self.converter.encode(smiles)
            return encoded.strip()
        except Exception:
            return None