tprofiler.distribution.gather

Distributed tensor gathering utilities for PyTorch.

This module provides utilities for gathering tensors across multiple processes in a distributed PyTorch environment. It handles tensors of different sizes along the concatenation dimension by using padding and unpadding strategies.

gather

tprofiler.distribution.gather.gather(tensor: Tensor, dim: int = 0, dst: int | None = None) → Tensor | None

Gather tensors from all processes in a distributed environment.

This function collects tensors from all processes and concatenates them along the specified dimension. Tensors whose sizes differ along that dimension are padded to a common size before gathering, and the padding is removed after concatenation.
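The padding strategy described above can be sketched roughly as follows. This is only an illustration, not the module's actual implementation: the helper names (pad_along_dim, unpad_and_concat, all_gather_variable) are hypothetical, zero is assumed as the padding value, and this sketch trims each gathered piece before concatenating, which may differ from the order the real code uses. Only standard torch.distributed calls are relied on.

```python
import torch
import torch.distributed as dist


def pad_along_dim(tensor, target_size, dim=0):
    """Zero-pad `tensor` along `dim` up to `target_size` (hypothetical helper)."""
    pad_amount = target_size - tensor.shape[dim]
    if pad_amount == 0:
        return tensor.contiguous()
    pad_shape = list(tensor.shape)
    pad_shape[dim] = pad_amount
    padding = tensor.new_zeros(pad_shape)  # same dtype and device as the input
    return torch.cat([tensor, padding], dim=dim)


def unpad_and_concat(padded, sizes, dim=0):
    """Trim each padded tensor back to its true size, then concatenate."""
    trimmed = [t.narrow(dim, 0, n) for t, n in zip(padded, sizes)]
    return torch.cat(trimmed, dim=dim)


def all_gather_variable(tensor, dim=0):
    """Sketch of an all-gather for tensors whose sizes differ along `dim`."""
    if not (dist.is_available() and dist.is_initialized()):
        return tensor  # non-distributed: nothing to gather

    world_size = dist.get_world_size()

    # 1. Exchange local sizes so every rank knows how much to trim later.
    local_size = torch.tensor([tensor.shape[dim]], device=tensor.device)
    size_list = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    sizes = [int(s.item()) for s in size_list]

    # 2. Pad to the largest size so all_gather sees identical shapes.
    padded = pad_along_dim(tensor, max(sizes), dim)
    buffers = [torch.empty_like(padded) for _ in range(world_size)]
    dist.all_gather(buffers, padded)

    # 3. Drop the padding and concatenate along `dim`.
    return unpad_and_concat(buffers, sizes, dim)
```

The two helpers can be exercised without a process group by treating a plain list of tensors as the per-rank inputs.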

Parameters:
  • tensor (torch.Tensor) – The tensor to gather from the current process.

  • dim (int) – The dimension along which to concatenate the gathered tensors. Defaults to 0.

  • dst (Optional[int]) – The destination rank that should receive the gathered result. If None, all ranks receive the result (all_gather mode). If specified, only the destination rank receives the result (gather mode). Defaults to None.

Returns:

The concatenated tensor from all processes if this rank should receive the result, None otherwise. In non-distributed environments, returns the input tensor unchanged.

Return type:

Optional[torch.Tensor]
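To illustrate how the dst argument could select between the two modes and produce the Optional return value, here is a minimal sketch. It assumes all ranks hold tensors of identical shape (so no padding is needed), and the function name gather_equal_shapes is hypothetical; the real function may be organized differently.

```python
import torch
import torch.distributed as dist


def gather_equal_shapes(tensor, dim=0, dst=None):
    """Sketch of dst dispatch, assuming every rank holds the same shape."""
    if not (dist.is_available() and dist.is_initialized()):
        return tensor  # non-distributed: return the input unchanged

    world_size = dist.get_world_size()
    tensor = tensor.contiguous()

    if dst is None:
        # all_gather mode: every rank receives every tensor.
        buffers = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(buffers, tensor)
        return torch.cat(buffers, dim=dim)

    # gather mode: only the destination rank allocates receive buffers.
    is_dst = dist.get_rank() == dst
    buffers = (
        [torch.empty_like(tensor) for _ in range(world_size)] if is_dst else None
    )
    dist.gather(tensor, gather_list=buffers, dst=dst)
    return torch.cat(buffers, dim=dim) if is_dst else None
```

Because non-destination ranks get None in gather mode, callers should check the result before using it, for example with `if result is not None:`.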

Example:

>>> import torch
>>> from tprofiler.distribution.gather import gather
>>> # All-gather mode: every rank gets the result
>>> local_tensor = torch.tensor([1, 2, 3])
>>> result = gather(local_tensor, dim=0)
>>> # result is the concatenation of the tensors from all ranks

>>> # Gather mode: only rank 0 gets the result
>>> local_tensor = torch.tensor([[1, 2], [3, 4]])
>>> result = gather(local_tensor, dim=0, dst=0)
>>> # Only rank 0 holds the concatenated result; all other ranks get None