Source code for graphlearn_torch.loader.transform

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Dict, Optional

import torch
from torch_geometric.data import Data, HeteroData

from ..sampler import SamplerOutput, HeteroSamplerOutput
from ..typing import NodeType, EdgeType, reverse_edge_type


def to_data(
    sampler_out: SamplerOutput,
    batch_labels: Optional[torch.Tensor] = None,
    node_feats: Optional[torch.Tensor] = None,
    edge_feats: Optional[torch.Tensor] = None,
    **kwargs
) -> Data:
    """Pack a homogeneous sampler output into a PyG ``Data`` object.

    The edge index is built by stacking ``sampler_out.row`` on top of
    ``sampler_out.col``. ``sampler_out.edge`` and ``sampler_out.node`` are
    copied onto the result as ``data.edge`` and ``data.node``.

    What else is attached depends on ``sampler_out.metadata``:

    * ``None``: ``data.batch`` is taken from ``sampler_out.batch`` and
      ``data.batch_size`` is its element count (``0`` when ``data.batch``
      reads back as ``None``).
    * a dict holding ``'edge_label_index'`` (binary negative sampling):
      the two rows of that index are swapped and stored as
      ``data.edge_label_index``; ``'edge_label'`` becomes
      ``data.edge_label``.
    * a dict holding ``'src_index'`` (triplet negative sampling):
      ``'src_index'``, ``'dst_pos_index'`` and ``'dst_neg_index'`` are
      copied onto ``data`` under the same names.
    * anything else: nothing more is attached.

    Args:
        sampler_out: Sampling result to convert.
        batch_labels: Stored as ``y`` on the result.
        node_feats: Stored as ``x`` on the result.
        edge_feats: Stored as ``edge_attr`` on the result.
        **kwargs: Forwarded unchanged to the ``Data`` constructor.

    Returns:
        The populated ``Data`` object.
    """
    stacked_edges = torch.stack([sampler_out.row, sampler_out.col])
    data = Data(
        x=node_feats,
        edge_index=stacked_edges,
        edge_attr=edge_feats,
        y=batch_labels,
        **kwargs
    )
    data.edge = sampler_out.edge
    data.node = sampler_out.node

    meta = sampler_out.metadata

    # No metadata: only the seed batch information is recorded.
    if meta is None:
        data.batch = sampler_out.batch
        if data.batch is not None:
            data.batch_size = sampler_out.batch.numel()
        else:
            data.batch_size = 0
        return data

    # Metadata of any non-dict type carries nothing we know how to attach.
    if not isinstance(meta, dict):
        return data

    if 'edge_label_index' in meta:
        # Binary negative sampling: store the label index with its two
        # rows swapped (row 1 first, then row 0).
        label_index = meta['edge_label_index']
        data.edge_label_index = torch.stack(
            (label_index[1], label_index[0]), dim=0)
        data.edge_label = meta['edge_label']
    elif 'src_index' in meta:
        # Triplet negative sampling: copy the three index fields verbatim.
        data.src_index = meta['src_index']
        data.dst_pos_index = meta['dst_pos_index']
        data.dst_neg_index = meta['dst_neg_index']

    return data
def to_hetero_data(
    hetero_sampler_out: HeteroSamplerOutput,
    batch_label_dict: Optional[Dict[NodeType, torch.Tensor]] = None,
    node_feat_dict: Optional[Dict[NodeType, torch.Tensor]] = None,
    edge_feat_dict: Optional[Dict[EdgeType, torch.Tensor]] = None,
    **kwargs
) -> HeteroData:
    """Pack a heterogeneous sampler output into a PyG ``HeteroData`` object.

    For every edge type returned by ``hetero_sampler_out.get_edge_index()``
    the edge index is stored, together with the matching entry of
    ``hetero_sampler_out.edge`` (when that mapping is not ``None``) and of
    ``edge_feat_dict`` (when given). For every node type in
    ``hetero_sampler_out.node`` the node tensor is stored as ``node`` and
    the matching entry of ``node_feat_dict`` (when given) as ``x``.

    The remaining fields depend on ``hetero_sampler_out.input_type``:

    * ``None`` or a ``NodeType`` (seed nodes): each entry of
      ``hetero_sampler_out.batch`` is stored as ``batch`` with its element
      count as ``batch_size``; labels from ``batch_label_dict`` go to ``y``.
    * otherwise (seed edges): the input type is passed through
      ``reverse_edge_type`` and the metadata is inspected. With
      ``'edge_label_index'`` (binary negative sampling) the row-swapped
      label index and ``'edge_label'`` are stored under the reversed edge
      type. With ``'src_index'`` (triplet negative sampling)
      ``src_index`` is stored under the last element of the reversed edge
      type and ``dst_pos_index`` / ``dst_neg_index`` under its first.

    Args:
        hetero_sampler_out: Sampling result to convert.
        batch_label_dict: Per-node-type labels, stored as ``y``.
        node_feat_dict: Per-node-type features, stored as ``x``.
        edge_feat_dict: Per-edge-type features, stored as ``edge_attr``.
        **kwargs: Forwarded unchanged to the ``HeteroData`` constructor.

    Returns:
        The populated ``HeteroData`` object.
    """
    data = HeteroData(**kwargs)

    # Edge-level fields, one storage per edge type.
    sampled_edges = hetero_sampler_out.edge
    for edge_type, edge_index in hetero_sampler_out.get_edge_index().items():
        edge_store = data[edge_type]
        edge_store.edge_index = edge_index
        if sampled_edges is not None:
            edge_store.edge = sampled_edges.get(edge_type, None)
        if edge_feat_dict is not None:
            edge_store.edge_attr = edge_feat_dict.get(edge_type, None)

    # Node-level fields, one storage per node type.
    for node_type, nodes in hetero_sampler_out.node.items():
        node_store = data[node_type]
        node_store.node = nodes
        if node_feat_dict is not None:
            node_store.x = node_feat_dict.get(node_type, None)

    input_type = hetero_sampler_out.input_type

    # Seed nodes: record the batch, its size and (optionally) its labels.
    if input_type is None or isinstance(input_type, NodeType):
        for node_type, seeds in hetero_sampler_out.batch.items():
            seed_store = data[node_type]
            seed_store.batch = seeds
            seed_store.batch_size = seeds.numel()
            if batch_label_dict is not None:
                seed_store.y = batch_label_dict.get(node_type, None)
        return data

    # Seed edges: label fields are attached relative to the reversed type.
    rev_input_type = reverse_edge_type(input_type)
    meta = hetero_sampler_out.metadata
    if 'edge_label_index' in meta:
        # Binary negative sampling: swap the two rows of the label index
        # and file it under the reversed edge type.
        label_index = meta['edge_label_index']
        label_store = data[rev_input_type]
        label_store.edge_label_index = torch.stack(
            (label_index[1], label_index[0]), dim=0)
        label_store.edge_label = meta['edge_label']
    elif 'src_index' in meta:
        # Triplet negative sampling: the index fields live on node
        # storages taken from the two ends of the reversed edge type.
        src_store = data[rev_input_type[-1]]
        dst_store = data[rev_input_type[0]]
        src_store.src_index = meta['src_index']
        dst_store.dst_pos_index = meta['dst_pos_index']
        dst_store.dst_neg_index = meta['dst_neg_index']

    return data