diff --git a/ltx_video/ltxv.py b/ltx_video/ltxv.py index 7be215b..e0decbb 100644 --- a/ltx_video/ltxv.py +++ b/ltx_video/ltxv.py @@ -154,6 +154,8 @@ class LTXV: mixed_precision_transformer = False ): + if dtype == torch.float16: + dtype = torch.bfloat16 self.mixed_precision_transformer = mixed_precision_transformer self.distilled = any("lora" in name for name in model_filepath) model_filepath = [name for name in model_filepath if not "lora" in name ]