# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
import time
import threading
from typing import Dict, Optional, Union
import torch
from ..channel import ShmChannel
from ..sampler import NodeSamplerInput, EdgeSamplerInput, SamplingConfig
from .dist_context import get_context, _set_server_context
from .dist_dataset import DistDataset
from .dist_options import RemoteDistSamplingWorkerOptions
from .dist_sampling_producer import DistMpSamplingProducer
from .rpc import barrier, init_rpc, shutdown_rpc
# Sleep period of the polling loop in ``DistServer.wait_for_exit``.
SERVER_EXIT_STATUS_CHECK_INTERVAL = 5.0
r""" Interval (in seconds) to check exit status of server.
"""
class DistServer(object):
    r""" A server that supports launching remote sampling workers for
    training clients.

    Note that this server is enabled only when the distribution mode is a
    server-client framework, and the graph and feature store will be partitioned
    and managed by all server nodes.

    Args:
      dataset (DistDataset): The ``DistDataset`` object of a partition of graph
        data and feature data, along with distributed partition books.
    """

    def __init__(self, dataset: DistDataset):
        self.dataset = dataset
        # Guards producer-id allocation and membership changes of both pools.
        self._lock = threading.RLock()
        # Exit flag polled by ``wait_for_exit`` and raised by ``exit``.
        self._exit = False
        # Next producer id to hand out; only ever incremented.
        self._producer_idx = 0
        self._producer_pool: Dict[int, DistMpSamplingProducer] = {}
        self._msg_buffer_pool: Dict[int, ShmChannel] = {}

    def shutdown(self):
        r""" Destroy every sampling producer still registered on this server.
        """
        # Iterate over a snapshot: destroying a producer mutates the pool.
        for producer_id in list(self._producer_pool.keys()):
            self.destroy_sampling_producer(producer_id)
        assert len(self._producer_pool) == 0
        assert len(self._msg_buffer_pool) == 0

    def wait_for_exit(self):
        r""" Block until the exit flag has been set to ``True``.

        The flag is polled, sleeping ``SERVER_EXIT_STATUS_CHECK_INTERVAL``
        seconds between checks.
        """
        while not self._exit:
            time.sleep(SERVER_EXIT_STATUS_CHECK_INTERVAL)

    def exit(self):
        r""" Set the exit flag to ``True``.

        Returns:
          The value of the exit flag after it has been set (``True``).
        """
        self._exit = True
        return self._exit

    def create_sampling_producer(
        self,
        sampler_input: Union[NodeSamplerInput, EdgeSamplerInput],
        sampling_config: SamplingConfig,
        worker_options: RemoteDistSamplingWorkerOptions,
    ) -> int:
        r""" Create and initialize an instance of ``DistSamplingProducer`` with
        a group of subprocesses for distributed sampling.

        Args:
          sampler_input (NodeSamplerInput or EdgeSamplerInput): The input data
            for sampling.
          sampling_config (SamplingConfig): Configuration of sampling meta info.
          worker_options (RemoteDistSamplingWorkerOptions): Options for launching
            remote sampling workers by this server.

        Returns:
          A unique id of created sampling producer on this server.
        """
        buffer = ShmChannel(
            worker_options.buffer_capacity, worker_options.buffer_size
        )
        producer = DistMpSamplingProducer(
            self.dataset, sampler_input, sampling_config, worker_options, buffer
        )
        # Initialize before taking the lock so that registration stays short.
        producer.init()
        with self._lock:
            producer_id = self._producer_idx
            self._producer_pool[producer_id] = producer
            self._msg_buffer_pool[producer_id] = buffer
            self._producer_idx += 1
        return producer_id

    def destroy_sampling_producer(self, producer_id: int):
        r""" Shutdown and destroy a sampling producer managed by this server with
        its producer id.

        Unknown or already-destroyed ids are ignored, so the call is safe to
        repeat.

        Args:
          producer_id (int): The id returned by ``create_sampling_producer``.
        """
        # Claim the producer atomically: whoever pops it is the only caller
        # allowed to shut it down, so overlapping destroy requests for the
        # same id can neither shut it down twice nor fail on a missing key.
        with self._lock:
            producer = self._producer_pool.pop(producer_id, None)
        if producer is None:
            return
        try:
            producer.shutdown()
        finally:
            # The buffer stays registered until the producer has stopped and
            # is released even when ``shutdown`` raises.
            with self._lock:
                self._msg_buffer_pool.pop(producer_id, None)

    def start_new_epoch_sampling(self, producer_id: int):
        r""" Start a new epoch sampling tasks for a specific sampling producer
        with its producer id.

        Nothing happens when no producer is registered under ``producer_id``.
        """
        producer = self._producer_pool.get(producer_id, None)
        if producer is not None:
            producer.produce_all()

    def fetch_one_sampled_message(self, producer_id: int):
        r""" Fetch a sampled message from the buffer of a specific sampling
        producer with its producer id.

        Returns:
          The result of ``recv()`` on the producer's buffer, or ``None`` when
          no buffer is registered under ``producer_id``.
        """
        buffer = self._msg_buffer_pool.get(producer_id, None)
        if buffer is None:
            return None
        return buffer.recv()
# Assigned by ``init_server`` and reset to ``None`` by
# ``wait_and_shutdown_server``; hence optional.
_dist_server: Optional[DistServer] = None
r""" ``DistServer`` instance of the current process.
"""
def get_server() -> Optional[DistServer]:
    r""" Get the ``DistServer`` instance on the current process.

    Returns:
      The module-level ``DistServer`` instance, or ``None`` when the global
      has not been assigned by ``init_server`` or has been cleared by
      ``wait_and_shutdown_server``.
    """
    return _dist_server
def init_server(num_servers: int, num_clients: int, server_rank: int,
                dataset: DistDataset, master_addr: str, master_port: int,
                num_rpc_threads: int = 16, request_timeout: int = 180,
                server_group_name: Optional[str] = None,):
    r""" Initialize the current process as a server and establish connections
    with all other servers and clients. Note that this method should be called
    only in the server-client distribution mode.

    Args:
      num_servers (int): Number of processes participating in the server group.
      num_clients (int): Number of processes participating in the client group.
      server_rank (int): Rank of the current process within the server group (it
        should be a number between 0 and ``num_servers``-1).
      dataset (DistDataset): The ``DistDataset`` object of a partition of graph
        data and feature data, along with distributed partition book info.
      master_addr (str): The master TCP address for RPC connection between all
        servers and clients, the value of this parameter should be same for all
        servers and clients.
      master_port (int): The master TCP port for RPC connection between all
        servers and clients, the value of this parameter should be same for all
        servers and clients.
      num_rpc_threads (int): The number of RPC worker threads used for the
        current server to respond remote requests. (Default: ``16``).
      request_timeout (int): The max timeout seconds for remote requests,
        otherwise an exception will be raised. (Default: ``180``).
      server_group_name (str): A unique name of the server group that current
        process belongs to. If set to ``None``, a default name will be used.
        (Default: ``None``).
    """
    global _dist_server
    # Order matters as written in the original: the server context is set
    # first, then the server object is published, then RPC is initialized.
    _set_server_context(num_servers, num_clients, server_rank, server_group_name)
    _dist_server = DistServer(dataset=dataset)
    init_rpc(master_addr, master_port, num_rpc_threads, request_timeout)
def wait_and_shutdown_server():
    r""" Block until the server's exit flag has been set, then shut down the
    server on the current process and destroy all RPC connections.

    Logs a warning and returns immediately when the current process has no
    distributed context.

    Raises:
      RuntimeError: If the current process context is not a server context.
    """
    global _dist_server
    current_context = get_context()
    if current_context is None:
        logging.warning("'wait_and_shutdown_server': try to shutdown server when "
                        "the current process has not been initialized as a server.")
        return
    if not current_context.is_server():
        raise RuntimeError(f"'wait_and_shutdown_server': role type of "
                           f"the current process context is not a server, "
                           f"got {current_context.role}.")
    # Blocks until ``DistServer.exit`` has raised the exit flag.
    _dist_server.wait_for_exit()
    _dist_server.shutdown()
    _dist_server = None
    # The barrier runs before the RPC layer is shut down.
    barrier()
    shutdown_rpc()
def _call_func_on_server(func, *args, **kwargs):
r""" A callee entry for remote requests on the server side.
"""
if not callable(func):
logging.warning(f"'_call_func_on_server': receive a non-callable "
f"function target {func}")
return None
server = get_server()
if hasattr(server, func.__name__):
return func(server, *args, **kwargs)
return func(*args, **kwargs)