diff --git a/generate.py b/generate.py index 739de46..c567440 100644 --- a/generate.py +++ b/generate.py @@ -949,7 +949,8 @@ def generate(args): if args.enable_teacache: wan_t2v.__class__.generate = t2v_generate wan_t2v.model.__class__.enable_teacache = True - wan_t2v.model.__class__.forward = teacache_forward + if args.ulysses_size * args.ring_size == 1: # not conflict with fsdp + wan_t2v.model.__class__.forward = teacache_forward wan_t2v.model.__class__.cnt = 0 wan_t2v.model.__class__.num_steps = args.sample_steps*2 wan_t2v.model.__class__.teacache_thresh = args.teacache_thresh