
molfeat.trans.graph

Graphs

AdjGraphTransformer

Bases: GraphTransformer

Transforms a molecule into a molecular graph representation formed by an adjacency matrix of atoms and a set of features for each atom (and, optionally, each bond).
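A minimal usage sketch (assuming the class is importable from molfeat.trans.graph, as this page's module path suggests); each returned item is an (adjacency matrix, atom features) pair, plus bond features when a bond featurizer is set:

from molfeat.trans.graph import AdjGraphTransformer  # assumed import path

transformer = AdjGraphTransformer()
graphs = transformer(["CCO", "c1ccc(N)cc1"])  # one entry per molecule
adj, atom_feats = graphs[0]  # (N, N) adjacency matrix and (N, d) atom features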

Source code in molfeat/trans/graph/adj.py
class AdjGraphTransformer(GraphTransformer):
    r"""
    Transforms a molecule into a molecular graph representation formed by an
    adjacency matrix of atoms and a set of features for each atom (and potentially bond).
    """

    def __init__(
        self,
        atom_featurizer: Optional[Callable] = None,
        bond_featurizer: Optional[Callable] = None,
        self_loop: bool = False,
        explicit_hydrogens: bool = False,
        canonical_atom_order: bool = True,
        max_n_atoms: Optional[int] = None,
        n_jobs: int = 1,
        verbose: bool = False,
        dtype: Optional[Callable] = None,
        **params,
    ):
        """
        Adjacency graph transformer

        Args:
            atom_featurizer: atom featurizer to use
            bond_featurizer: bond featurizer to use
            self_loop: whether to add self loops to the adjacency matrix. Your bond featurizer needs to supports this.
            explicit_hydrogens: Whether to use explicit hydrogen in preprocessing of the input molecule
            canonical_atom_order: Whether to use a canonical ordering of the atoms
            max_n_atoms: Maximum number of atom to set the size of the graph
            n_jobs: Number of job to run in parallel. Defaults to 1.
            verbose: Verbosity level. Defaults to True.
            dtype: Output data type. Defaults to None, where numpy arrays are returned.
        """
        super().__init__(
            atom_featurizer=atom_featurizer,
            bond_featurizer=bond_featurizer,
            max_n_atoms=max_n_atoms,
            self_loop=self_loop,
            n_jobs=n_jobs,
            verbose=verbose,
            dtype=dtype,
            canonical_atom_order=canonical_atom_order,
            explicit_hydrogens=explicit_hydrogens,
            **params,
        )

    def _graph_featurizer(self, mol: dm.Mol):
        """Internal adjacency graph featurizer

        Returns:
            mat : N,N matrix representing the graph
        """
        adj_mat = GetAdjacencyMatrix(mol)
        if self.self_loop:
            np.fill_diagonal(adj_mat, 1)
        return adj_mat

    @staticmethod
    def _collate_batch(batch, max_n_atoms=None, pack=False):
        """
        Collate a batch of samples. Expected format is either single graphs, e.g. a list of tuples of the form (adj, feats),
        or graphs together with their labels, where each sample is of the form ((adj, feats), label).

        Args:
             batch: list
                Batch of samples.
             max_n_atoms: Max num atoms in graphs.
             pack: Whether the graph should be packed or not into a supergraph.

        Returns:
            Collated samples.

        """
        if isinstance(batch[0], (list, tuple)) and len(batch[0]) > 2:
            graphs, feats, labels = map(list, zip(*batch))
            batched_graph = AdjGraphTransformer._collate_graphs(
                zip(graphs, feats), max_n_atoms=max_n_atoms, pack=pack
            )

            if torch.is_tensor(labels[0]):
                return batched_graph, torch.stack(labels)
            else:
                return batched_graph, labels

        # Otherwise we assume the batch is composed of single graphs.
        return AdjGraphTransformer._collate_graphs(batch, max_n_atoms=max_n_atoms, pack=pack)

    @staticmethod
    def _collate_graphs(batch, max_n_atoms, pack):
        if not all([len(b) == 2 for b in batch]):
            raise ValueError("Default collate function only supports pair of (Graph, AtomFeats) ")

        graphs, feats = zip(*batch)
        # in case someone does not convert to tensor and wants to use collate
        # who would do that ?
        graphs = [datatype.to_tensor(g) for g in graphs]
        feats = [datatype.to_tensor(f) for f in feats]
        if pack:
            return pack_graph(graphs, feats)
        else:
            if max_n_atoms is None:
                cur_max_atoms = max([x.shape[0] for x in feats])
            else:
                cur_max_atoms = max_n_atoms

            graphs = torch.stack(
                [
                    F.pad(
                        g,
                        (0, cur_max_atoms - g.shape[0], 0, cur_max_atoms - g.shape[1]),
                    )
                    for g in graphs
                ]
            )
            feats = torch.stack([F.pad(f, (0, 0, 0, cur_max_atoms - f.shape[0])) for f in feats])
        return graphs, feats

    def get_collate_fn(self, pack: bool = False, max_n_atoms: Optional[int] = None):
        """Get collate function. Adj Graph are collated either through batching
        or diagonally packing the graph into a super graph. Either a format of (batch, labels) or graph is supported.

        !!! note
            Edge features are not supported yet in the default collate because
            there is no straightforward and universal way to collate them

        Args:
            pack : Whether to pack or batch the graphs.
            max_n_atoms: Maximum number of node per graph when packing is False.
                If the graph needs to be packed and it is not set, instance attributes will be used
        """
        if self.bond_featurizer is not None:
            raise ValueError(
                "Default collate function is not supported for transformer with bond featurizer"
            )
        max_n_atoms = max_n_atoms or self.max_n_atoms

        return partial(self._collate_batch, pack=pack, max_n_atoms=max_n_atoms)

    def transform(self, mols: List[Union[dm.Mol, str]], keep_dict: bool = False, **kwargs):
        r"""
        Compute the graph featurization for a set of molecules.

        Args:
            mols: a list containing smiles or mol objects
            keep_dict: whether to keep atom and bond featurizer as dict or get the underlying data
            kwargs: arguments to pass to the `super().transform`

         Returns:
             features: a list of features for each molecule in the input set
        """
        features = super().transform(mols, **kwargs)
        if not keep_dict:
            out = []
            for i, feat in enumerate(features):
                if feat is not None:
                    graph, nodes, *bonds = feat
                    if isinstance(nodes, dict):
                        nodes = nodes[self.atom_featurizer.name]
                    if len(bonds) > 0 and isinstance(bonds[0], dict):
                        try:
                            bonds = bonds[0][self.bond_featurizer.name]
                            feat = (graph, nodes, bonds)
                        except KeyError as e:
                            # more information on failure
                            logger.error("Encountered Molecule without bonds")
                            raise e
                    else:
                        feat = (graph, nodes)
                out.append(feat)
            features = out
        return features

    def _transform(self, mol: dm.Mol):
        r"""
        Transforms a molecule into an Adjacency graph with a set of atom and optional bond features

        Args:
            mol: molecule to transform into features

        Returns
            feat: featurized input molecule (adj_mat, node_feat) or (adj_mat, node_feat, edge_feat)

        """
        if mol is None:
            return None

        try:
            adj_matrix = datatype.cast(self._graph_featurizer(mol), dtype=self.dtype)
            atom_data = self.atom_featurizer(mol, dtype=self.dtype)
            feats = (adj_matrix, atom_data)
            bond_data = None
            if self.bond_featurizer is not None:
                bond_data = self.bond_featurizer(mol, flat=False, dtype=self.dtype)
                feats = (
                    adj_matrix,
                    atom_data,
                    bond_data,
                )
        except Exception as e:
            if self.verbose:
                logger.error(e)
            feats = None
        return feats

__init__(atom_featurizer=None, bond_featurizer=None, self_loop=False, explicit_hydrogens=False, canonical_atom_order=True, max_n_atoms=None, n_jobs=1, verbose=False, dtype=None, **params)

Adjacency graph transformer

Parameters:

atom_featurizer (Optional[Callable], default: None): atom featurizer to use
bond_featurizer (Optional[Callable], default: None): bond featurizer to use
self_loop (bool, default: False): whether to add self-loops to the adjacency matrix; the bond featurizer needs to support this
explicit_hydrogens (bool, default: False): whether to add explicit hydrogens when preprocessing the input molecule
canonical_atom_order (bool, default: True): whether to use a canonical ordering of the atoms
max_n_atoms (Optional[int], default: None): maximum number of atoms, used to set the size of the graph
n_jobs (int, default: 1): number of jobs to run in parallel
verbose (bool, default: False): verbosity level
dtype (Optional[Callable], default: None): output data type; when None, numpy arrays are returned
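A constructor sketch with an explicit atom featurizer and a torch output dtype; the AtomCalculator import path is an assumption based on the name used in the source below and may differ in your molfeat version:

import torch

from molfeat.calc.atom import AtomCalculator  # assumed import path for the default atom featurizer
from molfeat.trans.graph import AdjGraphTransformer  # assumed import path

transformer = AdjGraphTransformer(
    atom_featurizer=AtomCalculator(),
    self_loop=True,     # the bond featurizer, if any, must support self-loops
    dtype=torch.float,  # return torch tensors instead of numpy arrays
)
adj, atom_feats = transformer(["CCO"])[0]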
Source code in molfeat/trans/graph/adj.py
def __init__(
    self,
    atom_featurizer: Optional[Callable] = None,
    bond_featurizer: Optional[Callable] = None,
    self_loop: bool = False,
    explicit_hydrogens: bool = False,
    canonical_atom_order: bool = True,
    max_n_atoms: Optional[int] = None,
    n_jobs: int = 1,
    verbose: bool = False,
    dtype: Optional[Callable] = None,
    **params,
):
    """
    Adjacency graph transformer

    Args:
        atom_featurizer: atom featurizer to use
        bond_featurizer: bond featurizer to use
        self_loop: whether to add self loops to the adjacency matrix. Your bond featurizer needs to supports this.
        explicit_hydrogens: Whether to use explicit hydrogen in preprocessing of the input molecule
        canonical_atom_order: Whether to use a canonical ordering of the atoms
        max_n_atoms: Maximum number of atom to set the size of the graph
        n_jobs: Number of job to run in parallel. Defaults to 1.
        verbose: Verbosity level. Defaults to True.
        dtype: Output data type. Defaults to None, where numpy arrays are returned.
    """
    super().__init__(
        atom_featurizer=atom_featurizer,
        bond_featurizer=bond_featurizer,
        max_n_atoms=max_n_atoms,
        self_loop=self_loop,
        n_jobs=n_jobs,
        verbose=verbose,
        dtype=dtype,
        canonical_atom_order=canonical_atom_order,
        explicit_hydrogens=explicit_hydrogens,
        **params,
    )

get_collate_fn(pack=False, max_n_atoms=None)

Get the collate function. Adjacency graphs are collated either by batching them or by diagonally packing them into a supergraph. Both (batch, labels) samples and bare graphs are supported.

Note

Edge features are not supported yet in the default collate because there is no straightforward and universal way to collate them

Parameters:

pack (bool, default: False): whether to pack or batch the graphs
max_n_atoms (Optional[int], default: None): maximum number of nodes per graph when packing is False; if it is not set, the instance attribute is used
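A sketch of the returned collate function with a plain torch.utils.data.DataLoader (assuming the same import path as above and no bond featurizer, which the default collate requires):

import torch
from torch.utils.data import DataLoader

from molfeat.trans.graph import AdjGraphTransformer  # assumed import path

transformer = AdjGraphTransformer(dtype=torch.float)
feats = transformer(["CCO", "c1ccccc1", "CC(=O)O"])

loader = DataLoader(feats, batch_size=2, collate_fn=transformer.get_collate_fn(pack=False))
adj_batch, node_batch = next(iter(loader))
# adj_batch is padded to (batch, max_atoms, max_atoms), node_batch to (batch, max_atoms, atom_dim)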
Source code in molfeat/trans/graph/adj.py
def get_collate_fn(self, pack: bool = False, max_n_atoms: Optional[int] = None):
    """Get collate function. Adj Graph are collated either through batching
    or diagonally packing the graph into a super graph. Either a format of (batch, labels) or graph is supported.

    !!! note
        Edge features are not supported yet in the default collate because
        there is no straightforward and universal way to collate them

    Args:
        pack : Whether to pack or batch the graphs.
        max_n_atoms: Maximum number of node per graph when packing is False.
            If the graph needs to be packed and it is not set, instance attributes will be used
    """
    if self.bond_featurizer is not None:
        raise ValueError(
            "Default collate function is not supported for transformer with bond featurizer"
        )
    max_n_atoms = max_n_atoms or self.max_n_atoms

    return partial(self._collate_batch, pack=pack, max_n_atoms=max_n_atoms)

transform(mols, keep_dict=False, **kwargs)

Compute the graph featurization for a set of molecules.

Parameters:

mols (List[Union[Mol, str]], required): a list containing SMILES or mol objects
keep_dict (bool, default: False): whether to keep the atom and bond features as a dict or to get the underlying data
kwargs (default: {}): arguments to pass to super().transform

Returns: features: a list of features for each molecule in the input set
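A short sketch of the keep_dict flag, under the same import assumption as above; with keep_dict=True the featurizer output is kept in its raw (possibly dict) form, while the default unwraps it into plain arrays:

from molfeat.trans.graph import AdjGraphTransformer  # assumed import path

transformer = AdjGraphTransformer()
raw = transformer.transform(["CCO"], keep_dict=True)    # featurizer output kept as-is (possibly a dict)
flat = transformer.transform(["CCO"], keep_dict=False)  # unwrapped into (adj, atom_feats[, bond_feats]) tuples
adj, atom_feats = flat[0]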

Source code in molfeat/trans/graph/adj.py
def transform(self, mols: List[Union[dm.Mol, str]], keep_dict: bool = False, **kwargs):
    r"""
    Compute the graph featurization for a set of molecules.

    Args:
        mols: a list containing smiles or mol objects
        keep_dict: whether to keep atom and bond featurizer as dict or get the underlying data
        kwargs: arguments to pass to the `super().transform`

     Returns:
         features: a list of features for each molecule in the input set
    """
    features = super().transform(mols, **kwargs)
    if not keep_dict:
        out = []
        for i, feat in enumerate(features):
            if feat is not None:
                graph, nodes, *bonds = feat
                if isinstance(nodes, dict):
                    nodes = nodes[self.atom_featurizer.name]
                if len(bonds) > 0 and isinstance(bonds[0], dict):
                    try:
                        bonds = bonds[0][self.bond_featurizer.name]
                        feat = (graph, nodes, bonds)
                    except KeyError as e:
                        # more information on failure
                        logger.error("Encountered Molecule without bonds")
                        raise e
                else:
                    feat = (graph, nodes)
            out.append(feat)
        features = out
    return features

CompleteGraphTransformer

Bases: GraphTransformer

Transforms a molecule into a complete graph

Source code in molfeat/trans/graph/adj.py
class CompleteGraphTransformer(GraphTransformer):
    """Transforms a molecule into a complete graph"""

    def _graph_featurizer(self, mol: dm.Mol):
        """Complete grah featurizer

        Args:
            mol: molecule to transform into a graph

        Returns:
            mat : N,N matrix representing the graph
        """
        n_atoms = mol.GetNumAtoms()
        adj_mat = np.ones((n_atoms, n_atoms))
        if not self.self_loop:
            np.fill_diagonal(adj_mat, 0)
        return adj_mat

DGLGraphTransformer

Bases: GraphTransformer

Transforms a molecule into a molecular graph representation formed by an adjacency matrix of atoms and a set of features for each atom (and, optionally, each bond), returned as a DGL graph.
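A usage sketch (assuming dgl, and preferably dgllife, are installed and that the class is importable from molfeat.trans.graph); each call returns dgl.DGLGraph objects whose ndata/edata carry the atom and bond features:

import torch
from torch.utils.data import DataLoader

from molfeat.trans.graph import DGLGraphTransformer  # assumed import path

transformer = DGLGraphTransformer(self_loop=True, dtype=torch.float)
graphs = transformer(["CCO", "c1ccccc1"])  # list of dgl.DGLGraph objects

loader = DataLoader(graphs, batch_size=2, collate_fn=transformer.get_collate_fn())
batched_graph = next(iter(loader))  # a single batched DGLGraph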

Source code in molfeat/trans/graph/adj.py
class DGLGraphTransformer(GraphTransformer):
    r"""
    Transforms a molecule into a molecular graph representation formed by an
    adjacency matrix of atoms and a set of features for each atom (and potentially bond).
    """

    def __init__(
        self,
        atom_featurizer: Optional[Callable] = None,
        bond_featurizer: Optional[Callable] = None,
        self_loop: bool = False,
        explicit_hydrogens: bool = False,
        canonical_atom_order: bool = True,
        complete_graph: bool = False,
        num_virtual_nodes: int = 0,
        n_jobs: int = 1,
        verbose: bool = False,
        dtype: Optional[Callable] = None,
        **params,
    ):
        """
        Adjacency graph transformer

        Args:
           atom_featurizer: atom featurizer to use
           bond_featurizer: atom featurizer to use
           self_loop: whether to use self loop or not
           explicit_hydrogens: Whether to use explicit hydrogen in preprocessing of the input molecule
           canonical_atom_order: Whether to use a canonical ordering of the atoms
           complete_graph: Whether to use a complete graph constructor or not
           num_virtual_nodes: number of virtual nodes to add
           n_jobs: Number of job to run in parallel. Defaults to 1.
           verbose: Verbosity level. Defaults to True.
           dtype: Output data type. Defaults to None, where numpy arrays are returned.
        """

        super().__init__(
            atom_featurizer=atom_featurizer,
            bond_featurizer=bond_featurizer,
            n_jobs=n_jobs,
            self_loop=self_loop,
            num_virtual_nodes=num_virtual_nodes,
            complete_graph=complete_graph,
            verbose=verbose,
            dtype=dtype,
            canonical_atom_order=canonical_atom_order,
            explicit_hydrogens=explicit_hydrogens,
            **params,
        )

        if not requires.check("dgllife"):
            logger.error(
                "Cannot find dgllife. It's required for some features. Please install it first !"
            )
        if not requires.check("dgl"):
            raise ValueError("Cannot find dgl, please install it first !")
        if self.dtype is not None and not datatype.is_dtype_tensor(self.dtype):
            raise ValueError("DGL featurizer only supports torch tensors currently")

    def auto_self_loop(self):
        """Patch the featurizer to auto support self loop based on the bond featurizer characteristics"""
        super().auto_self_loop()
        if isinstance(self.bond_featurizer, EdgeMatCalculator):
            self.self_loop = True

    def get_collate_fn(self, *args, **kwargs):
        """Return DGL collate function for a batch of molecular graph"""
        return self._dgl_collate

    @staticmethod
    def _dgl_collate(batch):
        """
        Batch of samples to be used with the featurizer. A sample of the batch is expected to
        be of the form (graph, label) or simply a graph.

        Args:
         batch: list
            batch of samples.

        returns:
            Batched lists of graphs and labels
        """
        if isinstance(batch[0], (list, tuple)):
            graphs, labels = map(list, zip(*batch))
            batched_graph = dgl.batch(graphs)

            if torch.is_tensor(labels[0]):
                return batched_graph, torch.stack(labels)
            else:
                return batched_graph, labels

        # Otherwise we assume the batch is composed of single graphs.
        return dgl.batch(batch)

    def _graph_featurizer(self, mol: dm.Mol):
        """Convert a molecule to a DGL graph.

        This only supports the bigraph and not any virtual nodes or complete graph.

        Args:
            mol (dm.Mol): molecule to transform into features

        Returns:
            graph (dgl.DGLGraph): graph built with dgl
        """

        n_atoms = mol.GetNumAtoms()
        num_bonds = mol.GetNumBonds()
        graph = dgl.graph()
        graph.add_nodes(n_atoms)
        bond_src = []
        bond_dst = []
        for i in range(num_bonds):
            bond = mol.GetBondWithIdx(i)
            begin_idx = bond.GetBeginAtom().GetIdx()
            end_idx = bond.GetEndAtom().GetIdx()
            bond_src.append(begin_idx)
            bond_dst.append(end_idx)
            # set up the reverse direction
            bond_src.append(end_idx)
            bond_dst.append(begin_idx)

        if self.self_loop:
            nodes = graph.nodes().tolist()
            bond_src.extend(nodes)
            bond_dst.extend(nodes)

        graph.add_edges(bond_src, bond_dst)
        return graph

    @property
    def atom_dim(self):
        return super(DGLGraphTransformer, self).atom_dim + int(self.num_virtual_nodes > 0)

    @property
    def bond_dim(self):
        return super(DGLGraphTransformer, self).bond_dim + int(self.num_virtual_nodes > 0)

    def _transform(self, mol: dm.Mol):
        r"""
        Transforms a molecule into an Adjacency graph with a set of atom and bond features

        Args:
            mol (dm.Mol): molecule to transform into features

        Returns
            graph (dgl.DGLGraph): a dgl graph containing atoms and bond data

        """
        if mol is None:
            return None

        graph = None
        if requires.check("dgllife"):
            graph_featurizer = dgllife_utils.mol_to_bigraph

            if self.complete_graph:
                graph_featurizer = dgllife_utils.mol_to_complete_graph
            try:
                graph = graph_featurizer(
                    mol,
                    add_self_loop=self.self_loop,
                    node_featurizer=self.__recast(self.atom_featurizer),
                    edge_featurizer=self.__recast(self.bond_featurizer),
                    canonical_atom_order=self.canonical_atom_order,
                    explicit_hydrogens=self.explicit_hydrogens,
                    num_virtual_nodes=self.num_virtual_nodes,
                )
            except Exception as e:
                if self.verbose:
                    logger.error(e)
                graph = None

        elif requires.check("dgl") and not self.complete_graph:
            # we need to build the graph ourselves.
            graph = self._graph_featurizer(mol)
            if self.atom_featurizer is not None:
                graph.ndata.update(self.atom_featurizer(mol, dtype=self.dtype))

            if self.bond_featurizer is not None:
                graph.edata.update(self.bond_featurizer(mol, dtype=self.dtype))

        else:
            raise ValueError(
                "Incorrect setup, please install missing packages (dgl, dgllife) for more features"
            )
        return graph

    def __recast(self, featurizer: Callable):
        """Recast the output of a featurizer to the transformer underlying type

        Args:
            featurizer: featurizer to patch
        """
        if featurizer is None:
            return None
        dtype = self.dtype or torch.float

        def patch_feats(*args, **kwargs):
            out_dict = featurizer(*args, **kwargs)
            out_dict = {k: datatype.cast(val, dtype=dtype) for k, val in out_dict.items()}
            return out_dict

        return patch_feats

__init__(atom_featurizer=None, bond_featurizer=None, self_loop=False, explicit_hydrogens=False, canonical_atom_order=True, complete_graph=False, num_virtual_nodes=0, n_jobs=1, verbose=False, dtype=None, **params)

Adjacency graph transformer

Parameters:

atom_featurizer (Optional[Callable], default: None): atom featurizer to use
bond_featurizer (Optional[Callable], default: None): bond featurizer to use
self_loop (bool, default: False): whether to use self-loops or not
explicit_hydrogens (bool, default: False): whether to add explicit hydrogens when preprocessing the input molecule
canonical_atom_order (bool, default: True): whether to use a canonical ordering of the atoms
complete_graph (bool, default: False): whether to use a complete graph constructor or not
num_virtual_nodes (int, default: 0): number of virtual nodes to add
n_jobs (int, default: 1): number of jobs to run in parallel
verbose (bool, default: False): verbosity level
dtype (Optional[Callable], default: None): output data type; only torch tensor dtypes (or None) are supported
Source code in molfeat/trans/graph/adj.py
def __init__(
    self,
    atom_featurizer: Optional[Callable] = None,
    bond_featurizer: Optional[Callable] = None,
    self_loop: bool = False,
    explicit_hydrogens: bool = False,
    canonical_atom_order: bool = True,
    complete_graph: bool = False,
    num_virtual_nodes: int = 0,
    n_jobs: int = 1,
    verbose: bool = False,
    dtype: Optional[Callable] = None,
    **params,
):
    """
    Adjacency graph transformer

    Args:
       atom_featurizer: atom featurizer to use
       bond_featurizer: atom featurizer to use
       self_loop: whether to use self loop or not
       explicit_hydrogens: Whether to use explicit hydrogen in preprocessing of the input molecule
       canonical_atom_order: Whether to use a canonical ordering of the atoms
       complete_graph: Whether to use a complete graph constructor or not
       num_virtual_nodes: number of virtual nodes to add
       n_jobs: Number of job to run in parallel. Defaults to 1.
       verbose: Verbosity level. Defaults to True.
       dtype: Output data type. Defaults to None, where numpy arrays are returned.
    """

    super().__init__(
        atom_featurizer=atom_featurizer,
        bond_featurizer=bond_featurizer,
        n_jobs=n_jobs,
        self_loop=self_loop,
        num_virtual_nodes=num_virtual_nodes,
        complete_graph=complete_graph,
        verbose=verbose,
        dtype=dtype,
        canonical_atom_order=canonical_atom_order,
        explicit_hydrogens=explicit_hydrogens,
        **params,
    )

    if not requires.check("dgllife"):
        logger.error(
            "Cannot find dgllife. It's required for some features. Please install it first !"
        )
    if not requires.check("dgl"):
        raise ValueError("Cannot find dgl, please install it first !")
    if self.dtype is not None and not datatype.is_dtype_tensor(self.dtype):
        raise ValueError("DGL featurizer only supports torch tensors currently")

__recast(featurizer)

Recast the output of a featurizer to the transformer's underlying data type

Parameters:

featurizer (Callable, required): featurizer to patch
Source code in molfeat/trans/graph/adj.py
def __recast(self, featurizer: Callable):
    """Recast the output of a featurizer to the transformer underlying type

    Args:
        featurizer: featurizer to patch
    """
    if featurizer is None:
        return None
    dtype = self.dtype or torch.float

    def patch_feats(*args, **kwargs):
        out_dict = featurizer(*args, **kwargs)
        out_dict = {k: datatype.cast(val, dtype=dtype) for k, val in out_dict.items()}
        return out_dict

    return patch_feats

auto_self_loop()

Patch the transformer to automatically enable self-loops based on the characteristics of the bond featurizer

Source code in molfeat/trans/graph/adj.py
def auto_self_loop(self):
    """Patch the featurizer to auto support self loop based on the bond featurizer characteristics"""
    super().auto_self_loop()
    if isinstance(self.bond_featurizer, EdgeMatCalculator):
        self.self_loop = True

get_collate_fn(*args, **kwargs)

Return the DGL collate function for a batch of molecular graphs

Source code in molfeat/trans/graph/adj.py
def get_collate_fn(self, *args, **kwargs):
    """Return DGL collate function for a batch of molecular graph"""
    return self._dgl_collate

DistGraphTransformer3D

Bases: AdjGraphTransformer

Graph featurizer using the 3D distance between each pair of atoms in place of the adjacency matrix. The self_loop attribute is ignored here, as the distance between an atom and itself is 0.
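Because the distances are read from 3D coordinates, the input molecules need a conformer; a minimal sketch using datamol's conformer generation (dm.conformers.generate) and the same import assumption as above:

import datamol as dm

from molfeat.trans.graph import DistGraphTransformer3D  # assumed import path

mol = dm.conformers.generate(dm.to_mol("CCO"))  # embed a 3D conformer first
transformer = DistGraphTransformer3D()
dist3d, atom_feats = transformer([mol])[0]
# dist3d[i, j] is the Euclidean distance between atoms i and j in the embedded conformer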

Source code in molfeat/trans/graph/adj.py
class DistGraphTransformer3D(AdjGraphTransformer):
    """
    Graph featurizer using the 3D distance between pair of atoms for the adjacency matrix
    The `self_loop` attribute is ignored here as the distance between an atom and itself is 0.

    """

    @requires_conformer
    def _graph_featurizer(self, mol: dm.Mol):
        """Graph topological distance featurizer

        Args:
            mol: molecule to transform into a graph

        Returns:
            mat : N,N matrix representing the graph
        """
        return Get3DDistanceMatrix(mol)

GraphTransformer

Bases: MoleculeTransformer

Base class for all graph transformers, including the DGL-based ones
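Since _transform and _graph_featurizer are left to subclasses, a custom graph transformer mostly needs to supply the graph construction; a hypothetical sketch that reuses AdjGraphTransformer's machinery instead of implementing _transform from scratch:

import numpy as np
from rdkit.Chem import GetAdjacencyMatrix

from molfeat.trans.graph import AdjGraphTransformer  # assumed import path


class BondOrderGraphTransformer(AdjGraphTransformer):
    """Hypothetical variant that weights each edge by its bond order."""

    def _graph_featurizer(self, mol):
        adj = GetAdjacencyMatrix(mol, useBO=True).astype(float)  # bond orders instead of 0/1
        if self.self_loop:
            np.fill_diagonal(adj, 1.0)
        return adj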

Source code in molfeat/trans/graph/adj.py
class GraphTransformer(MoleculeTransformer):
    """
    Base class for all graph transformers including DGL
    """

    def __init__(
        self,
        atom_featurizer: Optional[Callable] = None,
        bond_featurizer: Optional[Callable] = None,
        explicit_hydrogens: bool = False,
        canonical_atom_order: bool = True,
        self_loop: bool = False,
        n_jobs: int = 1,
        verbose: bool = False,
        dtype: Optional[Callable] = None,
        **params,
    ):
        """Mol to Graph transformer base class

        Args:
            atom_featurizer: atom featurizer to use
            bond_featurizer: atom featurizer to use
            explicit_hydrogens: Whether to use explicit hydrogen in preprocessing of the input molecule
            canonical_atom_order: Whether to use a canonical ordering of the atoms
            self_loop: Whether to add self loops or not
            n_jobs: Number of job to run in parallel. Defaults to 1.
            verbose: Verbosity level. Defaults to True.
            dtype: Output data type. Defaults to None
        """

        self._save_input_args()

        super().__init__(
            n_jobs=n_jobs,
            verbose=verbose,
            dtype=dtype,
            featurizer="none",
            self_loop=self_loop,
            canonical_atom_order=canonical_atom_order,
            explicit_hydrogens=explicit_hydrogens,
            **params,
        )
        if atom_featurizer is None:
            atom_featurizer = AtomCalculator()
        self.atom_featurizer = atom_featurizer
        self.bond_featurizer = bond_featurizer
        self._atom_dim = None
        self._bond_dim = None

    def auto_self_loop(self):
        """Patch the featurizer to auto support self loop based on the bond featurizer characteristics"""
        bf_self_loop = None
        if self.bond_featurizer is not None:
            bf_self_loop = getattr(self.bond_featurizer, "self_loop", None)
            bf_self_loop = bf_self_loop or getattr(self.bond_featurizer, "_self_loop", None)
        if bf_self_loop is not None:
            self.self_loop = bf_self_loop

    def preprocess(self, inputs, labels=None):
        """Preprocess list of input molecules
        Args:
            labels: For compatibility
        """
        inputs, labels = super().preprocess(inputs, labels)
        new_inputs = []
        for m in inputs:
            try:
                mol = dm.to_mol(
                    m, add_hs=self.explicit_hydrogens, ordered=self.canonical_atom_order
                )
            except Exception:
                mol = None
            new_inputs.append(mol)

        return new_inputs, labels

    def fit(self, **fit_params):
        """Fit the current transformer on given dataset."""
        if self.verbose:
            logger.error("GraphTransformer featurizers cannot be fitted !")
        return self

    @property
    def atom_dim(self):
        r"""
        Get the number of features per atom

        Returns:
            atom_dim (int): Number of atom features
        """
        if self._atom_dim is None:
            try:
                self._atom_dim = len(self.atom_featurizer)
            except Exception:
                _toy_mol = dm.to_mol("C")
                out = self.atom_featurizer(_toy_mol)
                self._atom_dim = sum([x.shape[-1] for x in out.values()])
        return self._atom_dim

    @property
    def bond_dim(self):
        r"""
        Get the number of features for a bond

        Returns:
            bond_dim (int): Number of bond features
        """
        if self.bond_featurizer is None:
            self._bond_dim = 0
        if self._bond_dim is None:
            try:
                self._bond_dim = len(self.bond_featurizer)
            except Exception:
                _toy_mol = dm.to_mol("CO")
                out = self.bond_featurizer(_toy_mol)
                self._bond_dim = sum([x.shape[-1] for x in out.values()])
        return self._bond_dim

    def _transform(self, mol: dm.Mol):
        r"""
        Compute features for a single molecule.
        This method would potentially need to be reimplemented by child classes

        Args:
            mol: molecule to transform into features

        Returns
            feat: featurized input molecule

        """
        raise NotImplementedError

    def __call__(self, mols: List[Union[dm.Mol, str]], ignore_errors: bool = False, **kwargs):
        r"""
        Calculate features for molecules. Using __call__, instead of transform.
        Note that most Transfomers allow you to specify
        a return datatype.

        Args:
            mols:  Mol or SMILES of the molecules to be transformed
            ignore_errors: Whether to ignore errors during featurization or raise an error.
            kwargs: Named parameters for the transform method

        Returns:
            feats: list of valid features
            ids: all valid molecule positions that did not failed during featurization
                Only returned when ignore_errors is True.

        """
        features = self.transform(mols, ignore_errors=ignore_errors, **kwargs)
        if not ignore_errors:
            return features
        features, ids = self._filter_none(features)
        return features, ids

atom_dim property

Get the number of features per atom

Returns:

atom_dim (int): Number of atom features

bond_dim property

Get the number of features for a bond

Returns:

bond_dim (int): Number of bond features
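These dimensions are convenient for sizing the input layer of a model; a short sketch using AdjGraphTransformer as a concrete subclass (same import assumption as above):

from molfeat.trans.graph import AdjGraphTransformer  # assumed import path

transformer = AdjGraphTransformer()
in_dim = transformer.atom_dim    # number of features per atom
edge_dim = transformer.bond_dim  # 0 here, since no bond featurizer was provided
# e.g. node_encoder = torch.nn.Linear(in_dim, hidden_dim)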

__call__(mols, ignore_errors=False, **kwargs)

Calculate features for molecules using __call__ instead of transform. Note that most transformers allow you to specify a return datatype.

Parameters:

mols (List[Union[Mol, str]], required): Mol or SMILES of the molecules to be transformed
ignore_errors (bool, default: False): whether to ignore errors during featurization or raise an error
kwargs (default: {}): named parameters for the transform method

Returns:

feats: list of valid features
ids: positions of all valid molecules that did not fail during featurization; only returned when ignore_errors is True
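A sketch of the error-tolerant call path, using AdjGraphTransformer as a concrete subclass (same import assumption as above); with ignore_errors=True, failed molecules are dropped and their original positions are returned alongside the features:

from molfeat.trans.graph import AdjGraphTransformer  # assumed import path

transformer = AdjGraphTransformer()
smiles = ["CCO", "not-a-smiles", "c1ccccc1"]

feats, valid_ids = transformer(smiles, ignore_errors=True)
# feats contains features only for the molecules that could be featurized,
# valid_ids their positions in the original input list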

Source code in molfeat/trans/graph/adj.py
def __call__(self, mols: List[Union[dm.Mol, str]], ignore_errors: bool = False, **kwargs):
    r"""
    Calculate features for molecules. Using __call__, instead of transform.
    Note that most Transfomers allow you to specify
    a return datatype.

    Args:
        mols:  Mol or SMILES of the molecules to be transformed
        ignore_errors: Whether to ignore errors during featurization or raise an error.
        kwargs: Named parameters for the transform method

    Returns:
        feats: list of valid features
        ids: all valid molecule positions that did not failed during featurization
            Only returned when ignore_errors is True.

    """
    features = self.transform(mols, ignore_errors=ignore_errors, **kwargs)
    if not ignore_errors:
        return features
    features, ids = self._filter_none(features)
    return features, ids

__init__(atom_featurizer=None, bond_featurizer=None, explicit_hydrogens=False, canonical_atom_order=True, self_loop=False, n_jobs=1, verbose=False, dtype=None, **params)

Mol to Graph transformer base class

Parameters:

atom_featurizer (Optional[Callable], default: None): atom featurizer to use
bond_featurizer (Optional[Callable], default: None): bond featurizer to use
explicit_hydrogens (bool, default: False): whether to add explicit hydrogens when preprocessing the input molecule
canonical_atom_order (bool, default: True): whether to use a canonical ordering of the atoms
self_loop (bool, default: False): whether to add self-loops or not
n_jobs (int, default: 1): number of jobs to run in parallel
verbose (bool, default: False): verbosity level
dtype (Optional[Callable], default: None): output data type; defaults to None
Source code in molfeat/trans/graph/adj.py
def __init__(
    self,
    atom_featurizer: Optional[Callable] = None,
    bond_featurizer: Optional[Callable] = None,
    explicit_hydrogens: bool = False,
    canonical_atom_order: bool = True,
    self_loop: bool = False,
    n_jobs: int = 1,
    verbose: bool = False,
    dtype: Optional[Callable] = None,
    **params,
):
    """Mol to Graph transformer base class

    Args:
        atom_featurizer: atom featurizer to use
        bond_featurizer: atom featurizer to use
        explicit_hydrogens: Whether to use explicit hydrogen in preprocessing of the input molecule
        canonical_atom_order: Whether to use a canonical ordering of the atoms
        self_loop: Whether to add self loops or not
        n_jobs: Number of job to run in parallel. Defaults to 1.
        verbose: Verbosity level. Defaults to True.
        dtype: Output data type. Defaults to None
    """

    self._save_input_args()

    super().__init__(
        n_jobs=n_jobs,
        verbose=verbose,
        dtype=dtype,
        featurizer="none",
        self_loop=self_loop,
        canonical_atom_order=canonical_atom_order,
        explicit_hydrogens=explicit_hydrogens,
        **params,
    )
    if atom_featurizer is None:
        atom_featurizer = AtomCalculator()
    self.atom_featurizer = atom_featurizer
    self.bond_featurizer = bond_featurizer
    self._atom_dim = None
    self._bond_dim = None

auto_self_loop()

Patch the transformer to automatically enable self-loops based on the characteristics of the bond featurizer

Source code in molfeat/trans/graph/adj.py
def auto_self_loop(self):
    """Patch the featurizer to auto support self loop based on the bond featurizer characteristics"""
    bf_self_loop = None
    if self.bond_featurizer is not None:
        bf_self_loop = getattr(self.bond_featurizer, "self_loop", None)
        bf_self_loop = bf_self_loop or getattr(self.bond_featurizer, "_self_loop", None)
    if bf_self_loop is not None:
        self.self_loop = bf_self_loop

fit(**fit_params)

Fit the current transformer on a given dataset.

Source code in molfeat/trans/graph/adj.py
def fit(self, **fit_params):
    """Fit the current transformer on given dataset."""
    if self.verbose:
        logger.error("GraphTransformer featurizers cannot be fitted !")
    return self

preprocess(inputs, labels=None)

Preprocess a list of input molecules.

Parameters:

labels: for compatibility

Source code in molfeat/trans/graph/adj.py
def preprocess(self, inputs, labels=None):
    """Preprocess list of input molecules
    Args:
        labels: For compatibility
    """
    inputs, labels = super().preprocess(inputs, labels)
    new_inputs = []
    for m in inputs:
        try:
            mol = dm.to_mol(
                m, add_hs=self.explicit_hydrogens, ordered=self.canonical_atom_order
            )
        except Exception:
            mol = None
        new_inputs.append(mol)

    return new_inputs, labels

PYGGraphTransformer

Bases: AdjGraphTransformer

Graph transformer for PyG (PyTorch Geometric) models
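A usage sketch (assuming torch_geometric is installed and the class is importable from molfeat.trans.graph); each molecule becomes a torch_geometric.data.Data object with x, edge_index and, when a bond featurizer is set, edge_attr, so PyG's own DataLoader can batch them directly:

from torch_geometric.loader import DataLoader  # PyG's loader batches Data objects natively

from molfeat.trans.graph import PYGGraphTransformer  # assumed import path

transformer = PYGGraphTransformer()
data_list = transformer(["CCO", "c1ccccc1"])  # list of torch_geometric.data.Data objects

loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))  # a torch_geometric Batch object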

Source code in molfeat/trans/graph/adj.py
class PYGGraphTransformer(AdjGraphTransformer):
    """
    Graph transformer for the PYG models
    """

    def _graph_featurizer(self, mol: dm.Mol):
        # we have used bond_calculator, therefore we need to
        # go over the molecules and fetch the proper bond info from the atom idx
        if self.bond_featurizer is None or (
            isinstance(self.bond_featurizer, EdgeMatCalculator)
            or hasattr(self.bond_featurizer, "pairwise_atom_funcs")
        ):
            graph = super()._graph_featurizer(mol)
            (rows, cols) = np.nonzero(graph)
            return np.vstack((rows, cols))

        # we have a regular bond calculator here instead of all pairwise atoms
        graph = []
        for i in range(mol.GetNumBonds()):
            bond = mol.GetBondWithIdx(i)
            a_idx_1 = bond.GetBeginAtomIdx()
            a_idx_2 = bond.GetEndAtomIdx()
            graph += [[a_idx_1, a_idx_2], [a_idx_2, a_idx_1]]
        if getattr(self.bond_featurizer, "_self_loop", False):
            graph.extend([[atom_ind, atom_ind] for atom_ind in range(mol.GetNumAtoms())])
        graph = np.asarray(graph).T
        return graph

    def _convert_feat_to_data_point(
        self,
        graph: np.ndarray,
        node_feat: np.ndarray,
        bond_feat: Optional[np.ndarray] = None,
    ):
        """Convert extracted graph features to a pyg Data object
        Args:
            graph: graph adjacency matrix
            node_feat: node features
            bond_feat: bond features

        Returns:
            datapoint: a pyg Data object
        """
        node_feat = torch.tensor(node_feat, dtype=torch.float32)
        # construct edge index array E of shape (2, n_edges)
        graph = torch.LongTensor(graph).view(2, -1)

        if bond_feat is not None:
            bond_feat = torch.tensor(bond_feat, dtype=torch.float32)
            if bond_feat.ndim == 3:
                bond_feat = bond_feat[graph[0, :], graph[1, :]]

        d = Data(x=node_feat, edge_index=graph, edge_attr=bond_feat)
        return d

    def transform(self, mols: List[Union[dm.Mol, str]], **kwargs):
        r"""
        Compute the graph featurization for a set of molecules.

        Args:
            mols: a list containing smiles or mol objects
            kwargs: arguments to pass to the `super().transform`

         Returns:
             features: a list of Data point for each molecule in the input set
        """
        features = super().transform(mols, keep_dict=False, **kwargs)
        return [self._convert_feat_to_data_point(*feat) for feat in features]

    def get_collate_fn(
        self,
        dataset: Optional[Union[PygDataset, Sequence[BaseData], DatasetAdapter]] = None,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        return_pair: Optional[bool] = True,
        **kwargs,
    ):
        """
        Get collate function for pyg graphs.
        Note: The `collate_fn` is not required when using `torch_geometric.loader.dataloader.DataLoader`.

        Args:
            dataset: The dataset from which to load the data and apply the collate function.
                This is required if the dataset is <torch_geometric.data.on_disk_dataset.OnDiskDataset>.
            follow_batch: Creates assignment batch vectors for each key in the list. (default: :obj:`None`)
            exclude_keys: Will exclude each key in the list. (default: :obj:`None`)
            return_pair: whether to return a pair of X,y or a databatch (default: :obj:`True`)

        Returns:
            Collated samples.

        See Also:
            <torch_geometric.loader.dataloader.Collator>
            <torch_geometric.loader.dataloader.DataLoader>
        """
        collator = Collater(dataset=dataset, follow_batch=follow_batch, exclude_keys=exclude_keys)
        return partial(self._collate_batch, collator=collator, return_pair=return_pair)

    @staticmethod
    def _collate_batch(batch, collator: Callable, return_pair: bool = False, **kwargs):
        """
        Collate a batch of samples.

        Args:
            batch: Batch of samples.
            collator: collator function
            return_pair: whether to return a pair of (X,y) a databatch
        Returns:
            Collated samples.
        """
        if isinstance(batch[0], (list, tuple)) and len(batch[0]) > 1:
            graphs, labels = map(list, zip(*batch))
            for graph, label in zip(graphs, labels):
                graph.y = label
            batch = graphs
        batch = collator(batch)
        if return_pair:
            return (batch, batch.y)
        return batch

get_collate_fn(dataset=None, follow_batch=None, exclude_keys=None, return_pair=True, **kwargs)

Get the collate function for PyG graphs. Note: the collate_fn is not required when using torch_geometric.loader.dataloader.DataLoader.

Parameters:

dataset (Optional[Union[Dataset, Sequence[BaseData], DatasetAdapter]], default: None): the dataset from which to load the data and apply the collate function; this is required if the dataset is a torch_geometric.data.on_disk_dataset.OnDiskDataset
follow_batch (Optional[List[str]], default: None): creates assignment batch vectors for each key in the list
exclude_keys (Optional[List[str]], default: None): will exclude each key in the list
return_pair (Optional[bool], default: True): whether to return an (X, y) pair or a single data batch

Returns:

Collated samples.

See Also: torch_geometric.loader.dataloader.Collater and torch_geometric.loader.dataloader.DataLoader
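A sketch of the returned collate function with a plain torch.utils.data.DataLoader and (graph, label) pairs, under the same assumptions as above:

import torch
from torch.utils.data import DataLoader

from molfeat.trans.graph import PYGGraphTransformer  # assumed import path

transformer = PYGGraphTransformer()
data_list = transformer(["CCO", "c1ccccc1", "CC(=O)O"])
labels = [torch.tensor([0.1]), torch.tensor([1.2]), torch.tensor([0.7])]

loader = DataLoader(
    list(zip(data_list, labels)),
    batch_size=2,
    collate_fn=transformer.get_collate_fn(return_pair=True),
)
batch, y = next(iter(loader))  # a PyG Batch plus the collated labels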

Source code in molfeat/trans/graph/adj.py
def get_collate_fn(
    self,
    dataset: Optional[Union[PygDataset, Sequence[BaseData], DatasetAdapter]] = None,
    follow_batch: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
    return_pair: Optional[bool] = True,
    **kwargs,
):
    """
    Get collate function for pyg graphs.
    Note: The `collate_fn` is not required when using `torch_geometric.loader.dataloader.DataLoader`.

    Args:
        dataset: The dataset from which to load the data and apply the collate function.
            This is required if the dataset is <torch_geometric.data.on_disk_dataset.OnDiskDataset>.
        follow_batch: Creates assignment batch vectors for each key in the list. (default: :obj:`None`)
        exclude_keys: Will exclude each key in the list. (default: :obj:`None`)
        return_pair: whether to return a pair of X,y or a databatch (default: :obj:`True`)

    Returns:
        Collated samples.

    See Also:
        <torch_geometric.loader.dataloader.Collator>
        <torch_geometric.loader.dataloader.DataLoader>
    """
    collator = Collater(dataset=dataset, follow_batch=follow_batch, exclude_keys=exclude_keys)
    return partial(self._collate_batch, collator=collator, return_pair=return_pair)

transform(mols, **kwargs)

Compute the graph featurization for a set of molecules.

Parameters:

mols (List[Union[Mol, str]], required): a list containing SMILES or mol objects
kwargs (default: {}): arguments to pass to super().transform

Returns: features: a list of PyG Data objects, one for each molecule in the input set

Source code in molfeat/trans/graph/adj.py
def transform(self, mols: List[Union[dm.Mol, str]], **kwargs):
    r"""
    Compute the graph featurization for a set of molecules.

    Args:
        mols: a list containing smiles or mol objects
        kwargs: arguments to pass to the `super().transform`

     Returns:
         features: a list of Data point for each molecule in the input set
    """
    features = super().transform(mols, keep_dict=False, **kwargs)
    return [self._convert_feat_to_data_point(*feat) for feat in features]

TopoDistGraphTransformer

Bases: AdjGraphTransformer

Graph featurizer using the topological distance between each pair of nodes instead of the adjacency matrix.

The self_loop attribute is ignored here as the distance between an atom and itself is 0.
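A brief sketch under the same import assumption as above; the matrix in the adjacency slot now holds topological (shortest-path) distances in bonds:

from molfeat.trans.graph import TopoDistGraphTransformer  # assumed import path

transformer = TopoDistGraphTransformer()
dist, atom_feats = transformer(["CCO"])[0]
# dist[i, j] counts the bonds on the shortest path between atoms i and j:
# 0 on the diagonal, 2 between the terminal carbon and the oxygen of CCO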

Source code in molfeat/trans/graph/adj.py
class TopoDistGraphTransformer(AdjGraphTransformer):
    """
    Graph featurizer using the topological distance between each pair
    of nodes instead of the adjacency matrix.

    The `self_loop` attribute is ignored here as the distance between an atom and itself is 0.
    """

    def _graph_featurizer(self, mol: dm.Mol):
        """Graph topological distance featurizer

        Args:
            mol: molecule to transform into a graph

        Returns:
            mat : N,N matrix representing the graph
        """
        return GetDistanceMatrix(mol)

Tree

MolTreeDecompositionTransformer

Bases: MoleculeTransformer

Transforms a molecule into a tree structure whose nodes correspond to different functional groups.
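A fit-then-transform sketch (assuming dgl is installed and that the class is importable from molfeat.trans.graph.tree, the source path shown below); the fragment vocabulary is learned from the training molecules unless it is supplied at construction:

from molfeat.trans.graph.tree import MolTreeDecompositionTransformer  # assumed import path

train_smiles = ["CCO", "c1ccccc1", "CC(=O)Nc1ccc(O)cc1"]

transformer = MolTreeDecompositionTransformer(one_hot=True)
transformer.fit(train_smiles)                # builds the fragment vocabulary
trees = transformer.transform(train_smiles)  # dgl graphs whose ndata["hv"] holds the fragment encodings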

Source code in molfeat/trans/graph/tree.py
class MolTreeDecompositionTransformer(MoleculeTransformer):
    r"""
    Transforms a molecule into a tree structure whose nodes correspond to different functional groups.
    """

    def __init__(
        self,
        vocab: Optional[Iterable] = None,
        one_hot: bool = False,
        dtype: Optional[Callable] = None,
        cache: bool = True,
        **params,
    ):
        """MolTree featurizer

        Args:
            vocab: List of the smiles of the functional groups or clusters.
                If None, the transformer should be fiited before any usage.
            one_hot (bool, optional): Whether or not for a tree a 1d array or a 2d array is returned as features
                If 1d array, vocabulary elements are mapped into integers,
                otherwise, vocabulary elements  ar mapped into one-hot vectors
            cache: Whether to cache the tree decomposition to avoid recomputing for seen molecules
            dtype: Output data type. Defaults to None

        Attributes:
            vocab: Mapping from clusters to integers
            vocab_size: The number of clusters + 1
            one_hot: Whether or not for a sequence a 1d array or a 2d array is returned as features
        """

        self._save_input_args()

        super().__init__(
            dtype=dtype,
            one_hot=one_hot,
            cache=cache,
            featurizer=TreeDecomposer(cache=cache),
            **params,
        )
        if vocab is not None:
            self.vocab = vocab
            self._vocab_size = len(self.vocab) + 1
            self._fitted = True
        else:
            self.vocab = None
            self._vocab_size = None
            self._fitted = False

        if not requires.check("dgl"):
            raise ValueError("dgl is required for this featurizer, please install it first")

        if self.dtype is not None and not datatype.is_dtype_tensor(self.dtype):
            raise ValueError("DGL featurizer only supports torch tensors currently")

    @property
    def vocab_size(self):
        """Compute vocab size of this featurizer

        Returns:
            size: vocab size
        """
        return self._vocab_size

    def fit(
        self,
        X: List[Union[dm.Mol, str]],
        y: Optional[list] = None,
        output_file: Optional[os.PathLike] = None,
        **fit_params,
    ):
        """Fit the current transformer on given dataset.

        The goal of fitting is for example to identify nan columns values
        that needs to be removed from the dataset

        Args:
            X: input list of molecules
            y (list, optional): Optional list of molecular properties. Defaults to None.
            output_file: path to a file that will be used to store the generated set of fragments.
            fit_params: key val of additional fit parameters


        Returns:
            self: MolTransformer instance after fitting
        """
        if self.vocab is not None:
            logger.warning("The previous vocabulary of fragments will be erased.")
        self.vocab = self.featurizer.get_vocab(X, output_file=output_file, log=self.verbose)
        self._vocab_size = len(self.vocab) + 1
        self._fitted = True

        # save the vocab in the state
        self._input_args["vocab"] = self.vocab

        return self

    def _transform(self, mol: dm.Mol):
        r"""
        Compute features for a single molecule.
        This method would potentially need to be reimplemented by child classes

        Args:
            mol (dm.Mol): molecule to transform into features

        Returns
            feat: featurized input molecule

        """
        if not self._fitted:
            raise ValueError(
                "Need to call the fit function before any transformation. \
                Or provide the fragments vocabulary at the object construction"
            )

        try:
            _, edges, fragments = self.featurizer(mol)
            n_nodes = len(fragments)
            enc = [self.vocab.index(f) + 1 if f in self.vocab else 0 for f in fragments]
            enc = datatype.cast(enc, (self.dtype or torch.long))
            graph = dgl.graph(([], []))
            graph.add_nodes(n_nodes)
            for edge in edges:
                graph.add_edges(*edge)
                graph.add_edges(*edge[::-1])

            if self.one_hot:
                enc = [one_hot_encoding(f, self.vocab, encode_unknown=True) for f in fragments]
                enc = np.asarray(enc)
                enc = datatype.cast(enc, (self.dtype or torch.float))

            graph.ndata["hv"] = enc
        except Exception as e:
            raise e
            if self.verbose:
                logger.error(e)
            graph = None
        return graph

vocab_size property

Compute the vocabulary size of this featurizer

Returns:

size: vocab size

__init__(vocab=None, one_hot=False, dtype=None, cache=True, **params)

MolTree featurizer

Parameters:

vocab (Optional[Iterable], default: None): list of the SMILES of the functional groups or clusters; if None, the transformer must be fitted before any usage
one_hot (bool, default: False): whether a 1D or a 2D array is returned as features for a tree; with a 1D array, vocabulary elements are mapped to integers, otherwise they are mapped to one-hot vectors
cache (bool, default: True): whether to cache the tree decomposition to avoid recomputing it for molecules that have already been seen
dtype (Optional[Callable], default: None): output data type; defaults to None

Attributes:

vocab: mapping from clusters to integers
vocab_size: the number of clusters + 1
one_hot: whether a 1D or a 2D array is returned as features for a sequence

Source code in molfeat/trans/graph/tree.py
def __init__(
    self,
    vocab: Optional[Iterable] = None,
    one_hot: bool = False,
    dtype: Optional[Callable] = None,
    cache: bool = True,
    **params,
):
    """MolTree featurizer

    Args:
        vocab: List of the smiles of the functional groups or clusters.
            If None, the transformer should be fiited before any usage.
        one_hot (bool, optional): Whether or not for a tree a 1d array or a 2d array is returned as features
            If 1d array, vocabulary elements are mapped into integers,
            otherwise, vocabulary elements  ar mapped into one-hot vectors
        cache: Whether to cache the tree decomposition to avoid recomputing for seen molecules
        dtype: Output data type. Defaults to None

    Attributes:
        vocab: Mapping from clusters to integers
        vocab_size: The number of clusters + 1
        one_hot: Whether or not for a sequence a 1d array or a 2d array is returned as features
    """

    self._save_input_args()

    super().__init__(
        dtype=dtype,
        one_hot=one_hot,
        cache=cache,
        featurizer=TreeDecomposer(cache=cache),
        **params,
    )
    if vocab is not None:
        self.vocab = vocab
        self._vocab_size = len(self.vocab) + 1
        self._fitted = True
    else:
        self.vocab = None
        self._vocab_size = None
        self._fitted = False

    if not requires.check("dgl"):
        raise ValueError("dgl is required for this featurizer, please install it first")

    if self.dtype is not None and not datatype.is_dtype_tensor(self.dtype):
        raise ValueError("DGL featurizer only supports torch tensors currently")

fit(X, y=None, output_file=None, **fit_params)

Fit the current transformer on a given dataset.

The goal of fitting is, for example, to identify NaN column values that need to be removed from the dataset.

Parameters:

X (List[Union[Mol, str]], required): input list of molecules
y (list, default: None): optional list of molecular properties
output_file (Optional[PathLike], default: None): path to a file that will be used to store the generated set of fragments
fit_params (default: {}): key-value pairs of additional fit parameters

Returns:

self: MolTransformer instance after fitting

Source code in molfeat/trans/graph/tree.py
def fit(
    self,
    X: List[Union[dm.Mol, str]],
    y: Optional[list] = None,
    output_file: Optional[os.PathLike] = None,
    **fit_params,
):
    """Fit the current transformer on given dataset.

    The goal of fitting is for example to identify nan columns values
    that needs to be removed from the dataset

    Args:
        X: input list of molecules
        y (list, optional): Optional list of molecular properties. Defaults to None.
        output_file: path to a file that will be used to store the generated set of fragments.
        fit_params: key val of additional fit parameters


    Returns:
        self: MolTransformer instance after fitting
    """
    if self.vocab is not None:
        logger.warning("The previous vocabulary of fragments will be erased.")
    self.vocab = self.featurizer.get_vocab(X, output_file=output_file, log=self.verbose)
    self._vocab_size = len(self.vocab) + 1
    self._fitted = True

    # save the vocab in the state
    self._input_args["vocab"] = self.vocab

    return self