# Source code for graphlearn_torch.distributed.dist_feature

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Dict, List, Optional, Tuple, Union

import torch

from ..data import Feature
from ..typing import (
  EdgeType, NodeType,
  PartitionBook, HeteroNodePartitionDict, HeteroEdgePartitionDict
)
from ..utils import get_available_device, ensure_device

from .rpc import (
  RpcDataPartitionRouter, RpcCalleeBase, rpc_register, rpc_request_async
)


# A `PartialFeature` carries the feature data for a subset of a given set of
# input node/edge ids. It is a pair of tensors:
#   - the first tensor holds the features of the ids in the subset;
#   - the second tensor records the index of each of those ids, i.e. its
#     position within the original input ids.
PartialFeature = Tuple[torch.Tensor, torch.Tensor]


class RpcFeatureLookupCallee(RpcCalleeBase):
  r""" Rpc callee that serves feature lookup requests issued by remote
  processes.

  Every incoming call is forwarded unchanged to ``local_get`` of the wrapped
  distributed feature object.

  Args:
    dist_feature: The object whose ``local_get`` method answers the requests.
  """
  def __init__(self, dist_feature):
    super().__init__()
    self.dist_feature = dist_feature

  def call(self, *args, **kwargs):
    r""" Forward all positional and keyword arguments to
    ``dist_feature.local_get`` and return its result.
    """
    lookup = self.dist_feature.local_get
    return lookup(*args, **kwargs)
class DistFeature(object):
  r""" Distributed feature data manager for global feature lookups.

  Args:
    num_partitions: Number of data partitions.
    partition_idx: Data partition idx of current process.
    local_feature: Local ``Feature`` instance.
    feature_pb: Partition book which records node/edge ids to worker node
      ids mapping on feature store.
    local_only: Use this instance only for local feature lookup or stitching.
      If set to ``True``, the related rpc callee will not be registered and
      users should ensure that lookups for remote features are not invoked
      through this instance. Default to ``False``.
    rpc_router: Router used to pick the target worker of a remote data
      partition. Must be provided when ``local_only`` is ``False``.
      Default to ``None``.
    device: Device used for computing. Default to ``None``.

  Note that `local_feature` and `feature_pb` should be a dictionary
  for hetero data.

  Raises:
    ValueError: If ``local_feature`` or ``feature_pb`` has an unsupported
      type, or if ``rpc_router`` is missing while ``local_only`` is ``False``.
  """
  def __init__(self,
               num_partitions: int,
               partition_idx: int,
               local_feature: Union[Feature,
                                    Dict[Union[NodeType, EdgeType], Feature]],
               feature_pb: Union[PartitionBook,
                                 HeteroNodePartitionDict,
                                 HeteroEdgePartitionDict],
               local_only: bool = False,
               rpc_router: Optional[RpcDataPartitionRouter] = None,
               device: Optional[torch.device] = None):
    self.num_partitions = num_partitions
    self.partition_idx = partition_idx

    self.device = get_available_device(device)
    ensure_device(self.device)

    # A dict of features means heterogeneous data, a single `Feature` means
    # homogeneous data; anything else is rejected.
    self.local_feature = local_feature
    if isinstance(self.local_feature, dict):
      self.data_cls = 'hetero'
      for feat in self.local_feature.values():
        feat.lazy_init_with_ipc_handle()
    elif isinstance(self.local_feature, Feature):
      self.data_cls = 'homo'
      self.local_feature.lazy_init_with_ipc_handle()
    else:
      raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                       f"feature type '{type(self.local_feature)}'")

    # The partition book must have the same (homo/hetero) layout as the
    # local feature.
    self.feature_pb = feature_pb
    if isinstance(self.feature_pb, dict):
      assert self.data_cls == 'hetero'
    elif isinstance(self.feature_pb, PartitionBook):
      assert self.data_cls == 'homo'
    else:
      raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                       f"patition book type '{type(self.feature_pb)}'")

    # The rpc callee is only registered when remote lookups are enabled;
    # `rpc_callee_id` stays `None` in 'local_only' mode.
    self.rpc_router = rpc_router
    if not local_only:
      if self.rpc_router is None:
        raise ValueError(f"'{self.__class__.__name__}': a rpc router must be "
                         f"provided when `local_only` set to `False`")
      rpc_callee = RpcFeatureLookupCallee(self)
      self.rpc_callee_id = rpc_register(rpc_callee)
    else:
      self.rpc_callee_id = None

  def _get_local_store(self,
                       input_type: Optional[Union[NodeType, EdgeType]]):
    r""" Return the ``(feature, partition book)`` pair to use for a lookup.

    For hetero data the pair is selected by ``input_type``, which must not
    be ``None``; for homo data ``input_type`` is ignored.
    """
    if self.data_cls == 'hetero':
      assert input_type is not None
      return self.local_feature[input_type], self.feature_pb[input_type]
    return self.local_feature, self.feature_pb

  def local_get(
    self,
    ids: torch.Tensor,
    input_type: Optional[Union[NodeType, EdgeType]] = None
  ) -> torch.Tensor:
    r""" Lookup features in the local feature store, the input node/edge ids
    should be guaranteed to be all local to the current feature store.

    Args:
      ids: input node/edge ids.
      input_type: input node/edge type for heterogeneous feature lookup.

    Return:
      The result of ``cpu_get(ids)`` on the selected local feature.
    """
    feat, _ = self._get_local_store(input_type)
    # TODO: check performance with `return feat[ids].cpu()`
    return feat.cpu_get(ids)

  def async_get(
    self,
    ids: torch.Tensor,
    input_type: Optional[Union[NodeType, EdgeType]] = None
  ) -> torch.futures.Future:
    r""" Lookup features asynchronously and return a future.

    Args:
      ids: input node/edge ids.
      input_type: input node/edge type for heterogeneous feature lookup.

    Return:
      torch.futures.Future: a future that is completed with the stitched
      feature tensor, or with the exception raised while waiting for the
      remote features or stitching the result.
    """
    ids = ids.to(self.device)
    # The remote requests are issued before the local lookup is performed
    # (presumably so that both can overlap).
    remote_fut = self._remote_selecting_get(ids, input_type)
    local_feature = self._local_selecting_get(ids, input_type)
    res_fut = torch.futures.Future()

    def on_done(*_):
      try:
        remote_feature_list = remote_fut.wait()
        result = self._stitch(ids, local_feature, remote_feature_list)
      except Exception as e:
        # Propagate any failure to the caller through the returned future.
        res_fut.set_exception(e)
      else:
        res_fut.set_result(result)

    remote_fut.add_done_callback(on_done)
    return res_fut

  def __getitem__(
    self,
    input: Union[torch.Tensor, Tuple[Union[NodeType, EdgeType], torch.Tensor]]
  ) -> torch.Tensor:
    r""" Lookup features synchronously in a '__getitem__' way.

    Args:
      input: either a tensor of node/edge ids, or a tuple of
        ``(node/edge type, ids)`` for heterogeneous feature lookup.

    Raises:
      ValueError: If ``input`` is neither a tensor nor a tuple.
    """
    if isinstance(input, torch.Tensor):
      input_type, ids = None, input
    elif isinstance(input, tuple):
      # Unpack from `input`: `ids` is not bound yet at this point.
      input_type, ids = input[0], input[1]
    else:
      raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                       f"type for feature lookup: '{type(input)}'")
    fut = self.async_get(ids, input_type)
    return fut.wait()

  def _local_selecting_get(
    self,
    ids: torch.Tensor,
    input_type: Optional[Union[NodeType, EdgeType]] = None
  ) -> PartialFeature:
    r""" Select node/edge ids only in the local feature store and lookup
    features of them.

    Args:
      ids: input node/edge ids.
      input_type: input node/edge type for heterogeneous feature lookup.

    Return:
      PartialFeature: features and index for local node/edge ids.
    """
    feat, pb = self._get_local_store(input_type)
    ids = ids.to(self.device)
    # Position of every id in the input, used later to scatter the partial
    # result back into the right rows.
    input_order = torch.arange(ids.size(0),
                               dtype=torch.long,
                               device=self.device)
    partition_ids = pb[ids].to(self.device)
    local_mask = (partition_ids == self.partition_idx)
    local_ids = torch.masked_select(ids, local_mask)
    local_index = torch.masked_select(input_order, local_mask)
    return feat[local_ids], local_index

  def _remote_selecting_get(
    self,
    ids: torch.Tensor,
    input_type: Optional[Union[NodeType, EdgeType]] = None
  ) -> torch.futures.Future:
    r""" Select node/edge ids only in the remote feature stores and fetch
    their features.

    Args:
      ids: input node/edge ids.
      input_type: input node/edge type for heterogeneous feature lookup.

    Return:
      torch.futures.Future: a torch future with a list of `PartialFeature`,
      which corresponds to partial features on different remote workers.
    """
    assert (
      self.rpc_callee_id is not None
    ), "Remote feature lookup is disabled in 'local_only' mode."
    _, pb = self._get_local_store(input_type)
    ids = ids.to(self.device)
    input_order = torch.arange(ids.size(0),
                               dtype=torch.long,
                               device=self.device)
    partition_ids = pb[ids].to(self.device)
    futs, indexes = [], []
    for pidx in range(0, self.num_partitions):
      if pidx == self.partition_idx:
        continue
      remote_mask = (partition_ids == pidx)
      remote_ids = torch.masked_select(ids, remote_mask)
      # Only send a request to partitions that own at least one input id.
      if remote_ids.shape[0] > 0:
        to_worker = self.rpc_router.get_to_worker(pidx)
        futs.append(rpc_request_async(to_worker,
                                      self.rpc_callee_id,
                                      args=(remote_ids.cpu(), input_type)))
        indexes.append(torch.masked_select(input_order, remote_mask))
    collect_fut = torch.futures.collect_all(futs)
    res_fut = torch.futures.Future()

    def on_done(*_):
      try:
        fut_list = collect_fut.wait()
        # `futs` and `indexes` were appended together, so the i-th response
        # pairs with the i-th index tensor.
        result = [(fut.wait(), index)
                  for fut, index in zip(fut_list, indexes)]
      except Exception as e:
        res_fut.set_exception(e)
      else:
        res_fut.set_result(result)

    collect_fut.add_done_callback(on_done)
    return res_fut

  def _stitch(
    self,
    ids: torch.Tensor,
    local: PartialFeature,
    remotes: List[PartialFeature]
  ) -> torch.Tensor:
    r""" Stitch local and remote partial features into a complete one.

    Args:
      ids: the complete input node/edge ids.
      local: partial feature of local node/edge ids.
      remotes: partial feature list of remote node/edge ids.

    Return:
      A tensor on ``self.device`` with one row per input id; its second
      dimension and dtype are taken from the local partial feature.
    """
    feat = torch.zeros(ids.shape[0],
                       local[0].shape[1],
                       dtype=local[0].dtype,
                       device=self.device)
    # Scatter each partial feature into the rows given by its index tensor.
    feat[local[1].to(self.device)] = local[0].to(self.device)
    for remote in remotes:
      feat[remote[1].to(self.device)] = remote[0].to(self.device)
    return feat