From 39203dd5d072089cbfad90889e41cf8797f9017b Mon Sep 17 00:00:00 2001 From: shaonc <55311820+shaonc@users.noreply.github.com> Date: Wed, 17 Mar 2021 23:29:14 -0700 Subject: [PATCH 1/2] Update Train Library To Run GAT and GraphSAGE Update Train Library to be able to run GAT and GraphSAGE Models from the Colab assignment. Create separate main functions to run the GraphSAGE and GAT Models on the small dataset --- models.py | 2 +- project_gat.py | 369 ++++++++++++++++++++++++++++++++++++++++++++++ run_gat.py | 40 +++++ run_graph_sage.py | 39 +++++ train.py | 66 ++++++++- 5 files changed, 512 insertions(+), 4 deletions(-) create mode 100644 project_gat.py create mode 100644 run_gat.py create mode 100644 run_graph_sage.py diff --git a/models.py b/models.py index acb2906..6ee16b8 100644 --- a/models.py +++ b/models.py @@ -32,7 +32,7 @@ def reset_parameters(self): for bn in self.bns: bn.reset_parameters() - def forward(self, x, adj_t): + def forward(self, x, adj_t=None, **kwargs): z = self.convs[0](x, adj_t) for i, layer in enumerate(self.bns): diff --git a/project_gat.py b/project_gat.py new file mode 100644 index 0000000..671570b --- /dev/null +++ b/project_gat.py @@ -0,0 +1,369 @@ +import torch_geometric + +import torch +import torch_scatter +import torch.nn as nn +import torch.nn.functional as F + +import torch_geometric.nn as pyg_nn +import torch_geometric.utils as pyg_utils + +from torch import Tensor +from typing import Union, Tuple, Optional +from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType, + OptTensor) + +from torch.nn import Parameter, Linear +from torch_sparse import SparseTensor, set_diag +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.utils import remove_self_loops, add_self_loops, softmax + +class GNNStack(torch.nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, args, emb=False): + super(GNNStack, self).__init__() + conv_model = self.build_conv_model(args.model_type) + self.convs = nn.ModuleList() + self.convs.append(conv_model(input_dim, hidden_dim, heads=args.heads)) + assert (args.num_layers >= 1), 'Number of layers is not >=1' + for l in range(args.num_layers-1): + self.convs.append(conv_model(hidden_dim, hidden_dim, heads=args.heads)) + + # post-message-passing + self.post_mp = nn.Sequential( + nn.Linear(args.heads * hidden_dim, hidden_dim), nn.Dropout(args.dropout), + nn.Linear(hidden_dim, output_dim)) + + self.dropout = args.dropout + self.num_layers = args.num_layers + + self.emb = emb + print("Dims:", input_dim, hidden_dim, output_dim, args.heads) + + def build_conv_model(self, model_type): + if model_type == 'GraphSage': + return GraphSage + elif model_type == 'GAT': + return GAT + + def forward(self, data, **kwargs): + x, edge_index = data, kwargs['edge_index'] + + for i in range(self.num_layers): + x = self.convs[i](x, edge_index) + x = F.relu(x) + x = F.dropout(x, p=self.dropout) + + x = self.post_mp(x) + + if self.emb == True: + return x + + return F.log_softmax(x, dim=1) + + def loss(self, pred, label): + return F.nll_loss(pred, label) + + def reset_parameters(self): + for c in self.convs: + c.reset_parameters() + self.post_mp[0].reset_parameters() + self.post_mp[2].reset_parameters() + + +class GAT(MessagePassing): + + def __init__(self, in_channels, out_channels, heads = 2, + negative_slope = 0.2, dropout = 0., **kwargs): + super(GAT, self).__init__(node_dim=0, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.negative_slope = negative_slope + self.dropout = dropout + + self.lin_l = None + self.lin_r = None + self.att_l = None + self.att_r = None + + self.lin_l = nn.Linear(in_channels, heads * out_channels) + + self.lin_r = self.lin_l + + self.att_r = nn.Parameter(torch.zeros([heads, out_channels, 1], dtype=torch.float)) + self.att_l = nn.Parameter(torch.zeros([heads, out_channels, 1], dtype=torch.float)) + + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.lin_l.weight) + nn.init.xavier_uniform_(self.lin_r.weight) + nn.init.xavier_uniform_(self.att_l) + nn.init.xavier_uniform_(self.att_r) + + def forward(self, x, edge_index, size = None): + + H, C = self.heads, self.out_channels + + z1 = self.lin_l(x) + z2 = self.lin_r(x) + h1 = z1.reshape([z1.shape[0], H, C]) + h2 = z2.reshape([z2.shape[0], H, C]) + h1e = h1[edge_index[0]] + h2e = h2[edge_index[1]] + + alpha_l = torch.matmul(self.att_l.reshape([1, H, 1, C]), h1.reshape([h1.shape[0], H, C, 1])) + alpha_r = torch.matmul(self.att_r.reshape([1, H, 1, C]), h2.reshape([h2.shape[0], H, C, 1])) + alpha_l = alpha_l.reshape([h1.shape[0], H]) + alpha_r = alpha_r.reshape([h2.shape[0], H]) + + z = self.propagate(edge_index, x=(h1, h2), alpha=(alpha_l, alpha_r)) + out = z.reshape([z.shape[0], z.shape[1] * z.shape[2]]) + + return out + + + def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i): + + ax = F.leaky_relu(alpha_i + alpha_j, negative_slope=self.negative_slope) + a = pyg_utils.softmax( + ax, + index=index, ptr=ptr, num_nodes=size_i) + a1 = F.dropout(a, p=self.dropout) + a1 = a1.reshape([a1.shape[0], a1.shape[1], 1]) + out = torch.mul(a1, x_j) + + return out + + + def aggregate(self, inputs, index, dim_size = None): + + out = torch_scatter.scatter(inputs, index, dim=0, dim_size=dim_size, reduce='sum') + + return out + + +class GraphSage(MessagePassing): + + def __init__(self, in_channels, out_channels, heads = 1, normalize = True, + bias = False, **kwargs): + super(GraphSage, self).__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.normalize = normalize + + self.lin_l = None + self.lin_r = None + + ############################################################################ + # TODO: Your code here! + # Define the layers needed for the message and update functions below. + # self.lin_l is the linear transformation that you apply to embedding + # for central node. + # self.lin_r is the linear transformation that you apply to aggregated + # message from neighbors. + # Our implementation is ~2 lines, but don't worry if you deviate from this. + self.lin_l = nn.Linear(in_channels, out_channels) + self.lin_r = nn.Linear(in_channels, out_channels) + ############################################################################ + + self.reset_parameters() + + def reset_parameters(self): + self.lin_l.reset_parameters() + self.lin_r.reset_parameters() + + def forward(self, x, edge_index, size = None): + """""" + + out = None + + ############################################################################ + # TODO: Your code here! + # Implement message passing, as well as any post-processing (our update rule). + # 1. First call propagate function to conduct the message passing. + # 1.1 See there for more information: + # https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html + # 1.2 We use the same representations for central (x_central) and + # neighbor (x_neighbor) nodes, which means you'll pass x=(x, x) + # to propagate. + # 2. Update our node embedding with skip connection. + # 3. If normalize is set, do L-2 normalization (defined in + # torch.nn.functional) + # Our implementation is ~5 lines, but don't worry if you deviate from this. + z = self.propagate(edge_index, x=(x, x)) + z1 = self.lin_l(x) + self.lin_r(z) + if self.normalize: + z1 = F.normalize(z1, p=2, dim=1) + out = z1 + ############################################################################ + + return out + + def message(self, x_j): + + out = None + + ############################################################################ + # TODO: Your code here! + # Implement your message function here. + # Our implementation is ~1 lines, but don't worry if you deviate from this. + out = x_j + + ############################################################################ + + return out + + def aggregate(self, inputs, index, dim_size = None): + + out = None + + # The axis along which to index number of nodes. + node_dim = self.node_dim + + ############################################################################ + # TODO: Your code here! + # Implement your aggregate function here. + # See here as how to use torch_scatter.scatter: + # https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html#torch_scatter.scatter + # Our implementation is ~1 lines, but don't worry if you deviate from this. + out = torch_scatter.scatter(inputs, index=index, dim=node_dim, reduce='mean') + ############################################################################ + + return out + + +# import torch.optim as optim + +# def build_optimizer(args, params): + # weight_decay = args.weight_decay + # filter_fn = filter(lambda p : p.requires_grad, params) + # if args.opt == 'adam': + # optimizer = optim.Adam(filter_fn, lr=args.lr, weight_decay=weight_decay) + # elif args.opt == 'sgd': + # optimizer = optim.SGD(filter_fn, lr=args.lr, momentum=0.95, weight_decay=weight_decay) + # elif args.opt == 'rmsprop': + # optimizer = optim.RMSprop(filter_fn, lr=args.lr, weight_decay=weight_decay) + # elif args.opt == 'adagrad': + # optimizer = optim.Adagrad(filter_fn, lr=args.lr, weight_decay=weight_decay) + # if args.opt_scheduler == 'none': + # return None, optimizer + # elif args.opt_scheduler == 'step': + # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate) + # elif args.opt_scheduler == 'cos': + # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart) + # return scheduler, optimizer + + +# import time + +# import networkx as nx +# import numpy as np +# import torch +# import torch.optim as optim + +# from torch_geometric.datasets import TUDataset +# from torch_geometric.datasets import Planetoid +# from torch_geometric.data import DataLoader + +# import torch_geometric.nn as pyg_nn + +# import matplotlib.pyplot as plt + + +# def train(dataset, args): + + # print("Node task. test set size:", np.sum(dataset[0]['train_mask'].numpy())) + # test_loader = loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) + + # build model + # model = GNNStack(dataset.num_node_features, args.hidden_dim, dataset.num_classes, + # args) + # scheduler, opt = build_optimizer(args, model.parameters()) + + #train + # losses = [] + # test_accs = [] + # for epoch in range(args.epochs): + # total_loss = 0 + # model.train() + # for batch in loader: + # opt.zero_grad() + # pred = model(batch) + # label = batch.y + # pred = pred[batch.train_mask] + # label = label[batch.train_mask] + # loss = model.loss(pred, label) + # loss.backward() + # opt.step() + # total_loss += loss.item() * batch.num_graphs + # total_loss /= len(loader.dataset) + # losses.append(total_loss) + + # if epoch % 10 == 0: + # test_acc = test(test_loader, model) + # test_accs.append(test_acc) + # else: + # test_accs.append(test_accs[-1]) + # return test_accs, losses + +# def test(loader, model, is_validation=True): + # model.eval() + + # correct = 0 + # for data in loader: + # with torch.no_grad(): + # max(dim=1) returns values, indices tuple; only need indices + # pred = model(data).max(dim=1)[1] + # label = data.y + + # mask = data.val_mask if is_validation else data.test_mask + #node classification: only evaluate on nodes in test set + # pred = pred[mask] + # label = data.y[mask] + + # correct += pred.eq(label).sum().item() + + # total = 0 + # for data in loader.dataset: + # total += torch.sum(data.val_mask if is_validation else data.test_mask).item() + # return correct / total + +class objectview(object): + def __init__(self, d): + self.__dict__ = d + +def main(): + for args in [ + {'model_type': 'GraphSage', 'dataset': 'cora', 'num_layers': 2, 'heads': 1, 'batch_size': 32, 'hidden_dim': 32, 'dropout': 0.5, 'epochs': 500, 'opt': 'adam', 'opt_scheduler': 'none', 'opt_restart': 0, 'weight_decay': 5e-3, 'lr': 0.01}, + ]: + args = objectview(args) + for model in ['GraphSage', 'GAT']: + args.model_type = model + + # Match the dimension. + if model == 'GAT': + args.heads = 2 + else: + args.heads = 1 + + if args.dataset == 'cora': + dataset = Planetoid(root='/tmp/cora', name='Cora') + else: + raise NotImplementedError("Unknown dataset") + test_accs, losses = train(dataset, args) + + print("Maximum accuracy: {0}".format(max(test_accs))) + print("Minimum loss: {0}".format(min(losses))) + + plt.title(dataset.name) + plt.plot(losses, label="training loss" + " - " + args.model_type) + plt.plot(test_accs, label="test accuracy" + " - " + args.model_type) + plt.legend() + plt.show() + +if __name__ == '__main__': + main() + diff --git a/run_gat.py b/run_gat.py new file mode 100644 index 0000000..1755b68 --- /dev/null +++ b/run_gat.py @@ -0,0 +1,40 @@ +import load_data +import train +import torch +import torch_geometric.data as tgd +import project_gat + +def main(): + data = load_data.load_small() + print(data.x.shape, data.x.dtype) + print(data.y.shape, data.y.dtype) + print(data.edge_index.shape) + data_loader = tgd.DataLoader([data]) + split_idx = { + 'train' : data.train_mask, + 'valid' : data.valid_mask, + 'test' : data.test_mask + } + + num_node_features = 100 + size_train = torch.sum(data.train_mask.to(torch.int)) + size_valid = torch.sum(data.valid_mask.to(torch.int)) + size_test = torch.sum(data.test_mask.to(torch.int)) + print("Splits: ", size_train, size_valid, size_test) + + num_labels = int(torch.max(data.y)) + print(num_labels) + args = {'model_type': 'GAT', 'num_layers': 2, 'heads': 1, 'batch_size': 32, 'hidden_dim': 32, 'dropout': 0.5, 'epochs': 10, + 'opt': 'adam', 'opt_scheduler': 'none', 'opt_restart': 0, 'weight_decay': 5e-3, 'lr': 0.01, + 'use_edge_index': 1, 'eval_small': 1} + + args_obj = project_gat.objectview(args) + + model = project_gat.GNNStack(num_node_features, args_obj.hidden_dim, num_labels, args=args_obj) + print('created model') + train.run(model, data_loader, split_idx, extra_args=args) + + +if __name__ == "__main__": + main() + \ No newline at end of file diff --git a/run_graph_sage.py b/run_graph_sage.py new file mode 100644 index 0000000..707aab4 --- /dev/null +++ b/run_graph_sage.py @@ -0,0 +1,39 @@ +import load_data +import train +import torch +import torch_geometric.data as tgd +import project_gat + +def main(): + data = load_data.load_small() + print(data.x.shape, data.x.dtype) + print(data.y.shape, data.y.dtype) + print(data.edge_index.shape) + data_loader = tgd.DataLoader([data]) + split_idx = { + 'train' : data.train_mask, + 'valid' : data.valid_mask, + 'test' : data.test_mask + } + + num_node_features = 100 + size_train = torch.sum(data.train_mask.to(torch.int)) + size_valid = torch.sum(data.valid_mask.to(torch.int)) + size_test = torch.sum(data.test_mask.to(torch.int)) + print("Splits: ", size_train, size_valid, size_test) + + num_labels = int(torch.max(data.y)) + print(num_labels) + + args = {'model_type': 'GraphSage', 'num_layers': 2, 'heads': 1, 'batch_size': 32, 'hidden_dim': 32, 'dropout': 0.5, 'epochs': 10, + 'opt': 'adam', 'opt_scheduler': 'none', 'opt_restart': 0, 'weight_decay': 5e-3, 'lr': 0.01, + 'use_edge_index': 1, 'eval_small': 1} + args_obj = project_gat.objectview(args) + model = project_gat.GNNStack(num_node_features, args_obj.hidden_dim, num_labels, args=args_obj) + print('created model') + train.run(model, data_loader, split_idx, extra_args=args) + + +if __name__ == "__main__": + main() + \ No newline at end of file diff --git a/train.py b/train.py index 757e287..f448821 100644 --- a/train.py +++ b/train.py @@ -21,7 +21,7 @@ def train(model, data_loader, optimizer, device): continue optimizer.zero_grad() - out = model(batch.x, batch.edge_index)[batch.train_mask] + out = model(batch.x, edge_index=batch.edge_index, adj_t=batch.edge_index)[batch.train_mask] y = torch.flatten(batch.y[batch.train_mask]) loss = F.nll_loss(out, y) loss.backward() @@ -37,12 +37,18 @@ def train(model, data_loader, optimizer, device): @torch.no_grad() -def test(model, data, split_idx, evaluator): +def test(model, data, split_idx, evaluator, use_edge_index=False): model.eval() print("starting") print(data) - out = model(data.x, data.adj_t) + if use_edge_index: + edge_index = data.edge_index + print(edge_index.shape) + + out = model(data.x, edge_index=edge_index) + else: + out = model(data.x, data.adj_t) print("Out done") y_pred = out.argmax(dim=-1, keepdim=True) @@ -62,7 +68,61 @@ def test(model, data, split_idx, evaluator): return train_acc, valid_acc +def run(model,data_loader,split_idx, extra_args=None): + device = "cpu" + args = { + 'device': device, + 'num_layers': 3, + 'hidden_dim': 256, + 'dropout': 0.5, + 'lr': 0.001, + 'epochs': 100, + } + if extra_args: + for k in extra_args: + args[k] = extra_args[k] + + dataset_name = "ogbn-products" + if extra_args.get('use_edge_index', 0) == 1: + dataset_eval = PygNodePropPredDataset(name=dataset_name) + else: + dataset_eval = PygNodePropPredDataset(name=dataset_name, transform=T.ToSparseTensor()) + + eval_data = dataset_eval[0] + + eval_split_idx = dataset_eval.get_idx_split() + + evaluator = Evaluator(name='ogbn-products') + + model.reset_parameters() + optimizer = torch.optim.Adam(model.parameters(), lr=args['lr']) + loss_fn = F.nll_loss + + best_model = None + best_valid_acc = 0 + + print("----------------------------------") + print("Params:", args) + print("======") + + for epoch in range(1, 1 + args["epochs"]): + model.to(device) + loss = train(model, data_loader, optimizer, device) + model.to("cpu") + if extra_args.get('eval_small', 0) == 1: + result = test(model, data_loader.dataset[0], split_idx, evaluator, use_edge_index=True) + else: + result = test(model, eval_data, eval_split_idx, evaluator, use_edge_index=extra_args['use_edge_index']) + train_acc, valid_acc = result + print(f'Epoch: {epoch:02d}, ' + f'Loss: {loss:.4f}, ' + f'Train: {100 * train_acc:.2f}%, ' + f'Valid: {100 * valid_acc:.2f}% ' + f'Test: {100 * 0:.2f}%') + + + if __name__ == "__main__": device = "cuda" args = { From d0e846a047532259cf8fe33331accd64e4f03bbb Mon Sep 17 00:00:00 2001 From: shaonc <55311820+shaonc@users.noreply.github.com> Date: Fri, 19 Mar 2021 23:21:59 -0700 Subject: [PATCH 2/2] add files to run the pytorch versions of GAT and GraphSAGE on the Clustered Data Add files to run the pytorch versions of GAT and GraphSAGE on the Clustered Data --- graphsage.py | 110 +++++++++++++++++++++++++++++++++++++++ load_data.py | 18 +++++++ project_gat.py | 14 +++-- run_torchg_gat.py | 36 +++++++++++++ run_torchg_graph_sage.py | 34 ++++++++++++ train.py | 10 ++-- 6 files changed, 213 insertions(+), 9 deletions(-) create mode 100644 graphsage.py create mode 100644 run_torchg_gat.py create mode 100644 run_torchg_graph_sage.py diff --git a/graphsage.py b/graphsage.py new file mode 100644 index 0000000..e5dfd30 --- /dev/null +++ b/graphsage.py @@ -0,0 +1,110 @@ +import torch +import torch_scatter +import torch.nn as nn +import torch.nn.functional as F + +import torch_geometric +torch_geometric.__version__ + +import torch_geometric.nn as pyg_nn +import torch_geometric.utils as pyg_utils + +from torch import Tensor +from typing import Union, Tuple, Optional +from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType, + OptTensor) + +from torch.nn import Parameter, Linear +from torch_sparse import SparseTensor, set_diag +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.utils import remove_self_loops, add_self_loops, softmax + +class GraphSage(MessagePassing): + + def __init__(self, in_channels, out_channels, normalize = True, + bias = False, **kwargs): + super(GraphSage, self).__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.normalize = normalize + + self.lin_l = None + self.lin_r = None + + ############################################################################ + # TODO: Your code here! + # Define the layers needed for the message and update functions below. + # self.lin_l is the linear transformation that you apply to embedding + # for central node. + # self.lin_r is the linear transformation that you apply to aggregated + # message from neighbors. + # Our implementation is ~2 lines, but don't worry if you deviate from this. + self.lin_l = nn.Linear(in_channels, out_channels) + self.lin_r = nn.Linear(in_channels, out_channels) + ############################################################################ + + self.reset_parameters() + + def reset_parameters(self): + self.lin_l.reset_parameters() + self.lin_r.reset_parameters() + + def forward(self, x, edge_index, size = None): + """""" + + out = None + + ############################################################################ + # TODO: Your code here! + # Implement message passing, as well as any post-processing (our update rule). + # 1. First call propagate function to conduct the message passing. + # 1.1 See there for more information: + # https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html + # 1.2 We use the same representations for central (x_central) and + # neighbor (x_neighbor) nodes, which means you'll pass x=(x, x) + # to propagate. + # 2. Update our node embedding with skip connection. + # 3. If normalize is set, do L-2 normalization (defined in + # torch.nn.functional) + # Our implementation is ~5 lines, but don't worry if you deviate from this. + z = self.propagate(edge_index, x=(x, x)) + z1 = self.lin_l(x) + self.lin_r(z) + if self.normalize: + z1 = F.normalize(z1, p=2, dim=1) + out = z1 + ############################################################################ + + return out + + def message(self, x_j): + + out = None + + ############################################################################ + # TODO: Your code here! + # Implement your message function here. + # Our implementation is ~1 lines, but don't worry if you deviate from this. + out = x_j + + ############################################################################ + + return out + + def aggregate(self, inputs, index, dim_size = None): + + out = None + + # The axis along which to index number of nodes. + node_dim = self.node_dim + + ############################################################################ + # TODO: Your code here! + # Implement your aggregate function here. + # See here as how to use torch_scatter.scatter: + # https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html#torch_scatter.scatter + # Our implementation is ~1 lines, but don't worry if you deviate from this. + out = torch_scatter.scatter(inputs, index=index, dim=node_dim, reduce='mean') + ############################################################################ + + return out diff --git a/load_data.py b/load_data.py index 3e525a2..3796896 100644 --- a/load_data.py +++ b/load_data.py @@ -5,6 +5,12 @@ from ogb.nodeproppred import PygNodePropPredDataset import os +def get_idx_split(): + dataset_name = "ogbn-products" + dataset = PygNodePropPredDataset(name=dataset_name) + split_idx = dataset.get_idx_split() + return split_idx + def get_product_clusters(): dataset_name = "ogbn-products" @@ -35,6 +41,13 @@ def get_product_clusters(): return cluster_data, dataset, data, split_idx +def small_clusters(): + dataset_name = "small_cluster" + data = torch.load("dataset/data_0.pt") + cluster_data = ClusterData(data, num_parts=10, save_dir="dataset") + return cluster_data + + def get_cluster_batches(cluster_data, batch_size): loader = ClusterLoader(cluster_data, batch_size=batch_size, shuffle=True, num_workers=1) return loader @@ -44,6 +57,11 @@ def load_small(): data = torch.load("dataset/data_0.pt") print(data) return data + + +def load_full(): + data = torch.load("dataset/partition_15000.pt") + return data def save_batch(loader, name): diff --git a/project_gat.py b/project_gat.py index 671570b..f1b8a93 100644 --- a/project_gat.py +++ b/project_gat.py @@ -44,9 +44,17 @@ def build_conv_model(self, model_type): return GraphSage elif model_type == 'GAT': return GAT - - def forward(self, data, **kwargs): - x, edge_index = data, kwargs['edge_index'] + elif model_type == 'torch_geometric_graph_sage': + def sgconv(in_channels, out_channels, **kwargs): + return pyg_nn.SAGEConv(in_channels, out_channels, normalize=True) + return sgconv + elif model_type == 'torch_geometric_gat': + def gatconv(in_channels, out_channels, heads=1): + return pyg_nn.GATConv(in_channels, out_channels, heads) + return gatconv + + def forward(self, data, edge_index, **kwargs): + x = data for i in range(self.num_layers): x = self.convs[i](x, edge_index) diff --git a/run_torchg_gat.py b/run_torchg_gat.py new file mode 100644 index 0000000..4e4ca33 --- /dev/null +++ b/run_torchg_gat.py @@ -0,0 +1,36 @@ +import load_data +import train +import torch +import torch_geometric.data as tgd +import project_gat + +def main(): + cluster_data, dataset, data, split_idx = load_data.get_product_clusters() + + cluster_loader = load_data.get_cluster_batches(cluster_data, 100) + + num_node_features = 100 + size_train = torch.sum(data.train_mask.to(torch.int)) + size_valid = torch.sum(data.valid_mask.to(torch.int)) + size_test = torch.sum(data.test_mask.to(torch.int)) + print("Splits: ", size_train, size_valid, size_test) + + num_labels = int(torch.max(data.y)) + print(num_labels) + + num_labels = int(torch.max(data.y)) + print(num_labels) + args = {'model_type': 'torch_geometric_gat', 'num_layers': 2, 'heads': 1, 'batch_size': 32, 'hidden_dim': 32, 'dropout': 0.5, 'epochs': 10, + 'opt': 'adam', 'opt_scheduler': 'none', 'opt_restart': 0, 'weight_decay': 5e-3, 'lr': 0.01, + } + + args_obj = project_gat.objectview(args) + + model = project_gat.GNNStack(num_node_features, args_obj.hidden_dim, num_labels, args=args_obj) + print('created model') + train.run(model, cluster_loader, split_idx, extra_args=args) + + +if __name__ == "__main__": + main() + \ No newline at end of file diff --git a/run_torchg_graph_sage.py b/run_torchg_graph_sage.py new file mode 100644 index 0000000..3b37479 --- /dev/null +++ b/run_torchg_graph_sage.py @@ -0,0 +1,34 @@ +import load_data +import train +import torch +import torch_geometric.data as tgd +import torch_geometric.transforms as T +import project_gat + + +def main(): + cluster_data, dataset, data, split_idx = load_data.get_product_clusters() + + cluster_loader = load_data.get_cluster_batches(cluster_data, 100) + + num_node_features = 100 + size_train = torch.sum(data.train_mask.to(torch.int)) + size_valid = torch.sum(data.valid_mask.to(torch.int)) + size_test = torch.sum(data.test_mask.to(torch.int)) + print("Splits: ", size_train, size_valid, size_test) + + num_labels = int(torch.max(data.y)) + print(num_labels) + + args = {'model_type': 'torch_geometric_graph_sage', 'num_layers': 2, 'heads': 1, 'batch_size': 32, 'hidden_dim': 32, 'dropout': 0.5, 'epochs': 10, + 'opt': 'adam', 'opt_scheduler': 'none', 'opt_restart': 0, 'weight_decay': 5e-3, 'lr': 0.01, + } + args_obj = project_gat.objectview(args) + model = project_gat.GNNStack(num_node_features, args_obj.hidden_dim, num_labels, args=args_obj) + print('created model') + train.run(model, cluster_loader, split_idx, extra_args=args) + + +if __name__ == "__main__": + main() + \ No newline at end of file diff --git a/train.py b/train.py index f448821..f0545e3 100644 --- a/train.py +++ b/train.py @@ -43,10 +43,7 @@ def test(model, data, split_idx, evaluator, use_edge_index=False): print(data) if use_edge_index: - edge_index = data.edge_index - print(edge_index.shape) - - out = model(data.x, edge_index=edge_index) + out = model(data.x, data.edge_index) else: out = model(data.x, data.adj_t) print("Out done") @@ -111,9 +108,10 @@ def run(model,data_loader,split_idx, extra_args=None): loss = train(model, data_loader, optimizer, device) model.to("cpu") if extra_args.get('eval_small', 0) == 1: - result = test(model, data_loader.dataset[0], split_idx, evaluator, use_edge_index=True) + eval_data = ld.load_small() + result = test(model, eval_data, split_idx, evaluator, use_edge_index=True) else: - result = test(model, eval_data, eval_split_idx, evaluator, use_edge_index=extra_args['use_edge_index']) + result = test(model, eval_data, eval_split_idx, evaluator, use_edge_index=False) train_acc, valid_acc = result print(f'Epoch: {epoch:02d}, ' f'Loss: {loss:.4f}, '