# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import math
from typing import Dict, Optional, Union
import torch
from .. import py_graphlearn_torch as pywrap
from ..data import Graph
from ..typing import NodeType, EdgeType, NumNeighbors, reverse_edge_type
from ..utils import (
merge_dict, merge_hetero_sampler_output, format_hetero_sampler_output,
id2idx_v2
)
from .base import (
BaseSampler, EdgeIndex,
NodeSamplerInput, EdgeSamplerInput,
SamplerOutput, HeteroSamplerOutput, NeighborOutput,
)
from .negative_sampler import RandomNegativeSampler
class NeighborSampler(BaseSampler):
    r"""Multi-hop neighbor sampler for homogeneous and heterogeneous graphs.

    The sampler works either on a single ``Graph`` (homogeneous) or on a dict
    mapping edge types to ``Graph`` objects (heterogeneous). The backend
    objects it relies on (neighbor samplers, negative samplers, the subgraph
    op and the inducer) are created lazily on first use.

    Args:
      graph: A ``Graph`` or a dict of ``Graph`` keyed by edge type.
      num_neighbors: Number of neighbors to sample at each hop. For
        heterogeneous graphs a list/tuple is broadcast to every edge type;
        a dict must provide the same number of hops for every edge type.
      device: Device on which sampling runs. It is overridden with the CPU
        device when the graph (or the first graph of the dict) is in
        ``'CPU'`` mode.
      with_edge: Whether to also return the ids of sampled edges.
      with_neg: Whether to build negative samplers (needed by
        ``sample_from_edges`` when negative sampling is requested).
      strategy: Stored on the instance; the samplers built by this class are
        always the random samplers.
    """

    def __init__(self,
                 graph: Union[Graph, Dict[EdgeType, Graph]],
                 num_neighbors: Optional[NumNeighbors] = None,
                 device: torch.device = torch.device('cuda', 0),
                 with_edge: bool = False,
                 with_neg: bool = False,
                 strategy: str = 'random'):
        self.graph = graph
        self.num_neighbors = num_neighbors
        self.device = device
        self.with_edge = with_edge
        self.with_neg = with_neg
        self.strategy = strategy

        # Backend objects, created on demand by the `lazy_init_*` methods and
        # `get_inducer`.
        self._subgraph_op = None
        self._sampler = None
        self._neg_sampler = None
        self._inducer = None

        if isinstance(self.graph, Graph):  # homo
            self._g_cls = 'homo'
            if self.graph.mode == 'CPU':
                self.device = torch.device('cpu')
        else:  # hetero
            self._g_cls = 'hetero'
            self.edge_types = []
            self.node_types = set()
            for etype in self.graph:
                self.edge_types.append(etype)
                self.node_types.add(etype[0])
                self.node_types.add(etype[2])
            # The mode of the first graph decides the device for all of them.
            if self.graph[self.edge_types[0]].mode == 'CPU':
                self.device = torch.device('cpu')
            self._set_num_neighbors_and_num_hops(self.num_neighbors)

    @property
    def subgraph_op(self):
        r"""The subgraph operator, created on first access."""
        self.lazy_init_subgraph_op()
        return self._subgraph_op

    def lazy_init_sampler(self):
        r"""Create the neighbor sampler(s) if they do not exist yet.

        Homogeneous graphs get a single sampler; heterogeneous graphs get a
        dict with one sampler per edge type.
        """
        if self._sampler is not None:
            return
        # Only touch the class that is actually needed, so that a build
        # without the other backend is not accessed.
        if self.device.type == 'cuda':
            sampler_cls = pywrap.CUDARandomSampler
        else:
            sampler_cls = pywrap.CPURandomSampler
        if self._g_cls == 'homo':
            self._sampler = sampler_cls(self.graph.graph_handler)
        else:  # hetero
            # Build the whole dict before assigning so that a failure does
            # not leave a partially initialised cache behind.
            self._sampler = {
                etype: sampler_cls(g.graph_handler)
                for etype, g in self.graph.items()
            }

    def lazy_init_neg_sampler(self):
        r"""Create the negative sampler(s) if ``with_neg`` is set and they do
        not exist yet (one per edge type for heterogeneous graphs).
        """
        if self._neg_sampler is not None or not self.with_neg:
            return
        mode = self.device.type.upper()
        if self._g_cls == 'homo':
            self._neg_sampler = RandomNegativeSampler(graph=self.graph,
                                                      mode=mode)
        else:  # hetero
            self._neg_sampler = {
                etype: RandomNegativeSampler(graph=g, mode=mode)
                for etype, g in self.graph.items()
            }

    def lazy_init_subgraph_op(self):
        r"""Create the subgraph operator if it does not exist yet.

        The operator is built from ``self.graph.graph_handler``, i.e. it
        expects ``self.graph`` to be a single ``Graph``.
        """
        if self._subgraph_op is not None:
            return
        if self.device.type == 'cuda':
            self._subgraph_op = pywrap.CUDASubGraphOp(self.graph.graph_handler)
        else:
            self._subgraph_op = pywrap.CPUSubGraphOp(self.graph.graph_handler)

    def sample_one_hop(
        self,
        input_seeds: torch.Tensor,
        req_num: int,
        etype: Optional[EdgeType] = None
    ) -> NeighborOutput:
        r"""Sample up to ``req_num`` neighbors for every seed.

        Args:
          input_seeds: Seed node ids; moved to ``self.device``.
          req_num: Number of neighbors requested per seed.
          etype: Edge type selecting the sampler on heterogeneous graphs;
            ``None`` uses the single homogeneous sampler.

        Returns:
          A ``NeighborOutput`` built from the neighbors, the per-seed
          neighbor counts and the edge ids (``None`` unless ``with_edge``).
        """
        self.lazy_init_sampler()
        sampler = self._sampler[etype] if etype is not None else self._sampler
        input_seeds = input_seeds.to(self.device)
        edge_ids = None
        if self.with_edge:
            nbrs, nbrs_num, edge_ids = sampler.sample_with_edge(
                input_seeds, req_num)
        else:
            nbrs, nbrs_num = sampler.sample(input_seeds, req_num)
        if nbrs.numel() == 0:
            # Nothing was sampled at all: return every seed as its own single
            # neighbor, with edge id -1 when edge ids are requested.
            nbrs, nbrs_num = input_seeds, torch.ones_like(input_seeds)
            if self.with_edge:
                edge_ids = -1 * nbrs_num
        return NeighborOutput(nbrs, nbrs_num, edge_ids)

    def sample_from_nodes(
        self,
        inputs: NodeSamplerInput,
        **kwargs
    ) -> Union[HeteroSamplerOutput, SamplerOutput]:
        r"""Sample a multi-hop subgraph starting from the given seed nodes.

        ``inputs`` is cast with ``NodeSamplerInput.cast``. On heterogeneous
        graphs ``inputs.input_type`` must be set.
        """
        inputs = NodeSamplerInput.cast(inputs)
        input_seeds = inputs.node.to(self.device)
        input_type = inputs.input_type
        if self._g_cls == 'hetero':
            assert input_type is not None
            return self._hetero_sample_from_nodes({input_type: input_seeds})
        return self._sample_from_nodes(input_seeds)

    def _sample_from_nodes(
        self,
        input_seeds: torch.Tensor
    ) -> SamplerOutput:
        r"""Sample on homogeneous graphs and induce a COO format subgraph.

        Note that messages in PyG are passed from src to dst. But we sample
        src's out neighbors and induce [src_index, dst_index] subgraphs. The
        direction of sampling is opposite to the direction of message
        passing. To be consistent with the semantics of PyG, the final edge
        index is transposed to [dst_index, src_index].
        """
        out_nodes, out_rows, out_cols, out_edges = [], [], [], []
        inducer = self.get_inducer(input_seeds.numel())
        srcs = inducer.init_node(input_seeds)
        batch = srcs
        out_nodes.append(srcs)
        for req_num in self.num_neighbors:
            out_nbrs = self.sample_one_hop(srcs, req_num)
            nodes, rows, cols = inducer.induce_next(
                srcs, out_nbrs.nbr, out_nbrs.nbr_num)
            out_nodes.append(nodes)
            out_rows.append(rows)
            out_cols.append(cols)
            if out_nbrs.edge is not None:
                out_edges.append(out_nbrs.edge)
            # Nodes induced at this hop are the seeds of the next one.
            srcs = nodes

        return SamplerOutput(
            node=torch.cat(out_nodes),
            # Rows and cols are swapped on purpose, see the docstring.
            row=torch.cat(out_cols),
            col=torch.cat(out_rows),
            edge=(torch.cat(out_edges) if out_edges else None),
            batch=batch,
            device=self.device
        )

    def _hetero_sample_from_nodes(
        self,
        input_seeds_dict: Dict[NodeType, torch.Tensor],
    ) -> HeteroSamplerOutput:
        r"""Sample on heterogeneous graphs and induce a COO format subgraph
        dict.

        Note that messages in PyG are passed from src to dst. But we sample
        src's out neighbors and induce [src_index, dst_index] subgraphs. The
        direction of sampling is opposite to the direction of message
        passing. To be consistent with the semantics of PyG, the final edge
        index is transposed to [dst_index, src_index] and edge_type is
        reversed as well. For example, given the edge_type (u, u2i, i), we
        sample by meta-path u->i, but return edge_index_dict
        {(i, rev_u2i, u) : [i, u]}.
        """
        # sample neighbors hop by hop.
        max_input_batch_size = max(
            t.numel() for t in input_seeds_dict.values())
        inducer = self.get_inducer(max_input_batch_size)
        src_dict = inducer.init_node(input_seeds_dict)
        batch = src_dict
        out_nodes, out_rows, out_cols, out_edges = {}, {}, {}, {}
        merge_dict(src_dict, out_nodes)
        for hop in range(self.num_hops):
            nbr_dict, edge_dict = {}, {}
            for etype in self.edge_types:
                src = src_dict.get(etype[0], None)
                req_num = self.num_neighbors[etype][hop]
                if src is None:
                    continue
                output = self.sample_one_hop(src, req_num, etype)
                nbr_dict[etype] = [src, output.nbr, output.nbr_num]
                if output.edge is not None:
                    edge_dict[etype] = output.edge
            nodes_dict, rows_dict, cols_dict = inducer.induce_next(nbr_dict)
            merge_dict(nodes_dict, out_nodes)
            merge_dict(rows_dict, out_rows)
            merge_dict(cols_dict, out_cols)
            merge_dict(edge_dict, out_edges)
            src_dict = nodes_dict

        # TODO: support inbound sampling.
        # Concatenate the per-hop pieces and store them under the reversed
        # edge type with rows and cols swapped (see the docstring).
        res_rows, res_cols, res_edges = {}, {}, {}
        for etype, rows in out_rows.items():
            rev_etype = reverse_edge_type(etype)
            res_rows[rev_etype] = torch.cat(out_cols[etype])
            res_cols[rev_etype] = torch.cat(rows)
            if self.with_edge:
                res_edges[rev_etype] = torch.cat(out_edges[etype])

        return HeteroSamplerOutput(
            node={k: torch.cat(v) for k, v in out_nodes.items()},
            row=res_rows,
            col=res_cols,
            edge=(res_edges if res_edges else None),
            batch=batch,
            edge_types=self.edge_types,
            device=self.device
        )

    @staticmethod
    def _edge_seed_metadata(neg_sampling, num_pos, edge_label,
                            src_index, dst_index):
        r"""Build the metadata dict attached to an edge-sampling output.

        ``src_index`` / ``dst_index`` hold, for every seed source /
        destination (positives first, then negatives), its position among
        the sampled nodes. Returns ``None`` when ``neg_sampling`` is neither
        binary nor triplet.
        """
        if neg_sampling is None or neg_sampling.is_binary():
            return {'edge_label_index': torch.stack([src_index, dst_index],
                                                    dim=0),
                    'edge_label': edge_label}
        if neg_sampling.is_triplet():
            # The first `num_pos` destinations are the positives, the rest
            # are the negatives, grouped per positive edge.
            dst_neg_index = dst_index[num_pos:].view(num_pos, -1).squeeze(-1)
            return {'src_index': src_index,
                    'dst_pos_index': dst_index[:num_pos],
                    'dst_neg_index': dst_neg_index}
        return None

    def sample_from_edges(
        self,
        inputs: EdgeSamplerInput,
        **kwargs,
    ) -> Union[HeteroSamplerOutput, SamplerOutput]:
        r"""Performs sampling from an edge sampler input, leveraging a
        sampling function of the same signature as `node_sample`.

        Currently, we support the out-edge sampling manner, so we reverse the
        direction of src and dst for the output so that features of the
        sampled nodes during training can be aggregated from k-hop to
        (k-1)-hop nodes.

        When ``inputs.neg_sampling`` is set, the negative sampler(s) must be
        available, i.e. the sampler must have been built with
        ``with_neg=True``. The result carries a ``metadata`` dict with either
        ``edge_label_index`` / ``edge_label`` (no or binary negative
        sampling) or ``src_index`` / ``dst_pos_index`` / ``dst_neg_index``
        (triplet negative sampling).
        """
        src = inputs.row.to(self.device)
        dst = inputs.col.to(self.device)
        edge_label = (None if inputs.label is None
                      else inputs.label.to(self.device))
        input_type = inputs.input_type
        neg_sampling = inputs.neg_sampling

        num_pos = src.numel()

        # Negative Sampling
        self.lazy_init_neg_sampler()
        if neg_sampling is not None:
            # When we are doing negative sampling, we append negative
            # information of nodes/edges to `src`, `dst`.
            # Later on, we can easily reconstruct what belongs to positive
            # and negative examples by slicing via `num_pos`.
            num_neg = math.ceil(num_pos * neg_sampling.amount)
            if neg_sampling.is_binary():
                # In the "binary" case, we randomly sample negative pairs of
                # nodes.
                if input_type is not None:
                    neg_pair = self._neg_sampler[input_type].sample(num_neg)
                else:
                    neg_pair = self._neg_sampler.sample(num_neg)
                src_neg, dst_neg = neg_pair[0], neg_pair[1]
                src = torch.cat([src, src_neg], dim=0)
                dst = torch.cat([dst, dst_neg], dim=0)
                if edge_label is None:
                    edge_label = torch.ones(num_pos, device=self.device)
                # Negatives get an all-zero label of matching trailing shape.
                size = (num_neg, ) + edge_label.size()[1:]
                edge_neg_label = edge_label.new_zeros(size)
                edge_label = torch.cat([edge_label, edge_neg_label])
            elif neg_sampling.is_triplet():
                # TODO: make triplet negative sampling strict.
                # In the "triplet" case, we randomly sample negative
                # destinations in a "non-strict" manner.
                assert num_neg % num_pos == 0
                if input_type is not None:
                    neg_pair = self._neg_sampler[input_type].sample(
                        num_neg, padding=True)
                else:
                    neg_pair = self._neg_sampler.sample(num_neg, padding=True)
                dst_neg = neg_pair[1]
                dst = torch.cat([dst, dst_neg], dim=0)
                assert edge_label is None

        # Neighbor Sampling
        num_src = src.numel()
        if input_type is not None:  # hetero
            src_type, dst_type = input_type[0], input_type[-1]
            distinct_types = src_type != dst_type
            if distinct_types:  # Two distinct node types:
                seed_dict = {src_type: src.unique(), dst_type: dst.unique()}
            else:
                # Only a single node type: merge source and destination.
                seed, inverse_seed = torch.cat([src, dst], dim=0).unique(
                    return_inverse=True)
                seed_dict = {src_type: seed}

            temp_out = [
                self.sample_from_nodes(
                    NodeSamplerInput(node=node, input_type=ntype))
                for ntype, node in seed_dict.items()
            ]
            if len(temp_out) == 2:
                out = merge_hetero_sampler_output(temp_out[0],
                                                  temp_out[1],
                                                  device=self.device)
            else:
                out = format_hetero_sampler_output(temp_out[0])

            # Positions of the seed endpoints among the sampled nodes.
            if distinct_types:
                src_index = id2idx_v2(src, out.node[src_type])
                dst_index = id2idx_v2(dst, out.node[dst_type])
            else:
                src_index = inverse_seed[:num_src]
                dst_index = inverse_seed[num_src:]

            metadata = self._edge_seed_metadata(
                neg_sampling, num_pos, edge_label, src_index, dst_index)
            if metadata is not None:
                out.metadata = metadata
                out.input_type = input_type
        else:  # homo
            seed, inverse_seed = torch.cat([src, dst], dim=0).unique(
                return_inverse=True)
            out = self.sample_from_nodes(seed)
            metadata = self._edge_seed_metadata(
                neg_sampling, num_pos, edge_label,
                inverse_seed[:num_src], inverse_seed[num_src:])
            if metadata is not None:
                out.metadata = metadata

        return out

    def sample_pyg_v1(self, ids: torch.Tensor):
        r"""Sample multi-hop neighbors and organize results to PyG's
        `EdgeIndex`.

        Args:
          ids: input ids, 1D tensor.

        Returns:
          A tuple ``(batch_size, out_ids, adjs)`` laid out like the output of
          PyG's `NeighborSampler` (PyG v1); ``adjs`` is ordered from the last
          sampled hop to the first.
        """
        ids = ids.to(self.device)
        adjs = []
        srcs = ids
        out_ids = ids
        batch_size = 0
        inducer = self.get_inducer(srcs.numel())
        for hop, req_num in enumerate(self.num_neighbors):
            srcs = inducer.init_node(srcs)
            if hop == 0:
                batch_size = srcs.numel()
            out_nbrs = self.sample_one_hop(srcs, req_num)
            nodes, rows, cols = inducer.induce_next(
                srcs, out_nbrs.nbr, out_nbrs.nbr_num)
            # we use csr instead of csc in PyG.
            edge_index = torch.stack([cols, rows])
            out_ids = torch.cat([srcs, nodes])
            adj_size = torch.tensor([out_ids.size(0), srcs.size(0)],
                                    dtype=torch.long)
            adjs.append(EdgeIndex(edge_index, out_nbrs.edge, adj_size))
            srcs = out_ids
        return batch_size, out_ids, adjs[::-1]

    def subgraph(
        self,
        inputs: NodeSamplerInput,
    ) -> SamplerOutput:
        r"""Induce the subgraph spanned by the seeds and, when
        ``num_neighbors`` is set, their sampled multi-hop neighbors.

        The ``metadata`` of the result holds, for every input seed, its index
        in the deduplicated node tensor handed to the subgraph op.
        """
        self.lazy_init_subgraph_op()
        inputs = NodeSamplerInput.cast(inputs)
        input_seeds = inputs.node.to(self.device)
        if self.num_neighbors is not None:
            nodes = [input_seeds]
            for num in self.num_neighbors:
                nbr = self.sample_one_hop(nodes[-1], num).nbr
                nodes.append(torch.unique(nbr))
            nodes, mapping = torch.cat(nodes).unique(return_inverse=True)
        else:
            nodes, mapping = torch.unique(input_seeds, return_inverse=True)
        subgraph = self._subgraph_op.node_subgraph(nodes, self.with_edge)
        return SamplerOutput(
            node=subgraph.nodes,
            row=subgraph.rows,
            col=subgraph.cols,
            edge=subgraph.eids if self.with_edge else None,
            device=self.device,
            # The seeds come first in the concatenation, so the leading
            # entries of the inverse mapping belong to them.
            metadata=mapping[:input_seeds.numel()])

    def sample_prob(
        self,
        inputs: NodeSamplerInput,
        node_cnt: Union[int, Dict[NodeType, int]]
    ) -> Union[torch.Tensor, Dict[NodeType, torch.Tensor]]:
        r"""Get the probability of each node being sampled.

        Args:
          inputs: Seed nodes; on heterogeneous graphs ``input_type`` must be
            set.
          node_cnt: Number of nodes (homogeneous), or a dict keyed by node
            type (heterogeneous).
        """
        self.lazy_init_sampler()
        inputs = NodeSamplerInput.cast(inputs)
        input_seeds = inputs.node.to(self.device)
        input_type = inputs.input_type
        if self._g_cls == 'hetero':
            assert input_type is not None
            return self._hetero_sample_prob({input_type: input_seeds},
                                            node_cnt)
        return self._sample_prob(input_seeds, node_cnt)

    def _sample_prob(
        self,
        input_seeds: torch.Tensor,
        node_cnt: int
    ) -> torch.Tensor:
        r"""Homogeneous ``sample_prob``: start from a prior of 1 for seeds and
        0.01 for every other node and run ``cal_nbr_prob`` once per hop.
        """
        last_prob = torch.full((node_cnt,), 0.01,
                               device=self.device, dtype=torch.float32)
        last_prob[input_seeds] = 1
        for req in self.num_neighbors:
            cur_prob = torch.zeros(node_cnt, device=self.device,
                                   dtype=torch.float32)
            self._sampler.cal_nbr_prob(
                req, last_prob, last_prob, self.graph.graph_handler, cur_prob
            )
            last_prob = cur_prob
        return last_prob

    def _hetero_sample_prob(
        self,
        input_seeds_dict: Dict[NodeType, torch.Tensor],
        node_dict: Dict[NodeType, int]
    ) -> Dict[NodeType, torch.Tensor]:
        r"""Heterogeneous ``sample_prob``.

        Args:
          input_seeds_dict: Seed node ids per node type. Node types without
            seeds simply start from the prior.
          node_dict: Number of nodes per node type. A value may be an int or
            an object whose ``size(0)`` gives the count.

        Returns:
          A dict with one probability tensor per node type of ``node_dict``
          (when there is at least one hop). Node types for which nothing was
          computed get zeros.
        """
        device = self.device

        def num_nodes(ntype):
            count = node_dict[ntype]
            return count if isinstance(count, int) else count.size(0)

        def zeros(ntype):
            return torch.zeros(num_nodes(ntype), device=device,
                               dtype=torch.float32)

        probs = {ntype: [] for ntype in node_dict}

        def current_prob(ntype):
            # Probabilities already computed for this node type win;
            # otherwise use the prior: 1 for seeds, 0.005 elsewhere.
            if len(probs[ntype]) > 0:
                return self._aggregate_prob(probs[ntype], num_nodes(ntype),
                                            device=device)
            prior = torch.full((num_nodes(ntype),), 0.005,
                               device=device, dtype=torch.float32)
            if ntype in input_seeds_dict:
                prior[input_seeds_dict[ntype]] = 1
            return prior

        # calculate probs for each subgraph
        for hop in range(self.num_hops):
            for etype in self.edge_types:
                src_type, dst_type = etype[0], etype[2]
                req = self.num_neighbors[etype][hop]
                last_prob = current_prob(src_type)
                if src_type == dst_type:
                    # homogenous subgraph case
                    cur_prob = zeros(src_type)
                    self._sampler[etype].cal_nbr_prob(
                        req, last_prob, last_prob,
                        self.graph[etype].graph_handler, cur_prob
                    )
                    probs[src_type].append(cur_prob)
                else:
                    # hetero bipartite graph case
                    nbr_prob = current_prob(dst_type)
                    temp_probs = []
                    for nbr_etype in self.edge_types:
                        # Only edge types leaving the neighbor node type.
                        if nbr_etype[0] != dst_type:
                            continue
                        cur_prob = zeros(src_type)
                        self._sampler[etype].cal_nbr_prob(
                            req, last_prob, nbr_prob,
                            self.graph[nbr_etype].graph_handler, cur_prob
                        )
                        last_prob = cur_prob
                        temp_probs.append(cur_prob)
                    # aggregate prob for the bipartite graph
                    # with #{subgraphs where the neighbours are}
                    probs[src_type].append(
                        self._aggregate_prob(temp_probs, num_nodes(src_type),
                                             device=device))

            # aggregate probs from each subgraph
            # with #{subgraphs}
            is_last_hop = hop == self.num_hops - 1
            for ntype, hop_probs in probs.items():
                if len(hop_probs) == 0 and not is_last_hop:
                    # Nothing computed for this node type so far: keep the
                    # list empty so the next hop still starts from the prior.
                    continue
                merged = self._aggregate_prob(hop_probs, num_nodes(ntype),
                                              device=device)
                probs[ntype] = merged if is_last_hop else [merged]
        return probs

    def get_inducer(self, input_batch_size: int):
        r"""Return the cached inducer, creating it on the first call.

        The inducer is sized from the ``input_batch_size`` of the first call;
        later calls return the same object regardless of their argument.
        """
        if self._inducer is None:
            self._inducer = self.create_inducer(input_batch_size)
        return self._inducer

    def create_inducer(self, input_batch_size: int):
        r"""Create a new inducer matching the device and graph class, sized
        with ``_max_sampled_nodes(input_batch_size)``.
        """
        max_num_nodes = self._max_sampled_nodes(input_batch_size)
        on_cuda = self.device.type == 'cuda'
        if self._g_cls == 'homo':
            inducer_cls = pywrap.CUDAInducer if on_cuda else pywrap.CPUInducer
        else:
            inducer_cls = (pywrap.CUDAHeteroInducer if on_cuda
                           else pywrap.CPUHeteroInducer)
        return inducer_cls(max_num_nodes)

    def _set_num_neighbors_and_num_hops(self, num_neighbors):
        r"""Normalise ``num_neighbors`` to a dict keyed by edge type and set
        ``self.num_hops``.

        Raises:
          AssertionError: If ``num_neighbors`` is neither a list/tuple nor a
            dict.
          ValueError: If the edge types do not all have the same number of
            hops.
        """
        if isinstance(num_neighbors, (list, tuple)):
            num_neighbors = {key: num_neighbors for key in self.edge_types}
        assert isinstance(num_neighbors, dict)
        self.num_neighbors = num_neighbors
        # Add at least one element to the list to ensure `max` is
        # well-defined
        self.num_hops = max([0] + [len(v) for v in num_neighbors.values()])
        for key, value in self.num_neighbors.items():
            if len(value) != self.num_hops:
                raise ValueError(f"Expected the edge type {key} to have "
                                 f"{self.num_hops} entries (got {len(value)})")

    def _max_sampled_nodes(
        self,
        input_batch_size: int,
    ) -> Union[int, Dict[str, int]]:
        r"""Upper bound on the number of nodes one sampling pass can produce.

        For homogeneous graphs this is the batch size plus, for every hop,
        the batch size times the product of the fan-outs up to that hop. For
        heterogeneous graphs the same per-hop bounds are computed per edge
        type and added to both the source and the destination node type of
        that edge type; the result is a dict keyed by node type.
        """
        if self._g_cls == 'homo':
            res = [input_batch_size]
            for num in self.num_neighbors:
                res.append(res[-1] * num)
            return sum(res)

        res = {k: [] for k in self.node_types}
        for etype, num_list in self.num_neighbors.items():
            tmp_res = [input_batch_size]
            for num in num_list:
                tmp_res.append(tmp_res[-1] * num)
            res[etype[0]].extend(tmp_res)
            res[etype[2]].extend(tmp_res)
        return {k: sum(v) for k, v in res.items()}

    def _aggregate_prob(self, probs, node_num, device):
        r"""Aggregate probs from each subgraph.

        p = 1 - ((1-p_0)(1-p_1)...(1-p_k))**(1/k)
        where k := #{subgraphs}. Each factor is actually (1.002 - p_i), and
        the result is clamped to be non-negative.

        An empty ``probs`` yields an all-zero tensor.
        """
        if len(probs) == 0:
            # Nothing to aggregate; also avoids dividing by zero below.
            return torch.zeros(node_num, device=device, dtype=torch.float32)
        res = torch.ones(node_num, device=device, dtype=torch.float32)
        for temp_prob in probs:
            # to avoid the case that p_i=1 causes p=1 s.t the whole
            # importance won't be decided by one term.
            res *= (1 + .002 - temp_prob)
        res = 1 - res ** (1 / len(probs))
        return res.clamp(min=0.0)