Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions graphsage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import torch
import torch_scatter
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric
torch_geometric.__version__

import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

from torch import Tensor
from typing import Union, Tuple, Optional
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
OptTensor)

from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

class GraphSage(MessagePassing):

def __init__(self, in_channels, out_channels, normalize = True,
bias = False, **kwargs):
super(GraphSage, self).__init__(**kwargs)

self.in_channels = in_channels
self.out_channels = out_channels
self.normalize = normalize

self.lin_l = None
self.lin_r = None

############################################################################
# TODO: Your code here!
# Define the layers needed for the message and update functions below.
# self.lin_l is the linear transformation that you apply to embedding
# for central node.
# self.lin_r is the linear transformation that you apply to aggregated
# message from neighbors.
# Our implementation is ~2 lines, but don't worry if you deviate from this.
self.lin_l = nn.Linear(in_channels, out_channels)
self.lin_r = nn.Linear(in_channels, out_channels)
############################################################################

self.reset_parameters()

def reset_parameters(self):
self.lin_l.reset_parameters()
self.lin_r.reset_parameters()

def forward(self, x, edge_index, size = None):
""""""

out = None

############################################################################
# TODO: Your code here!
# Implement message passing, as well as any post-processing (our update rule).
# 1. First call propagate function to conduct the message passing.
# 1.1 See there for more information:
# https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html
# 1.2 We use the same representations for central (x_central) and
# neighbor (x_neighbor) nodes, which means you'll pass x=(x, x)
# to propagate.
# 2. Update our node embedding with skip connection.
# 3. If normalize is set, do L-2 normalization (defined in
# torch.nn.functional)
# Our implementation is ~5 lines, but don't worry if you deviate from this.
z = self.propagate(edge_index, x=(x, x))
z1 = self.lin_l(x) + self.lin_r(z)
if self.normalize:
z1 = F.normalize(z1, p=2, dim=1)
out = z1
############################################################################

return out

def message(self, x_j):

out = None

############################################################################
# TODO: Your code here!
# Implement your message function here.
# Our implementation is ~1 lines, but don't worry if you deviate from this.
out = x_j

############################################################################

return out

def aggregate(self, inputs, index, dim_size = None):

out = None

# The axis along which to index number of nodes.
node_dim = self.node_dim

############################################################################
# TODO: Your code here!
# Implement your aggregate function here.
# See here as how to use torch_scatter.scatter:
# https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html#torch_scatter.scatter
# Our implementation is ~1 lines, but don't worry if you deviate from this.
out = torch_scatter.scatter(inputs, index=index, dim=node_dim, reduce='mean')
############################################################################

return out
18 changes: 18 additions & 0 deletions load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
from ogb.nodeproppred import PygNodePropPredDataset
import os

def get_idx_split():
dataset_name = "ogbn-products"
dataset = PygNodePropPredDataset(name=dataset_name)
split_idx = dataset.get_idx_split()
return split_idx


def get_product_clusters():
dataset_name = "ogbn-products"
Expand Down Expand Up @@ -35,6 +41,13 @@ def get_product_clusters():
return cluster_data, dataset, data, split_idx


def small_clusters():
dataset_name = "small_cluster"
data = torch.load("dataset/data_0.pt")
cluster_data = ClusterData(data, num_parts=10, save_dir="dataset")
return cluster_data


def get_cluster_batches(cluster_data, batch_size):
loader = ClusterLoader(cluster_data, batch_size=batch_size, shuffle=True, num_workers=1)
return loader
Expand All @@ -44,6 +57,11 @@ def load_small():
data = torch.load("dataset/data_0.pt")
print(data)
return data


def load_full():
data = torch.load("dataset/partition_15000.pt")
return data


def save_batch(loader, name):
Expand Down
2 changes: 1 addition & 1 deletion models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def reset_parameters(self):
for bn in self.bns:
bn.reset_parameters()

def forward(self, x, adj_t):
def forward(self, x, adj_t=None, **kwargs):
z = self.convs[0](x, adj_t)

for i, layer in enumerate(self.bns):
Expand Down
Loading