ModelParallelStrategy

class lightning.fabric.strategies.ModelParallelStrategy(parallelize_fn, data_parallel_size='auto', tensor_parallel_size='auto', save_distributed_checkpoint=True, process_group_backend=None, timeout=datetime.timedelta(seconds=1800))[source]

Bases: ParallelStrategy

Enables user-defined parallelism applied to a model.

Warning

This is an experimental feature.

Currently supports up to 2D parallelism: specifically, the combination of Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are themselves still experimental. Requires PyTorch 2.4 or newer. A usage sketch follows the parameter list below.

Parameters
  • parallelize_fn (Callable[[TypeVar(TModel, bound= Module), DeviceMesh], TypeVar(TModel, bound= Module)]) – A function that applies parallelisms to a module. The strategy will provide the model and device mesh as input.

  • data_parallel_size (Union[Literal['auto'], int]) – The number of devices within a data-parallel group. Defaults to "auto", which sets this size to the number of nodes in the cluster.

  • tensor_parallel_size (Union[Literal['auto'], int]) – The number of devices within a tensor-parallel group. Defaults to "auto", which sets this size to the number of GPUs in a single node.

  • save_distributed_checkpoint (bool) – If True, each rank saves its shard of weights and optimizer states to a file. The checkpoint is a folder with as many files as the world size. If False, the full weights and optimizer states get assembled on rank 0 and saved to a single file.
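
Example

The following is a minimal sketch of constructing the strategy with a user-defined parallelize_fn. The import path for fully_shard follows the experimental FSDP2 API in PyTorch 2.4, the mesh dimension names "data_parallel" and "tensor_parallel" are assumed to be exposed by the strategy's device mesh, and the submodule names, model, and device counts are purely illustrative.

    import torch.nn as nn
    from torch.distributed._composable.fsdp import fully_shard  # PyTorch 2.4 FSDP2 API
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    from lightning.fabric import Fabric
    from lightning.fabric.strategies import ModelParallelStrategy


    def parallelize(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        # Tensor-parallelize selected submodules (the layer names are illustrative).
        plan = {
            "feed_forward.w1": ColwiseParallel(),
            "feed_forward.w2": RowwiseParallel(),
        }
        parallelize_module(model, device_mesh["tensor_parallel"], plan)

        # Shard the tensor-parallel model across the data-parallel dimension (FSDP2).
        fully_shard(model, mesh=device_mesh["data_parallel"])
        return model


    strategy = ModelParallelStrategy(
        parallelize_fn=parallelize,
        data_parallel_size=2,    # two data-parallel groups
        tensor_parallel_size=2,  # two GPUs per tensor-parallel group
    )
    fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy)

The strategy invokes parallelize_fn when the module is set up, passing the module together with the device mesh, as described in the parameter list above.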

_configure_launcher()[source]

Attach the launcher used by this strategy.

Return type

None

all_reduce(tensor, group=None, reduce_op='mean')[source]

Reduces the given tensor (e.g. across GPUs/processes).

Parameters
  • tensor (Tensor) – the tensor to sync and reduce

  • group (Optional[Any]) – the process group to reduce

  • reduce_op (Union[ReduceOp, str, None]) – the reduction operation. Defaults to ‘mean’. Can also be the string ‘sum’ or a ReduceOp instance.

Return type

Tensor
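
A short usage sketch, assuming a fabric object created with this strategy and already launched; the loss value is illustrative:

    import torch

    # Each process holds its own local value.
    local_loss = torch.tensor(0.25, device=fabric.device)

    # Average across all processes (the default reduce_op is "mean").
    mean_loss = fabric.all_reduce(local_loss)

    # Sum instead of averaging.
    total_loss = fabric.all_reduce(local_loss, reduce_op="sum")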

barrier(*args, **kwargs)[source]

Synchronizes all processes, blocking them until the whole group enters this function.

Parameters

name – an optional name to pass into barrier.

Return type

None

broadcast(obj, src=0)[source]

Broadcasts an object to all processes.

Parameters
  • obj (TypeVar(TBroadcast)) – the object to broadcast

  • src (int) – source rank

Return type

TypeVar(TBroadcast)
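
A usage sketch for broadcast() and barrier(), assuming a launched fabric object; the configuration dict is illustrative:

    # Decide something on rank 0 and share it with every other process.
    if fabric.global_rank == 0:
        run_config = {"lr": 3e-4, "seed": 42}
    else:
        run_config = None

    run_config = fabric.broadcast(run_config, src=0)  # all ranks now hold rank 0's dict

    fabric.barrier()  # block until every process reaches this point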

load_checkpoint(path, state=None, strict=True)[source]

Load the contents from a checkpoint and restore the state of the given objects.

Return type

Dict[str, Any]
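
A usage sketch via Fabric, assuming model and optimizer were set up with this strategy; the checkpoint path is illustrative:

    state = {"model": model, "optimizer": optimizer, "step": 0}

    # With distributed checkpoints, the path is the checkpoint folder; each rank
    # loads its own shards directly into the objects referenced by `state`.
    remainder = fabric.load("checkpoints/step-1000", state)

    # Primitive entries in `state` are updated in place; keys present in the
    # checkpoint but not claimed by `state` are returned in `remainder`.
    resumed_step = state["step"]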

module_init_context(empty_init=None)[source]

A context manager wrapping the model instantiation.

Here, the strategy can control how the parameters of the model get created (device, dtype) and/or apply other patches to the model.

Parameters

empty_init (Optional[bool]) – Whether to initialize the model with empty weights (uninitialized memory). If None, the strategy will decide. Some strategies may not support all options.

Return type

ContextManager
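
In practice this context is entered through Fabric.init_module(). A sketch, where MyLargeTransformer is a hypothetical model class and strategy is the instance constructed in the example above:

    fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy)
    fabric.launch()

    with fabric.init_module(empty_init=True):
        # With empty_init=True the strategy can skip expensive weight
        # initialization; real values are materialized later, e.g. when the
        # model is sharded in setup() or a checkpoint is loaded.
        model = MyLargeTransformer()  # hypothetical model class

    model = fabric.setup(model)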

module_to_device(module)[source]

Moves the model to the correct device.

Return type

None

save_checkpoint(path, state, storage_options=None, filter=None)[source]

Save model, optimizer, and other state to a checkpoint on disk.

If distributed checkpointing is enabled (default), the checkpoint gets saved as a directory containing one file per process, with the model and optimizer shards stored per file. Additionally, a metadata file meta.pt with the rest of the user's state is saved from rank 0 only. If distributed checkpointing is disabled (save_distributed_checkpoint=False), the checkpoint is written to a single file containing the weights, optimizer state, and other metadata.

Return type

None
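
A usage sketch via Fabric; the checkpoint path is illustrative:

    state = {"model": model, "optimizer": optimizer, "step": 1000}

    # With save_distributed_checkpoint=True (the default), this writes a folder,
    # e.g. checkpoints/step-1000/, containing one shard file per process plus a
    # meta.pt file with the remaining user state from rank 0.
    fabric.save("checkpoints/step-1000", state)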

setup_environment()[source]

Set up any processes or distributed connections.

This must be called by the framework at the beginning of every process, before any distributed communication takes place.

Return type

None

setup_module(module)[source]

Performs setup for the model, e.g., by wrapping it in another class.

Return type

Module

property distributed_sampler_kwargs: Dict[str, Any]

Arguments for the DistributedSampler.

If this property is not defined, or it returns None, then the DistributedSampler will not be used.
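
A sketch of how these arguments are typically consumed (usually the keys num_replicas and rank); Fabric applies them automatically when setting up dataloaders with a distributed sampler, and the dataset and batch size below are illustrative:

    from torch.utils.data import DataLoader, DistributedSampler

    sampler = DistributedSampler(train_dataset, **strategy.distributed_sampler_kwargs)
    dataloader = DataLoader(train_dataset, batch_size=8, sampler=sampler)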

property root_device: device

Returns the root device.