Skip to content

Graphormer

Graphormer pretrained models

GraphormerTransformer

Bases: PretrainedMolTransformer

Graphormer transformer based on pretrained sequence embedder

Attributes:

Name Type Description
featurizer

Graphormer embedding object

dtype

Data type. Use call instead

pooling

Pooling method for Graphormer's embedding layer

Source code in molfeat/trans/pretrained/graphormer.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
class GraphormerTransformer(PretrainedMolTransformer):
    r"""
    Graphormer transformer based on pretrained sequence embedder

    Attributes:
        featurizer: Graphormer embedding object
        dtype: Data type. Use call instead
        pooling: Pooling method for Graphormer's embedding layer
        name: Name of the pretrained model (the ``kind`` constructor argument)
        max_length: Value forwarded to the featurizer as ``max_nodes``
        version: Version value given at construction, stored as-is
    """

    def __init__(
        self,
        kind: str = "pcqm4mv2_graphormer_base",
        dtype: Callable = np.float32,
        pooling: str = "mean",
        max_length: Optional[int] = None,
        version=None,
        **params,
    ):
        """Build the transformer and its Graphormer embedding extractor.

        Args:
            kind: name of the pretrained model, stored as ``self.name``
            dtype: data type, forwarded to the parent constructor
            pooling: pooling name (wrapped in ``Pooling(dim=1, name=...)``)
                or an already-built pooling object, which is used as-is
            max_length: value forwarded to the featurizer as ``max_nodes``
            version: stored on the instance as ``self.version``
            **params: extra keyword arguments forwarded to the parent constructor

        Raises:
            ValueError: if ``requires.check("graphormer")`` is falsy
        """
        super().__init__(dtype=dtype, pooling=pooling, **params)
        if not requires.check("graphormer"):
            raise ValueError("`graphormer` is required to use this featurizer.")

        self.preload = True
        self.name = kind
        self._require_mols = False
        self.max_length = max_length
        # A pooling name is turned into a Pooling object; anything else is kept.
        if isinstance(pooling, str):
            pooling = Pooling(dim=1, name=pooling)
        self.pooling = pooling
        self.featurizer = GraphormerEmbeddingsExtractor(
            pretrained_name=self.name, max_nodes=self.max_length
        )
        self.featurizer.config.max_nodes = self.max_length
        self.version = version

    def __repr__(self):
        """Return a string showing the model name, pooling name and dtype."""
        return "{}(name={}, pooling={}, dtype={})".format(
            self.__class__.__name__,
            _parse_to_evaluable_str(self.name),
            _parse_to_evaluable_str(self.pooling.name),
            _parse_to_evaluable_str(self.dtype),
        )

    @staticmethod
    def list_available_models():
        """List available graphormer model to use"""
        return [
            "pcqm4mv1_graphormer_base",  # PCQM4Mv1
            "pcqm4mv2_graphormer_base",  # PCQM4Mv2
            "pcqm4mv1_graphormer_base_for_molhiv",  # ogbg-molhiv
            "oc20is2re_graphormer3d_base",  # Open Catalyst Challenge
        ]

    def _embed(self, inputs: List[str], **kwargs):
        """Internal molecular embedding

        Args:
            inputs: batch passed to the featurizer's model
                (presumably the output of `_convert` -- TODO confirm)
            **kwargs: accepted but not used here

        Returns:
            The pooled model output converted with ``.numpy()``
        """
        # Inference only: no gradients are tracked.
        with torch.no_grad():
            x = self.featurizer.model(inputs)
            x = self.pooling(x)
        return x.numpy()

    def __getstate__(self):
        """Getting state to allow pickling

        The featurizer is left out of the state (it is rebuilt by
        `__setstate__` through `_update_params`) and the precompute
        cache is reset to ``None``.
        """
        # Filter *before* copying so that neither the featurizer nor the
        # cache is deep-copied just to be thrown away.
        skipped = ("featurizer", "precompute_cache")
        d = copy.deepcopy(
            {key: value for key, value in self.__dict__.items() if key not in skipped}
        )
        d["precompute_cache"] = None
        return d

    def __setstate__(self, d):
        """Setting state during reloading pickling"""
        self.__dict__.update(d)
        # The featurizer is not part of the pickled state; rebuild it.
        self._update_params()

    def compute_max_length(self, inputs: list):
        """Compute maximum node number for the input list of molecules

        Args:
            inputs: input list of molecules

        Returns:
            The largest ``item.x.size(0)`` over the inference dataset built
            from ``inputs``.

        Raises:
            ValueError: raised by ``max`` when the dataset yields no item
        """
        dataset = GraphormerInferenceDataset(
            inputs,
            multi_hop_max_dist=self.featurizer.config.multi_hop_max_dist,
            spatial_pos_max=self.featurizer.config.spatial_pos_max,
        )
        xs = [item.x.size(0) for item in dataset]
        return max(xs)

    def set_max_length(self, max_length: int):
        """Set the maximum length for this featurizer"""
        self.max_length = max_length
        # Rebuild the featurizer with the new value, then preload again.
        self._update_params()
        self._preload()

    def _convert(self, inputs: list, **kwargs):
        """Convert molecules to the right format

        Args:
            inputs: inputs to preprocess

        Returns:
            processed: pre-processed input list
        """
        # Parent conversion first, then the featurizer's own conversion.
        inputs = super()._convert(inputs, **kwargs)
        batch = self.featurizer._convert(inputs)
        return batch

    def _update_params(self):
        """Run the parent update, then rebuild the featurizer from the current
        ``name`` and ``max_length``."""
        super()._update_params()
        self.featurizer = GraphormerEmbeddingsExtractor(
            pretrained_name=self.name, max_nodes=self.max_length
        )
        self.featurizer.config.max_nodes = self.max_length

featurizer = GraphormerEmbeddingsExtractor(pretrained_name=self.name, max_nodes=self.max_length) instance-attribute

max_length = max_length instance-attribute

name = kind instance-attribute

pooling = pooling instance-attribute

preload = True instance-attribute

version = version instance-attribute

__getstate__()

Getting state to allow pickling

Source code in molfeat/trans/pretrained/graphormer.py
84
85
86
87
88
89
def __getstate__(self):
    """Getting state to allow pickling

    Returns a deep copy of the instance attributes without the
    ``featurizer`` entry and with ``precompute_cache`` set to ``None``.
    """
    # Filter *before* copying so that neither the featurizer nor the cache
    # is deep-copied just to be thrown away.
    skipped = ("featurizer", "precompute_cache")
    d = copy.deepcopy(
        {key: value for key, value in self.__dict__.items() if key not in skipped}
    )
    d["precompute_cache"] = None
    return d

__init__(kind='pcqm4mv2_graphormer_base', dtype=np.float32, pooling='mean', max_length=None, version=None, **params)

Source code in molfeat/trans/pretrained/graphormer.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def __init__(
    self,
    kind: str = "pcqm4mv2_graphormer_base",
    dtype: Callable = np.float32,
    pooling: str = "mean",
    max_length: Optional[int] = None,
    version=None,
    **params,
):
    """Build the transformer and its Graphormer embedding extractor.

    Args:
        kind: name of the pretrained model, stored as ``self.name``
        dtype: data type, forwarded to the parent constructor
        pooling: pooling name (wrapped in ``Pooling(dim=1, name=...)``)
            or an already-built pooling object, which is used as-is
        max_length: value forwarded to the featurizer as ``max_nodes``
        version: stored on the instance as ``self.version``
        **params: extra keyword arguments forwarded to the parent constructor

    Raises:
        ValueError: if ``requires.check("graphormer")`` is falsy
    """
    super().__init__(dtype=dtype, pooling=pooling, **params)
    graphormer_available = requires.check("graphormer")
    if not graphormer_available:
        raise ValueError("`graphormer` is required to use this featurizer.")

    self.preload = True
    self.name = kind
    self._require_mols = False
    self.max_length = max_length

    # Keep the value as given, then swap a pooling name for a Pooling object.
    self.pooling = pooling
    if isinstance(pooling, str):
        self.pooling = Pooling(dim=1, name=pooling)

    extractor = GraphormerEmbeddingsExtractor(
        pretrained_name=self.name, max_nodes=self.max_length
    )
    self.featurizer = extractor
    extractor.config.max_nodes = self.max_length
    self.version = version

__repr__()

Source code in molfeat/trans/pretrained/graphormer.py
55
56
57
58
59
60
61
def __repr__(self):
    """Return a string showing the class name, model name, pooling name and dtype."""
    class_name = self.__class__.__name__
    name_str = _parse_to_evaluable_str(self.name)
    pooling_str = _parse_to_evaluable_str(self.pooling.name)
    dtype_str = _parse_to_evaluable_str(self.dtype)
    template = "{}(name={}, pooling={}, dtype={})"
    return template.format(class_name, name_str, pooling_str, dtype_str)

__setstate__(d)

Setting state during reloading pickling

Source code in molfeat/trans/pretrained/graphormer.py
91
92
93
94
def __setstate__(self, d):
    """Restore the instance attributes from a pickled state.

    Args:
        d: attribute dictionary to load into the instance
    """
    vars(self).update(d)
    # Re-run the parameter update hook now that the attributes are restored.
    self._update_params()

compute_max_length(inputs)

Compute maximum node number for the input list of molecules

Parameters:

Name Type Description Default
inputs list

input list of molecules

required
Source code in molfeat/trans/pretrained/graphormer.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def compute_max_length(self, inputs: list):
    """Return the largest first-dimension size of ``x`` over the inputs.

    An inference dataset is built from ``inputs`` using the featurizer's
    configuration; each item contributes ``item.x.size(0)`` (presumably its
    number of nodes -- confirm against the dataset class).

    Args:
        inputs: input list of molecules
    """
    config = self.featurizer.config
    graphs = GraphormerInferenceDataset(
        inputs,
        multi_hop_max_dist=config.multi_hop_max_dist,
        spatial_pos_max=config.spatial_pos_max,
    )
    return max(graph.x.size(0) for graph in graphs)

list_available_models() staticmethod

List the available Graphormer models

Source code in molfeat/trans/pretrained/graphormer.py
63
64
65
66
67
68
69
70
71
@staticmethod
def list_available_models():
    """List available graphormer model to use"""
    return [
        "pcqm4mv1_graphormer_base",  # PCQM4Mv1
        "pcqm4mv2_graphormer_base",  # PCQM4Mv2
        "pcqm4mv1_graphormer_base_for_molhiv",  # ogbg-molhiv
        "oc20is2re_graphormer3d_base",  # Open Catalyst Challenge
    ]

set_max_length(max_length)

Set the maximum length for this featurizer

Source code in molfeat/trans/pretrained/graphormer.py
110
111
112
113
114
def set_max_length(self, max_length: int):
    """Store a new maximum length, then re-run the update and preload hooks.

    Args:
        max_length: new value assigned to ``self.max_length``
    """
    self.max_length = max_length

    # Order matters: parameters are updated before preloading.
    self._update_params()
    self._preload()