Source code for tprofiler.distribution.gather

"""
Distributed tensor gathering utilities for PyTorch.

This module provides utilities for gathering tensors across multiple processes in a distributed
PyTorch environment. It handles tensors of different sizes along the concatenation dimension
by using padding and unpadding strategies.
"""

from typing import Optional

import torch
import torch.distributed as dist


def gather(
    tensor: torch.Tensor, dim: int = 0, dst: Optional[int] = None
) -> Optional[torch.Tensor]:
    """
    Gather tensors from all processes in a distributed environment.

    This function collects tensors from all processes and concatenates them along the
    specified dimension. It handles tensors of different sizes by padding them to the
    same size before gathering and then removing the padding after concatenation.

    :param tensor: The tensor to gather from the current process.
    :type tensor: torch.Tensor
    :param dim: The dimension along which to concatenate the gathered tensors.
        Defaults to 0.
    :type dim: int
    :param dst: The destination rank that should receive the gathered result. If None,
        all ranks receive the result (all_gather mode). If specified, only the
        destination rank receives the result (gather mode).
    :type dst: Optional[int]
    :return: The concatenated tensor from all processes if this rank should receive the
        result, None otherwise. In non-distributed environments, returns the input
        tensor unchanged.
    :rtype: Optional[torch.Tensor]

    Example::

        >>> # All-gather mode: all ranks get the result
        >>> local_tensor = torch.tensor([1, 2, 3])
        >>> result = gather(local_tensor, dim=0)
        >>> # result will be the concatenation of tensors from all ranks

        >>> # Gather mode: only rank 0 gets the result
        >>> local_tensor = torch.tensor([[1, 2], [3, 4]])
        >>> result = gather(local_tensor, dim=0, dst=0)
        >>> # Only rank 0 will have the concatenated result; other ranks get None
    """
    if not dist.is_initialized():
        return tensor

    world_size = dist.get_world_size()
    rank = dist.get_rank()

    if world_size == 1:
        return tensor

    # 1. Collect the sizes of the tensors along the concatenation dimension from all ranks
    concat_dim_size = torch.tensor([tensor.shape[dim]], dtype=torch.long, device=tensor.device)
    all_concat_dim_sizes = [torch.zeros_like(concat_dim_size) for _ in range(world_size)]
    dist.all_gather(all_concat_dim_sizes, concat_dim_size)
    # Convert to a Python list for easier subsequent use
    all_concat_dim_sizes = [size.item() for size in all_concat_dim_sizes]

    # 2. The gather collectives require identically shaped tensors, so pad all tensors
    #    to the maximum size along the concatenation dimension
    max_concat_dim_size = max(all_concat_dim_sizes)

    # Create the padded tensor
    padded_shape = list(tensor.shape)
    padded_shape[dim] = max_concat_dim_size
    padded_tensor = torch.zeros(padded_shape, dtype=tensor.dtype, device=tensor.device)

    # Copy the original tensor into the padded tensor
    slices = [slice(None)] * tensor.ndim
    slices[dim] = slice(0, tensor.shape[dim])
    padded_tensor[tuple(slices)] = tensor

    # 3. Execute the gather operation
    if dst is None:
        # all_gather mode: all ranks get the result
        gathered_tensors = [torch.zeros_like(padded_tensor) for _ in range(world_size)]
        dist.all_gather(gathered_tensors, padded_tensor)
        should_return_result = True
    else:
        # gather mode: only the destination rank gets the result
        if rank == dst:
            gathered_tensors = [torch.zeros_like(padded_tensor) for _ in range(world_size)]
        else:
            gathered_tensors = None
        dist.gather(padded_tensor, gathered_tensors, dst=dst)
        should_return_result = (rank == dst)

    # 4. If the current rank should return the result, remove the padding and concatenate
    if should_return_result:
        # Remove padding using the true sizes collected in step 1
        unpadded_tensors = []
        for i, gathered_tensor in enumerate(gathered_tensors):
            slices = [slice(None)] * gathered_tensor.ndim
            slices[dim] = slice(0, all_concat_dim_sizes[i])
            unpadded_tensors.append(gathered_tensor[tuple(slices)])

        # Concatenate along the specified dimension
        return torch.cat(unpadded_tensors, dim=dim)
    else:
        return None
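
The sketch below is a minimal illustration (not part of the module) of how this function might be exercised in a multi-process job launched with torchrun; the script name, the "gloo" backend, and the per-rank tensor shapes are assumptions made for the example.

# usage_sketch.py -- hypothetical example script, not part of tprofiler
import torch
import torch.distributed as dist

from tprofiler.distribution.gather import gather


def main() -> None:
    # torchrun sets RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT, so the
    # default env:// initialization works here
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()

    # Each rank contributes a different number of rows along dim 0,
    # which is exactly the case the padding/unpadding logic handles
    local = torch.arange((rank + 1) * 2, dtype=torch.float32).reshape(rank + 1, 2)

    # All-gather mode: every rank receives the concatenated tensor
    everyone = gather(local, dim=0)

    # Gather mode: only rank 0 receives the result; other ranks get None
    only_rank0 = gather(local, dim=0, dst=0)

    if rank == 0:
        print(everyone.shape, only_rank0.shape)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

With two processes (e.g. torchrun --nproc_per_node=2 usage_sketch.py), rank 0 contributes a (1, 2) tensor and rank 1 a (2, 2) tensor, so both gathered results have shape (3, 2).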