"""
Time measurement and profiling utilities for distributed PyTorch applications.
This module provides comprehensive timing functionality for distributed training scenarios,
including context managers for timing code blocks, gathering timing data across processes,
and visualization tools for analyzing performance metrics.
The main components include:
- TimeManager: Core timing functionality with context managers
- GatheredTime: Analysis and visualization of distributed timing data
- ProfiledTime: Container for multiple GatheredTime objects with serialization support
"""
import time
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import field, dataclass
from typing import Dict, List, Tuple, Optional
import numpy as np
import torch
import torch.distributed as dist
from hbutils.reflection import context
from hbutils.string import plural_word
from matplotlib import pyplot as plt
from ..distribution import gather
from ..utils import Stack
# Name of the context variable under which the timer stack is published
# (read by ``_get_timer_stack`` and set by ``TimeManager.enable_timer``).
_TIMER_STACK_NAME = 'timer_stack'
def _get_timer_stack() -> Stack:
    """
    Return the timer stack visible in the current context.

    The stack is looked up in the context variables under
    ``_TIMER_STACK_NAME``. When nothing is stored there, or the stored value
    is falsy, a brand-new ``Stack`` is returned instead. The new stack is not
    written back to the context by this function.

    :return: The timer stack found in the context, or a fresh one.
    :rtype: Stack
    """
    existing = context().get(_TIMER_STACK_NAME, None)
    if existing:
        return existing
    return Stack()
@dataclass
class TimeManager:
    """
    Record durations of named operations and gather them across processes.

    Durations are stored per operation name, in seconds, in the order they
    were measured. Measurements are normally added through the :meth:`timer`
    context manager and can be collected from every process with
    :meth:`gather`.

    :param records: Dictionary storing timing records for different operations.
    :type records: Dict[str, List[float]]

    Example::

        >>> tm = TimeManager()
        >>> with tm.timer('operation1'):
        ...     time.sleep(0.1)
        >>> times = tm._get_time('operation1')
        >>> len(times) == 1
        True
    """

    records: Dict[str, List[float]] = field(default_factory=dict)

    def _append_time(self, name: str, secs: float):
        """
        Append one timing record for the given operation.

        The record list for ``name`` is created on first use.

        :param name: The name of the operation being timed.
        :type name: str
        :param secs: The execution time in seconds.
        :type secs: float
        """
        self.records.setdefault(name, []).append(secs)

    def _get_time(self, name: str) -> List[float]:
        """
        Get all timing records for a specific operation.

        For a known operation the stored list itself is returned (not a
        copy); for an unknown operation a new empty list is returned and
        nothing is added to ``records``.

        :param name: The name of the operation.
        :type name: str
        :return: List of timing records in seconds.
        :rtype: List[float]
        """
        return self.records.get(name, [])

    def _get_time_torch(self, name: str) -> torch.Tensor:
        """
        Get the timing records of an operation as a float32 tensor.

        :param name: The name of the operation.
        :type name: str
        :return: Timing records as a 1-D float32 tensor (empty if none exist).
        :rtype: torch.Tensor
        """
        return torch.tensor(self._get_time(name), dtype=torch.float32)

    def _get_time_with_rank(self, name: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get timing records together with the rank that produced them.

        The second tensor has the same shape as the first and is filled with
        ``dist.get_rank()`` of the calling process, so both can be gathered
        side by side.

        :param name: The name of the operation.
        :type name: str
        :return: Tuple of (times, ranks) tensors; ranks are int32.
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
        t = self._get_time_torch(name)
        r = torch.full_like(t, dist.get_rank(), dtype=torch.int32)
        return t, r

    @contextmanager
    def timer(self, name: str):
        """
        Context manager that times the enclosed block.

        The elapsed time is appended to the records of ``name`` when the
        block finishes normally. If the block raises, the exception
        propagates and no record is added.

        :param name: The name to associate with this timing measurement.
        :type name: str
        :yields: None

        Example::

            >>> tm = TimeManager()
            >>> with tm.timer('my_operation'):
            ...     # Your code here
            ...     pass
        """
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted and is unsuitable for intervals.
        start_time = time.perf_counter()
        yield
        self._append_time(name, time.perf_counter() - start_time)

    def clear(self):
        """
        Clear all timing records.

        Empties ``records`` in place.
        """
        self.records.clear()

    @contextmanager
    def enable_timer(self):
        """
        Context manager that makes this TimeManager the top of the timer stack.

        The instance is pushed onto the stack returned by
        ``_get_timer_stack()``, and that stack is published in the context
        variables under ``_TIMER_STACK_NAME`` for the duration of the block.
        The instance is popped again when the block exits, whether normally
        or through an exception.

        :yields: None

        Example::

            >>> tm = TimeManager()
            >>> with tm.enable_timer():
            ...     pass  # code here can reach ``tm`` through the timer stack
        """
        timer_stack = _get_timer_stack()
        timer_stack.push(self)
        try:
            with context().vars(**{_TIMER_STACK_NAME: timer_stack}):
                yield
        finally:
            timer_stack.pop()

    def gather(self, dst: Optional[int] = None) -> Optional['ProfiledTime']:
        """
        Gather timing data from all processes in a distributed environment.

        For every recorded operation the local times and their rank labels
        are passed to the project's ``gather`` helper, and the results are
        wrapped in :class:`GatheredTime` objects.

        NOTE(review): every participating process is expected to hold the
        same set of operation names, since two gather calls are issued per
        name -- confirm against the ``gather`` helper's contract.

        :param dst: Destination rank to gather data to. If None, every rank
            is treated as a destination.
        :type dst: Optional[int]
        :return: Gathered timing data keyed by operation name, or None if this
            process is not the destination rank.
        :rtype: Optional[ProfiledTime]

        Example::

            >>> tm = TimeManager()
            >>> # After recording some times...
            >>> gathered = tm.gather(dst=0)  # Gather to rank 0
        """
        is_gather_dst = dst is None or dst == dist.get_rank()
        collected = {}
        # Dict insertion order may differ between ranks (operations can be
        # first timed in a different order), so the gather calls are issued
        # in sorted-name order to keep them aligned across processes.
        for key in sorted(self.records):
            times, ranks = self._get_time_with_rank(key)
            gathered_times = gather(times, dst=dst)
            gathered_ranks = gather(ranks, dst=dst)
            if is_gather_dst:
                collected[key] = GatheredTime(
                    times=gathered_times,
                    ranks=gathered_ranks,
                )

        if not is_gather_dst:
            return None
        # Hand the results back in this process's own recording order.
        return ProfiledTime({key: collected[key] for key in self.records})
@dataclass
class ProfiledTime(Mapping):
    """
    Read-only mapping from operation names to gathered timing data.

    Wraps a dictionary of :class:`GatheredTime` objects, exposes it through
    the ``Mapping`` interface, and can save it to / load it from a file with
    ``torch.save`` / ``torch.load``.

    :param records: Dictionary mapping operation names to GatheredTime objects.
        When omitted (None), an empty dictionary is used.
    :type records: Optional[Dict[str, GatheredTime]]

    Example::

        >>> pt = ProfiledTime({'op1': gathered_time1, 'op2': gathered_time2})
        >>> pt['op1'].mean()  # Access timing data for operation 'op1'
        >>> pt.save('profile.pt')  # Save profiling data
        >>> loaded_pt = ProfiledTime.load('profile.pt')  # Load profiling data
    """

    records: Optional[Dict[str, 'GatheredTime']] = None

    def __post_init__(self):
        """
        Substitute an empty dictionary when no records were provided.

        Only ``None`` is replaced; a dictionary passed by the caller, even an
        empty one, is kept as the backing store.
        """
        if self.records is None:
            self.records = {}

    def __getitem__(self, __key: str):
        """
        Get a GatheredTime object by operation name.

        :param __key: The operation name to retrieve timing data for.
        :type __key: str
        :return: The GatheredTime object for the specified operation.
        :rtype: GatheredTime
        :raises KeyError: If no data is stored under that name.
        """
        return self.records[__key]

    def __len__(self) -> int:
        """
        Get the number of operations with timing data.

        :return: Number of operations in the records.
        :rtype: int
        """
        return len(self.records)

    def __iter__(self):
        """
        Iterate over operation names.

        :return: Iterator over operation names.
        """
        return iter(self.records)

    def save(self, file):
        """
        Save the profiling data to a file.

        Only the ``records`` dictionary is written, via ``torch.save``.

        :param file: File path or file-like object to save to.
        :type file: str or file-like object

        Example::

            >>> pt = ProfiledTime({'op1': gathered_time})
            >>> pt.save('my_profile.pt')
        """
        torch.save(self.records, file)

    @classmethod
    def load(cls, file):
        """
        Load profiling data from a file.

        Tensors are mapped to the CPU on load.

        .. warning::
            The file is read with ``weights_only=False``, i.e. it is fully
            unpickled. Only load files from a trusted source.

        :param file: File path or file-like object to load from.
        :type file: str or file-like object
        :return: New ProfiledTime instance with loaded data.
        :rtype: ProfiledTime

        Example::

            >>> pt = ProfiledTime.load('my_profile.pt')
            >>> print(pt['op1'].mean())  # Analyze loaded timing data
        """
        # SECURITY: weights_only=False allows arbitrary pickled objects, which
        # is needed for GatheredTime values but unsafe for untrusted input.
        records = torch.load(file, map_location='cpu', weights_only=False)
        return cls(records=records)
@dataclass
class GatheredTime:
    """
    Timing measurements paired with the rank that produced each one.

    ``times`` and ``ranks`` are parallel tensors: element *i* of ``ranks`` is
    the rank label of measurement *i* in ``times``. The class offers
    filtering, summary statistics and a histogram plot.

    :param times: Tensor containing timing measurements.
    :type times: torch.Tensor
    :param ranks: Tensor containing corresponding rank information.
    :type ranks: torch.Tensor

    Example::

        >>> times = torch.tensor([0.1, 0.2, 0.15])
        >>> ranks = torch.tensor([0, 1, 0])
        >>> gt = GatheredTime(times, ranks)
        >>> gt.mean()
        0.15...
    """

    times: torch.Tensor
    ranks: torch.Tensor

    def get_rank(self, *ranks: int) -> 'GatheredTime':
        """
        Filter timing data for specific ranks.

        Keeps the measurements whose rank label equals any of the given
        values. Calling it without arguments yields an empty result.

        :param ranks: Rank numbers to filter for.
        :type ranks: int
        :return: New GatheredTime instance with filtered data.
        :rtype: GatheredTime

        Example::

            >>> gt = GatheredTime(torch.tensor([0.1, 0.2]), torch.tensor([0, 1]))
            >>> rank0_data = gt.get_rank(0)
        """
        # zeros_like already places the mask on the same device as ``ranks``.
        mask = torch.zeros_like(self.ranks, dtype=torch.bool)
        for rank in ranks:
            mask |= (self.ranks == rank)
        return GatheredTime(
            times=self.times[mask],
            ranks=self.ranks[mask],
        )

    def __getitem__(self, item) -> 'GatheredTime':
        """
        Get a subset of the gathered time data using indexing.

        The same index is applied to both ``times`` and ``ranks`` so they
        stay aligned. Whatever tensor indexing accepts is accepted here.

        :param item: Index or slice to apply to the data.
        :return: New GatheredTime instance with indexed data.
        :rtype: GatheredTime
        """
        return GatheredTime(
            times=self.times[item],
            ranks=self.ranks[item],
        )

    def __bool__(self):
        """
        Check if the gathered time data is non-empty.

        :return: True if there is at least one rank entry, False otherwise.
        :rtype: bool
        """
        return self.ranks.numel() > 0

    def sum(self) -> float:
        """
        Calculate the sum of all timing measurements.

        :return: Sum of all times in seconds.
        :rtype: float
        """
        return self.times.sum().item()

    def count(self) -> int:
        """
        Get the total number of timing measurements.

        Uses the element count, not ``shape[0]``, so it also works for the
        0-dim tensors produced by integer indexing (``gt[0].count() == 1``)
        and agrees with ``__bool__``.

        :return: Number of timing measurements.
        :rtype: int
        """
        return self.times.numel()

    def mean(self) -> float:
        """
        Calculate the mean of all timing measurements.

        :return: Mean time in seconds (NaN when there are no measurements).
        :rtype: float
        """
        return self.times.mean().item()

    def std(self) -> float:
        """
        Calculate the standard deviation of timing measurements.

        This is ``torch.Tensor.std`` with its default settings, i.e. the
        sample standard deviation.

        :return: Standard deviation in seconds.
        :rtype: float
        """
        return self.times.std().item()

    def rank_count(self) -> int:
        """
        Get the number of unique ranks in the data.

        :return: Number of unique ranks.
        :rtype: int
        """
        return torch.unique(self.ranks).numel()

    def hist(self, bins: Optional[int] = None, separate_ranks: bool = False,
             alpha: float = 0.7, title: Optional[str] = None, ax=None, **kwargs):
        """
        Plot histogram of timing distribution.

        Draws either one histogram over all measurements or one histogram
        per rank, then sets the title, axis labels and grid.

        :param bins: Number of histogram bins. If None, uses matplotlib default.
        :type bins: Optional[int]
        :param separate_ranks: Whether to show each rank separately.
        :type separate_ranks: bool
        :param alpha: Transparency level when separate_ranks=True.
        :type alpha: float
        :param title: Title for the plot. Defaults to 'Time Distribution'.
        :type title: Optional[str]
        :param ax: Matplotlib axes object for plotting. If None, uses current axes.
        :type ax: Optional[matplotlib.axes.Axes]
        :param kwargs: Additional arguments passed to matplotlib hist function.
        :type kwargs: dict

        Example::

            >>> import matplotlib.pyplot as plt
            >>> gt = GatheredTime(torch.tensor([0.1, 0.2, 0.15]), torch.tensor([0, 1, 0]))
            >>> fig, ax = plt.subplots()
            >>> gt.hist(ax=ax, bins=10)
        """
        if ax is None:
            ax = plt.gca()
        title = title or 'Time Distribution'

        times_np = self.times.detach().cpu().numpy()
        ranks_np = self.ranks.detach().cpu().numpy()

        if not separate_ranks:
            # Show all ranks together
            ax.hist(times_np, bins=bins, alpha=1.0, **kwargs)
            ax.set_title(f'{title}\n'
                         f'(All {plural_word(self.rank_count(), "Rank")}, n={len(times_np)}, '
                         f'mean={self.mean():.3g}s, std={self.std():.3g}s)')
        else:
            # Show each rank separately
            for rank in np.unique(ranks_np):
                rank_times = times_np[ranks_np == rank]
                # Use the sample std (ddof=1) so per-rank labels follow the
                # same definition as ``self.std()`` in the title; it is
                # undefined for a single measurement.
                if len(rank_times) > 1:
                    rank_std = rank_times.std(ddof=1)
                else:
                    rank_std = float('nan')
                ax.hist(
                    rank_times, bins=bins, alpha=alpha,
                    label=f'Rank #{rank} (n={len(rank_times)}, '
                          f'mean={rank_times.mean():.2g}s, std={rank_std:.2g}s)',
                    **kwargs
                )
            ax.set_title(f'{title} by {plural_word(self.rank_count(), "Rank")}\n'
                         f'(mean={self.mean():.3g}s, std={self.std():.3g}s)')
            ax.legend()

        ax.set_xlabel('Time (seconds)')
        ax.set_ylabel('Frequency')
        ax.grid()