-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: CPU implementation of the graph convolution op. Reviewed By: nikhilaravi, gkioxari Differential Revision: D21384361 fbshipit-source-id: bc96730e9727bb9aa1b0a232dcb82f0c0d12fe6b
- Loading branch information
1 parent
4872a2c
commit 7944d24
Showing
4 changed files
with
68 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
|
||
#include <ATen/ATen.h> | ||
|
||
at::Tensor GatherScatterCpu( | ||
const at::Tensor& input, | ||
const at::Tensor& edges, | ||
bool directed, | ||
bool backward) { | ||
const auto num_vertices = input.size(0); | ||
const auto input_feature_dim = input.size(1); | ||
const auto num_edges = edges.size(0); | ||
|
||
auto output = at::zeros({num_vertices, input_feature_dim}, input.options()); | ||
|
||
auto input_a = input.accessor<float, 2>(); | ||
auto edges_a = edges.accessor<int64_t, 2>(); | ||
auto output_a = output.accessor<float, 2>(); | ||
const int v0_idx = backward ? 1 : 0; | ||
const int v1_idx = backward ? 0 : 1; | ||
|
||
for (int e = 0; e < num_edges; ++e) { | ||
// Get indices of vertices which form the edge. | ||
const int64_t v0 = edges_a[e][v0_idx]; | ||
const int64_t v1 = edges_a[e][v1_idx]; | ||
|
||
for (int d = 0; d < input_feature_dim; ++d) { | ||
output_a[v0][d] += input_a[v1][d]; | ||
if (!directed) { | ||
output_a[v1][d] += input_a[v0][d]; | ||
} | ||
} | ||
} | ||
return output; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters