-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathutils.py
More file actions
64 lines (53 loc) · 2.21 KB
/
utils.py
File metadata and controls
64 lines (53 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import random
import numpy as np
import torch
from torch_geometric.datasets import Planetoid, WikipediaNetwork, WebKB, Actor
# Directory that contains this file; load_dataset() places its dataset
# directories underneath this path.
root = os.path.dirname(__file__)
def set_seed(seed):
    """Seed every random number generator used by this module.

    Seeds, in order, Python's ``random`` module, NumPy's global generator,
    PyTorch's generator, and PyTorch's CUDA generator. Repeated runs with the
    same ``seed`` therefore draw the same random numbers.

    Args:
        seed: Integer seed forwarded unchanged to each generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)
def accuracy(output, labels):
    """Return the fraction of rows whose arg-max prediction equals its label.

    Args:
        output: Score tensor. The predicted class of each row is the index of
            its largest entry along dim 1 (presumably shape ``(N, C)``;
            confirm with callers).
        labels: Ground-truth class indices, compared element-wise with the
            predictions.

    Returns:
        A 0-dim float64 tensor equal to ``#correct / len(labels)``.
    """
    _, predicted = output.max(1)
    # Match the label dtype so the element-wise comparison is well-defined.
    predicted = predicted.type_as(labels)
    num_correct = (predicted == labels).double().sum()
    return num_correct / len(labels)
def load_dataset(name: str, device=None):
    """Load a single-graph node-classification dataset and unpack it.

    Supported names (case-insensitive) and their torch_geometric loaders:
      * ``cora``, ``pubmed``, ``citeseer``       -> ``Planetoid``
      * ``chameleon``, ``squirrel``              -> ``WikipediaNetwork``
      * ``cornell``, ``texas``, ``wisconsin``    -> ``WebKB``
      * ``actor``                                -> ``Actor``

    Files are stored under ``<root>/dataset/<LoaderName>``, where ``root`` is
    the module-level directory of this file.

    Args:
        name: Dataset name; it is lower-cased before matching.
        device: Target ``torch.device`` for the graph tensors. Defaults to CPU.

    Returns:
        ``(x, y, nfeat, nclass, adj, train_mask, val_mask, test_mask)``:
          * ``x`` and ``y`` are the node features and labels of the first
            graph.
          * ``nfeat`` is ``data.num_node_features``.
          * ``nclass`` is the number of distinct values in ``y``.
          * ``adj`` is the sparse COO adjacency built by ``eidx_to_sp`` from
            ``data.edge_index``.
          * The masks are returned exactly as stored on the dataset. Some
            datasets may store one column per split (``select_mask`` handles
            that case); TODO confirm per dataset.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    if device is None:
        device = torch.device('cpu')
    name = name.lower()
    # os.path.join avoids producing an absolute "/dataset/..." path when
    # ``root`` is the empty string (a relative __main__.__file__).
    if name in ("cora", "pubmed", "citeseer"):
        dataset = Planetoid(root=os.path.join(root, "dataset", "Planetoid"), name=name)
    elif name in ("chameleon", "squirrel"):
        dataset = WikipediaNetwork(root=os.path.join(root, "dataset", "WikipediaNetwork"), name=name)
    elif name in ("cornell", "texas", "wisconsin"):
        dataset = WebKB(root=os.path.join(root, "dataset", "WebKB"), name=name)
    elif name == "actor":
        dataset = Actor(root=os.path.join(root, "dataset", "Actor"))
    else:
        # Raising a plain string is itself a TypeError in Python 3, which
        # hid this message; raise a real exception that names the dataset.
        raise ValueError(
            f"Unsupported dataset {name!r}. Please implement support for this "
            "dataset in function load_dataset()."
        )
    data = dataset[0].to(device)
    x, y = data.x, data.y
    adj = eidx_to_sp(len(x), data.edge_index)
    nfeat = data.num_node_features
    nclass = len(torch.unique(y))
    return x, y, nfeat, nclass, adj, data.train_mask, data.val_mask, data.test_mask
def eidx_to_sp(n: int, edge_index: torch.Tensor, device=None) -> torch.Tensor:
    """Convert an edge list into an ``n x n`` sparse COO adjacency matrix.

    Each column ``(src, dst)`` of ``edge_index`` becomes a 1.0 entry at
    ``[src, dst]``. No coalescing is performed here, so repeated edges stay
    as separate stored entries.

    Args:
        n: Number of nodes; the result has size ``[n, n]``.
        edge_index: Integer index tensor of shape ``(2, E)`` (row indices,
            column indices), as required by ``torch.sparse_coo_tensor`` for
            a 2-D result.
        device: Device for the returned tensor. Defaults to
            ``edge_index.device``.

    Returns:
        A float32 sparse COO tensor of size ``[n, n]``.
    """
    if device is None:
        device = edge_index.device
    # One float32 weight per edge, allocated directly on edge_index's device
    # rather than through an E-element Python list and a CPU-to-device copy.
    values = torch.ones(edge_index.size(1), dtype=torch.float32, device=edge_index.device)
    coo = torch.sparse_coo_tensor(indices=edge_index, values=values, size=[n, n])
    return coo.to(device)
def select_mask(
    i: int, train: torch.Tensor, val: torch.Tensor, test: torch.Tensor
) -> "tuple[torch.Tensor, torch.Tensor, torch.Tensor]":
    """Return the ``i``-th (train, val, test) split.

    The rank of ``train`` decides how all three masks are treated:

    * If ``train`` is 1-D, the masks describe a single fixed split. They are
      returned unchanged and ``i`` is ignored.
    * Otherwise, column ``i`` (along dim 1) of each mask is selected and
      flattened to 1-D.

    Args:
        i: Split index; used only when the masks are multi-dimensional.
        train: Training mask.
        val: Validation mask.
        test: Test mask.

    Returns:
        A 3-tuple ``(train, val, test)`` for split ``i``. The previous
        ``-> torch.Tensor`` annotation was incorrect.
    """
    if train.dim() == 1:
        return train, val, test
    column = torch.tensor([i], device=train.device)
    # index_select keeps dim 1 with size 1 (as a copy); reshape(-1) drops it.
    return tuple(torch.index_select(mask, 1, column).reshape(-1) for mask in (train, val, test))