Source code for graphlearn_torch.sampler.neighbor_sampler

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import math
from typing import Dict, Optional, Union

import torch

from .. import py_graphlearn_torch as pywrap
from ..data import Graph
from ..typing import NodeType, EdgeType, NumNeighbors, reverse_edge_type
from ..utils import (
    merge_dict, merge_hetero_sampler_output, format_hetero_sampler_output,
    id2idx_v2
)

from .base import (
  BaseSampler, EdgeIndex,
  NodeSamplerInput, EdgeSamplerInput,
  SamplerOutput, HeteroSamplerOutput, NeighborOutput,
)
from .negative_sampler import RandomNegativeSampler

class NeighborSampler(BaseSampler):
  r""" Neighbor Sampler.

  Samples multi-hop neighbors on either a homogeneous ``Graph`` or a
  heterogeneous graph given as a ``Dict[EdgeType, Graph]``, and induces the
  sampled subgraph in COO format.

  Args:
    graph: A homogeneous ``Graph``, or a dict mapping each edge type to the
      ``Graph`` that stores it.
    num_neighbors: Per-hop fan-outs. For heterogeneous graphs, a list/tuple is
      applied to every edge type, or a dict keyed by edge type may be given;
      every entry must have the same number of hops.
    device: Device to sample on. It is overridden with CPU when the (first)
      graph is in ``'CPU'`` mode.
    with_edge: Whether to also return the ids of sampled edges.
    with_neg: Whether to build negative samplers. :meth:`sample_from_edges`
      needs them when negative sampling is requested.
    strategy: Sampling strategy. It is only stored here; the samplers created
      by :meth:`lazy_init_sampler` are always random samplers.
  """
  def __init__(self,
               graph: Union[Graph, Dict[EdgeType, Graph]],
               num_neighbors: Optional[NumNeighbors] = None,
               device: torch.device = torch.device('cuda', 0),
               with_edge: bool = False,
               with_neg: bool = False,
               strategy: str = 'random'):
    self.graph = graph
    self.num_neighbors = num_neighbors
    self.device = device
    self.with_edge = with_edge
    self.with_neg = with_neg
    self.strategy = strategy
    self._subgraph_op = None
    self._sampler = None
    self._neg_sampler = None
    self._inducer = None
    # Input batch size the cached inducer was created for (see `get_inducer`).
    self._inducer_batch_size = 0

    if isinstance(self.graph, Graph):  # homo
      self._g_cls = 'homo'
      if self.graph.mode == 'CPU':
        self.device = torch.device('cpu')
    else:  # hetero
      self._g_cls = 'hetero'
      self.edge_types = []
      self.node_types = set()
      for etype, _ in self.graph.items():
        self.edge_types.append(etype)
        self.node_types.add(etype[0])
        self.node_types.add(etype[2])
      if self.graph[self.edge_types[0]].mode == 'CPU':
        self.device = torch.device('cpu')
      # Only the hetero path normalizes `num_neighbors` into a per-edge-type
      # dict; the homo path keeps it as the list that was passed in.
      self._set_num_neighbors_and_num_hops(self.num_neighbors)

  @property
  def subgraph_op(self):
    r""" The subgraph operator, created on first access. """
    self.lazy_init_subgraph_op()
    return self._subgraph_op

  def lazy_init_sampler(self):
    r""" Create the underlying random sampler(s) on first use.

    A single sampler is created for a homogeneous graph, and one per edge
    type for a heterogeneous graph. The CUDA implementation is used iff
    ``self.device.type == 'cuda'``, for both graph kinds.
    """
    if self._sampler is not None:
      return
    sampler_cls = (pywrap.CUDARandomSampler if self.device.type == 'cuda'
                   else pywrap.CPURandomSampler)
    if self._g_cls == 'homo':
      self._sampler = sampler_cls(self.graph.graph_handler)
    else:  # hetero
      self._sampler = {}
      for etype, g in self.graph.items():
        self._sampler[etype] = sampler_cls(g.graph_handler)

  def lazy_init_neg_sampler(self):
    r""" Create the negative sampler(s) on first use if ``with_neg`` is set.

    For a heterogeneous graph one ``RandomNegativeSampler`` is created per
    edge type. When ``with_neg`` is False this is a no-op and
    ``self._neg_sampler`` stays ``None``.
    """
    if self._neg_sampler is None and self.with_neg:
      if self._g_cls == 'homo':
        self._neg_sampler = RandomNegativeSampler(
          graph=self.graph,
          mode=self.device.type.upper()
        )
      else:  # hetero
        self._neg_sampler = {}
        for etype, g in self.graph.items():
          self._neg_sampler[etype] = RandomNegativeSampler(
            graph=g,
            mode=self.device.type.upper()
          )

  def lazy_init_subgraph_op(self):
    r""" Create the subgraph operator on first use.

    NOTE(review): this uses ``self.graph.graph_handler`` and therefore only
    works for a homogeneous graph.
    """
    if self._subgraph_op is None:
      if self.device.type == 'cuda':
        self._subgraph_op = pywrap.CUDASubGraphOp(self.graph.graph_handler)
      else:
        self._subgraph_op = pywrap.CPUSubGraphOp(self.graph.graph_handler)

  def sample_one_hop(
    self,
    input_seeds: torch.Tensor,
    req_num: int,
    etype: EdgeType = None
  ) -> NeighborOutput:
    r""" Sample up to ``req_num`` out-neighbors for each seed.

    Args:
      input_seeds: Seed node ids; moved to ``self.device``.
      req_num: Requested number of neighbors per seed.
      etype: Edge type to sample on (heterogeneous graphs only).

    Returns:
      ``NeighborOutput(nbr, nbr_num, edge)``. If nothing at all is sampled,
      each seed is returned as its own single neighbor, and the edge ids (when
      ``with_edge``) are all ``-1``.
    """
    self.lazy_init_sampler()
    sampler = self._sampler[etype] if etype is not None else self._sampler
    input_seeds = input_seeds.to(self.device)
    edge_ids = None
    if not self.with_edge:
      nbrs, nbrs_num = sampler.sample(input_seeds, req_num)
    else:
      nbrs, nbrs_num, edge_ids = sampler.sample_with_edge(input_seeds, req_num)
    if nbrs.numel() == 0:
      nbrs, nbrs_num = input_seeds, torch.ones_like(input_seeds)
      if self.with_edge:
        edge_ids = -1 * nbrs_num
    return NeighborOutput(nbrs, nbrs_num, edge_ids)

  def sample_from_nodes(
    self,
    inputs: NodeSamplerInput,
    **kwargs
  ) -> Union[HeteroSamplerOutput, SamplerOutput]:
    r""" Sample multi-hop neighborhoods starting from the given seed nodes.

    ``inputs.input_type`` is required for heterogeneous graphs.
    """
    inputs = NodeSamplerInput.cast(inputs)
    input_seeds = inputs.node.to(self.device)
    input_type = inputs.input_type
    if self._g_cls == 'hetero':
      assert input_type is not None
      output = self._hetero_sample_from_nodes({input_type: input_seeds})
    else:
      output = self._sample_from_nodes(input_seeds)
    return output

  def _sample_from_nodes(
    self,
    input_seeds: torch.Tensor
  ) -> SamplerOutput:
    r""" Sample on homogeneous graphs and induce a COO format subgraph.

    Note that messages in PyG are passed from src to dst. We sample src's out
    neighbors and induce [src_index, dst_index] subgraphs, so the direction
    of sampling is opposite to the direction of message passing. To be
    consistent with PyG semantics, the final edge index is transposed to
    [dst_index, src_index].
    """
    out_nodes, out_rows, out_cols, out_edges = [], [], [], []
    inducer = self.get_inducer(input_seeds.numel())
    srcs = inducer.init_node(input_seeds)
    batch = srcs
    out_nodes.append(srcs)
    for req_num in self.num_neighbors:
      out_nbrs = self.sample_one_hop(srcs, req_num)
      nodes, rows, cols = inducer.induce_next(
        srcs, out_nbrs.nbr, out_nbrs.nbr_num)
      out_nodes.append(nodes)
      out_rows.append(rows)
      out_cols.append(cols)
      if out_nbrs.edge is not None:
        out_edges.append(out_nbrs.edge)
      # Only the nodes newly discovered in this hop are expanded next.
      srcs = nodes

    return SamplerOutput(
      node=torch.cat(out_nodes),
      # Swap rows/cols to match PyG's message-passing direction.
      row=torch.cat(out_cols),
      col=torch.cat(out_rows),
      edge=(torch.cat(out_edges) if out_edges else None),
      batch=batch,
      device=self.device
    )

  def _hetero_sample_from_nodes(
    self,
    input_seeds_dict: Dict[NodeType, torch.Tensor],
  ) -> HeteroSamplerOutput:
    r""" Sample on heterogeneous graphs and induce a COO format subgraph dict.

    Note that messages in PyG are passed from src to dst. We sample src's out
    neighbors and induce [src_index, dst_index] subgraphs, so the direction
    of sampling is opposite to the direction of message passing. To be
    consistent with PyG semantics, the final edge index is transposed to
    [dst_index, src_index] and the edge type is reversed as well. For
    example, given the edge type (u, u2i, i), we sample along the meta-path
    u->i but return edge_index_dict {(i, rev_u2i, u) : [i, u]}.
    """
    # The inducer is sized by the largest per-type seed batch.
    max_input_batch_size = max([t.numel() for t in input_seeds_dict.values()])
    inducer = self.get_inducer(max_input_batch_size)
    src_dict = inducer.init_node(input_seeds_dict)
    batch = src_dict
    out_nodes, out_rows, out_cols, out_edges = {}, {}, {}, {}
    merge_dict(src_dict, out_nodes)
    # Sample neighbors hop by hop.
    for i in range(self.num_hops):
      nbr_dict, edge_dict = {}, {}
      for etype in self.edge_types:
        src = src_dict.get(etype[0], None)
        req_num = self.num_neighbors[etype][i]
        if src is not None:
          output = self.sample_one_hop(src, req_num, etype)
          nbr_dict[etype] = [src, output.nbr, output.nbr_num]
          if output.edge is not None:
            edge_dict[etype] = output.edge
      nodes_dict, rows_dict, cols_dict = inducer.induce_next(nbr_dict)
      merge_dict(nodes_dict, out_nodes)
      merge_dict(rows_dict, out_rows)
      merge_dict(cols_dict, out_cols)
      merge_dict(edge_dict, out_edges)
      src_dict = nodes_dict

    for etype, rows in out_rows.items():
      out_rows[etype] = torch.cat(rows)
      out_cols[etype] = torch.cat(out_cols[etype])
      if self.with_edge:
        out_edges[etype] = torch.cat(out_edges[etype])

    # TODO: support inbound sampling.
    res_rows, res_cols, res_edges = {}, {}, {}
    for etype, rows in out_rows.items():
      rev_etype = reverse_edge_type(etype)
      res_rows[rev_etype] = out_cols[etype]
      res_cols[rev_etype] = rows
      if self.with_edge:
        res_edges[rev_etype] = out_edges[etype]

    return HeteroSamplerOutput(
      node={k : torch.cat(v) for k, v in out_nodes.items()},
      row=res_rows,
      col=res_cols,
      edge=(res_edges if len(res_edges) else None),
      batch=batch,
      edge_types=self.edge_types,
      device=self.device
    )

  def sample_from_edges(
    self,
    inputs: EdgeSamplerInput,
    **kwargs,
  ) -> Union[HeteroSamplerOutput, SamplerOutput]:
    r""" Perform sampling from an edge sampler input, using a sampling
    function with the same signature as `node_sample`.

    Only out-edge sampling is currently supported, so the direction of src
    and dst is reversed in the output. This lets features of the sampled
    nodes be aggregated from k-hop to (k-1)-hop nodes during training.

    Raises:
      ValueError: If ``inputs.neg_sampling`` is set but the sampler was
        created with ``with_neg=False``.
    """
    src = inputs.row.to(self.device)
    dst = inputs.col.to(self.device)
    edge_label = None if inputs.label is None else inputs.label.to(self.device)
    input_type = inputs.input_type
    neg_sampling = inputs.neg_sampling

    num_pos = src.numel()
    num_neg = 0
    # Negative Sampling
    self.lazy_init_neg_sampler()
    if neg_sampling is not None:
      if self._neg_sampler is None:
        raise ValueError(
          "Negative sampling was requested, but this sampler was created "
          "with `with_neg=False`."
        )
      neg_sampler = (self._neg_sampler[input_type] if input_type is not None
                     else self._neg_sampler)
      # When doing negative sampling, negative nodes/edges are appended to
      # `src` and `dst`. Positive and negative examples can later be told
      # apart by slicing at `num_pos`.
      num_neg = math.ceil(num_pos * neg_sampling.amount)
      if neg_sampling.is_binary():
        # In the "binary" case, we randomly sample negative pairs of nodes.
        neg_pair = neg_sampler.sample(num_neg)
        src_neg, dst_neg = neg_pair[0], neg_pair[1]
        src = torch.cat([src, src_neg], dim=0)
        dst = torch.cat([dst, dst_neg], dim=0)
        if edge_label is None:
          edge_label = torch.ones(num_pos, device=self.device)
        size = (num_neg, ) + edge_label.size()[1:]
        edge_neg_label = edge_label.new_zeros(size)
        edge_label = torch.cat([edge_label, edge_neg_label])
      elif neg_sampling.is_triplet():
        # TODO: make triplet negative sampling strict.
        # In the "triplet" case, we randomly sample negative destinations
        # in a "non-strict" manner.
        assert num_neg % num_pos == 0
        neg_pair = neg_sampler.sample(num_neg, padding=True)
        dst_neg = neg_pair[1]
        dst = torch.cat([dst, dst_neg], dim=0)
        assert edge_label is None

    # Neighbor Sampling
    if input_type is not None:  # hetero
      if input_type[0] != input_type[-1]:
        # Two distinct node types: keep the raw endpoints so they can be
        # mapped into the sampled node lists afterwards.
        src_seed, dst_seed = src, dst
        seed_dict = {input_type[0]: src.unique(),
                     input_type[-1]: dst.unique()}
      else:
        # Only a single node type: merge both source and destination.
        seed = torch.cat([src, dst], dim=0)
        seed, inverse_seed = seed.unique(return_inverse=True)
        seed_dict = {input_type[0]: seed}

      temp_out = []
      for it, node in seed_dict.items():
        seeds = NodeSamplerInput(node=node, input_type=it)
        temp_out.append(self.sample_from_nodes(seeds))
      if len(temp_out) == 2:
        out = merge_hetero_sampler_output(temp_out[0],
                                          temp_out[1],
                                          device=self.device)
      else:
        out = format_hetero_sampler_output(temp_out[0])

      # edge_label
      if neg_sampling is None or neg_sampling.is_binary():
        if input_type[0] != input_type[-1]:
          inverse_src = id2idx_v2(src_seed, out.node[input_type[0]])
          inverse_dst = id2idx_v2(dst_seed, out.node[input_type[-1]])
          edge_label_index = torch.stack([
            inverse_src,
            inverse_dst,
          ], dim=0)
        else:
          edge_label_index = inverse_seed.view(2, -1)

        out.metadata = {'edge_label_index': edge_label_index,
                        'edge_label': edge_label}
        out.input_type = input_type
      elif neg_sampling.is_triplet():
        if input_type[0] != input_type[-1]:
          inverse_src = id2idx_v2(src_seed, out.node[input_type[0]])
          inverse_dst = id2idx_v2(dst_seed, out.node[input_type[-1]])
          src_index = inverse_src
          dst_pos_index = inverse_dst[:num_pos]
          dst_neg_index = inverse_dst[num_pos:]
        else:
          src_index = inverse_seed[:num_pos]
          dst_pos_index = inverse_seed[num_pos:2 * num_pos]
          dst_neg_index = inverse_seed[2 * num_pos:]
        dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)

        out.metadata = {'src_index': src_index,
                        'dst_pos_index': dst_pos_index,
                        'dst_neg_index': dst_neg_index}
        out.input_type = input_type
    else:  # homo
      seed = torch.cat([src, dst], dim=0)
      seed, inverse_seed = seed.unique(return_inverse=True)
      out = self.sample_from_nodes(seed)

      # edge_label
      if neg_sampling is None or neg_sampling.is_binary():
        edge_label_index = inverse_seed.view(2, -1)

        out.metadata = {'edge_label_index': edge_label_index,
                        'edge_label': edge_label}
      elif neg_sampling.is_triplet():
        src_index = inverse_seed[:num_pos]
        dst_pos_index = inverse_seed[num_pos:2 * num_pos]
        dst_neg_index = inverse_seed[2 * num_pos:]
        dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)

        out.metadata = {'src_index': src_index,
                        'dst_pos_index': dst_pos_index,
                        'dst_neg_index': dst_neg_index}

    return out

  def sample_pyg_v1(self, ids: torch.Tensor):
    r""" Sample multi-hop neighbors and organize results to PyG's `EdgeIndex`.

    Args:
      ids: input ids, 1D tensor.

    Returns:
      ``(batch_size, out_ids, adjs)``, organized like the output of PyG's
      `NeighborSampler` (PyG v1). ``adjs`` is ordered from the last hop to the
      first.
    """
    ids = ids.to(self.device)
    adjs = []
    srcs = ids
    out_ids = ids
    batch_size = 0
    inducer = self.get_inducer(srcs.numel())
    for i, req_num in enumerate(self.num_neighbors):
      srcs = inducer.init_node(srcs)
      batch_size = srcs.numel() if i == 0 else batch_size
      out_nbrs = self.sample_one_hop(srcs, req_num)
      nodes, rows, cols = \
        inducer.induce_next(srcs, out_nbrs.nbr, out_nbrs.nbr_num)
      edge_index = torch.stack([cols, rows])  # we use csr instead of csc in PyG.
      out_ids = torch.cat([srcs, nodes])
      adj_size = torch.LongTensor([out_ids.size(0), srcs.size(0)])
      adjs.append(EdgeIndex(edge_index, out_nbrs.edge, adj_size))
      # Unlike `_sample_from_nodes`, the next hop expands all nodes so far.
      srcs = out_ids
    return batch_size, out_ids, adjs[::-1]

  def subgraph(
    self,
    inputs: NodeSamplerInput,
  ) -> SamplerOutput:
    r""" Induce the subgraph over the seeds and, if ``num_neighbors`` is set,
    their sampled multi-hop neighbors.

    ``metadata`` of the result holds, for each input seed, its index in the
    sorted unique node set passed to the subgraph operator.
    """
    self.lazy_init_subgraph_op()
    inputs = NodeSamplerInput.cast(inputs)
    input_seeds = inputs.node.to(self.device)
    if self.num_neighbors is not None:
      nodes = [input_seeds]
      for num in self.num_neighbors:
        nbr = self.sample_one_hop(nodes[-1], num).nbr
        nodes.append(torch.unique(nbr))
      nodes, mapping = torch.cat(nodes).unique(return_inverse=True)
    else:
      nodes, mapping = torch.unique(input_seeds, return_inverse=True)
    subgraph = self._subgraph_op.node_subgraph(nodes, self.with_edge)
    return SamplerOutput(
      node=subgraph.nodes,
      row=subgraph.rows,
      col=subgraph.cols,
      edge=subgraph.eids if self.with_edge else None,
      device=self.device,
      metadata=mapping[:input_seeds.numel()])

  def sample_prob(
    self,
    inputs: NodeSamplerInput,
    node_cnt: Union[int, Dict[NodeType, int]]
  ) -> Union[torch.Tensor, Dict[NodeType, torch.Tensor]]:
    r""" Get the probability of each node being sampled.

    Args:
      inputs: Seed nodes. ``input_type`` is required for heterogeneous
        graphs.
      node_cnt: Total number of nodes (homo), or per node type a count or any
        object whose ``size(0)`` is that count, e.g. a node-id tensor
        (hetero).
    """
    self.lazy_init_sampler()
    inputs = NodeSamplerInput.cast(inputs)
    input_seeds = inputs.node.to(self.device)
    input_type = inputs.input_type
    if self._g_cls == 'hetero':
      assert input_type is not None
      output = self._hetero_sample_prob({input_type : input_seeds}, node_cnt)
    else:
      output = self._sample_prob(input_seeds, node_cnt)
    return output

  def _sample_prob(
    self,
    input_seeds: torch.Tensor,
    node_cnt: int
  ) -> torch.Tensor:
    r""" Homogeneous sampling probability.

    Seeds start at 1 and every other node starts at 0.01. The probability is
    then propagated once per hop by the sampler's ``cal_nbr_prob``.
    """
    last_prob = \
      torch.ones(node_cnt, device=self.device, dtype=torch.float32) * 0.01
    last_prob[input_seeds] = 1
    for req in self.num_neighbors:
      cur_prob = torch.zeros(node_cnt, device=self.device, dtype=torch.float32)
      self._sampler.cal_nbr_prob(
        req, last_prob, last_prob, self.graph.graph_handler, cur_prob
      )
      last_prob = cur_prob
    return last_prob

  @staticmethod
  def _num_nodes(cnt: Union[int, torch.Tensor]) -> int:
    r""" Node count from either an int or an object whose ``size(0)`` is the
    count (e.g. a tensor of node ids). """
    return cnt if isinstance(cnt, int) else cnt.size(0)

  def _init_prob(
    self,
    ntype: NodeType,
    input_seeds_dict: Dict[NodeType, torch.Tensor],
    node_dict: Dict[NodeType, Union[int, torch.Tensor]]
  ) -> torch.Tensor:
    r""" Starting probability for ``ntype``: 1 for seeds of that type (if any
    were given) and 0.005 for every other node. """
    prob = torch.ones(self._num_nodes(node_dict[ntype]),
                      device=self.device,
                      dtype=torch.float32) * 0.005
    # Only one seed type is usually given, so other types have no seeds.
    if ntype in input_seeds_dict:
      prob[input_seeds_dict[ntype]] = 1
    return prob

  def _hetero_sample_prob(
    self,
    input_seeds_dict: Dict[NodeType, torch.Tensor],
    node_dict: Dict[NodeType, Union[int, torch.Tensor]]
  ) -> Dict[NodeType, torch.Tensor]:
    r""" Heterogeneous sampling probability, computed per node type.

    For every hop and every edge type, a probability is computed for the
    source node type with the edge type's sampler. Within each hop, the
    per-subgraph results of every node type are then aggregated with
    :meth:`_aggregate_prob`.

    NOTE(review): a node type that is never the source of any edge type keeps
    an empty list, and aggregating an empty list divides by zero.
    """
    probs = {ntype: [] for ntype in node_dict.keys()}
    for i in range(self.num_hops):
      # calculate probs for each subgraph
      for etype in self.edge_types:
        src_type, dst_type = etype[0], etype[2]
        req = self.num_neighbors[etype][i]
        src_cnt = self._num_nodes(node_dict[src_type])
        if len(probs[src_type]) == 0:
          last_prob = self._init_prob(src_type, input_seeds_dict, node_dict)
        else:
          last_prob = self._aggregate_prob(
            probs[src_type], src_cnt, device=self.device)

        if src_type == dst_type:
          # homogeneous subgraph case
          cur_prob = torch.zeros(src_cnt, device=self.device,
                                 dtype=torch.float32)
          self._sampler[etype].cal_nbr_prob(
            req, last_prob, last_prob, self.graph[etype].graph_handler,
            cur_prob
          )
          probs[src_type].append(cur_prob)
          continue

        # hetero bipartite graph case
        nbr_etypes = [nbr_etype for nbr_etype in self.edge_types
                      if nbr_etype[0] == dst_type]
        # prepare nbr_prob
        if len(probs[dst_type]) == 0:
          nbr_prob = self._init_prob(dst_type, input_seeds_dict, node_dict)
        else:
          nbr_prob = self._aggregate_prob(
            probs[dst_type], self._num_nodes(node_dict[dst_type]),
            device=self.device)
        temp_probs = []
        for nbr_etype in nbr_etypes:
          cur_prob = torch.zeros(src_cnt, device=self.device,
                                 dtype=torch.float32)
          self._sampler[etype].cal_nbr_prob(
            req, last_prob, nbr_prob, self.graph[nbr_etype].graph_handler,
            cur_prob
          )
          # The next neighbor edge type starts from this result.
          last_prob = cur_prob
          temp_probs.append(last_prob)
        # aggregate prob for the bipartite graph
        # with #{subgraphs where the neighbours are}
        probs[src_type].append(
          self._aggregate_prob(temp_probs, src_cnt, device=self.device))

      # aggregate probs from each subgraph with #{subgraphs}
      for ntype, prob in probs.items():
        res = self._aggregate_prob(
          prob, self._num_nodes(node_dict[ntype]), device=self.device)
        probs[ntype] = res if i == self.num_hops - 1 else [res]
    return probs

  def get_inducer(self, input_batch_size: int):
    r""" Return the cached inducer.

    The inducer is created on first use. Its capacity is derived from
    ``input_batch_size``, so it is rebuilt when a larger batch than the one it
    was sized for is requested.
    """
    if self._inducer is None or input_batch_size > self._inducer_batch_size:
      self._inducer = self.create_inducer(input_batch_size)
      self._inducer_batch_size = input_batch_size
    return self._inducer

  def create_inducer(self, input_batch_size: int):
    r""" Build a new CUDA/CPU, homo/hetero inducer sized for
    ``input_batch_size`` seeds. """
    max_num_nodes = self._max_sampled_nodes(input_batch_size)
    if self.device.type == 'cuda':
      inducer_cls = (pywrap.CUDAInducer if self._g_cls == 'homo'
                     else pywrap.CUDAHeteroInducer)
    else:
      inducer_cls = (pywrap.CPUInducer if self._g_cls == 'homo'
                     else pywrap.CPUHeteroInducer)
    return inducer_cls(max_num_nodes)

  def _set_num_neighbors_and_num_hops(self, num_neighbors):
    r""" Normalize ``num_neighbors`` into a per-edge-type dict and set
    ``num_hops``.

    Raises:
      ValueError: If edge types have different numbers of hops.
    """
    if isinstance(num_neighbors, (list, tuple)):
      num_neighbors = {key: num_neighbors for key in self.edge_types}
    assert isinstance(num_neighbors, dict)
    self.num_neighbors = num_neighbors
    # Add at least one element to the list to ensure `max` is well-defined
    self.num_hops = max([0] + [len(v) for v in num_neighbors.values()])
    for key, value in self.num_neighbors.items():
      if len(value) != self.num_hops:
        raise ValueError(f"Expected the edge type {key} to have "
                         f"{self.num_hops} entries (got {len(value)})")

  def _max_sampled_nodes(
    self,
    input_batch_size: int,
  ) -> Union[int, Dict[str, int]]:
    r""" Node-count estimate used to size an inducer.

    Homo: ``b + b*n0 + b*n0*n1 + ...`` for fan-outs ``n0, n1, ...``.
    Hetero: per node type, the same series summed over every edge type in
    which that type appears as source or destination.
    """
    if self._g_cls == 'homo':
      res = [input_batch_size]
      for num in self.num_neighbors:
        res.append(res[-1] * num)
      return sum(res)

    res = {k : [] for k in self.node_types}
    for etype, num_list in self.num_neighbors.items():
      tmp_res = [input_batch_size]
      for num in num_list:
        tmp_res.append(tmp_res[-1] * num)
      res[etype[0]].extend(tmp_res)
      res[etype[2]].extend(tmp_res)
    return {k : sum(v) for k, v in res.items()}

  def _aggregate_prob(self, probs, node_num, device):
    r""" Aggregate probs from each subgraph.

    Computes ``p = 1 - ((1.002 - p_0)(1.002 - p_1)...(1.002 - p_{k-1}))**(1/k)``
    with ``k := #{subgraphs}`` (``len(probs)``), clamped to be >= 0.

    NOTE(review): ``probs`` must be non-empty; ``k == 0`` raises
    ZeroDivisionError.
    """
    res = torch.ones(node_num, device=device, dtype=torch.float32)
    for temp_prob in probs:
      # The 0.002 offset keeps a single p_i == 1 from zeroing the product, so
      # the result is not decided by one term alone.
      res *= (1 + .002 - temp_prob)
    res = 1 - res ** (1/len(probs))
    return res.clamp(min=0.0)