diff --git a/baseline.ipynb b/baseline.ipynb new file mode 100644 index 0000000..9c66bf1 --- /dev/null +++ b/baseline.ipynb @@ -0,0 +1,918 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 29086, + "status": "ok", + "timestamp": 1614993114537, + "user": { + "displayName": "Arvind Srivastav", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhKsFI_4660fvIUHDgSAVEBHXHmkjgDRf91ukLdhw=s64", + "userId": "13824725762542667997" + }, + "user_tz": 480 + }, + "id": "nQQedrLOIM5n", + "outputId": "c977889e-1d58-4e73-e2de-2db18a9dbcd8" + }, + "outputs": [], + "source": [ + "#!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html\n", + "#!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html\n", + "#!pip install -q torch-geometric\n", + "#!pip install -q git+https://github.com/snap-stanford/deepsnap.git\n", + "#!pip install -q ogb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "executionInfo": { + "elapsed": 8138, + "status": "ok", + "timestamp": 1614993206556, + "user": { + "displayName": "Arvind Srivastav", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhKsFI_4660fvIUHDgSAVEBHXHmkjgDRf91ukLdhw=s64", + "userId": "13824725762542667997" + }, + "user_tz": 480 + }, + "id": "g4BNgPthWJXh", + "outputId": "8d5c40d4-725c-4069-f5fb-e89a34434a20" + }, + "outputs": [], + "source": [ + "import torch_geometric\n", + "import torch_geometric as pyg\n", + "torch_geometric.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "v5mEnLJ9zrBH" + }, + "source": [ + "## Mount drive and load data folder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 21285, + "status": "ok", + "timestamp": 1614993927175, + "user": { + "displayName": "Arvind Srivastav", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhKsFI_4660fvIUHDgSAVEBHXHmkjgDRf91ukLdhw=s64", + "userId": "13824725762542667997" + }, + "user_tz": 480 + }, + "id": "_xROmo7tzcAL", + "outputId": "170ad2f9-ae79-4b74-c61f-fd9730510402" + }, + "outputs": [], + "source": [ + "use_localenv = 1\n", + "if use_localenv == 0:\n", + " from google.colab import drive\n", + " drive.mount('/content/drive')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 3872, + "status": "ok", + "timestamp": 1614993932239, + "user": { + "displayName": "Arvind Srivastav", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhKsFI_4660fvIUHDgSAVEBHXHmkjgDRf91ukLdhw=s64", + "userId": "13824725762542667997" + }, + "user_tz": 480 + }, + "id": "hf8e9GbxzvJt", + "outputId": "2edf8ad5-d491-4c0f-d5cd-1e491595d1bb" + }, + "outputs": [], + "source": [ + "if use_localenv == 0:\n", + " data_dir = \"/content/drive/MyDrive/CS224w-project/dataset/dataset/ogbn_products/\"\n", + " %cd drive/MyDrive/CS224w-project/dataset/dataset/ogbn_products/\n", + " !ls\n", + "else:\n", + " import os\n", + " data_dir = \"C:\\\\Users\\\\shaon\\\\Desktop\\\\CS224W\\\\CS224W_2021\\\\CS224W_PROJECT\\\\dataset\\\\ogbn_products\"\n", + " os.chdir(data_dir)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2fouCD0AU5Vt" + }, + "source": [ + "## Generate small graph for baseline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "21RZA4U_3DXE" + }, + "outputs": [], + "source": [ + "import random\n", + "\n", + "def sample_edges(edge_file_name, output_file_name):\n", + " d1 = {}\n", + " l = []\n", + " for i in range(1000):\n", + " d1[random.randint(0, 61859140)] = 1\n", + "\n", + " f = open(edge_file_name)\n", + " c = 0\n", + " out = []\n", + " while True:\n", + " l = f.readline()\n", + " if len(l) == 0:\n", + " break\n", + " if c in d1:\n", + " out.append(l)\n", + " c += 1\n", + "\n", + " w = open(output_file_name, 'w')\n", + " for line in out:\n", + " w.write(line)\n", + "\n", + " w.close()\n", + "\n", + "\n", + "def get_node_features(node_file, node_label_file, edge_file,\n", + " node_feature_output, node_label_output):\n", + " f = open(edge_file)\n", + " nodes = []\n", + " for line in f.readlines():\n", + " values = line.strip().split(',')\n", + " for x in values:\n", + " if len(x.strip()) > 0:\n", + " nodes.append(int(x))\n", + "\n", + " print(\"Num nodes before dedup=\", len(nodes))\n", + " nodes = set(nodes)\n", + " print(\"Num nodes=\", len(nodes))\n", + "\n", + " c = 0\n", + " nf = open(node_file)\n", + " nl = open(node_label_file)\n", + " outf = []\n", + " outl = []\n", + " while True:\n", + " features = nf.readline()\n", + " label = nl.readline()\n", + "\n", + " if not features:\n", + " break\n", + " if c in nodes:\n", + " outf.append(str(c) + ',' + features)\n", + " outl.append(str(c) + ',' + label)\n", + " c = c + 1\n", + "\n", + " w = open(node_feature_output, 'w')\n", + " for line in outf:\n", + " w.write(line)\n", + " w.close()\n", + "\n", + " w = open(node_label_output, 'w')\n", + " for line in outl:\n", + " w.write(line)\n", + " w.close()\n", + " \n", + "#!pwd\n", + "import os\n", + "data_dir = \"raw/\"\n", + "if not os.path.exists('raw/out'):\n", + " os.makedirs('raw/out')\n", + "\n", + " edge_input_file = data_dir + \"edge.csv\"\n", + " edge_output_file = data_dir + \"out/output.csv\"\n", + "\n", + " sample_edges(edge_input_file, edge_output_file)\n", + "\n", + " node_file = data_dir + \"node-feat.csv\"\n", + " node_label_file = data_dir + \"node-label.csv\"\n", + " node_feature_output = data_dir + \"out/node-feat.csv\"\n", + " node_label_output = data_dir + \"out/node-label.csv\"\n", + " get_node_features(node_file, node_label_file, edge_output_file,\n", + " node_feature_output, node_label_output)\n", + "else:\n", + " edge_output_file = data_dir + \"out/output.csv\"\n", + " node_feature_output = data_dir + \"out/node-feat.csv\"\n", + " node_label_output = data_dir + \"out/node-label.csv\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RWlV3LMcWRBJ" + }, + "source": [ + "## GNN Stack Module\n", + "\n", + "Below is the implementation for a general GNN Module that could plugin any layers, including **GraphSage**, **GAT**, etc. This module is provided for you, and you own **GraphSage** and **GAT** layers will function as components in the GNNStack Module.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "F54FqJdtWMn-" + }, + "outputs": [], + "source": [ + "# GNN Stack Module \n", + "import torch\n", + "import torch_scatter\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import torch_geometric.nn as pyg_nn\n", + "import torch_geometric.utils as pyg_utils\n", + "\n", + "from torch import Tensor\n", + "from typing import Union, Tuple, Optional\n", + "from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType, OptTensor)\n", + "\n", + "from torch.nn import Parameter, Linear\n", + "from torch_sparse import SparseTensor, set_diag\n", + "from torch_geometric.nn.conv import MessagePassing\n", + "from torch_geometric.utils import remove_self_loops, add_self_loops, softmax\n", + "\n", + "class GNNStack(torch.nn.Module):\n", + " def __init__(self, input_dim, hidden_dim, output_dim, args, emb=False):\n", + " super(GNNStack, self).__init__()\n", + " conv_model = self.build_conv_model(args.model_type)\n", + " self.convs = nn.ModuleList()\n", + " self.convs.append(conv_model(input_dim, hidden_dim))\n", + " assert (args.num_layers >= 1), 'Number of layers is not >=1'\n", + " for l in range(args.num_layers-1):\n", + " self.convs.append(conv_model(args.heads * hidden_dim, hidden_dim))\n", + "\n", + " # post-message-passing\n", + " self.post_mp = nn.Sequential(\n", + " nn.Linear(args.heads * hidden_dim, hidden_dim), nn.Dropout(args.dropout), \n", + " nn.Linear(hidden_dim, output_dim))\n", + "\n", + " self.dropout = args.dropout\n", + " self.num_layers = args.num_layers\n", + "\n", + " self.emb = emb\n", + "\n", + " def build_conv_model(self, model_type):\n", + " if model_type == 'GraphSage':\n", + " return GraphSage\n", + " elif model_type == 'GAT':\n", + " return GAT\n", + "\n", + " def forward(self, data):\n", + " print(data)\n", + " x, edge_index = data.x, data.edge_index\n", + " \n", + " for i in range(self.num_layers):\n", + " x = self.convs[i](x, edge_index)\n", + " x = F.relu(x)\n", + " x = F.dropout(x, p=self.dropout)\n", + "\n", + " x = self.post_mp(x)\n", + "\n", + " if self.emb == True:\n", + " return x\n", + "\n", + " return F.log_softmax(x, dim=1)\n", + "\n", + " def loss(self, pred, label):\n", + " return F.nll_loss(pred, label)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OfTCVJmTWaCl" + }, + "source": [ + "## GraphSage Implementation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2ngZJjEuWdPU" + }, + "outputs": [], + "source": [ + "class GraphSage(MessagePassing):\n", + " \n", + " def __init__(self, in_channels, out_channels, normalize = True,\n", + " bias = None, **kwargs): \n", + " super(GraphSage, self).__init__(**kwargs)\n", + "\n", + " self.in_channels = in_channels\n", + " self.out_channels = out_channels\n", + " self.normalize = normalize\n", + " self.bias = bias\n", + " self.lin_l = None\n", + " self.lin_r = None\n", + "\n", + " self.lin_l = torch.nn.Linear(self.in_channels, self.out_channels)\n", + " self.lin_r = torch.nn.Linear(self.in_channels, self.out_channels)\n", + "\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " self.lin_l.reset_parameters()\n", + " self.lin_r.reset_parameters()\n", + "\n", + " def forward(self, x, edge_index, size = None):\n", + " out = None\n", + " h_l = self.lin_l(x)\n", + " print(\"ZZZ:\", edge_index.shape, x.shape, size)\n", + " h_r = self.propagate(edge_index, x=(x, x))\n", + " print(\"ZZZ:\", edge_index.shape, x.shape, size)\n", + " h_r = self.lin_r(x)\n", + " out = h_l + h_r\n", + " if self.normalize:\n", + " out = F.normalize(out)\n", + " return out\n", + "\n", + " def message(self, x_j):\n", + " out = x_j\n", + "\n", + " return out\n", + "\n", + " def aggregate(self, inputs, index, dim_size = None):\n", + "\n", + " out = None\n", + " node_dim = self.node_dim\n", + " out = torch_scatter.scatter(src=inputs, index=index, dim=node_dim, reduce='mean')\n", + "\n", + " return out\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rZaBcVJpWhpZ" + }, + "source": [ + "## GAT Implementation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Hkch0GbGWjMj" + }, + "outputs": [], + "source": [ + "class GAT(MessagePassing):\n", + "\n", + " def __init__(self, in_channels, out_channels, heads = 2,\n", + " negative_slope = 0.2, dropout = 0., **kwargs):\n", + " super(GAT, self).__init__(node_dim=0, **kwargs)\n", + "\n", + " self.in_channels = in_channels\n", + " self.out_channels = out_channels\n", + " self.heads = heads\n", + " self.negative_slope = negative_slope\n", + " self.dropout = dropout\n", + " self.lin_l = nn.Linear(in_channels, heads * out_channels)\n", + "\n", + " self.lin_r = self.lin_l\n", + "\n", + " self.att_r = nn.Parameter(torch.zeros([heads, out_channels, 1], dtype=torch.float))\n", + " self.att_l = nn.Parameter(torch.zeros([heads, out_channels, 1], dtype=torch.float))\n", + "\n", + " self.reset_parameters()\n", + "\n", + " def reset_parameters(self):\n", + " nn.init.xavier_uniform_(self.lin_l.weight)\n", + " nn.init.xavier_uniform_(self.lin_r.weight)\n", + " nn.init.xavier_uniform_(self.att_l)\n", + " nn.init.xavier_uniform_(self.att_r)\n", + "\n", + " def forward(self, x, edge_index, size = None):\n", + "\n", + " H, C = self.heads, self.out_channels\n", + " z1 = self.lin_l(x)\n", + " z2 = self.lin_r(x)\n", + " h1 = z1.reshape([z1.shape[0], H, C])\n", + " h2 = z2.reshape([z2.shape[0], H, C])\n", + " h1e = h1[edge_index[0]]\n", + " h2e = h2[edge_index[1]]\n", + "\n", + " alpha_l = torch.matmul(self.att_l.reshape([1, H, 1, C]), h1.reshape([h1.shape[0], H, C, 1]))\n", + " alpha_r = torch.matmul(self.att_r.reshape([1, H, 1, C]), h2.reshape([h2.shape[0], H, C, 1]))\n", + " alpha_l = alpha_l.reshape([h1.shape[0], H])\n", + " alpha_r = alpha_r.reshape([h2.shape[0], H])\n", + "\n", + " z = self.propagate(edge_index, x=(h1, h2), alpha=(alpha_l, alpha_r))\n", + " out = z.reshape([z.shape[0], z.shape[1] * z.shape[2]])\n", + "\n", + " return out\n", + "\n", + " def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):\n", + " ax = F.leaky_relu(alpha_i + alpha_j, negative_slope=self.negative_slope)\n", + " a = pyg_utils.softmax(\n", + " ax,\n", + " index=index, ptr=ptr, num_nodes=size_i)\n", + " a1 = F.dropout(a, p=self.dropout)\n", + " a1 = a1.reshape([a1.shape[0], a1.shape[1], 1])\n", + " out = torch.mul(a1, x_j)\n", + "\n", + " return out\n", + "\n", + " def aggregate(self, inputs, index, dim_size = None):\n", + "\n", + " out = torch_scatter.scatter(inputs, index, dim=0, dim_size=dim_size, reduce='sum')\n", + "\n", + " return out" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ajK1UsIkWm25" + }, + "source": [ + "## Building Optimizers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2cPxd4bKWo2F" + }, + "outputs": [], + "source": [ + "import torch.optim as optim\n", + "\n", + "def build_optimizer(args, params):\n", + " weight_decay = args.weight_decay\n", + " filter_fn = filter(lambda p : p.requires_grad, params)\n", + " if args.opt == 'adam':\n", + " optimizer = optim.Adam(filter_fn, lr=args.lr, weight_decay=weight_decay)\n", + " elif args.opt == 'sgd':\n", + " optimizer = optim.SGD(filter_fn, lr=args.lr, momentum=0.95, weight_decay=weight_decay)\n", + " elif args.opt == 'rmsprop':\n", + " optimizer = optim.RMSprop(filter_fn, lr=args.lr, weight_decay=weight_decay)\n", + " elif args.opt == 'adagrad':\n", + " optimizer = optim.Adagrad(filter_fn, lr=args.lr, weight_decay=weight_decay)\n", + " if args.opt_scheduler == 'none':\n", + " return None, optimizer\n", + " elif args.opt_scheduler == 'step':\n", + " scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)\n", + " elif args.opt_scheduler == 'cos':\n", + " scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart)\n", + " return scheduler, optimizer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sbSrQI9GWrrN" + }, + "source": [ + "## Training and testing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "a4z4xvSVWtrF" + }, + "outputs": [], + "source": [ + "import time\n", + "\n", + "import networkx as nx\n", + "import numpy as np\n", + "import torch\n", + "import torch.optim as optim\n", + "\n", + "from torch_geometric.data import DataLoader\n", + "\n", + "import torch_geometric.nn as pyg_nn\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "def train(dataset, args):\n", + " \n", + " test_loader = loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)\n", + " # build model\n", + " # model = GNNStack(dataset.num_node_features, args.hidden_dim, dataset.num_classes, args)\n", + " model = GNNStack(100, args.hidden_dim, args.num_classes, args)\n", + " scheduler, opt = build_optimizer(args, model.parameters())\n", + "\n", + " # train\n", + " losses = []\n", + " test_accs = []\n", + " for epoch in range(args.epochs):\n", + " total_loss = 0\n", + " model.train()\n", + " for batch in loader:\n", + " opt.zero_grad()\n", + " pred = model(batch)\n", + " label = batch.y.reshape(batch.y.shape[0])\n", + " print(\"ZZZ20\", pred.shape, label.shape, pred.dtype, label.dtype)\n", + " #pred = pred[batch.train_mask] \n", + " #label = label[batch.train_mask]\n", + " loss = model.loss(pred, label)\n", + " loss.backward()\n", + " opt.step()\n", + " #total_loss += loss.item() * batch.num_graphs\n", + " total_loss += loss.item()\n", + " # total_loss /= len(loader.dataset)\n", + " losses.append(total_loss)\n", + "\n", + " # uncommented # Arvind: commented test function \n", + " if epoch % 10 == 0:\n", + " test_acc = test(test_loader, model)\n", + " test_accs.append(test_acc)\n", + " else:\n", + " test_accs.append(test_accs[-1])\n", + " return test_accs, losses\n", + "\n", + "def test(loader, model, is_validation=True):\n", + " model.eval()\n", + "\n", + " correct = 0\n", + " for data in loader:\n", + " with torch.no_grad():\n", + " # max(dim=1) returns values, indices tuple; only need indices\n", + " pred = model(data).max(dim=1)[1]\n", + " label = data.y\n", + "\n", + " mask = data.val_mask if is_validation else data.test_mask\n", + " # node classification: only evaluate on nodes in test set\n", + " pred = pred[mask]\n", + " label = data.y[mask]\n", + " \n", + " correct += pred.eq(label).sum().item()\n", + "\n", + " total = 0\n", + " for data in loader.dataset:\n", + " total += torch.sum(data.val_mask if is_validation else data.test_mask).item()\n", + " return correct / total\n", + "\n", + "class objectview(object):\n", + " def __init__(self, d):\n", + " self.__dict__ = d" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RcQWxCyU3_lo" + }, + "source": [ + "## Train on ogbn-products" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3wnHxNhtVgNy" + }, + "source": [ + "## Load dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "87EKi89i9BBG" + }, + "outputs": [], + "source": [ + "# from ogb.nodeproppred import NodePropPredDataset\n", + "# dataset = NodePropPredDataset(name = 'ogbn-products')\n", + "# split_idx = dataset.get_idx_split()\n", + "# train_idx, valid_idx, test_idx = split_idx[\"train\"], split_idx[\"valid\"], split_idx[\"test\"]\n", + "\n", + "# from ogb.nodeproppred import PygNodePropPredDataset, Evaluator\n", + "# dataset = PygNodePropPredDataset('ogbn-products')\n", + "# split_idx = dataset.get_idx_split()\n", + "# data = dataset[0]\n", + "# train_idx = split_idx['train']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 349, + "status": "ok", + "timestamp": 1614999517590, + "user": { + "displayName": "Arvind Srivastav", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhKsFI_4660fvIUHDgSAVEBHXHmkjgDRf91ukLdhw=s64", + "userId": "13824725762542667997" + }, + "user_tz": 480 + }, + "id": "deaeTkxdRsua", + "outputId": "98b234ae-f676-47c5-acc6-e9b763faee09" + }, + "outputs": [], + "source": [ + "# Data subset\n", + "\n", + "def remap_edges(edge_list, node_map):\n", + " remapped_edges = torch.zeros(edge_list.shape, dtype=torch.long)\n", + " for e in edge_list:\n", + " remapped_edges[0] = node_map[e[0]]\n", + " remapped_edges[1] = node_map[e[1]]\n", + " return remapped_edges\n", + "\n", + "def load_sparse_node_features(node_feature_file):\n", + " node_map = {}\n", + " features_list = []\n", + " with open(node_feature_file) as f:\n", + " while 1:\n", + " line = f.readline()\n", + " if not line:\n", + " break\n", + " values = line.strip().split(',')\n", + " features = [float(x) for x in values[1:]]\n", + " features_list.append(features)\n", + " node_id = int(values[0])\n", + " if not (node_id in node_map):\n", + " node_map[node_id] = len(node_map)\n", + " features_tensor = torch.Tensor(features_list)\n", + " return (node_map, features_tensor)\n", + "\n", + "edges = np.genfromtxt(edge_output_file, dtype=int, delimiter=',')\n", + "print(\"zzz1\", edges.shape)\n", + "# labels = np.genfromtxt(node_label_output, dtype=float, delimiter=',')\n", + "node_map, features = load_sparse_node_features(node_feature_output)\n", + "print(\"zzz2\", len(node_map), features.shape)\n", + "nmap2, labels = load_sparse_node_features(node_label_output)\n", + "print(\"zzz2\", len(nmap2), labels.shape, labels.dtype)\n", + "\n", + "## Sanity Check: Each node has a label. Each edge endpoint is in node_map\n", + "error_count = 0\n", + "for n in node_map:\n", + " if not (n in nmap2):\n", + " error_count = error_count + 1\n", + " if error_count < 10:\n", + " print(\"Problem with node:\", n)\n", + "\n", + "error_count = 0\n", + "for e in edges:\n", + " if not ((e[0] in node_map) and (e[1] in node_map)):\n", + " error_count = error_count + 1\n", + " if error_count < 10:\n", + " print(\"Problem with edge: \", e[0], e[1])\n", + " \n", + "print(\"zzz3\", torch.max(labels))\n", + " \n", + " \n", + "###\n", + "\n", + "num_classes = int(torch.round(torch.max(labels)).item()) + 1\n", + "num_nodes = features.shape[0]\n", + "# Create the train, valid, test indexes\n", + "idx_mask = torch.rand(num_nodes)\n", + "train_mask = torch.lt(idx_mask, 0.8)\n", + "test_mask = torch.logical_and(torch.lt(idx_mask, 0.9), torch.ge(idx_mask, 0.8))\n", + "valid_mask = torch.ge(idx_mask, 0.9)\n", + "print(\"zzz3\", num_classes, train_mask.shape, test_mask.shape, valid_mask.shape, train_mask.dtype)\n", + "num_masked = torch.sum(torch.logical_or(torch.logical_or(train_mask, test_mask), valid_mask).to(torch.int)).item()\n", + "assert num_masked == num_nodes, \"{} {}\".format(num_classes, num_masked)\n", + "num_common = torch.sum(torch.logical_and(torch.logical_and(train_mask, test_mask), valid_mask).to(torch.int)).item()\n", + "assert num_common == 0\n", + "###\n", + "\n", + "edges = remap_edges(edges, node_map).t()\n", + "labels = torch.round(labels).to(torch.long)\n", + "features = torch.FloatTensor(features)\n", + "\n", + "data = pyg.data.Data(x=features, edge_index=edges, y=labels, test_mask=test_mask, val_mask=valid_mask, train_mask=train_mask)\n", + "print(data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "executionInfo": { + "elapsed": 373, + "status": "error", + "timestamp": 1614999485794, + "user": { + "displayName": "Arvind Srivastav", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhKsFI_4660fvIUHDgSAVEBHXHmkjgDRf91ukLdhw=s64", + "userId": "13824725762542667997" + }, + "user_tz": 480 + }, + "id": "FxE93EcWWywt", + "outputId": "92a24fb2-ae7c-402e-dc44-eacb94f2a313" + }, + "outputs": [], + "source": [ + "def main():\n", + " for args in [\n", + " {'model_type': 'GraphSage',\n", + " 'num_classes': num_classes,\n", + " 'dataset': 'cora', 'num_layers': 2,\n", + " 'heads': 1, 'batch_size': 32, 'hidden_dim': 32, 'dropout': 0.5,\n", + " 'epochs': 50, 'opt': 'adam', 'opt_scheduler': 'none', 'opt_restart': 0,\n", + " 'weight_decay': 5e-3, 'lr': 0.01},\n", + " ]:\n", + " args = objectview(args)\n", + " for model in ['GraphSage', 'GAT']:\n", + " args.model_type = model\n", + "\n", + " # Match the dimension.\n", + " if model == 'GAT':\n", + " args.heads = 2\n", + " else:\n", + " args.heads = 1\n", + "\n", + " test_accs, losses = train([data], args) \n", + "\n", + " print(\"Maximum accuracy: {0}\".format(max(test_accs)))\n", + " print(\"Minimum loss: {0}\".format(min(losses)))\n", + "\n", + " plt.plot(losses, label=\"training loss\" + \" - \" + args.model_type)\n", + " plt.plot(test_accs, label=\"test accuracy\" + \" - \" + args.model_type)\n", + " plt.legend()\n", + " plt.show()\n", + "\n", + "if __name__ == '__main__':\n", + " main()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 301, + "status": "ok", + "timestamp": 1614999455639, + "user": { + "displayName": "Arvind Srivastav", + "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GhKsFI_4660fvIUHDgSAVEBHXHmkjgDRf91ukLdhw=s64", + "userId": "13824725762542667997" + }, + "user_tz": 480 + }, + "id": "WQTFwXDrdkkF", + "outputId": "1017f124-10b6-4ff2-fb10-41212c816afd" + }, + "outputs": [], + "source": [ + "# loader = DataLoader([data], batch_size=32)\n", + "# print(loader)\n", + "# for batch in loader:\n", + "# print(batch)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "n3iiIHui9Dv6" + }, + "outputs": [], + "source": [ + "# loader = DataLoader(data, batch_size=32, shuffle=True)\n", + "# for batch in loader:\n", + "# print(batch)\n", + "# print('sd')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pxTwT4afH4ie" + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "s2QXMsJ2Vxqi" + }, + "source": [ + "## Archived" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HefWGx8--RZf" + }, + "source": [ + "## Inductive split using DeepSNAP" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vie4Wruv-U9h" + }, + "outputs": [], + "source": [ + "# !pip install -q deepsnap\n", + "# from deepsnap.graph import Graph\n", + "# from deepsnap.batch import Batch\n", + "# from deepsnap.dataset import GraphDataset\n", + "# graphs = GraphDataset.pyg_to_graphs(dataset)\n", + "\n", + "# task = 'node'\n", + "# dataset = GraphDataset(graphs, task=task)\n", + "# dataset_train, dataset_val, dataset_test = dataset.split(\n", + "# transductive=False, split_ratio=[0.2, 0.2, 0.6])" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [ + "ajK1UsIkWm25" + ], + "machine_shape": "hm", + "name": "baseline.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}