diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 7d9953151b0..e3cea98a7f4 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1468,6 +1468,8 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                 if self.ddp_handler is not None:
                     self.ddp_handler.register_comm_hook(model)
         elif self.distributed_type == DistributedType.TP:
+            if not model.supports_tp_plan:
+                raise NotImplementedError("Provided model does not support tensor parallelism")
             model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
         elif self.distributed_type == DistributedType.FSDP:
             # We need to fix the optimizer *before* sharding the model
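
For context, a minimal standalone sketch of how the new guard behaves for a model without a tensor-parallel plan. The `ToyModel` class and `prepare_tp` helper are hypothetical stand-ins for the real code path, not Accelerate API; `supports_tp_plan` is assumed to be the transformers model property the check relies on.

```python
# Illustrative sketch only; mirrors the fail-fast check added in prepare_model.
import torch


class ToyModel(torch.nn.Module):
    # A model without a registered tensor-parallel plan reports False here.
    supports_tp_plan = False


def prepare_tp(model):
    # Raise a clear error up front instead of failing later inside
    # model.tensor_parallel(device_mesh["tp"]).
    if not model.supports_tp_plan:
        raise NotImplementedError("Provided model does not support tensor parallelism")
    return model


try:
    prepare_tp(ToyModel())
except NotImplementedError as exc:
    print(exc)  # -> Provided model does not support tensor parallelism
```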