# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import math
import queue
from dataclasses import dataclass
from typing import List, Optional, Union, Tuple
import torch
from .. import py_graphlearn_torch as pywrap
from ..channel import ChannelBase, SampleMessage
from ..sampler import (
NodeSamplerInput, EdgeSamplerInput,
NeighborOutput, SamplerOutput, HeteroSamplerOutput,
NeighborSampler
)
from ..typing import EdgeType, as_str, NumNeighbors
from ..utils import (
get_available_device, ensure_device, merge_dict, id2idx,
merge_hetero_sampler_output, format_hetero_sampler_output,
id2idx_v2
)
from .dist_dataset import DistDataset
from .dist_feature import DistFeature
from .dist_graph import DistGraph
from .event_loop import ConcurrentEventLoop, wrap_torch_future
from .rpc import (
RpcCalleeBase, rpc_register, rpc_request_async,
RpcDataPartitionRouter, rpc_sync_data_partitions
)
@dataclass
class PartialNeighborOutput:
  r""" One partition's share of a one-hop neighbor sampling result.

  A seed batch is split by owning partition before sampling; each piece is
  wrapped in this record so the pieces can later be stitched back together.
  """
  # Positions of this subset's ids within the original seed tensor.
  index: torch.Tensor
  # Neighbors sampled for exactly the ids selected by ``index``.
  output: NeighborOutput
class RpcSamplingCallee(RpcCalleeBase):
  r""" Rpc callee that serves one-hop sampling requests issued by remote
  processes, using the sampler owned by this process.

  Args:
    sampler (NeighborSampler): The local sampler used to answer requests.
    device (torch.device): The device that is made current before sampling.
  """
  def __init__(self, sampler: NeighborSampler, device: torch.device):
    super().__init__()
    self.sampler = sampler
    self.device = device

  def call(self, *args, **kwargs):
    r""" Run ``sampler.sample_one_hop`` with the forwarded arguments and
    return its output moved to the cpu device.
    """
    ensure_device(self.device)
    one_hop = self.sampler.sample_one_hop(*args, **kwargs)
    cpu = torch.device('cpu')
    return one_hop.to(cpu)
class RpcSubGraphCallee(RpcCalleeBase):
  r""" Rpc callee that serves node-induced subgraph requests issued by remote
  processes, using the sampler owned by this process.

  Args:
    sampler (NeighborSampler): The local sampler whose ``subgraph_op`` is used.
    device (torch.device): The device that is made current before inducing.
  """
  def __init__(self, sampler: NeighborSampler, device: torch.device):
    super().__init__()
    self.sampler = sampler
    self.device = device

  def call(self, *args, **kwargs):
    r""" Induce the subgraph of the node ids given as the first positional
    argument.

    The keyword argument ``with_edge`` is required and selects whether edge
    ids are returned.

    Returns:
      A tuple ``(nodes, rows, cols, eids)`` of cpu tensors; ``eids`` is
      ``None`` when ``with_edge`` is falsy.
    """
    ensure_device(self.device)
    with_edge = kwargs['with_edge']
    induced = self.sampler.subgraph_op.node_subgraph(
      args[0].to(self.device), with_edge)
    if with_edge:
      eids = induced.eids.to('cpu')
    else:
      eids = None
    nodes = induced.nodes.to('cpu')
    rows = induced.rows.to('cpu')
    cols = induced.cols.to('cpu')
    return nodes, rows, cols, eids
[docs]class DistNeighborSampler(ConcurrentEventLoop):
r""" Asynchronous and distributed neighbor sampler.
Args:
data (DistDataset): The graph and feature data with partition info.
num_neighbors (NumNeighbors, optional): The number of sampling neighbors on
each hop. (default: ``None``).
with_edge (bool): Whether to sample with edge ids. (default: ``False``).
with_neg (bool): Passed through to the local ``NeighborSampler``.
(default: ``False``).
collect_features (bool): Whether to collect features for sampled results.
(default: ``False``).
channel (ChannelBase, optional): The message channel to send sampled
results. If set to `None`, the sampled results will be returned
directly with `sample_from_nodes`. (default: ``None``).
concurrency (int): The max number of concurrent seed batches processed by
the current sampler. (default: ``1``).
device: The device to use for sampling. If set to ``None``, the current
cuda device (got by ``torch.cuda.current_device``) will be used if
available, otherwise, the cpu device will be used. (default: ``None``).
"""
def __init__(self,
             data: DistDataset,
             num_neighbors: Optional[NumNeighbors] = None,
             with_edge: bool = False,
             with_neg: bool = False,
             collect_features: bool = False,
             channel: Optional[ChannelBase] = None,
             concurrency: int = 1,
             device: Optional[torch.device] = None):
  r""" Build the local sampler, the rpc routing state and the optional
  distributed feature stores, then start the event loop.

  Raises:
    ValueError: If ``data`` is not a ``DistDataset``.
  """
  # Plain configuration, stored before anything is validated.
  self.data = data
  self.num_neighbors = num_neighbors
  self.max_input_size = 0
  self.with_edge = with_edge
  self.with_neg = with_neg
  self.collect_features = collect_features
  self.channel = channel
  self.concurrency = concurrency
  self.device = get_available_device(device)

  if not isinstance(data, DistDataset):
    raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                     f"data type '{type(data)}'")

  num_parts = data.num_partitions
  part_idx = data.partition_idx

  # Exchange partition placement with the other workers and build the router
  # used for every remote request.
  self.rpc_router = RpcDataPartitionRouter(
    rpc_sync_data_partitions(
      num_data_partitions=num_parts,
      current_partition_idx=part_idx
    )
  )
  self.dist_graph = DistGraph(
    num_parts, part_idx, data.graph, data.node_pb, data.edge_pb
  )

  # Feature stores are only created when feature collection is requested and
  # the dataset actually carries the corresponding features.
  self.dist_node_feature = None
  self.dist_edge_feature = None
  if self.collect_features:
    feature_opts = dict(
      local_only=False, rpc_router=self.rpc_router, device=self.device
    )
    if data.node_features is not None:
      self.dist_node_feature = DistFeature(
        num_parts, part_idx, data.node_features, data.node_feat_pb,
        **feature_opts
      )
    if self.with_edge and data.edge_features is not None:
      self.dist_edge_feature = DistFeature(
        num_parts, part_idx, data.edge_features, data.edge_feat_pb,
        **feature_opts
      )

  self.sampler = NeighborSampler(
    self.dist_graph.local_graph, self.num_neighbors,
    self.device, self.with_edge, self.with_neg
  )
  # Pool of reusable inducers, bounded by the number of concurrent batches.
  self.inducer_pool = queue.Queue(maxsize=self.concurrency)

  # Register the callees that answer remote sampling / subgraph requests.
  self.rpc_sample_callee_id = rpc_register(
    RpcSamplingCallee(self.sampler, self.device))
  self.rpc_subgraph_callee_id = rpc_register(
    RpcSubGraphCallee(self.sampler, self.device))

  if self.dist_graph.data_cls == 'hetero':
    # Take the per-edge-type settings from the local sampler.
    self.num_neighbors = self.sampler.num_neighbors
    self.num_hops = self.sampler.num_hops
    self.edge_types = self.sampler.edge_types

  super().__init__(self.concurrency)
  # Make the sampling device current on the event loop thread.
  self._loop.call_soon_threadsafe(ensure_device, self.device)
def sample_from_nodes(
  self,
  inputs: NodeSamplerInput,
  **kwargs
) -> Optional[SampleMessage]:
  r""" Sample multi-hop neighbors starting from nodes, optionally collect
  features, and deliver the result.

  With no output channel configured the call blocks and returns the sampled
  message. With a channel, the work is scheduled on the event loop, the
  message is sent to the channel, and ``None`` is returned immediately.

  Args:
    inputs (NodeSamplerInput): The input data with node indices to start
      sampling from.
    **kwargs: ``callback`` (optional) is forwarded to ``add_task`` when a
      channel is configured.
  """
  seeds = NodeSamplerInput.cast(inputs)
  task = self._send_adapter(self._sample_from_nodes, seeds)
  if self.channel is not None:
    self.add_task(coro=task, callback=kwargs.get('callback', None))
    return None
  return self.run_task(coro=task)
def sample_from_edges(
  self,
  inputs: EdgeSamplerInput,
  **kwargs,
) -> Optional[SampleMessage]:
  r""" Sample multi-hop neighbors starting from edges, optionally collect
  features, and deliver the result.

  With no output channel configured the call blocks and returns the sampled
  message. With a channel, the work is scheduled on the event loop, the
  message is sent to the channel, and ``None`` is returned immediately.

  Args:
    inputs (EdgeSamplerInput): The input data for sampling from edges
      including the (1) source node indices, the (2) destination node
      indices, the (3) optional edge labels and the (4) input edge type.
    **kwargs: ``callback`` (optional) is forwarded to ``add_task`` when a
      channel is configured.
  """
  task = self._send_adapter(self._sample_from_edges, inputs)
  if self.channel is not None:
    self.add_task(coro=task, callback=kwargs.get('callback', None))
    return None
  return self.run_task(coro=task)
def subgraph(
  self,
  inputs: NodeSamplerInput,
  **kwargs
) -> Optional[SampleMessage]:
  r""" Induce an enclosing subgraph from the input nodes and, when
  ``self.num_neighbors`` is not ``None``, their sampled neighbors.

  With no output channel configured the call blocks and returns the message.
  With a channel, the work is scheduled on the event loop, the message is
  sent to the channel, and ``None`` is returned immediately.

  Args:
    inputs (NodeSamplerInput): The seed nodes of the subgraph.
    **kwargs: ``callback`` (optional) is forwarded to ``add_task`` when a
      channel is configured.
  """
  seeds = NodeSamplerInput.cast(inputs)
  task = self._send_adapter(self._subgraph, seeds)
  if self.channel is not None:
    self.add_task(coro=task, callback=kwargs.get('callback', None))
    return None
  return self.run_task(coro=task)
async def _send_adapter(
  self,
  async_func,
  *args, **kwargs
) -> Optional[SampleMessage]:
  r""" Await ``async_func(*args, **kwargs)``, collate its sampler output into
  a message, and either return the message (no channel configured) or send
  it to the channel and return ``None``.
  """
  message = await self._colloate_fn(await async_func(*args, **kwargs))
  if self.channel is not None:
    self.channel.send(message)
    return None
  return message
async def _sample_from_nodes(
  self,
  inputs: NodeSamplerInput,
) -> Optional[SampleMessage]:
  r""" Run multi-hop neighbor sampling from the seed nodes in ``inputs``.

  For a 'hetero' graph, every hop samples each edge type whose source node
  type is present in the current frontier, and a ``HeteroSamplerOutput`` is
  produced. Otherwise hops are sampled one after another and a
  ``SamplerOutput`` is produced. In both cases the metadata carries
  ``'input_type'`` and ``'bs'`` (the number of initial nodes returned by the
  inducer).
  """
  seeds = inputs.node.to(self.device)
  seed_type = inputs.input_type
  # Track the largest batch seen so far; it sizes newly created inducers.
  self.max_input_size = max(self.max_input_size, seeds.numel())
  inducer = self._acquire_inducer()

  if self.dist_graph.data_cls == 'hetero':
    assert seed_type is not None
    frontier = inducer.init_node({seed_type: seeds})
    batch_size = frontier[seed_type].numel()
    node_acc, row_acc, col_acc, eid_acc = {}, {}, {}, {}
    merge_dict(frontier, node_acc)

    for hop in range(self.num_hops):
      # Launch every edge type of this hop concurrently ...
      pending = {}
      for etype in self.edge_types:
        hop_srcs = frontier.get(etype[0], None)
        fanout = self.num_neighbors[etype][hop]
        if hop_srcs is None:
          continue
        pending[etype] = self._loop.create_task(
          self._sample_one_hop(hop_srcs, fanout, etype))

      # ... then gather the results in launch order.
      nbr_dict, edge_dict = {}, {}
      for etype, task in pending.items():
        hop_out: NeighborOutput = await task
        nbr_dict[etype] = [frontier[etype[0]], hop_out.nbr, hop_out.nbr_num]
        if hop_out.edge is not None:
          edge_dict[etype] = hop_out.edge

      new_nodes, new_rows, new_cols = inducer.induce_next(nbr_dict)
      merge_dict(new_nodes, node_acc)
      merge_dict(new_rows, row_acc)
      merge_dict(new_cols, col_acc)
      merge_dict(edge_dict, eid_acc)
      frontier = new_nodes

    if self.with_edge:
      edge_out = {etype: torch.cat(eids) for etype, eids in eid_acc.items()}
    else:
      edge_out = None
    sample_output = HeteroSamplerOutput(
      node={ntype: torch.cat(nodes) for ntype, nodes in node_acc.items()},
      row={etype: torch.cat(rows) for etype, rows in row_acc.items()},
      col={etype: torch.cat(cols) for etype, cols in col_acc.items()},
      edge=edge_out,
      metadata={'input_type': seed_type, 'bs': batch_size}
    )
  else:
    frontier = inducer.init_node(seeds)
    batch_size = frontier.numel()
    node_chunks = [frontier]
    row_chunks, col_chunks, eid_chunks = [], [], []

    for fanout in self.num_neighbors:
      hop_out: NeighborOutput = await self._sample_one_hop(
        frontier, fanout, None)
      new_nodes, rows, cols = inducer.induce_next(
        frontier, hop_out.nbr, hop_out.nbr_num)
      node_chunks.append(new_nodes)
      row_chunks.append(rows)
      col_chunks.append(cols)
      eid_chunks.append(hop_out.edge)
      frontier = new_nodes

    sample_output = SamplerOutput(
      node=torch.cat(node_chunks),
      row=torch.cat(row_chunks),
      col=torch.cat(col_chunks),
      edge=(torch.cat(eid_chunks) if self.with_edge else None),
      metadata={'input_type': None, 'bs': batch_size}
    )

  # Hand the inducer back so later batches can reuse it.
  self.inducer_pool.put(inducer)
  return sample_output
async def _sample_from_edges(
  self,
  inputs: EdgeSamplerInput,
) -> Optional[SampleMessage]:
  r""" Sample around the endpoints of the input edges by delegating to
  `_sample_from_nodes`.

  Currently, we support the out-edge sampling manner, so we reverse the
  direction of src and dst for the output so that features of the sampled
  nodes during training can be aggregated from k-hop to (k-1)-hop nodes.

  Note: Negative sampling is performed locally and unable to fetch positive
  edges from remote, so the negative sampling in the distributed case is
  currently non-strict for both binary and triplet manner.

  The returned output's metadata gains ``edge_label_index`` / ``edge_label``
  (no negative sampling, or binary) or ``src_index`` / ``dst_pos_index`` /
  ``dst_neg_index`` (triplet).
  """
  src = inputs.row.to(self.device)
  dst = inputs.col.to(self.device)
  edge_label = None if inputs.label is None else inputs.label.to(self.device)
  input_type = inputs.input_type
  neg_sampling = inputs.neg_sampling

  num_pos = src.numel()
  num_neg = 0

  def neg_sampler():
    # Per-edge-type sampler when an input edge type is given.
    holder = self.sampler._neg_sampler
    return holder if input_type is None else holder[input_type]

  # ---- Negative sampling ----
  self.sampler.lazy_init_neg_sampler()
  if neg_sampling is not None:
    # Negatives are appended after the positives, so positives and negatives
    # can be told apart later by slicing at `num_pos`.
    num_neg = math.ceil(num_pos * neg_sampling.amount)
    if neg_sampling.is_binary():
      # Binary: random negative (src, dst) pairs with zero labels.
      neg_pair = neg_sampler().sample(num_neg)
      src = torch.cat([src, neg_pair[0]], dim=0)
      dst = torch.cat([dst, neg_pair[1]], dim=0)
      if edge_label is None:
        edge_label = torch.ones(num_pos, device=self.device)
      neg_shape = (num_neg, ) + edge_label.size()[1:]
      edge_label = torch.cat([edge_label, edge_label.new_zeros(neg_shape)])
    elif neg_sampling.is_triplet():
      # Triplet: only negative destinations are drawn.
      assert num_neg % num_pos == 0
      neg_pair = neg_sampler().sample(num_neg, padding=True)
      dst = torch.cat([dst, neg_pair[1]], dim=0)
      assert edge_label is None

  # ---- Neighbor sampling ----
  is_hetero = input_type is not None
  # True when the source and destination node types of the edge type differ.
  two_types = is_hetero and input_type[0] != input_type[-1]

  if two_types:
    src_seed, dst_seed = src, dst
    src, _ = src.unique(return_inverse=True)
    dst, _ = dst.unique(return_inverse=True)
    seed_dict = {input_type[0]: src, input_type[-1]: dst}
  else:
    # A single node type (or homo graph): merge sources and destinations.
    seed, inverse_seed = torch.cat([src, dst], dim=0).unique(
      return_inverse=True)
    if is_hetero:
      seed_dict = {input_type[0]: seed}

  if is_hetero:
    partial_outs = []
    for ntype, ids in seed_dict.items():
      partial_outs.append(await self._sample_from_nodes(
        NodeSamplerInput(node=ids, input_type=ntype)))
    if len(partial_outs) == 2:
      out = merge_hetero_sampler_output(partial_outs[0],
                                        partial_outs[1],
                                        device=self.device)
    else:
      out = format_hetero_sampler_output(partial_outs[0])
  else:
    out = await self._sample_from_nodes(NodeSamplerInput.cast(seed))

  # ---- Index / label metadata ----
  if neg_sampling is None or neg_sampling.is_binary():
    if two_types:
      edge_label_index = torch.stack([
        id2idx_v2(src_seed, out.node[input_type[0]]),
        id2idx_v2(dst_seed, out.node[input_type[-1]]),
      ], dim=0)
    else:
      edge_label_index = inverse_seed.view(2, -1)
    extra = {'edge_label_index': edge_label_index,
             'edge_label': edge_label}
  elif neg_sampling.is_triplet():
    if two_types:
      src_index = id2idx_v2(src_seed, out.node[input_type[0]])
      dst_index = id2idx_v2(dst_seed, out.node[input_type[-1]])
      dst_pos_index = dst_index[:num_pos]
      dst_neg_index = dst_index[num_pos:]
    else:
      src_index = inverse_seed[:num_pos]
      dst_pos_index = inverse_seed[num_pos:2 * num_pos]
      dst_neg_index = inverse_seed[2 * num_pos:]
    dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)
    extra = {'src_index': src_index,
             'dst_pos_index': dst_pos_index,
             'dst_neg_index': dst_neg_index}
  else:
    return out

  out.metadata.update(extra)
  if is_hetero:
    out.input_type = input_type
  return out
async def _subgraph(
  self,
  inputs: NodeSamplerInput,
) -> Optional[SampleMessage]:
  r""" Induce the subgraph over the seed nodes and, when ``num_neighbors``
  is set, the nodes reached by hop-by-hop neighbor sampling.

  Every partition that owns at least one of the nodes is asked to induce a
  subgraph over the full node set (locally or via rpc); the partial edge
  lists are relabeled to positions in the unique node tensor and
  concatenated. ``metadata['mapping']`` holds the positions of the input
  seeds within that tensor.

  Raises:
    NotImplementedError: For 'hetero' graphs.
  """
  inputs = NodeSamplerInput.cast(inputs)
  seeds = inputs.node.to(self.device)
  if self.dist_graph.data_cls == 'hetero':
    raise NotImplementedError

  # Neighbor sampling: expand the node set one hop at a time.
  if self.num_neighbors is None:
    nodes = seeds
  else:
    hops = [seeds]
    for fanout in self.num_neighbors:
      hop_out = await self._sample_one_hop(hops[-1], fanout, None)
      hops.append(torch.unique(hop_out.nbr))
    nodes = torch.cat(hops)
  nodes, mapping = torch.unique(nodes, return_inverse=True)
  nid2idx = id2idx(nodes)

  # Subgraph inducing, starting with the local partition.
  owners = self.dist_graph.get_node_partitions(nodes)
  owners = owners.to(self.device)
  rows, cols, eids, futs = [], [], [], []
  for offset in range(self.data.num_partitions):
    pidx = (self.data.partition_idx + offset) % self.data.num_partitions
    owned = torch.masked_select(nodes, (owners == pidx))
    if owned.shape[0] <= 0:
      continue
    if pidx != self.data.partition_idx:
      to_worker = self.rpc_router.get_to_worker(pidx)
      futs.append(rpc_request_async(to_worker,
                                    self.rpc_subgraph_callee_id,
                                    args=(nodes.cpu(),),
                                    kwargs={'with_edge': self.with_edge}))
      continue
    local_sub = self.sampler.subgraph_op.node_subgraph(nodes, self.with_edge)
    # Relabel row and col indices.
    rows.append(nid2idx[local_sub.nodes[local_sub.rows]])
    cols.append(nid2idx[local_sub.nodes[local_sub.cols]])
    if self.with_edge:
      eids.append(local_sub.eids.to(self.device))

  if futs:
    done = await wrap_torch_future(torch.futures.collect_all(futs))
    for fut in done:
      res_nodes, res_rows, res_cols, res_eids = fut.wait()
      res_nodes = res_nodes.to(self.device)
      rows.append(nid2idx[res_nodes[res_rows]])
      cols.append(nid2idx[res_nodes[res_cols]])
      if self.with_edge:
        eids.append(res_eids.to(self.device))

  return SamplerOutput(
    node=nodes,
    row=torch.cat(rows),
    col=torch.cat(cols),
    edge=torch.cat(eids) if self.with_edge else None,
    device=self.device,
    metadata={'mapping': mapping[:seeds.numel()]})
def _acquire_inducer(self):
if self.inducer_pool.empty():
return self.sampler.create_inducer(self.max_input_size)
return self.inducer_pool.get()
def _stitch_sample_results(
  self,
  input_seeds: torch.Tensor,
  results: List[PartialNeighborOutput]
) -> NeighborOutput:
  r""" Combine per-partition neighbor outputs into a single output for
  ``input_seeds``, using the cuda or cpu stitching op that matches
  ``self.device``.

  Edge ids are only passed to the stitching op when ``self.with_edge`` is
  set; otherwise an empty list is passed.
  """
  indices, nbrs_parts, nbrs_num_parts = [], [], []
  for part in results:
    indices.append(part.index)
    nbrs_parts.append(part.output.nbr)
    nbrs_num_parts.append(part.output.nbr_num)
  eids_parts = [part.output.edge for part in results] if self.with_edge else []

  on_cuda = self.device.type == 'cuda'
  stitch = (pywrap.cuda_stitch_sample_results if on_cuda
            else pywrap.cpu_stitch_sample_results)
  nbrs, nbrs_num, eids = stitch(
    input_seeds, indices, nbrs_parts, nbrs_num_parts, eids_parts)
  return NeighborOutput(nbrs, nbrs_num, eids)
async def _sample_one_hop(
  self,
  srcs: torch.Tensor,
  num_nbr: int,
  etype: Optional[EdgeType]
) -> NeighborOutput:
  r""" Sample one-hop neighbors of ``srcs`` across all data partitions.

  The ids are split by owning partition. Ids owned by the local partition
  are sampled directly; the rest are sent to their partitions via rpc. When
  any remote partition is involved, the partial results are stitched back
  into the order of ``srcs``.

  Args:
    srcs: input ids, 1D tensor.
    num_nbr: request(max) number of neighbors for one hop.
    etype: edge type to sample from input ids (``None`` when no edge type
      applies).

  Returns:
    NeighborOutput: the sampled neighbors for ``srcs``. For empty ``srcs``
    an empty output is returned (edge ids are ``None`` unless
    ``self.with_edge`` is set).
  """
  device = self.device
  srcs = srcs.to(device)

  # No seeds means no partition produces a result; return an empty output
  # instead of failing on ``partition_results[0]`` below.
  if srcs.numel() == 0:
    return NeighborOutput(
      srcs.new_empty((0,)),
      torch.empty(0, dtype=torch.long, device=device),
      srcs.new_empty((0,)) if self.with_edge else None
    )

  # Position of every id within ``srcs``, used to restore order on stitching.
  orders = torch.arange(srcs.size(0), dtype=torch.long, device=device)
  src_ntype = etype[0] if etype is not None else None
  partition_ids = self.dist_graph.get_node_partitions(srcs, src_ntype)
  partition_ids = partition_ids.to(device)

  partition_results: List[PartialNeighborOutput] = []
  remote_orders_list: List[torch.Tensor] = []
  futs: List[torch.futures.Future] = []

  # Visit partitions starting with the local one.
  for i in range(self.data.num_partitions):
    pidx = (self.data.partition_idx + i) % self.data.num_partitions
    p_mask = (partition_ids == pidx)
    p_ids = torch.masked_select(srcs, p_mask)
    if p_ids.shape[0] == 0:
      continue
    p_orders = torch.masked_select(orders, p_mask)
    if pidx == self.data.partition_idx:
      p_nbr_out = self.sampler.sample_one_hop(p_ids, num_nbr, etype)
      partition_results.append(PartialNeighborOutput(p_orders, p_nbr_out))
    else:
      remote_orders_list.append(p_orders)
      to_worker = self.rpc_router.get_to_worker(pidx)
      futs.append(rpc_request_async(to_worker,
                                    self.rpc_sample_callee_id,
                                    args=(p_ids.cpu(), num_nbr, etype)))

  # Without remote sampling results: every id was local, so the single local
  # result already follows the order of ``srcs``.
  if not remote_orders_list:
    return partition_results[0].output

  # With remote sampling results.
  res_fut_list = await wrap_torch_future(torch.futures.collect_all(futs))
  for remote_orders, res_fut in zip(remote_orders_list, res_fut_list):
    partition_results.append(
      PartialNeighborOutput(
        index=remote_orders,
        output=res_fut.wait().to(device)
      )
    )
  return self._stitch_sample_results(srcs, partition_results)
async def _colloate_fn(
  self,
  output: Union[SamplerOutput, HeteroSamplerOutput]
) -> SampleMessage:
  r""" Collect labels and features for the sampled subgraph if necessary,
  and put them into a sample message.

  When ``output.metadata`` is a dict, its ``'input_type'`` and ``'bs'``
  entries are removed from it (the dict is modified in place), ``'meta'`` is
  written as ``[is_hetero, batch_size]`` and every remaining entry is copied
  into the message as-is.

  Args:
    output: The sampler output to convert.

  Returns:
    SampleMessage: a dict of the sampled ids / rows / cols (and eids when
    ``self.with_edge``), plus labels and features where available. Keys are
    prefixed with the node / edge type string for 'hetero' graphs.
  """
  is_hetero = (self.dist_graph.data_cls == 'hetero')
  result_map = {}

  # Same fallback as a metadata dict without 'input_type'; keeps the name
  # bound when metadata is not a dict at all.
  input_type = ''
  if isinstance(output.metadata, dict):
    # Scan kv and add metadata.
    input_type = output.metadata.pop('input_type', '')
    batch_size = output.metadata.pop('bs', 1)
    result_map['meta'] = torch.LongTensor([int(is_hetero), batch_size])
    result_map.update(output.metadata)

  if is_hetero:
    for ntype, nodes in output.node.items():
      result_map[f'{as_str(ntype)}.ids'] = nodes
    for etype, rows in output.row.items():
      etype_str = as_str(etype)
      result_map[f'{etype_str}.rows'] = rows
      result_map[f'{etype_str}.cols'] = output.col[etype]
      if self.with_edge:
        result_map[f'{etype_str}.eids'] = output.edge[etype]

    # Collect node labels of input node type. A tuple input type is skipped
    # (an edge type rather than a node type, as set by edge sampling).
    if not isinstance(input_type, tuple):
      node_labels = self.data.get_node_label(input_type)
      if node_labels is not None:
        result_map[f'{as_str(input_type)}.nlabels'] = \
          node_labels[output.node[input_type]]

    # Collect node features: issue all requests first, then await them.
    if self.dist_node_feature is not None:
      nfeat_fut_dict = {
        ntype: self.dist_node_feature.async_get(nodes.to(torch.long), ntype)
        for ntype, nodes in output.node.items()
      }
      for ntype, fut in nfeat_fut_dict.items():
        nfeats = await wrap_torch_future(fut)
        result_map[f'{as_str(ntype)}.nfeats'] = nfeats

    # Collect edge features.
    if self.dist_edge_feature is not None and self.with_edge:
      efeat_fut_dict = {}
      for etype in self.edge_types:
        # Edge types with no sampled edge ids have no entry; check for None
        # before converting.
        eids = result_map.get(f'{as_str(etype)}.eids', None)
        if eids is not None:
          efeat_fut_dict[etype] = self.dist_edge_feature.async_get(
            eids.to(torch.long), etype)
      for etype, fut in efeat_fut_dict.items():
        efeats = await wrap_torch_future(fut)
        result_map[f'{as_str(etype)}.efeats'] = efeats
  else:
    result_map['ids'] = output.node
    result_map['rows'] = output.row
    result_map['cols'] = output.col
    if self.with_edge:
      result_map['eids'] = output.edge

    # Collect node labels.
    node_labels = self.data.get_node_label()
    if node_labels is not None:
      result_map['nlabels'] = node_labels[output.node]

    # Collect node features.
    if self.dist_node_feature is not None:
      fut = self.dist_node_feature.async_get(output.node)
      nfeats = await wrap_torch_future(fut)
      result_map['nfeats'] = nfeats

    # Collect edge features.
    if self.dist_edge_feature is not None:
      eids = result_map['eids']
      fut = self.dist_edge_feature.async_get(eids)
      efeats = await wrap_torch_future(fut)
      result_map['efeats'] = efeats

  return result_map