messagepassing.py (forked from seongjunyun/Graph_Transformer_Networks)
import inspect

import torch
from torch_geometric.utils import scatter_

# Argument names that propagate() fills in itself instead of reading them
# from **kwargs.
special_args = [
    'edge_index', 'edge_index_i', 'edge_index_j', 'size', 'size_i', 'size_j'
]
__size_error_msg__ = ('All tensors which should get mapped to the same source '
                      'or target nodes must be of same size in dimension 0.')

class MessagePassing(torch.nn.Module):
    r"""Base class for creating message passing layers

    .. math::
        \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
        \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
        \left(\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{i,j}\right) \right),

    where :math:`\square` denotes a differentiable, permutation invariant
    function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}`
    and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as
    MLPs.
    See `here <https://rusty1s.github.io/pytorch_geometric/build/html/notes/
    create_gnn.html>`__ for the accompanying tutorial.

    Args:
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`).
            (default: :obj:`"add"`)
        flow (string, optional): The flow direction of message passing
            (:obj:`"source_to_target"` or :obj:`"target_to_source"`).
            (default: :obj:`"source_to_target"`)
    """
    def __init__(self, aggr='add', flow='source_to_target'):
        super(MessagePassing, self).__init__()

        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max']

        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        # Inspect the signatures of message() and update() so that propagate()
        # can later assemble their arguments by name.
        self.__message_args__ = inspect.getfullargspec(self.message)[0][1:]
        self.__special_args__ = [(i, arg)
                                 for i, arg in enumerate(self.__message_args__)
                                 if arg in special_args]
        self.__message_args__ = [
            arg for arg in self.__message_args__ if arg not in special_args
        ]
        self.__update_args__ = inspect.getfullargspec(self.update)[0][2:]
    def propagate(self, edge_index, size=None, **kwargs):
        r"""The initial call to start propagating messages.

        Args:
            edge_index (Tensor): The :obj:`[2, num_edges]` indices of a
                general (sparse) assignment matrix of shape :obj:`[N, M]`
                (can be directed or undirected).
            size (list or tuple, optional): The size :obj:`[N, M]` of the
                assignment matrix. If set to :obj:`None`, the size will be
                inferred automatically. (default: :obj:`None`)
            **kwargs: Any additional data which is needed to construct
                messages and to update node embeddings.
        """
        size = [None, None] if size is None else list(size)
        assert len(size) == 2

        # Under the default 'source_to_target' flow, messages flow from node
        # j (edge_index[0]) to node i (edge_index[1]), so aggregation happens
        # at index i.
        i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0)
        ij = {"_i": i, "_j": j}

        # Collect the arguments of message() by name; any tensor whose name
        # ends in '_i' or '_j' is lifted to the corresponding edge endpoints.
        message_args = []
        for arg in self.__message_args__:
            if arg[-2:] in ij.keys():
                tmp = kwargs[arg[:-2]]
                if tmp is None:  # pragma: no cover
                    message_args.append(tmp)
                else:
                    idx = ij[arg[-2:]]
                    if isinstance(tmp, (tuple, list)):
                        # Bipartite case: one tensor per node dimension.
                        assert len(tmp) == 2
                        if size[1 - idx] is None:
                            size[1 - idx] = tmp[1 - idx].size(0)
                        if size[1 - idx] != tmp[1 - idx].size(0):
                            raise ValueError(__size_error_msg__)
                        tmp = tmp[idx]

                    if size[idx] is None:
                        size[idx] = tmp.size(0)
                    if size[idx] != tmp.size(0):
                        raise ValueError(__size_error_msg__)

                    # Lift node-level features to edge level.
                    tmp = torch.index_select(tmp, 0, edge_index[idx])
                    message_args.append(tmp)
            else:
                message_args.append(kwargs[arg])

        size[0] = size[1] if size[0] is None else size[0]
        size[1] = size[0] if size[1] is None else size[1]

        kwargs['edge_index'] = edge_index
        kwargs['size'] = size

        # Re-insert the special arguments at their original positions in the
        # message() signature.
        for (idx, arg) in self.__special_args__:
            if arg[-2:] in ij.keys():
                message_args.insert(idx, kwargs[arg[:-2]][ij[arg[-2:]]])
            else:
                message_args.insert(idx, kwargs[arg])

        update_args = [kwargs[arg] for arg in self.__update_args__]

        # message -> aggregate (scatter) -> update.
        out = self.message(*message_args)
        out = scatter_(self.aggr, out, edge_index[i], dim_size=size[i])
        out = self.update(out, *update_args)

        return out
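    # Note: for a bipartite assignment matrix of shape [N, M], `size` can be
    # passed explicitly and node features may be given as a pair, e.g.
    # `self.propagate(edge_index, size=(N, M), x=(x_src, x_dst))` under the
    # default flow; the tuple/list branch in propagate() then selects the
    # tensor matching each '_i'/'_j' suffix.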
    def message(self, x_j):  # pragma: no cover
        r"""Constructs messages in analogy to :math:`\phi_{\mathbf{\Theta}}`
        for each edge in :math:`(i,j) \in \mathcal{E}`.

        Can take any argument which was initially passed to :meth:`propagate`.
        In addition, features can be lifted to the target node :math:`i` and
        the source node :math:`j` (under the default
        :obj:`"source_to_target"` flow) by appending :obj:`_i` or :obj:`_j`
        to the variable name, *e.g.*, :obj:`x_i` and :obj:`x_j`."""
        return x_j
    def update(self, aggr_out):  # pragma: no cover
        r"""Updates node embeddings in analogy to
        :math:`\gamma_{\mathbf{\Theta}}` for each node
        :math:`i \in \mathcal{V}`.

        Takes the output of aggregation as its first argument, plus any
        argument which was initially passed to :meth:`propagate`."""
        return aggr_out
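

# The following is an illustrative sketch, not part of the original file: a
# minimal mean-aggregating convolution built on the MessagePassing base class
# above. The class name `SimpleConv` and its parameters are hypothetical.
class SimpleConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(SimpleConv, self).__init__(aggr='mean')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]; edge_index has shape [2, E].
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j, edge_index_i):
        # x_j: source-node features lifted to each edge, shape [E, in_channels].
        # edge_index_i is a "special" argument filled in by propagate(); it is
        # unused here and included only to demonstrate the mechanism.
        return self.lin(x_j)

    def update(self, aggr_out):
        # aggr_out: per-node mean of incoming messages, shape [N, out_channels].
        return aggr_out

# Usage sketch: for x of shape [N, 16] and edge_index of shape [2, E],
# SimpleConv(16, 32)(x, edge_index) returns a [N, 32] tensor of updated
# node embeddings.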