# Source code for graphlearn_torch.distributed.dist_context

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from enum import Enum
from typing import Optional


class DistRole(Enum):
    r""" Role types for distributed context groups.

    Each process in a distributed job holds exactly one of these roles, which
    is stored on its ``DistContext``.
    """
    # As a worker in a distributed worker group (non-server mode).
    WORKER = 1
    # As a server in a distributed server group (server-client mode).
    SERVER = 2
    # As a client in a distributed client group (server-client mode).
    CLIENT = 3
# Fallback group names, used when a caller does not supply a group name for
# the corresponding role.
_DEFAULT_WORKER_GROUP = '_default_worker'
_DEFAULT_SERVER_GROUP = '_default_server'
_DEFAULT_CLIENT_GROUP = '_default_client'
class DistContext(object):
    r""" Distributed context info of the current process.

    Args:
      role (DistRole): The role type of the current context group.
      group_name (str): A unique name of the current role group.
      world_size (int): The number of processes in the current role group.
      rank (int): The current process rank within the current role group.
      global_world_size (int): The total number of processes in all role groups.
      global_rank (int): The current process rank within all role groups.

    Raises:
      AssertionError: If a world size is not positive, a rank falls outside
        ``[0, world_size)`` / ``[0, global_world_size)``, or ``world_size``
        exceeds ``global_world_size``.

    Instances define ``__eq__`` without ``__hash__`` and are therefore
    unhashable.
    """
    def __init__(self,
                 role: DistRole,
                 group_name: str,
                 world_size: int,
                 rank: int,
                 global_world_size: int,
                 global_rank: int):
        assert world_size > 0 and rank in range(world_size), (
            f"invalid rank {rank} for world_size {world_size}"
        )
        assert global_world_size > 0 and global_rank in range(global_world_size), (
            f"invalid global_rank {global_rank} for "
            f"global_world_size {global_world_size}"
        )
        assert world_size <= global_world_size, (
            f"world_size {world_size} exceeds "
            f"global_world_size {global_world_size}"
        )
        self.role = role
        self.group_name = group_name
        self.world_size = world_size
        self.rank = rank
        self.global_world_size = global_world_size
        self.global_rank = global_rank

    def __repr__(self) -> str:
        r""" Render as ``ClassName(attr: value, ...)`` over all instance
        attributes, in assignment order.
        """
        fields = ", ".join(
            f"{key}: {value}" for key, value in self.__dict__.items()
        )
        return f"{self.__class__.__name__}({fields})"

    def __eq__(self, obj):
        r""" Two contexts are equal when all of their instance attributes match.
        """
        if not isinstance(obj, DistContext):
            # Let Python try the reflected comparison; ``ctx == other`` still
            # evaluates to ``False`` for unrelated types.
            return NotImplemented
        # Comparing the whole attribute dicts keeps equality symmetric and
        # cannot raise ``KeyError`` when one side lacks an attribute.
        return self.__dict__ == obj.__dict__

    def is_worker(self) -> bool:
        r""" Whether the current process is in a worker group (non-server mode).
        """
        return self.role == DistRole.WORKER

    def is_server(self) -> bool:
        r""" Whether the current process is in a server group.
        """
        return self.role == DistRole.SERVER

    def is_client(self) -> bool:
        r""" Whether the current process is in a client group.
        """
        return self.role == DistRole.CLIENT

    def num_servers(self) -> int:
        r""" Number of server processes as seen from this context.

        For a server this is its own group size; for a client it is every
        process outside the client group; for any other role it is 0.
        """
        if self.role == DistRole.SERVER:
            return self.world_size
        if self.role == DistRole.CLIENT:
            return self.global_world_size - self.world_size
        return 0

    def num_clients(self) -> int:
        r""" Number of client processes as seen from this context.

        For a client this is its own group size; for a server it is every
        process outside the server group; for any other role it is 0.
        """
        if self.role == DistRole.CLIENT:
            return self.world_size
        if self.role == DistRole.SERVER:
            return self.global_world_size - self.world_size
        return 0

    @property
    def worker_name(self) -> str:
        r""" Get worker name of the current process of this context, formatted
        as ``"{group_name}-{rank}"``.
        """
        return f"{self.group_name}-{self.rank}"
_dist_context: Optional[DistContext] = None
r""" Distributed context on the current process, or ``None`` while no context
has been set.
"""
def get_context() -> Optional[DistContext]:
    r""" Get distributed context info of the current process.

    Returns:
      The module-level ``DistContext`` of this process, or ``None`` if no
      context has been set yet.
    """
    return _dist_context
def _set_worker_context(world_size: int, rank: int,
                        group_name: Optional[str] = None):
    r""" Install a non-server worker context as the distributed context of the
    current process.

    In this mode the worker group is the only group, so the global world size
    and global rank are the same as the group-local ones.
    """
    global _dist_context
    if group_name is None:
        group_name = _DEFAULT_WORKER_GROUP
    _dist_context = DistContext(
        role=DistRole.WORKER,
        group_name=group_name,
        world_size=world_size,
        rank=rank,
        global_world_size=world_size,
        global_rank=rank,
    )


def _set_server_context(num_servers: int, num_clients: int, server_rank: int,
                        server_group_name: Optional[str] = None):
    r""" Install a server context as the distributed context of the current
    process.

    Servers take the leading global ranks, so a server's global rank is equal
    to its rank within the server group.
    """
    global _dist_context
    assert num_servers > 0 and num_clients > 0
    if server_group_name is None:
        server_group_name = _DEFAULT_SERVER_GROUP
    total_procs = num_servers + num_clients
    _dist_context = DistContext(
        role=DistRole.SERVER,
        group_name=server_group_name,
        world_size=num_servers,
        rank=server_rank,
        global_world_size=total_procs,
        global_rank=server_rank,
    )


def _set_client_context(num_servers: int, num_clients: int, client_rank: int,
                        client_group_name: Optional[str] = None):
    r""" Install a client context as the distributed context of the current
    process.

    Client global ranks are offset by ``num_servers``.
    """
    global _dist_context
    assert num_servers > 0 and num_clients > 0
    if client_group_name is None:
        client_group_name = _DEFAULT_CLIENT_GROUP
    total_procs = num_servers + num_clients
    _dist_context = DistContext(
        role=DistRole.CLIENT,
        group_name=client_group_name,
        world_size=num_clients,
        rank=client_rank,
        global_world_size=total_procs,
        global_rank=num_servers + client_rank,
    )
def init_worker_group(world_size: int, rank: int,
                      group_name: Optional[str] = None):
    r""" Initialize a simple worker group on the current process. This method
    should be called only in a non-server distribution mode with a group of
    parallel workers.

    Args:
      world_size (int): Number of all processes participating in the
        distributed worker group.
      rank (int): Rank of the current process within the distributed group (it
        should be a number between 0 and ``world_size``-1).
      group_name (str): A unique name of the distributed group that current
        process belongs to. If set to ``None``, a default name will be used.
    """
    _set_worker_context(world_size=world_size, rank=rank,
                        group_name=group_name)