Source code for graphlearn_torch.sampler.base

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from abc import ABC, abstractmethod
import math
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union

import torch

from ..typing import NodeType, EdgeType, NumNeighbors
from ..utils import CastMixin


class EdgeIndex(NamedTuple):
    r""" Counterpart of PyG's :class:`~torch_geometric.loader.EdgeIndex`, the
    record type used by the old data loader
    :class:`~torch_geometric.loader.NeighborSampler`:
    https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/loader/neighbor_sampler.py

    Fields:
      edge_index (torch.Tensor): The edge index tensor.
      e_id (torch.Tensor, optional): The edge ids, or :obj:`None`.
      size (Tuple[int, int]): A pair of integers carried along unchanged.
    """
    edge_index: torch.Tensor
    e_id: Optional[torch.Tensor]
    size: Tuple[int, int]

    def to(self, *args, **kwargs):
        r""" Build a new :class:`EdgeIndex` whose tensors are the result of
        calling :meth:`torch.Tensor.to` with the given arguments.

        ``size`` is copied as is, and a missing ``e_id`` stays :obj:`None`.
        """
        converted_index = self.edge_index.to(*args, **kwargs)
        if self.e_id is None:
            converted_e_id = None
        else:
            converted_e_id = self.e_id.to(*args, **kwargs)
        return EdgeIndex(converted_index, converted_e_id, self.size)
@dataclass
class NodeSamplerInput(CastMixin):
    r""" The sampling input of
    :meth:`~graphlearn_torch.sampler.BaseSampler.sample_from_nodes`.

    This class corresponds to :class:`~torch_geometric.sampler.NodeSamplerInput`:
    https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/sampler/base.py

    Args:
      node (torch.Tensor): The indices of seed nodes to start sampling from.
      input_type (str, optional): The input node type (in case of sampling in
        a heterogeneous graph). (default: :obj:`None`).
    """
    node: torch.Tensor
    input_type: Optional[NodeType] = None

    def __getitem__(self, index: Union[torch.Tensor, Any]) -> 'NodeSamplerInput':
        r""" Select a subset of the seed nodes.

        A non-tensor ``index`` is first converted to a ``torch.long`` tensor.
        The result is a new :class:`NodeSamplerInput` with the same
        ``input_type``.
        """
        if not isinstance(index, torch.Tensor):
            index = torch.tensor(index, dtype=torch.long)
        return NodeSamplerInput(self.node[index], self.input_type)

    def __len__(self):
        r""" Number of elements in the ``node`` tensor. """
        return self.node.numel()

    def share_memory(self):
        r""" Move the ``node`` tensor into shared memory in place and return
        ``self``.
        """
        self.node.share_memory_()
        return self

    def to(self, device: torch.device):
        r""" Move the ``node`` tensor to ``device`` and return ``self``.

        Args:
          device (torch.device): The target device.
        """
        # ``Tensor.to`` is out-of-place: it returns the converted tensor and
        # leaves the receiver untouched, so the result has to be stored back.
        self.node = self.node.to(device)
        return self
class NegativeSamplingMode(Enum):
    r""" The supported negative sampling modes.

    * ``binary``: Randomly sample negative edges in the graph.
    * ``triplet``: Randomly sample negative destination nodes for each
      positive source node.
    """
    # Negative edges are drawn at random from the graph.
    binary = 'binary'
    # Negative destination nodes are drawn for every positive source node.
    triplet = 'triplet'
@dataclass
class NegativeSampling(CastMixin):
    r"""The negative sampling configuration of a
    :class:`~graphlearn_torch.sampler.BaseSampler` when calling
    :meth:`~graphlearn_torch.sampler.BaseSampler.sample_from_edges`.

    Args:
      mode (str): The negative sampling mode
        (:obj:`"binary"` or :obj:`"triplet"`).
        If set to :obj:`"binary"`, will randomly sample negative links
        from the graph.
        If set to :obj:`"triplet"`, will randomly sample negative
        destination nodes for each positive source node.
      amount (int or float, optional): The ratio of sampled negative edges to
        the number of positive edges. (default: :obj:`1`)
      weight (torch.Tensor, optional): A node-level vector determining the
        sampling of nodes. Does not necessarily need to sum up to one.
        If not given, negative nodes will be sampled uniformly.
        (default: :obj:`None`)

    Raises:
      ValueError: If ``amount`` is not positive, or if it is not integral
        while ``mode`` is ``"triplet"``.
    """
    mode: NegativeSamplingMode
    amount: Union[int, float] = 1
    weight: Optional[torch.Tensor] = None

    def __init__(
        self,
        mode: Union[NegativeSamplingMode, str],
        amount: Union[int, float] = 1,
        weight: Optional[torch.Tensor] = None,
    ):
        # Accept either the enum member or its string value.
        self.mode = NegativeSamplingMode(mode)
        self.amount = amount
        self.weight = weight

        if self.amount <= 0:
            raise ValueError(f"The attribute 'amount' needs to be positive "
                             f"for '{self.__class__.__name__}' "
                             f"(got {self.amount})")

        if self.is_triplet():
            if self.amount != math.ceil(self.amount):
                raise ValueError(f"The attribute 'amount' needs to be an "
                                 f"integer for '{self.__class__.__name__}' "
                                 f"with 'triplet' negative sampling "
                                 f"(got {self.amount}).")
            # Normalise integral floats (e.g. ``2.0``) to ``int``.
            self.amount = math.ceil(self.amount)

    def is_binary(self) -> bool:
        r""" Whether the mode is :attr:`NegativeSamplingMode.binary`. """
        return self.mode == NegativeSamplingMode.binary

    def is_triplet(self) -> bool:
        r""" Whether the mode is :attr:`NegativeSamplingMode.triplet`. """
        return self.mode == NegativeSamplingMode.triplet

    def share_memory(self):
        r""" Move ``weight`` (if any) into shared memory in place and return
        ``self``.
        """
        if self.weight is not None:
            self.weight.share_memory_()
        return self

    def to(self, device: torch.device):
        r""" Move ``weight`` (if any) to ``device`` and return ``self``.

        Args:
          device (torch.device): The target device.
        """
        if self.weight is not None:
            # ``Tensor.to`` returns a new tensor rather than mutating the
            # receiver, so the result must be assigned back.
            self.weight = self.weight.to(device)
        return self
@dataclass
class EdgeSamplerInput(CastMixin):
    r""" The sampling input of
    :meth:`~graphlearn_torch.sampler.BaseSampler.sample_from_edges`.

    This class corresponds to :class:`~torch_geometric.sampler.EdgeSamplerInput`:
    https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/sampler/base.py

    Args:
      row (torch.Tensor): The source node indices of seed links to start
        sampling from.
      col (torch.Tensor): The destination node indices of seed links to start
        sampling from.
      label (torch.Tensor, optional): The label for the seed links.
        (default: :obj:`None`).
      input_type (Tuple[str, str, str], optional): The input edge type (in
        case of sampling in a heterogeneous graph). (default: :obj:`None`).
      neg_sampling (NegativeSampling, optional): The negative sampling
        configuration attached to these seed links. (default: :obj:`None`).
    """
    row: torch.Tensor
    col: torch.Tensor
    label: Optional[torch.Tensor] = None
    input_type: Optional[EdgeType] = None
    neg_sampling: Optional[NegativeSampling] = None

    def __getitem__(self, index: Union[torch.Tensor, Any]) -> 'EdgeSamplerInput':
        r""" Select a subset of the seed links.

        A non-tensor ``index`` is first converted to a ``torch.long`` tensor.
        ``row``, ``col`` and (if present) ``label`` are indexed, while
        ``input_type`` and ``neg_sampling`` are passed through unchanged.
        """
        if not isinstance(index, torch.Tensor):
            index = torch.tensor(index, dtype=torch.long)
        return EdgeSamplerInput(
            self.row[index],
            self.col[index],
            self.label[index] if self.label is not None else None,
            self.input_type,
            self.neg_sampling,
        )

    def __len__(self):
        r""" Number of elements in the ``row`` tensor. """
        return self.row.numel()

    def share_memory(self):
        r""" Move all held tensors into shared memory in place and return
        ``self``.
        """
        self.row.share_memory_()
        self.col.share_memory_()
        if self.label is not None:
            self.label.share_memory_()
        # Guard on ``neg_sampling`` itself: it is independent of ``label``.
        if self.neg_sampling is not None:
            self.neg_sampling.share_memory()
        return self

    def to(self, device: torch.device):
        r""" Move all held tensors to ``device`` and return ``self``.

        Args:
          device (torch.device): The target device.
        """
        # ``Tensor.to`` is out-of-place, so every result is stored back.
        self.row = self.row.to(device)
        self.col = self.col.to(device)
        if self.label is not None:
            self.label = self.label.to(device)
        # Guard on ``neg_sampling`` itself: it is independent of ``label``.
        if self.neg_sampling is not None:
            self.neg_sampling = self.neg_sampling.to(device)
        return self
@dataclass
class SamplerOutput(CastMixin):
    r""" Result record produced by a
    :class:`~graphlearn_torch.sampler.BaseSampler` when sampling on a
    homogeneous graph.

    It mirrors :class:`~torch_geometric.sampler.SamplerOutput`:
    https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/sampler/base.py

    Args:
      node (torch.Tensor): The sampled nodes in the original graph.
      row (torch.Tensor): Source node indices of the sampled subgraph,
        re-indexed to :obj:`{ 0, ..., num_nodes - 1 }` so that they refer to
        positions in the :obj:`node` tensor.
      col (torch.Tensor): Destination node indices of the sampled subgraph,
        re-indexed to :obj:`{ 0, ..., num_nodes - 1 }` so that they refer to
        positions in the :obj:`node` tensor.
      edge (torch.Tensor, optional): The sampled edges in the original graph,
        used to look up edge features there. May be omitted when no edge
        attributes are present. (default: :obj:`None`).
      batch (torch.Tensor, optional): For each sampled node, the seed node it
        belongs to. Can be present in case of disjoint subgraph sampling per
        seed node. (default: :obj:`None`).
      device (torch.device, optional): The device that all data of this
        output resides in. (default: :obj:`None`).
      metadata: (Any, optional): Additional metadata information.
        (default: :obj:`None`).
    """
    # Required fields.
    node: torch.Tensor
    row: torch.Tensor
    col: torch.Tensor
    # Optional fields, all defaulting to ``None``.
    edge: Optional[torch.Tensor] = None
    batch: Optional[torch.Tensor] = None
    device: Optional[torch.device] = None
    metadata: Optional[Any] = None
@dataclass
class HeteroSamplerOutput(CastMixin):
    r""" Result record produced by a
    :class:`~graphlearn_torch.sampler.BaseSampler` when sampling on a
    heterogeneous graph.

    It mirrors :class:`~torch_geometric.sampler.HeteroSamplerOutput`:
    https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/sampler/base.py

    Args:
      node (Dict[str, torch.Tensor]): The sampled nodes in the original graph
        for each node type.
      row (Dict[Tuple[str, str, str], torch.Tensor]): Source node indices of
        the sampled subgraph for each edge type, re-indexed to
        :obj:`{ 0, ..., num_nodes - 1 }` so that they refer to positions in
        the :obj:`node` tensor of the source node type.
      col (Dict[Tuple[str, str, str], torch.Tensor]): Destination node indices
        of the sampled subgraph for each edge type, re-indexed to
        :obj:`{ 0, ..., num_nodes - 1 }` so that they refer to positions in
        the :obj:`node` tensor of the destination node type.
      edge (Dict[Tuple[str, str, str], torch.Tensor], optional): The sampled
        edges in the original graph for each edge type, used to look up edge
        features there. May be omitted when no edge attributes are present.
        (default: :obj:`None`).
      batch (Dict[str, torch.Tensor], optional): For each node type and each
        sampled node, the seed node it belongs to. Can be present in case of
        disjoint subgraph sampling per seed node. (default: :obj:`None`).
      edge_types: (List[Tuple[str, str, str]], optional): The list of edge
        types of the sampled subgraph. (default: :obj:`None`).
      input_type: (Union[NodeType, EdgeType], optional): The input type of
        seed nodes or edge_label_index. (default: :obj:`None`).
      device (torch.device, optional): The device that all data of this
        output resides in. (default: :obj:`None`).
      metadata: (Any, optional): Additional metadata information.
        (default: :obj:`None`)
    """
    # Required fields.
    node: Dict[NodeType, torch.Tensor]
    row: Dict[EdgeType, torch.Tensor]
    col: Dict[EdgeType, torch.Tensor]
    # Optional fields, all defaulting to ``None``.
    edge: Optional[Dict[EdgeType, torch.Tensor]] = None
    batch: Optional[Dict[NodeType, torch.Tensor]] = None
    edge_types: Optional[List[EdgeType]] = None
    input_type: Optional[Union[NodeType, EdgeType]] = None
    device: Optional[torch.device] = None
    metadata: Optional[Any] = None

    def get_edge_index(self):
        r""" Assemble one edge index tensor per edge type.

        Every edge type present in ``row`` maps to ``row`` and ``col``
        stacked with :func:`torch.stack`. If ``edge_types`` is given, each
        listed type that has no entry receives an empty ``(2, 0)`` tensor of
        dtype ``torch.long`` moved to ``self.device``.

        Returns:
          A dict keyed by edge type holding the stacked tensors.
        """
        edge_index = {}
        for etype, src in self.row.items():
            edge_index[etype] = torch.stack([src, self.col[etype]])

        if self.edge_types is None:
            return edge_index

        # Make sure every declared edge type has an entry, even when no edge
        # of that type was sampled.
        for etype in self.edge_types:
            if edge_index.get(etype) is None:
                placeholder = torch.empty((2, 0), dtype=torch.long)
                edge_index[etype] = placeholder.to(self.device)
        return edge_index
@dataclass
class NeighborOutput(CastMixin):
    r""" The output of sampled neighbor results for a single hop sampling.

    Args:
      nbr (torch.Tensor): A 1D tensor of all sampled neighborhood node ids.
      nbr_num (torch.Tensor): A 1D tensor that identifies the number of
        neighborhood nodes for each source node. Must be the same length as
        the source nodes of this sampling hop.
      edge (torch.Tensor, optional): The edge ids corresponding to the sampled
        edges (from source node to the sampled neighborhood node). Should be
        the same length as :obj:`nbr` if provided. (default: :obj:`None`).
    """
    nbr: torch.Tensor
    nbr_num: torch.Tensor
    # Documented as optional, so it defaults to ``None``; explicitly passing
    # the value (positionally or by keyword) keeps working as before.
    edge: Optional[torch.Tensor] = None

    def to(self, device: torch.device):
        r""" Return a new :class:`NeighborOutput` whose tensors have been
        moved to ``device``. A missing ``edge`` stays :obj:`None`.

        Args:
          device (torch.device): The target device.
        """
        return NeighborOutput(
            nbr=self.nbr.to(device),
            nbr_num=self.nbr_num.to(device),
            edge=(self.edge.to(device) if self.edge is not None else None),
        )
class SamplingType(Enum):
    r""" Enumeration of the available sampling types.

    Members are ``NODE``, ``LINK``, ``SUBGRAPH`` and ``RANDOM_WALK``, with
    the integer values ``0`` through ``3`` respectively.
    """
    NODE = 0
    LINK = 1
    SUBGRAPH = 2
    RANDOM_WALK = 3
@dataclass
class SamplingConfig:
    r""" Plain record bundling the configuration info for sampling.

    All fields are required; none has a default value.
    """
    # Which kind of sampling to run (see ``SamplingType``).
    sampling_type: SamplingType
    # Neighbor counts; may be ``None``.
    num_neighbors: Optional[NumNeighbors]
    # NOTE(review): the flags below are presumably consumed by the data
    # loaders and samplers; their exact semantics are not visible here.
    batch_size: int
    shuffle: bool
    drop_last: bool
    with_edge: bool
    collect_features: bool
    with_neg: bool
class BaseSampler(ABC):
    r""" A base class that initializes a graph sampler and provides
    :meth:`sample_from_nodes` and :meth:`sample_from_edges` routines.

    This class corresponds to :class:`~torch_geometric.sampler.BaseSampler`:
    https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/sampler/base.py

    All three routines are abstract and must be implemented by subclasses.
    """

    @abstractmethod
    def sample_from_nodes(
        self,
        inputs: NodeSamplerInput,
        **kwargs
    ) -> Union[HeteroSamplerOutput, SamplerOutput]:
        r""" Performs sampling from the nodes specified in :obj:`inputs`,
        returning a sampled subgraph(egograph) in the specified output format.

        Args:
          inputs (NodeSamplerInput): The input data with node indices to start
            sampling from.
          **kwargs: Additional keyword arguments accepted by the concrete
            sampler.
        """

    @abstractmethod
    def sample_from_edges(
        self,
        inputs: EdgeSamplerInput,
        **kwargs,
    ) -> Union[HeteroSamplerOutput, SamplerOutput]:
        r""" Performs sampling from the edges specified in :obj:`inputs`,
        returning a sampled subgraph(egograph) in the specified output format.

        Args:
          inputs (EdgeSamplerInput): The input data for sampling from edges
            including the (1) source node indices, the (2) destination node
            indices, the (3) optional edge labels and the (4) input edge type.
          **kwargs: Additional keyword arguments accepted by the concrete
            sampler.
        """

    @abstractmethod
    def subgraph(
        self,
        inputs: NodeSamplerInput,
    ) -> SamplerOutput:
        r""" Induce an enclosing subgraph based on inputs and their neighbors(if
        num_neighbors is not None).

        Args:
          inputs (NodeSamplerInput): The input data with node indices to
            induce subgraph from.

        Returns:
          The sampled unique nodes, relabeled rows and cols, original edge_ids,
          and a mapping from indices in `inputs` to new indices in output
          nodes, i.e. nodes[mapping] = inputs.
        """