From 9494b995fd9c697f2e8c263846eb3f2fcc2ad3cd Mon Sep 17 00:00:00 2001 From: shaonc <55311820+shaonc@users.noreply.github.com> Date: Sat, 6 Mar 2021 15:49:57 -0800 Subject: [PATCH] fix sampling and loading functions fix sampling function to output node ids in the sampled node feature and node label files. Fix loading functions to load and remap node ids after loading from sampled files. Generate appropriate test validation and training masks before starting the run. --- baseline.ipynb | 918 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 918 insertions(+) create mode 100644 baseline.ipynb diff --git a/baseline.ipynb b/baseline.ipynb new file mode 100644 index 0000000..9c66bf1 --- /dev/null +++ b/baseline.ipynb @@ -0,0 +1,918 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 29086, + "status": "ok", + "timestamp": 1614993114537, + "user": { + "displayName": "Arvind Srivastav", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhKsFI_4660fvIUHDgSAVEBHXHmkjgDRf91ukLdhw=s64", + "userId": "13824725762542667997" + }, + "user_tz": 480 + }, + "id": "nQQedrLOIM5n", + "outputId": "c977889e-1d58-4e73-e2de-2db18a9dbcd8" + }, + "outputs": [], + "source": [ + "#!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html\n", + "#!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html\n", + "#!pip install -q torch-geometric\n", + "#!pip install -q git+https://github.com/snap-stanford/deepsnap.git\n", + "#!pip install -q ogb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "executionInfo": { + "elapsed": 8138, + "status": "ok", + "timestamp": 1614993206556, + "user": { + "displayName": "Arvind Srivastav", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhKsFI_4660fvIUHDgSAVEBHXHmkjgDRf91ukLdhw=s64", + "userId": "13824725762542667997" + }, + "user_tz": 480 + }, + "id": "g4BNgPthWJXh", + "outputId": "8d5c40d4-725c-4069-f5fb-e89a34434a20" + }, + "outputs": [], + "source": [ + "import torch_geometric\n", + "import torch_geometric as pyg\n", + "torch_geometric.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "v5mEnLJ9zrBH" + }, + "source": [ + "## Mount drive and load data folder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 21285, + "status": "ok", + "timestamp": 1614993927175, + "user": { + "displayName": "Arvind Srivastav", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhKsFI_4660fvIUHDgSAVEBHXHmkjgDRf91ukLdhw=s64", + "userId": "13824725762542667997" + }, + "user_tz": 480 + }, + "id": "_xROmo7tzcAL", + "outputId": "170ad2f9-ae79-4b74-c61f-fd9730510402" + }, + "outputs": [], + "source": [ + "use_localenv = 1\n", + "if use_localenv == 0:\n", + " from google.colab import drive\n", + " drive.mount('/content/drive')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 3872, + "status": "ok", + "timestamp": 1614993932239, + "user": { + "displayName": "Arvind Srivastav", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhKsFI_4660fvIUHDgSAVEBHXHmkjgDRf91ukLdhw=s64", + "userId": "13824725762542667997" + }, + "user_tz": 480 + }, + "id": "hf8e9GbxzvJt", + "outputId": "2edf8ad5-d491-4c0f-d5cd-1e491595d1bb" + }, + "outputs": [], + "source": [ + "if use_localenv == 0:\n", + " data_dir = \"/content/drive/MyDrive/CS224w-project/dataset/dataset/ogbn_products/\"\n", + " %cd drive/MyDrive/CS224w-project/dataset/dataset/ogbn_products/\n", + " !ls\n", + "else:\n", + " import os\n", + " data_dir = \"C:\\\\Users\\\\shaon\\\\Desktop\\\\CS224W\\\\CS224W_2021\\\\CS224W_PROJECT\\\\dataset\\\\ogbn_products\"\n", + " os.chdir(data_dir)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2fouCD0AU5Vt" + }, + "source": [ + "## Generate small graph for baseline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "21RZA4U_3DXE" + }, + "outputs": [], + "source": [ + "import random\n", + "\n", + "def sample_edges(edge_file_name, output_file_name):\n", + " d1 = {}\n", + " l = []\n", + " for i in range(1000):\n", + " d1[random.randint(0, 61859140)] = 1\n", + "\n", + " f = open(edge_file_name)\n", + " c = 0\n", + " out = []\n", + " while True:\n", + " l = f.readline()\n", + " if len(l) == 0:\n", + " break\n", + " if c in d1:\n", + " out.append(l)\n", + " c += 1\n", + "\n", + " w = open(output_file_name, 'w')\n", + " for line in out:\n", + " w.write(line)\n", + "\n", + " w.close()\n", + "\n", + "\n", + "def get_node_features(node_file, node_label_file, edge_file,\n", + " node_feature_output, node_label_output):\n", + " f = open(edge_file)\n", + " nodes = []\n", + " for line in f.readlines():\n", + " values = line.strip().split(',')\n", + " for x in values:\n", + " if len(x.strip()) > 0:\n", + " nodes.append(int(x))\n", + "\n", + " print(\"Num nodes before dedup=\", len(nodes))\n", + " nodes = set(nodes)\n", + " print(\"Num nodes=\", len(nodes))\n", + "\n", + " c = 0\n", + " nf = open(node_file)\n", + " nl = open(node_label_file)\n", + " outf = []\n", + " outl = []\n", + " while True:\n", + " features = nf.readline()\n", + " label = nl.readline()\n", + "\n", + " if not features:\n", + " break\n", + " if c in nodes:\n", + " outf.append(str(c) + ',' + features)\n", + " outl.append(str(c) + ',' + label)\n", + " c = c + 1\n", + "\n", + " w = open(node_feature_output, 'w')\n", + " for line in outf:\n", + " w.write(line)\n", + " w.close()\n", + "\n", + " w = open(node_label_output, 'w')\n", + " for line in outl:\n", + " w.write(line)\n", + " w.close()\n", + " \n", + "#!pwd\n", + "import os\n", + "data_dir = \"raw/\"\n", + "if not os.path.exists('raw/out'):\n", + " os.makedirs('raw/out')\n", + "\n", + " edge_input_file = data_dir + \"edge.csv\"\n", + " edge_output_file = data_dir + \"out/output.csv\"\n", + "\n", + " sample_edges(edge_input_file, edge_output_file)\n", + "\n", + " node_file = data_dir + \"node-feat.csv\"\n", + " node_label_file = data_dir + \"node-label.csv\"\n", + " node_feature_output = data_dir + \"out/node-feat.csv\"\n", + " node_label_output = data_dir + \"out/node-label.csv\"\n", + " get_node_features(node_file, node_label_file, edge_output_file,\n", + " node_feature_output, node_label_output)\n", + "else:\n", + " edge_output_file = data_dir + \"out/output.csv\"\n", + " node_feature_output = data_dir + \"out/node-feat.csv\"\n", + " node_label_output = data_dir + \"out/node-label.csv\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RWlV3LMcWRBJ" + }, + "source": [ + "## GNN Stack Module\n", + "\n", + "Below is the implementation for a general GNN Module that could plugin any layers, including **GraphSage**, **GAT**, etc. This module is provided for you, and you own **GraphSage** and **GAT** layers will function as components in the GNNStack Module.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "F54FqJdtWMn-" + }, + "outputs": [], + "source": [ + "# GNN Stack Module \n", + "import torch\n", + "import torch_scatter\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import torch_geometric.nn as pyg_nn\n", + "import torch_geometric.utils as pyg_utils\n", + "\n", + "from torch import Tensor\n", + "from typing import Union, Tuple, Optional\n", + "from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType, OptTensor)\n", + "\n", + "from torch.nn import Parameter, Linear\n", + "from torch_sparse import SparseTensor, set_diag\n", + "from torch_geometric.nn.conv import MessagePassing\n", + "from torch_geometric.utils import remove_self_loops, add_self_loops, softmax\n", + "\n", + "class GNNStack(torch.nn.Module):\n", + " def __init__(self, input_dim, hidden_dim, output_dim, args, emb=False):\n", + " super(GNNStack, self).__init__()\n", + " conv_model = self.build_conv_model(args.model_type)\n", + " self.convs = nn.ModuleList()\n", + " self.convs.append(conv_model(input_dim, hidden_dim))\n", + " assert (args.num_layers >= 1), 'Number of layers is not >=1'\n", + " for l in range(args.num_layers-1):\n", + " self.convs.append(conv_model(args.heads * hidden_dim, hidden_dim))\n", + "\n", + " # post-message-passing\n", + " self.post_mp = nn.Sequential(\n", + " nn.Linear(args.heads * hidden_dim, hidden_dim), nn.Dropout(args.dropout), \n", + " nn.Linear(hidden_dim, output_dim))\n", + "\n", + " self.dropout = args.dropout\n", + " self.num_layers = args.num_layers\n", + "\n", + " self.emb = emb\n", + "\n", + " def build_conv_model(self, model_type):\n", + " if model_type == 'GraphSage':\n", + " return GraphSage\n", + " elif model_type == 'GAT':\n", + " return GAT\n", + "\n", + " def forward(self, data):\n", + " print(data)\n", + " x, edge_index = data.x, data.edge_index\n", + " \n", + " for i in range(self.num_layers):\n", + " x = self.convs[i](x, edge_index)\n", + " x = F.relu(x)\n", + " x = F.dropout(x, p=self.dropout)\n", + "\n", + " x = self.post_mp(x)\n", + "\n", + " if self.emb == True:\n", + " return x\n", + "\n", + " return F.log_softmax(x, dim=1)\n", + "\n", + " def loss(self, pred, label):\n", + " return F.nll_loss(pred, label)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OfTCVJmTWaCl" + }, + "source": [ + "## GraphSage Implementation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2ngZJjEuWdPU" + }, + "outputs": [], + "source": [ + "class GraphSage(MessagePassing):\n", + " \n", + " def __init__(self, in_channels, out_channels, normalize = True,\n", + " bias = None, **kwargs): \n", + " super(GraphSage, self).__init__(**kwargs)\n", + "\n", + " self.in_channels = in_channels\n", + " self.out_channels = out_channels\n", + " self.normalize = normalize\n", + " self.bias = bias\n", + " self.lin_l = None\n", + " self.lin_r = None\n", + "\n", + " self.lin_l = torch.nn.Linear(self.in_channels, self.out_channels)\n", + " self.lin_r = torch.nn.Linear(self.in_channels, self.out_channels)\n", + "\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " self.lin_l.reset_parameters()\n", + " self.lin_r.reset_parameters()\n", + "\n", + " def forward(self, x, edge_index, size = None):\n", + " out = None\n", + " h_l = self.lin_l(x)\n", + " print(\"ZZZ:\", edge_index.shape, x.shape, size)\n", + " h_r = self.propagate(edge_index, x=(x, x))\n", + " print(\"ZZZ:\", edge_index.shape, x.shape, size)\n", + " h_r = self.lin_r(x)\n", + " out = h_l + h_r\n", + " if self.normalize:\n", + " out = F.normalize(out)\n", + " return out\n", + "\n", + " def message(self, x_j):\n", + " out = x_j\n", + "\n", + " return out\n", + "\n", + " def aggregate(self, inputs, index, dim_size = None):\n", + "\n", + " out = None\n", + " node_dim = self.node_dim\n", + " out = torch_scatter.scatter(src=inputs, index=index, dim=node_dim, reduce='mean')\n", + "\n", + " return out\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rZaBcVJpWhpZ" + }, + "source": [ + "## GAT Implementation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Hkch0GbGWjMj" + }, + "outputs": [], + "source": [ + "class GAT(MessagePassing):\n", + "\n", + " def __init__(self, in_channels, out_channels, heads = 2,\n", + " negative_slope = 0.2, dropout = 0., **kwargs):\n", + " super(GAT, self).__init__(node_dim=0, **kwargs)\n", + "\n", + " self.in_channels = in_channels\n", + " self.out_channels = out_channels\n", + " self.heads = heads\n", + " self.negative_slope = negative_slope\n", + " self.dropout = dropout\n", + " self.lin_l = nn.Linear(in_channels, heads * out_channels)\n", + "\n", + " self.lin_r = self.lin_l\n", + "\n", + " self.att_r = nn.Parameter(torch.zeros([heads, out_channels, 1], dtype=torch.float))\n", + " self.att_l = nn.Parameter(torch.zeros([heads, out_channels, 1], dtype=torch.float))\n", + "\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " nn.init.xavier_uniform_(self.lin_l.weight)\n", + " nn.init.xavier_uniform_(self.lin_r.weight)\n", + " nn.init.xavier_uniform_(self.att_l)\n", + " nn.init.xavier_uniform_(self.att_r)\n", + "\n", + " def forward(self, x, edge_index, size = None):\n", + "\n", + " H, C = self.heads, self.out_channels\n", + " z1 = self.lin_l(x)\n", + " z2 = self.lin_r(x)\n", + " h1 = z1.reshape([z1.shape[0], H, C])\n", + " h2 = z2.reshape([z2.shape[0], H, C])\n", + " h1e = h1[edge_index[0]]\n", + " h2e = h2[edge_index[1]]\n", + "\n", + " alpha_l = torch.matmul(self.att_l.reshape([1, H, 1, C]), h1.reshape([h1.shape[0], H, C, 1]))\n", + " alpha_r = torch.matmul(self.att_r.reshape([1, H, 1, C]), h2.reshape([h2.shape[0], H, C, 1]))\n", + " alpha_l = alpha_l.reshape([h1.shape[0], H])\n", + " alpha_r = alpha_r.reshape([h2.shape[0], H])\n", + "\n", + " z = self.propagate(edge_index, x=(h1, h2), alpha=(alpha_l, alpha_r))\n", + " out = z.reshape([z.shape[0], z.shape[1] * z.shape[2]])\n", + "\n", + " return out\n", + "\n", + " def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):\n", + " ax = F.leaky_relu(alpha_i + alpha_j, negative_slope=self.negative_slope)\n", + " a = pyg_utils.softmax(\n", + " ax,\n", + " index=index, ptr=ptr, num_nodes=size_i)\n", + " a1 = F.dropout(a, p=self.dropout)\n", + " a1 = a1.reshape([a1.shape[0], a1.shape[1], 1])\n", + " out = torch.mul(a1, x_j)\n", + "\n", + " return out\n", + "\n", + " def aggregate(self, inputs, index, dim_size = None):\n", + "\n", + " out = torch_scatter.scatter(inputs, index, dim=0, dim_size=dim_size, reduce='sum')\n", + "\n", + " return out" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ajK1UsIkWm25" + }, + "source": [ + "## Building Optimizers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2cPxd4bKWo2F" + }, + "outputs": [], + "source": [ + "import torch.optim as optim\n", + "\n", + "def build_optimizer(args, params):\n", + " weight_decay = args.weight_decay\n", + " filter_fn = filter(lambda p : p.requires_grad, params)\n", + " if args.opt == 'adam':\n", + " optimizer = optim.Adam(filter_fn, lr=args.lr, weight_decay=weight_decay)\n", + " elif args.opt == 'sgd':\n", + " optimizer = optim.SGD(filter_fn, lr=args.lr, momentum=0.95, weight_decay=weight_decay)\n", + " elif args.opt == 'rmsprop':\n", + " optimizer = optim.RMSprop(filter_fn, lr=args.lr, weight_decay=weight_decay)\n", + " elif args.opt == 'adagrad':\n", + " optimizer = optim.Adagrad(filter_fn, lr=args.lr, weight_decay=weight_decay)\n", + " if args.opt_scheduler == 'none':\n", + " return None, optimizer\n", + " elif args.opt_scheduler == 'step':\n", + " scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)\n", + " elif args.opt_scheduler == 'cos':\n", + " scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart)\n", + " return scheduler, optimizer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sbSrQI9GWrrN" + }, + "source": [ + "## Training and testing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "a4z4xvSVWtrF" + }, + "outputs": [], + "source": [ + "import time\n", + "\n", + "import networkx as nx\n", + "import numpy as np\n", + "import torch\n", + "import torch.optim as optim\n", + "\n", + "from torch_geometric.data import DataLoader\n", + "\n", + "import torch_geometric.nn as pyg_nn\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "def train(dataset, args):\n", + " \n", + " test_loader = loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)\n", + " # build model\n", + " # model = GNNStack(dataset.num_node_features, args.hidden_dim, dataset.num_classes, args)\n", + " model = GNNStack(100, args.hidden_dim, args.num_classes, args)\n", + " scheduler, opt = build_optimizer(args, model.parameters())\n", + "\n", + " # train\n", + " losses = []\n", + " test_accs = []\n", + " for epoch in range(args.epochs):\n", + " total_loss = 0\n", + " model.train()\n", + " for batch in loader:\n", + " opt.zero_grad()\n", + " pred = model(batch)\n", + " label = batch.y.reshape(batch.y.shape[0])\n", + " print(\"ZZZ20\", pred.shape, label.shape, pred.dtype, label.dtype)\n", + " #pred = pred[batch.train_mask] \n", + " #label = label[batch.train_mask]\n", + " loss = model.loss(pred, label)\n", + " loss.backward()\n", + " opt.step()\n", + " #total_loss += loss.item() * batch.num_graphs\n", + " total_loss += loss.item()\n", + " # total_loss /= len(loader.dataset)\n", + " losses.append(total_loss)\n", + "\n", + " # uncommented # Arvind: commented test function \n", + " if epoch % 10 == 0:\n", + " test_acc = test(test_loader, model)\n", + " test_accs.append(test_acc)\n", + " else:\n", + " test_accs.append(test_accs[-1])\n", + " return test_accs, losses\n", + "\n", + "def test(loader, model, is_validation=True):\n", + " model.eval()\n", + "\n", + " correct = 0\n", + " for data in loader:\n", + " with torch.no_grad():\n", + " # max(dim=1) returns values, indices tuple; only need indices\n", + " pred = model(data).max(dim=1)[1]\n", + " label = data.y\n", + "\n", + " mask = data.val_mask if is_validation else data.test_mask\n", + " # node classification: only evaluate on nodes in test set\n", + " pred = pred[mask]\n", + " label = data.y[mask]\n", + " \n", + " correct += pred.eq(label).sum().item()\n", + "\n", + " total = 0\n", + " for data in loader.dataset:\n", + " total += torch.sum(data.val_mask if is_validation else data.test_mask).item()\n", + " return correct / total\n", + "\n", + "class objectview(object):\n", + " def __init__(self, d):\n", + " self.__dict__ = d" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RcQWxCyU3_lo" + }, + "source": [ + "## Train on ogbn-products" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3wnHxNhtVgNy" + }, + "source": [ + "## Load dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "87EKi89i9BBG" + }, + "outputs": [], + "source": [ + "# from ogb.nodeproppred import NodePropPredDataset\n", + "# dataset = NodePropPredDataset(name = 'ogbn-products')\n", + "# split_idx = dataset.get_idx_split()\n", + "# train_idx, valid_idx, test_idx = split_idx[\"train\"], split_idx[\"valid\"], split_idx[\"test\"]\n", + "\n", + "# from ogb.nodeproppred import PygNodePropPredDataset, Evaluator\n", + "# dataset = PygNodePropPredDataset('ogbn-products')\n", + "# split_idx = dataset.get_idx_split()\n", + "# data = dataset[0]\n", + "# train_idx = split_idx['train']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 349, + "status": "ok", + "timestamp": 1614999517590, + "user": { + "displayName": "Arvind Srivastav", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhKsFI_4660fvIUHDgSAVEBHXHmkjgDRf91ukLdhw=s64", + "userId": "13824725762542667997" + }, + "user_tz": 480 + }, + "id": "deaeTkxdRsua", + "outputId": "98b234ae-f676-47c5-acc6-e9b763faee09" + }, + "outputs": [], + "source": [ + "# Data subset\n", + "\n", + "def remap_edges(edge_list, node_map):\n", + " remapped_edges = torch.zeros(edge_list.shape, dtype=torch.long)\n", + " for e in edge_list:\n", + " remapped_edges[0] = node_map[e[0]]\n", + " remapped_edges[1] = node_map[e[1]]\n", + " return remapped_edges\n", + "\n", + "def load_sparse_node_features(node_feature_file):\n", + " node_map = {}\n", + " features_list = []\n", + " with open(node_feature_file) as f:\n", + " while 1:\n", + " line = f.readline()\n", + " if not line:\n", + " break\n", + " values = line.strip().split(',')\n", + " features = [float(x) for x in values[1:]]\n", + " features_list.append(features)\n", + " node_id = int(values[0])\n", + " if not (node_id in node_map):\n", + " node_map[node_id] = len(node_map)\n", + " features_tensor = torch.Tensor(features_list)\n", + " return (node_map, features_tensor)\n", + "\n", + "edges = np.genfromtxt(edge_output_file, dtype=int, delimiter=',')\n", + "print(\"zzz1\", edges.shape)\n", + "# labels = np.genfromtxt(node_label_output, dtype=float, delimiter=',')\n", + "node_map, features = load_sparse_node_features(node_feature_output)\n", + "print(\"zzz2\", len(node_map), features.shape)\n", + "nmap2, labels = load_sparse_node_features(node_label_output)\n", + "print(\"zzz2\", len(nmap2), labels.shape, labels.dtype)\n", + "\n", + "## Sanity Check: Each node has a label. Each edge endpoint is in node_map\n", + "error_count = 0\n", + "for n in node_map:\n", + " if not (n in nmap2):\n", + " error_count = error_count + 1\n", + " if error_count < 10:\n", + " print(\"Problem with node:\", n)\n", + "\n", + "error_count = 0\n", + "for e in edges:\n", + " if not ((e[0] in node_map) and (e[1] in node_map)):\n", + " error_count = error_count + 1\n", + " if error_count < 10:\n", + " print(\"Problem with edge: \", e[0], e[1])\n", + " \n", + "print(\"zzz3\", torch.max(labels))\n", + " \n", + " \n", + "###\n", + "\n", + "num_classes = int(torch.round(torch.max(labels)).item()) + 1\n", + "num_nodes = features.shape[0]\n", + "# Create the train, valid, test indexes\n", + "idx_mask = torch.rand(num_nodes)\n", + "train_mask = torch.lt(idx_mask, 0.8)\n", + "test_mask = torch.logical_and(torch.lt(idx_mask, 0.9), torch.ge(idx_mask, 0.8))\n", + "valid_mask = torch.ge(idx_mask, 0.9)\n", + "print(\"zzz3\", num_classes, train_mask.shape, test_mask.shape, valid_mask.shape, train_mask.dtype)\n", + "num_masked = torch.sum(torch.logical_or(torch.logical_or(train_mask, test_mask), valid_mask).to(torch.int)).item()\n", + "assert num_masked == num_nodes, \"{} {}\".format(num_classes, num_masked)\n", + "num_common = torch.sum(torch.logical_and(torch.logical_and(train_mask, test_mask), valid_mask).to(torch.int)).item()\n", + "assert num_common == 0\n", + "###\n", + "\n", + "edges = remap_edges(edges, node_map).t()\n", + "labels = torch.round(labels).to(torch.long)\n", + "features = torch.FloatTensor(features)\n", + "\n", + "data = pyg.data.Data(x=features, edge_index=edges, y=labels, test_mask=test_mask, val_mask=valid_mask, train_mask=train_mask)\n", + "print(data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "executionInfo": { + "elapsed": 373, + "status": "error", + "timestamp": 1614999485794, + "user": { + "displayName": "Arvind Srivastav", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhKsFI_4660fvIUHDgSAVEBHXHmkjgDRf91ukLdhw=s64", + "userId": "13824725762542667997" + }, + "user_tz": 480 + }, + "id": "FxE93EcWWywt", + "outputId": "92a24fb2-ae7c-402e-dc44-eacb94f2a313" + }, + "outputs": [], + "source": [ + "def main():\n", + " for args in [\n", + " {'model_type': 'GraphSage',\n", + " 'num_classes': num_classes,\n", + " 'dataset': 'cora', 'num_layers': 2,\n", + " 'heads': 1, 'batch_size': 32, 'hidden_dim': 32, 'dropout': 0.5,\n", + " 'epochs': 50, 'opt': 'adam', 'opt_scheduler': 'none', 'opt_restart': 0,\n", + " 'weight_decay': 5e-3, 'lr': 0.01},\n", + " ]:\n", + " args = objectview(args)\n", + " for model in ['GraphSage', 'GAT']:\n", + " args.model_type = model\n", + "\n", + " # Match the dimension.\n", + " if model == 'GAT':\n", + " args.heads = 2\n", + " else:\n", + " args.heads = 1\n", + "\n", + " test_accs, losses = train([data], args) \n", + "\n", + " print(\"Maximum accuracy: {0}\".format(max(test_accs)))\n", + " print(\"Minimum loss: {0}\".format(min(losses)))\n", + "\n", + " plt.plot(losses, label=\"training loss\" + \" - \" + args.model_type)\n", + " plt.plot(test_accs, label=\"test accuracy\" + \" - \" + args.model_type)\n", + " plt.legend()\n", + " plt.show()\n", + "\n", + "if __name__ == '__main__':\n", + " main()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 301, + "status": "ok", + "timestamp": 1614999455639, + "user": { + "displayName": "Arvind Srivastav", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhKsFI_4660fvIUHDgSAVEBHXHmkjgDRf91ukLdhw=s64", + "userId": "13824725762542667997" + }, + "user_tz": 480 + }, + "id": "WQTFwXDrdkkF", + "outputId": "1017f124-10b6-4ff2-fb10-41212c816afd" + }, + "outputs": [], + "source": [ + "# loader = DataLoader([data], batch_size=32)\n", + "# print(loader)\n", + "# for batch in loader:\n", + "# print(batch)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "n3iiIHui9Dv6" + }, + "outputs": [], + "source": [ + "# loader = DataLoader(data, batch_size=32, shuffle=True)\n", + "# for batch in loader:\n", + "# print(batch)\n", + "# print('sd')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pxTwT4afH4ie" + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "s2QXMsJ2Vxqi" + }, + "source": [ + "## Archived" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HefWGx8--RZf" + }, + "source": [ + "## Inductive split using DeepSNAP" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vie4Wruv-U9h" + }, + "outputs": [], + "source": [ + "# !pip install -q deepsnap\n", + "# from deepsnap.graph import Graph\n", + "# from deepsnap.batch import Batch\n", + "# from deepsnap.dataset import GraphDataset\n", + "# graphs = GraphDataset.pyg_to_graphs(dataset)\n", + "\n", + "# task = 'node'\n", + "# dataset = GraphDataset(graphs, task=task)\n", + "# dataset_train, dataset_val, dataset_test = dataset.split(\n", + "# transductive=False, split_ratio=[0.2, 0.2, 0.6])" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [ + "ajK1UsIkWm25" + ], + "machine_shape": "hm", + "name": "baseline.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}