tprofiler.time.manage
Time measurement and profiling utilities for distributed PyTorch applications.
This module provides comprehensive timing functionality for distributed training scenarios, including context managers for timing code blocks, gathering timing data across processes, and visualization tools for analyzing performance metrics.
The main components include:
- TimeManager: Core timing functionality with context managers
- GatheredTime: Analysis and visualization of distributed timing data
- ProfiledTime: Container for multiple GatheredTime objects with serialization support
TimeManager
- class tprofiler.time.manage.TimeManager(records: Dict[str, List[float]] = <factory>)
A comprehensive time measurement manager for tracking execution times.
This class provides functionality to measure and record execution times for different operations, with support for distributed environments and context management. It maintains timing records for named operations and provides context managers for convenient timing measurement.
- Parameters:
records (Dict[str, List[float]]) – Dictionary storing timing records for different operations.
Example:
>>> import time
>>> tm = TimeManager()
>>> with tm.timer('operation1'):
...     time.sleep(0.1)
>>> times = tm._get_time('operation1')
>>> len(times) == 1
True
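The record-keeping described above can be sketched as a dataclass. The `<factory>` default in the signature is how Python dataclasses display a field created with `default_factory`; that the factory is `dict` is an assumption. `SketchTimeManager` is a hypothetical, stdlib-only stand-in rather than tprofiler's code, and it models only the bookkeeping behind `_append_time`, `_get_time`, and `clear`.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SketchTimeManager:
    # Hypothetical stand-in for TimeManager; default_factory=dict is assumed
    records: Dict[str, List[float]] = field(default_factory=dict)

    def _append_time(self, name: str, secs: float) -> None:
        # Create the list on the first measurement, then append to it
        self.records.setdefault(name, []).append(secs)

    def _get_time(self, name: str) -> List[float]:
        # Unknown names yield an empty list without adding a key;
        # returning a copy is a design choice of this sketch
        return list(self.records.get(name, []))

    def clear(self) -> None:
        # Drop every recorded measurement
        self.records.clear()


tm = SketchTimeManager()
tm._append_time("forward", 0.12)
tm._append_time("forward", 0.10)
print(tm._get_time("forward"))   # [0.12, 0.1]
print(tm._get_time("backward"))  # []
tm.clear()
print(tm.records)                # {}
```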
- _append_time(name: str, secs: float)
Append a timing record for the specified operation.
This method adds a new timing measurement to the records for the given operation name. If this is the first measurement for the operation, a new list is created.
- Parameters:
name (str) – The name of the operation being timed.
secs (float) – The execution time in seconds.
- _get_time(name: str) → List[float]
Get all timing records for a specific operation.
Retrieves the complete list of timing measurements recorded for the specified operation name. Returns an empty list if no measurements have been recorded for this operation.
- Parameters:
name (str) – The name of the operation.
- Returns:
List of timing records in seconds.
- Return type:
List[float]
- _get_time_torch(name: str) → Tensor
Get timing records as a PyTorch tensor.
Converts the timing records for the specified operation into a PyTorch tensor with float32 dtype. This is useful for mathematical operations and distributed gathering.
- Parameters:
name (str) – The name of the operation.
- Returns:
Timing records as a float32 tensor.
- Return type:
torch.Tensor
- _get_time_with_rank(name: str) → Tuple[Tensor, Tensor]
Get timing records along with corresponding rank information.
Returns timing data paired with rank information for distributed environments. Each timing measurement is associated with the current process rank.
- Parameters:
name (str) – The name of the operation.
- Returns:
Tuple of (times, ranks) tensors.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
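The pairing amounts to building a rank sequence of the same length as the times, filled with the current process rank, so that position i in both sequences describes the same measurement. The hypothetical helper below takes the rank as an explicit argument and uses plain lists in place of tensors; how the library obtains the rank is not shown.

```python
from typing import List, Tuple


def times_with_rank(times: List[float], rank: int) -> Tuple[List[float], List[int]]:
    # Hypothetical helper: one rank entry per measurement keeps the
    # two sequences aligned index by index
    return list(times), [rank] * len(times)


times, ranks = times_with_rank([0.21, 0.19, 0.20], rank=3)
print(ranks)  # [3, 3, 3]
```

This alignment is what lets measurements from many processes be concatenated later without losing track of where each one came from.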
- clear()
Clear all timing records.
Removes all recorded timing data from this TimeManager instance, effectively resetting it to a clean state.
- enable_timer()
Context manager to enable this TimeManager in the timer stack.
This method pushes the current TimeManager instance onto the timer stack, making it available for use by timer decorators and context managers throughout the nested execution context. The TimeManager is automatically removed from the stack when the context exits.
- Yields:
None
Example:
>>> tm = TimeManager()
>>> with tm.enable_timer():
...     # Now timer decorators and contexts will use this TimeManager
...     with timer('operation'):
...         pass
- gather(dst: int | None = None) → ProfiledTime | None
Gather timing data from all processes in a distributed environment.
Collects timing measurements from all processes in the distributed group and returns them as a ProfiledTime that maps each operation name to a GatheredTime object. This is useful for comparing performance characteristics across all processes in a distributed training run.
- Parameters:
dst (Optional[int]) – Destination rank to gather data to. If None, gathers to all ranks.
- Returns:
A ProfiledTime mapping operation names to GatheredTime objects, or None on ranks other than dst.
- Return type:
Optional[ProfiledTime]
Example:
>>> tm = TimeManager()
>>> # After recording some times...
>>> gathered = tm.gather(dst=0)  # Gather to rank 0
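Conceptually, gathering concatenates every rank's `(times, ranks)` pair into one flat pair per operation name. The sketch below simulates only that merge, in a single process: `merge_rank_records` is hypothetical, no `torch.distributed` communication happens, and the per-rank inputs are illustrative values rather than real measurements. Operations that a rank never recorded simply receive no entries from that rank.

```python
from typing import Dict, List, Tuple

Pair = Tuple[List[float], List[int]]


def merge_rank_records(per_rank: Dict[int, Dict[str, List[float]]]) -> Dict[str, Pair]:
    # Hypothetical stand-in for the gather step: per_rank maps a rank to
    # that rank's TimeManager-style records
    merged: Dict[str, Pair] = {}
    for rank in sorted(per_rank):
        for name, times in per_rank[rank].items():
            all_times, all_ranks = merged.setdefault(name, ([], []))
            all_times.extend(times)
            all_ranks.extend([rank] * len(times))  # label each time with its rank
    return merged


# Illustrative input only
merged = merge_rank_records({
    0: {"forward": [0.10, 0.11], "backward": [0.30]},
    1: {"forward": [0.12]},
})
print(merged["forward"])   # ([0.1, 0.11, 0.12], [0, 0, 1])
print(merged["backward"])  # ([0.3], [0])
```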
- timer(name: str)
Context manager for timing code execution.
This context manager measures the execution time of the code block within it and automatically records the timing under the specified name. It uses high-precision time measurement and is suitable for both short and long-running operations.
- Parameters:
name (str) – The name to associate with this timing measurement.
- Yields:
None
Example:
>>> tm = TimeManager()
>>> with tm.timer('my_operation'):
...     # Your code here
...     pass
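The measurement pattern itself can be sketched with `time.perf_counter()` and a `try`/`finally`, so a timing is recorded even when the block raises. Which clock tprofiler uses, and whether it records blocks that raise, are assumptions here; `record_time` and the `records` dict are hypothetical. Context managers built with `contextlib.contextmanager` also work as function decorators, which is one way the "timer decorators" mentioned earlier could be provided.

```python
import time
from contextlib import contextmanager
from typing import Dict, List

records: Dict[str, List[float]] = {}


@contextmanager
def record_time(store: Dict[str, List[float]], name: str):
    # perf_counter is a high-resolution clock intended for measuring intervals
    start = time.perf_counter()
    try:
        yield
    finally:
        # Runs on normal exit and on exceptions alike (an assumed policy)
        store.setdefault(name, []).append(time.perf_counter() - start)


with record_time(records, "sleep"):
    time.sleep(0.01)


@record_time(records, "work")  # each call re-enters a fresh context manager
def work() -> int:
    return sum(range(1000))


work()
work()
print(len(records["sleep"]), len(records["work"]))  # 1 2
```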
ProfiledTime
- class tprofiler.time.manage.ProfiledTime(records: Dict[str, GatheredTime] | None = None)
Container for multiple GatheredTime objects with serialization support.
This class acts as a mapping container for gathered timing data from multiple operations, providing serialization capabilities for saving and loading profiling results. It implements the Mapping interface to allow dictionary-like access to the timing data.
- Parameters:
records (Dict[str, GatheredTime]) – Dictionary mapping operation names to GatheredTime objects.
Example:
>>> pt = ProfiledTime({'op1': gathered_time1, 'op2': gathered_time2})
>>> pt['op1'].mean()  # Access timing data for operation 'op1'
>>> pt.save('profile.pt')  # Save profiling data
>>> loaded_pt = ProfiledTime.load('profile.pt')  # Load profiling data
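Implementing the Mapping interface typically means subclassing `collections.abc.Mapping`, which requires only `__getitem__`, `__iter__`, and `__len__` and then supplies `keys()`, `items()`, `get()`, and `in` automatically. `SketchProfiledTime` below is hypothetical and stores plain lists where the real class holds GatheredTime objects; treating `records=None` as an empty mapping mirrors the signature but is an assumption.

```python
from collections.abc import Mapping
from typing import Dict, Iterator, List, Optional


class SketchProfiledTime(Mapping):
    def __init__(self, records: Optional[Dict[str, List[float]]] = None) -> None:
        # Assumed default: no records means an empty mapping
        self._records = dict(records or {})

    def __getitem__(self, key: str) -> List[float]:
        return self._records[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._records)

    def __len__(self) -> int:
        return len(self._records)


pt = SketchProfiledTime({"op1": [0.1, 0.2], "op2": [0.3]})
print(len(pt), sorted(pt.keys()), "op1" in pt)  # 2 ['op1', 'op2'] True
print(pt.get("missing"))                        # None
```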
- __getitem__(__key: str) → GatheredTime
Get a GatheredTime object by operation name.
Provides dictionary-like access to the timing data for specific operations.
- Parameters:
__key (str) – The operation name to retrieve timing data for.
- Returns:
The GatheredTime object for the specified operation.
- Return type:
GatheredTime
- __len__() → int
Get the number of operations with timing data.
- Returns:
Number of operations in the records.
- Return type:
int
- classmethod load(file)
Load profiling data from a file.
Creates a new ProfiledTime instance from previously saved timing data, allowing for analysis of profiling results from previous runs.
- Parameters:
file (str or file-like object) – File path or file-like object to load from.
- Returns:
New ProfiledTime instance with loaded data.
- Return type:
ProfiledTime
Example:
>>> pt = ProfiledTime.load('my_profile.pt')
>>> print(pt['op1'].mean())  # Analyze loaded timing data
- save(file)
Save the profiling data to a file.
Serializes the timing records to a file using PyTorch’s save functionality, allowing the profiling data to be persisted and loaded later for analysis.
- Parameters:
file (str or file-like object) – File path or file-like object to save to.
Example:
>>> pt = ProfiledTime({'op1': gathered_time})
>>> pt.save('my_profile.pt')
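A save/load round trip that accepts either a path or a file-like object can be sketched with the standard `pickle` module, which `torch.save` uses for serialization by default; substituting it keeps the example free of a PyTorch dependency. The `save_records` and `load_records` helpers and the plain-dict payload are hypothetical, and an in-memory `io.BytesIO` plays the role of the file-like object.

```python
import io
import pickle
from typing import BinaryIO, Dict, List, Union


def save_records(records: Dict[str, List[float]], file: Union[str, BinaryIO]) -> None:
    # A str is treated as a path; anything else as an open binary file object
    if isinstance(file, str):
        with open(file, "wb") as fh:
            pickle.dump(records, fh)
    else:
        pickle.dump(records, file)


def load_records(file: Union[str, BinaryIO]) -> Dict[str, List[float]]:
    if isinstance(file, str):
        with open(file, "rb") as fh:
            return pickle.load(fh)
    return pickle.load(file)


buffer = io.BytesIO()
save_records({"op1": [0.1, 0.2]}, buffer)
buffer.seek(0)  # rewind before reading back
print(load_records(buffer))  # {'op1': [0.1, 0.2]}
```

Unpickling can execute arbitrary code, so only load profiles from sources you trust.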
GatheredTime
- class tprofiler.time.manage.GatheredTime(times: Tensor, ranks: Tensor)
Container for timing data gathered from multiple processes.
This class provides analysis and visualization capabilities for timing data collected from distributed training environments. It maintains timing measurements along with their corresponding process ranks, enabling detailed performance analysis across the distributed system.
- Parameters:
times (torch.Tensor) – Tensor containing timing measurements.
ranks (torch.Tensor) – Tensor containing corresponding rank information.
Example:
>>> import torch
>>> times = torch.tensor([0.1, 0.2, 0.15])
>>> ranks = torch.tensor([0, 1, 0])
>>> gt = GatheredTime(times, ranks)
>>> gt.mean()
0.15...
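The basic analysis methods reduce to simple operations over two aligned sequences. `SketchGatheredTime` below is a hypothetical, list-based stand-in for the tensor-backed class; it shows `count`, `__bool__`, `mean`, and an `__getitem__` that applies the same index or slice to both sequences so every time keeps its rank. Returning a one-element object for an integer index is an assumption of this sketch.

```python
from typing import Sequence, Union


class SketchGatheredTime:
    def __init__(self, times: Sequence[float], ranks: Sequence[int]) -> None:
        if len(times) != len(ranks):
            raise ValueError("times and ranks must have the same length")
        self.times = list(times)
        self.ranks = list(ranks)

    def count(self) -> int:
        return len(self.times)

    def __bool__(self) -> bool:
        return self.count() > 0

    def mean(self) -> float:
        # Average over every measurement from every rank
        # (raises ZeroDivisionError when empty)
        return sum(self.times) / len(self.times)

    def __getitem__(self, item: Union[int, slice]) -> "SketchGatheredTime":
        # Apply the same index to both lists so each time keeps its rank
        if isinstance(item, slice):
            return SketchGatheredTime(self.times[item], self.ranks[item])
        return SketchGatheredTime([self.times[item]], [self.ranks[item]])


gt = SketchGatheredTime([0.1, 0.2, 0.15], [0, 1, 0])
print(gt.count(), bool(gt))        # 3 True
print(round(gt.mean(), 4))         # 0.15
print(gt[1:].times, gt[1:].ranks)  # [0.2, 0.15] [1, 0]
```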
- __bool__()
Check if the gathered time data is non-empty.
Returns True if there are timing measurements available, False if the data structure is empty.
- Returns:
True if there is timing data, False otherwise.
- Return type:
bool
- __getitem__(item) → GatheredTime
Get a subset of the gathered time data using indexing.
Supports standard Python indexing and slicing operations to extract subsets of the timing data while maintaining the correspondence between times and ranks.
- Parameters:
item – Index or slice to apply to the data.
- Returns:
New GatheredTime instance with indexed data.
- Return type:
GatheredTime
- count() → int
Get the total number of timing measurements.
Returns the total count of timing measurements across all ranks, which indicates how many times the measured operation was executed.
- Returns:
Number of timing measurements.
- Return type:
int
- get_rank(*ranks: int) → GatheredTime
Filter timing data for specific ranks.
Creates a new GatheredTime instance containing only the timing data from the specified process ranks. This is useful for analyzing performance characteristics of specific processes or comparing performance across different ranks.
- Parameters:
ranks (int) – One or more rank numbers whose timing data to keep.
- Returns:
New GatheredTime instance with filtered data.
- Return type:
GatheredTime
Example:
>>> import torch
>>> gt = GatheredTime(torch.tensor([0.1, 0.2]), torch.tensor([0, 1]))
>>> rank0_data = gt.get_rank(0)
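Filtering by rank keeps only the positions whose rank label is among the requested ranks, in their original order. `filter_ranks` is a hypothetical list-based sketch of that selection; how the real method treats ranks absent from the data is not documented, and here they simply contribute nothing.

```python
from typing import List, Sequence, Tuple


def filter_ranks(times: Sequence[float], ranks: Sequence[int],
                 *wanted: int) -> Tuple[List[float], List[int]]:
    keep = set(wanted)  # set gives constant-time membership checks
    pairs = [(t, r) for t, r in zip(times, ranks) if r in keep]
    return [t for t, _ in pairs], [r for _, r in pairs]


times = [0.10, 0.20, 0.15, 0.30]
ranks = [0, 1, 0, 2]
print(filter_ranks(times, ranks, 0))     # ([0.1, 0.15], [0, 0])
print(filter_ranks(times, ranks, 1, 2))  # ([0.2, 0.3], [1, 2])
```

With tensors, the same selection can be expressed as a boolean mask such as `torch.isin(ranks, torch.tensor(wanted))` applied to both tensors, though the library may implement it differently.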
- hist(bins: int | None = None, separate_ranks: bool = False, alpha: float = 0.7, title: str | None = None, ax=None, **kwargs)
Plot histogram of timing distribution.
Creates a histogram visualization of the timing data distribution, with options to show all ranks together or separately. This is useful for understanding the performance characteristics and identifying outliers or patterns in execution times.
- Parameters:
bins (Optional[int]) – Number of histogram bins. If None, uses matplotlib default.
separate_ranks (bool) – Whether to show each rank separately.
alpha (float) – Transparency level when separate_ranks=True.
title (Optional[str]) – Title for the plot. Defaults to ‘Time Distribution’.
ax (Optional[matplotlib.axes.Axes]) – Matplotlib axes object for plotting. If None, uses current axes.
kwargs (dict) – Additional keyword arguments passed to Matplotlib's hist function.
Example:
>>> import matplotlib.pyplot as plt
>>> import torch
>>> gt = GatheredTime(torch.tensor([0.1, 0.2, 0.15]), torch.tensor([0, 1, 0]))
>>> fig, ax = plt.subplots()
>>> gt.hist(ax=ax, bins=10)
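With `separate_ranks=True`, the plot needs one group of times per rank, each drawn semi-transparently so overlapping distributions stay visible. The sketch below shows that grouping plus a plotting helper; `group_by_rank` and `plot_per_rank` are hypothetical, and the real method's labels, colors, and bin handling may differ. Matplotlib is imported only inside the plotting helper, so the grouping runs without it.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Sequence


def group_by_rank(times: Sequence[float], ranks: Sequence[int]) -> Dict[int, List[float]]:
    groups: Dict[int, List[float]] = defaultdict(list)
    for t, r in zip(times, ranks):
        groups[r].append(t)
    return dict(sorted(groups.items()))  # ranks in ascending order


def plot_per_rank(times, ranks, bins: Optional[int] = None, alpha: float = 0.7, ax=None):
    import matplotlib.pyplot as plt  # deferred so grouping works without matplotlib

    ax = ax if ax is not None else plt.gca()
    for rank, values in group_by_rank(times, ranks).items():
        # One translucent histogram per rank; bins=None uses matplotlib's default
        ax.hist(values, bins=bins, alpha=alpha, label=f"rank {rank}")
    ax.set_title("Time Distribution")
    ax.legend()
    return ax


print(group_by_rank([0.1, 0.2, 0.15], [0, 1, 0]))  # {0: [0.1, 0.15], 1: [0.2]}
```

Because each `hist` call chooses its own bin edges from that rank's data when `bins` is an integer, the bars of different ranks need not line up exactly.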
- mean() → float
Calculate the mean of all timing measurements.
Computes the average execution time across all measurements and ranks, providing a central tendency measure for the operation performance.
- Returns:
Mean time in seconds.
- Return type:
float
- rank_count() → int
Get the number of unique ranks in the data.
Returns the count of distinct process ranks represented in the gathered timing data, indicating how many processes contributed measurements.
- Returns:
Number of unique ranks.
- Return type:
int
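Counting contributing processes means counting distinct rank labels, not measurements. `count_unique_ranks` below is a hypothetical stdlib sketch; with tensors, `torch.unique(ranks).numel()` computes the same quantity, though the library's own implementation is not shown here.

```python
from typing import Sequence


def count_unique_ranks(ranks: Sequence[int]) -> int:
    # Distinct processes that reported at least one measurement
    return len(set(ranks))


ranks = [0, 1, 0, 3, 1]
print(len(ranks), count_unique_ranks(ranks))  # 5 3
```

Since a rank that recorded nothing contributes no entries, this count can be smaller than the world size.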