# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from multiprocessing.reduction import ForkingPickler
from typing import Optional, Tuple, Union
import torch
from .. import py_graphlearn_torch as pywrap
from ..typing import TensorDataType
from ..utils import (
convert_to_tensor, share_memory, ptr2ind, coo_to_csr, coo_to_csc
)
[docs]class CSRTopo(object):
r""" Graph topology in CSR format.
Args:
edge_index (a 2D torch.Tensor or numpy.ndarray, or a tuple): The edge
index for graph topology.
edge_ids (torch.Tensor or numpy.ndarray, optional): The edge ids for
graph edges. If set to ``None``, it will be aranged by the edge size.
(default: ``None``)
layout (str): The edge layout representation for the input edge index,
should be 'COO' (rows and cols uncompressed), 'CSR' (rows compressed)
or 'CSC' (columns compressed). (default: 'COO')
"""
def __init__(self,
edge_index: Union[TensorDataType,
Tuple[TensorDataType, TensorDataType]],
edge_ids: Optional[TensorDataType] = None,
layout: str = 'COO'):
layout = str(layout).upper()
if layout not in ['COO', 'CSR', 'CSC']:
raise RuntimeError(f"'{self.__class__.__name__}': got "
f"invalid edge layout {layout}")
edge_index = convert_to_tensor(edge_index, dtype=torch.int64)
row, col = edge_index[0], edge_index[1]
num_edges = max(row.numel(), col.numel())
edge_ids = convert_to_tensor(edge_ids, dtype=torch.int64)
if edge_ids is None:
edge_ids = torch.arange(num_edges, dtype=torch.int64, device=row.device)
else:
assert edge_ids.numel() == num_edges
if layout == 'CSR':
self._indptr, self._indices = row, col
self._edge_ids = edge_ids
else:
if layout == 'CSC':
col = ptr2ind(col)
self._indptr, self._indices, self._edge_ids = \
coo_to_csr(row, col, edge_value=edge_ids)
[docs] def to_coo(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
r""" Convert to COO format.
Returns:
row indice tensor, column indice tensor and edge id tensor
"""
return ptr2ind(self._indptr), self._indices, self._edge_ids
[docs] def to_csc(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
r""" Convert to CSC format.
Returns:
row indice tensor, column ptr tensor and edge id tensor
"""
row, col, edge_ids = self.to_coo()
return coo_to_csc(row, col, edge_value=edge_ids)
@property
def indptr(self):
return self._indptr
@property
def indices(self):
return self._indices
@property
def edge_ids(self):
r""" local edge ids in CSR.
"""
return self._edge_ids
@property
def degrees(self):
return self._indptr[1:] - self._indptr[:-1]
@property
def row_count(self):
return self._indptr.shape[0] - 1
@property
def edge_count(self):
return self._indices.shape[0]
[docs] def share_memory_(self):
self._indptr = share_memory(self._indptr)
self._indices = share_memory(self._indices)
self._edge_ids = share_memory(self._edge_ids)
def __getitem__(self, key):
return getattr(self, key, None)
def __setitem__(self, key, value):
setattr(self, key, value)
[docs]class Graph(object):
r""" A graph object used for graph operations such as sampling.
There are three modes supported:
1.'CPU': graph data are stored in the CPU memory and graph
operations are also executed on CPU.
2.'ZERO_COPY': graph data are stored in the pinned CPU memory and graph
operations are executed on GPU.
3.'CUDA': graph data are stored in the GPU memory and graph operations
are executed on GPU.
Args:
csr_topo (CSRTopo): An instance of ``CSRTopo`` with graph topology data.
mode (str): The graph operation mode, must be 'CPU', 'ZERO_COPY' or 'CUDA'.
(Default: 'ZERO_COPY').
device (int, optional): The target cuda device rank to perform graph
operations. Note that this parameter will be ignored if the graph mode
set to 'CPU'. The value of ``torch.cuda.current_device()`` will be used
if set to ``None``. (Default: ``None``).
"""
def __init__(self, csr_topo: CSRTopo, mode = 'ZERO_COPY',
device: Optional[int] = None):
self.csr_topo = csr_topo
self.csr_topo.share_memory_()
self.mode = mode.upper()
self.device = device
if self.mode != 'CPU' and self.device is not None:
self.device = int(self.device)
assert (
self.device >= 0 and self.device < torch.cuda.device_count()
), f"'{self.__class__.__name__}': invalid device rank {self.device}"
self._graph = None
[docs] def lazy_init(self):
if self._graph is not None:
return
self._graph = pywrap.Graph()
indptr = self.csr_topo.indptr
indices = self.csr_topo.indices
if self.csr_topo.edge_ids is not None:
edge_ids = self.csr_topo.edge_ids
else:
edge_ids = torch.empty(0)
if self.mode == 'CPU':
self._graph.init_cpu_from_csr(indptr, indices, edge_ids)
else:
if self.device is None:
self.device = torch.cuda.current_device()
if self.mode == 'CUDA':
self._graph.init_cuda_from_csr(
indptr, indices, self.device, pywrap.GraphMode.DMA, edge_ids
)
elif self.mode == 'ZERO_COPY':
self._graph.init_cuda_from_csr(
indptr, indices, self.device, pywrap.GraphMode.ZERO_COPY, edge_ids
)
else:
raise ValueError(f"'{self.__class__.__name__}': "
f"invalid mode {self.mode}")
[docs] def share_ipc(self):
r""" Create ipc handle for multiprocessing.
Returns:
A tuple of csr_topo and graph mode.
"""
return self.csr_topo, self.mode
[docs] @classmethod
def from_ipc_handle(cls, ipc_handle):
r""" Create from ipc handle.
"""
csr_topo, mode = ipc_handle
return cls(csr_topo, mode, device=None)
@property
def row_count(self):
self.lazy_init()
return self._graph.get_row_count()
@property
def col_count(self):
self.lazy_init()
return self._graph.get_col_count()
@property
def edge_count(self):
self.lazy_init()
return self._graph.get_edge_count()
@property
def graph_handler(self):
r""" Get a pointer to the underlying graph object for graph operations
such as sampling.
"""
self.lazy_init()
return self._graph
## Pickling Registration
[docs]def rebuild_graph(ipc_handle):
graph = Graph.from_ipc_handle(ipc_handle)
return graph
[docs]def reduce_graph(graph: Graph):
ipc_handle = graph.share_ipc()
return (rebuild_graph, (ipc_handle, ))
ForkingPickler.register(Graph, reduce_graph)