graphlearn_torch.utils

common

ensure_dir(dir_path: str)[source]
merge_dict(in_dict: Dict[Any, Any], out_dict: Dict[Any, Any])[source]
get_free_port(host: str = '127.0.0.1') → int[source]
id2idx_v2(gid, book)[source]
index_select(data, index)[source]
merge_hetero_sampler_output(in_sample: Any, out_sample: Any, device)[source]
format_hetero_sampler_output(in_sample: Any)[source]
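
The signatures above come without descriptions. As an illustration of what a helper with the shape of get_free_port usually does, here is a minimal sketch that asks the operating system for an unused TCP port by binding to port 0. The name `find_free_port` is hypothetical and the body is an assumption, not the library's implementation.

```python
import socket


def find_free_port(host: str = "127.0.0.1") -> int:
    """Hypothetical helper: return a TCP port that was free when checked."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Binding to port 0 lets the OS choose any currently unused port.
        s.bind((host, 0))
        # getsockname() returns (address, port) for an IPv4 socket.
        return s.getsockname()[1]


port = find_free_port()
```

The port is released when the socket closes, so another process could claim it before the caller binds to it again; callers that need certainty should retry on a bind failure.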

device

get_available_device(device: device | None = None) → device[source]

Get an available device. If the input device is not None, it is returned directly. Otherwise, an available device is chosen (the current CUDA device is preferred if available).

assign_device()[source]

Assign a device to use; the CUDA device is preferred if available.

ensure_device(device: device)[source]

Make sure that the current CUDA kernel corresponds to the assigned device.

exit_status

python_exit_status = False

Whether Python is shutting down. This flag is guaranteed to be set before the Python core library resources are freed, but Python may already have been exiting for some time when it is set.

The hook that sets this flag is _set_python_exit_flag, the same one used in PyTorch’s DataLoader: https://github.com/pytorch/pytorch/blob/f1a6f32b72b7c2b73277f89bbf7e7459a400d80a/torch/utils/data/_utils/__init__.py

mixin

class CastMixin[source]

Bases: object

This class is the same as PyG’s CastMixin: https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/utils/mixin.py

classmethod cast(*args, **kwargs)[source]

singleton

singleton(cls)[source]

Singleton class decorator.
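
A singleton class decorator is commonly written as a closure that caches the first instance. The following is a minimal sketch of that common pattern; it is an assumption about the general technique, not a copy of this module's code, and the `Registry` class is made up for the example.

```python
def singleton(cls):
    """Sketch of a singleton decorator: one shared instance per class."""
    instances = {}

    def get_instance(*args, **kwargs):
        # Construct on first call only; later arguments are ignored.
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class Registry:  # hypothetical example class
    def __init__(self, name="default"):
        self.name = name


a = Registry("first")
b = Registry("second")  # returns the instance created above
```

With this pattern the decorated name refers to a function rather than a class, so `isinstance(obj, Registry)` no longer works; that trade-off is typical of closure-based singletons.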

tensor

tensor_equal_with_device(lhs: Tensor, rhs: Tensor)[source]

Check whether the data and device of two tensors are the same.

id2idx(ids: List[int] | Tensor)[source]

Get a tensor that maps each id to its original index.
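
The idea of an id-to-index mapping can be shown with plain lists: build a lookup table indexed by id whose entries are positions in the input. The sketch below uses a hypothetical `id2idx_list` and fills unused slots with -1; both the name and the fill value are assumptions, and the library function works on tensors instead.

```python
from typing import List


def id2idx_list(ids: List[int]) -> List[int]:
    """Hypothetical list-based sketch: table[ids[i]] == i for every i."""
    if not ids:
        return []
    # One slot per possible id from 0 to max(ids); -1 marks ids not present.
    table = [-1] * (max(ids) + 1)
    for idx, gid in enumerate(ids):
        table[gid] = idx
    return table


table = id2idx_list([5, 2, 7])
# table == [-1, -1, 1, -1, -1, 0, -1, 2]
```

The table has length max(id) + 1, so this layout suits dense id ranges; sparse ids would be better served by a dictionary.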

convert_to_tensor(data: Any, dtype: dtype | None = None)[source]

Convert the input data to a tensor-based type.

apply_to_all_tensor(data: Any, tensor_method, *args, **kwargs)[source]

Apply the specified method recursively to all tensors contained in the input data.
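
The recursive traversal can be sketched without PyTorch by passing the leaf type as a parameter. `apply_to_all` below is a hypothetical stand-in: it assumes the containers of interest are dicts, lists, and tuples, which may differ from what the library handles.

```python
from typing import Any, Callable


def apply_to_all(data: Any, leaf_type: type, fn: Callable[[Any], Any]) -> Any:
    """Hypothetical sketch: apply `fn` to every `leaf_type` value in `data`."""
    if isinstance(data, leaf_type):
        return fn(data)
    if isinstance(data, dict):
        return {k: apply_to_all(v, leaf_type, fn) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        # Rebuild the same container type (plain list or tuple only).
        return type(data)(apply_to_all(v, leaf_type, fn) for v in data)
    # Anything else is returned unchanged.
    return data


nested = {"a": [1.5, (2.5, "x")], "b": 3}
doubled = apply_to_all(nested, float, lambda x: x * 2)
# doubled == {"a": [3.0, (5.0, "x")], "b": 3}
```

Substituting a tensor class for `float` and a tensor method for the lambda gives the shape of helpers such as share_memory and squeeze, which apply one operation to every tensor in a nested structure.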

share_memory(data: Any)[source]

Share memory for all tensors contained by the input data.

squeeze(data: Any)[source]

Squeeze all tensors contained by the input data.

units

parse_size(sz) → int[source]
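
The signature alone does not say which inputs are accepted. As a rough illustration of a size parser that returns a byte count, the sketch below assumes binary units (1 KB = 1024 bytes) and accepts numbers or strings such as "1.5MB". The name `parse_size_sketch`, the unit table, and the accepted formats are all assumptions, not the library's behavior.

```python
import re

# Assumed unit table: binary multiples, case-insensitive suffixes.
_UNITS = {
    "": 1, "B": 1,
    "K": 2**10, "KB": 2**10,
    "M": 2**20, "MB": 2**20,
    "G": 2**30, "GB": 2**30,
    "T": 2**40, "TB": 2**40,
}


def parse_size_sketch(sz) -> int:
    """Hypothetical parser: convert a number or a size string to bytes."""
    if isinstance(sz, (int, float)):
        return int(sz)
    m = re.fullmatch(r"\s*(\d+(?:\.\d+)?)\s*([A-Za-z]*)\s*", str(sz))
    if m is None or m.group(2).upper() not in _UNITS:
        raise ValueError(f"cannot parse size: {sz!r}")
    return int(float(m.group(1)) * _UNITS[m.group(2).upper()])


one_kb = parse_size_sketch("1KB")
```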