import getpass
import math
import os
from dataclasses import dataclass
from pathlib import Path

import requests
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download, login
from PIL import ExifTags, Image
from safetensors.torch import load_file as load_sft

from .model import Flux, FluxLoraWrapper, FluxParams
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
from .modules.conditioner import HFEmbedder

CHECKPOINTS_DIR = Path("checkpoints")

BFL_API_KEY = os.getenv("BFL_API_KEY")


def ensure_hf_auth():
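    """Authenticate to the HuggingFace Hub if a token is available.

    Tries the HF_TOKEN environment variable first, then falls back to a
    previously cached login token. Returns True if authentication is available.
    """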
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        print("Trying to authenticate to HuggingFace with the HF_TOKEN environment variable.")
        try:
            login(token=hf_token)
            print("Successfully authenticated with HuggingFace using HF_TOKEN")
            return True
        except Exception as e:
            print(f"Warning: Failed to authenticate with HF_TOKEN: {e}")

    if os.path.exists(os.path.expanduser("~/.cache/huggingface/token")):
        print("Already authenticated with HuggingFace")
        return True

    return False


def prompt_for_hf_auth():
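    """Interactively prompt for a HuggingFace token and log in with it."""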
    try:
        token = getpass.getpass("HF Token (hidden input): ").strip()
        if not token:
            print("No token provided. Aborting.")
            return False

        login(token=token)
        print("Successfully authenticated!")
        return True
    except KeyboardInterrupt:
        print("\nAuthentication cancelled by user.")
        return False
    except Exception as auth_e:
        print(f"Authentication failed: {auth_e}")
        print("Tip: You can also run 'huggingface-cli login' or set the HF_TOKEN environment variable")
        return False


def get_checkpoint_path(repo_id: str, filename: str, env_var: str) -> Path:
    """Get the local path for a checkpoint file, downloading if necessary."""
    # The original Hub lookup/download logic is disabled: weights are expected
    # to already exist locally at `filename` (e.g. under ckpts/).
    # if os.environ.get(env_var) is not None:
    #     local_path = os.environ[env_var]
    #     if os.path.exists(local_path):
    #         return Path(local_path)
    #
    #     print(
    #         f"Trying to load model {repo_id}, {filename} from environment "
    #         f"variable {env_var}. But file {local_path} does not exist. "
    #         "Falling back to default location."
    #     )
    #
    # # Create a safe directory name from repo_id
    # safe_repo_name = repo_id.replace("/", "_")
    # checkpoint_dir = CHECKPOINTS_DIR / safe_repo_name
    # checkpoint_dir.mkdir(exist_ok=True)
    #
    # local_path = checkpoint_dir / filename

    local_path = Path(filename)

    if False:  # Hub download path kept for reference but never taken
        checkpoint_dir = local_path.parent  # hypothetical target dir; this branch is disabled
        print(f"Downloading {filename} from {repo_id} to {local_path}")
        try:
            ensure_hf_auth()
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=checkpoint_dir)
        except Exception as e:
            if "gated repo" in str(e).lower() or "restricted" in str(e).lower():
                print(f"\nError: Cannot access {repo_id} -- this is a gated repository.")

                # Try one more time to authenticate
                if prompt_for_hf_auth():
                    # Retry the download after authentication
                    print("Retrying download...")
                    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=checkpoint_dir)
                else:
                    print("Authentication failed or cancelled.")
                    print("You can also run 'huggingface-cli login' or set the HF_TOKEN environment variable")
                    raise RuntimeError(f"Authentication required for {repo_id}")
            else:
                raise e

    return local_path
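
# Illustrative usage (hypothetical call; the env-var override above is disabled):
#   ae_path = get_checkpoint_path("black-forest-labs/FLUX.1-dev", "ckpts/flux_vae.safetensors", "FLUX_AE")
#   -> Path("ckpts/flux_vae.safetensors")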


def download_onnx_models_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None:
    """Download ONNX models for TRT to our checkpoints directory."""
    onnx_repo_map = {
        "flux-dev": "black-forest-labs/FLUX.1-dev-onnx",
        "flux-schnell": "black-forest-labs/FLUX.1-schnell-onnx",
        "flux-dev-canny": "black-forest-labs/FLUX.1-Canny-dev-onnx",
        "flux-dev-depth": "black-forest-labs/FLUX.1-Depth-dev-onnx",
        "flux-dev-redux": "black-forest-labs/FLUX.1-Redux-dev-onnx",
        "flux-dev-fill": "black-forest-labs/FLUX.1-Fill-dev-onnx",
        "flux-dev-kontext": "black-forest-labs/FLUX.1-Kontext-dev-onnx",
    }

    if model_name not in onnx_repo_map:
        return None  # No ONNX repository required for this model

    repo_id = onnx_repo_map[model_name]
    safe_repo_name = repo_id.replace("/", "_")
    onnx_dir = CHECKPOINTS_DIR / safe_repo_name

    # Map of module names to their ONNX file paths (using the specified precision)
    onnx_file_map = {
        "clip": "clip.opt/model.onnx",
        "transformer": f"transformer.opt/{trt_transformer_precision}/model.onnx",
        "transformer_data": f"transformer.opt/{trt_transformer_precision}/backbone.onnx_data",
        "t5": "t5.opt/model.onnx",
        "t5_data": "t5.opt/backbone.onnx_data",
        "vae": "vae.opt/model.onnx",
    }

    # If all files exist locally, return the custom_onnx_paths format
    if onnx_dir.exists():
        all_files_exist = True
        custom_paths = []
        for module, onnx_file in onnx_file_map.items():
            if module.endswith("_data"):
                continue  # Skip data files
            local_path = onnx_dir / onnx_file
            if not local_path.exists():
                all_files_exist = False
                break
            custom_paths.append(f"{module}:{local_path}")

        if all_files_exist:
            print(f"ONNX models ready in {onnx_dir}")
            return ",".join(custom_paths)

    # Otherwise download the missing files
    print(f"Downloading ONNX models from {repo_id} to {onnx_dir}")
    print(f"Using transformer precision: {trt_transformer_precision}")
    onnx_dir.mkdir(parents=True, exist_ok=True)

    # Download all ONNX files
    for module, onnx_file in onnx_file_map.items():
        local_path = onnx_dir / onnx_file
        if local_path.exists():
            continue  # Already downloaded

        # Create parent directories
        local_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            print(f"Downloading {onnx_file}")
            hf_hub_download(repo_id=repo_id, filename=onnx_file, local_dir=onnx_dir)
        except Exception as e:
            if "does not exist" in str(e).lower() or "not found" in str(e).lower():
                continue
            elif "gated repo" in str(e).lower() or "restricted" in str(e).lower():
                print(f"Cannot access {repo_id} - requires license acceptance")
                print("Please follow these steps:")
                print(f" 1. Visit: https://huggingface.co/{repo_id}")
                print(" 2. Log in to your HuggingFace account")
                print(" 3. Accept the license terms and conditions")
                print(" 4. Then retry this command")
                raise RuntimeError(f"License acceptance required for {model_name}")
            else:
                # Re-raise other errors
                raise

    print(f"ONNX models ready in {onnx_dir}")

    # Return the custom_onnx_paths format that TRT expects: "module1:path1,module2:path2"
    # Note: only return the actual module paths, not the data files
    custom_paths = []
    for module, onnx_file in onnx_file_map.items():
        if module.endswith("_data"):
            continue  # Skip the data files in the return paths
        full_path = onnx_dir / onnx_file
        if full_path.exists():
            custom_paths.append(f"{module}:{full_path}")

    return ",".join(custom_paths)


def check_onnx_access_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None:
    """Check ONNX access and download models for TRT; returns the custom_onnx_paths string (or None)."""
    return download_onnx_models_for_trt(model_name, trt_transformer_precision)


def track_usage_via_api(name: str, n=1) -> None:
    """
    Track usage of licensed models via the BFL API for commercial licensing compliance.

    For more information on licensing BFL's models for commercial use and usage reporting,
    see the README.md or visit: https://dashboard.bfl.ai/licensing/subscriptions?showInstructions=true
    """
    assert BFL_API_KEY is not None, "BFL_API_KEY is not set"

    model_slug_map = {
        "flux-dev": "flux-1-dev",
        "flux-dev-kontext": "flux-1-kontext-dev",
        "flux-dev-fill": "flux-tools",
        "flux-dev-depth": "flux-tools",
        "flux-dev-canny": "flux-tools",
        "flux-dev-canny-lora": "flux-tools",
        "flux-dev-depth-lora": "flux-tools",
        "flux-dev-redux": "flux-tools",
    }

    if name not in model_slug_map:
        print(f"Skipping usage tracking for {name}, as it cannot be tracked. Please check the model name.")
        return

    model_slug = model_slug_map[name]
    url = f"https://api.bfl.ai/v1/licenses/models/{model_slug}/usage"
    headers = {"x-key": BFL_API_KEY, "Content-Type": "application/json"}
    payload = {"number_of_generations": n}

    response = requests.post(url, headers=headers, json=payload)
    if response.status_code != 200:
        raise Exception(f"Failed to track usage: {response.status_code} {response.text}")
    else:
        print(f"Successfully tracked usage for {name} with {n} generations")


def save_image(
    nsfw_classifier,
    name: str,
    output_name: str,
    idx: int,
    x: torch.Tensor,
    add_sampling_metadata: bool,
    prompt: str,
    nsfw_threshold: float = 0.85,
    track_usage: bool = False,
) -> int:
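    """Convert a decoded image tensor to PIL, run an optional NSFW check, and save it.

    Stamps EXIF metadata (plus the prompt when add_sampling_metadata is set),
    optionally reports usage via the BFL API, and returns the updated image
    index (incremented only when the image was actually saved).
    """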
    fn = output_name.format(idx=idx)
    print(f"Saving {fn}")
    # bring into PIL format and save
    x = x.clamp(-1, 1)
    x = rearrange(x[0], "c h w -> h w c")
    img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

    if nsfw_classifier is not None:
        nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
    else:
        nsfw_score = nsfw_threshold - 1.0

    if nsfw_score < nsfw_threshold:
        exif_data = Image.Exif()
        if name in ["flux-dev", "flux-schnell"]:
            exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
        else:
            exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
        exif_data[ExifTags.Base.Make] = "Black Forest Labs"
        exif_data[ExifTags.Base.Model] = name
        if add_sampling_metadata:
            exif_data[ExifTags.Base.ImageDescription] = prompt
        img.save(fn, exif=exif_data, quality=95, subsampling=0)
        if track_usage:
            track_usage_via_api(name, 1)
        idx += 1
    else:
        print("Your generated image may contain NSFW content.")

    return idx


@dataclass
class ModelSpec:
    params: FluxParams
    ae_params: AutoEncoderParams
    repo_id: str
    repo_flow: str
    repo_ae: str
    lora_repo_id: str | None = None
    lora_filename: str | None = None


configs = {
    "flux-dev": ModelSpec(
        repo_id="",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-canny": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Canny-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-canny-lora": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        lora_repo_id="black-forest-labs/FLUX.1-Canny-dev-lora",
        lora_filename="flux1-canny-dev-lora.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-depth": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Depth-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-depth-lora": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        lora_repo_id="black-forest-labs/FLUX.1-Depth-dev-lora",
        lora_filename="flux1-depth-dev-lora.safetensors",
        params=FluxParams(
            in_channels=128,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-redux": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Redux-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-fill": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Fill-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=384,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-dev-kontext": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-Kontext-dev",
        repo_flow="",
        repo_ae="ckpts/flux_vae.safetensors",
        params=FluxParams(
            in_channels=64,
            out_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}
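
# All entries share the same autoencoder parameters; the variants differ mainly in
# in_channels (64 for dev/schnell/redux/kontext, 128 for canny/depth and their LoRA
# variants, 384 for fill), guidance_embed (False only for flux-schnell), and the
# optional LoRA fields.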


PREFERED_KONTEXT_RESOLUTIONS = [
    (672, 1568),
    (688, 1504),
    (720, 1456),
    (752, 1392),
    (800, 1328),
    (832, 1248),
    (880, 1184),
    (944, 1104),
    (1024, 1024),
    (1104, 944),
    (1184, 880),
    (1248, 832),
    (1328, 800),
    (1392, 752),
    (1456, 720),
    (1504, 688),
    (1568, 672),
]
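
# Each resolution pair stays close to a one-megapixel budget
# (e.g. 672 * 1568 = 1,053,696 px vs. 1024 * 1024 = 1,048,576 px).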


def aspect_ratio_to_height_width(aspect_ratio: str, area: int = 1024**2) -> tuple[int, int]:
    width = float(aspect_ratio.split(":")[0])
    height = float(aspect_ratio.split(":")[1])
    ratio = width / height
    width = round(math.sqrt(area * ratio))
    height = round(math.sqrt(area / ratio))
    # Snap both sides down to multiples of 16; note the return order is (width, height)
    return 16 * (width // 16), 16 * (height // 16)
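
# Example: aspect_ratio_to_height_width("16:9") -> (1360, 768) with the default
# one-megapixel area.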


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
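    """Report missing/unexpected state-dict keys from a non-strict load."""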
    if len(missing) > 0 and len(unexpected) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
        print("\n" + "-" * 79 + "\n")
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
    elif len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    elif len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def load_flow_model(name: str, model_filename, device: str | torch.device = "cuda", verbose: bool = True) -> Flux:
    # Loading Flux
    config = configs[name]

    ckpt_path = model_filename  # config.repo_flow

    # Instantiate on the meta device so no memory is allocated before the
    # checkpoint data is materialized.
    with torch.device("meta"):
        if config.lora_repo_id is not None and config.lora_filename is not None:
            model = FluxLoraWrapper(params=config.params).to(torch.bfloat16)
        else:
            model = Flux(config.params).to(torch.bfloat16)

    # print(f"Loading checkpoint: {ckpt_path}")
    # The weights in `model_filename` are loaded into the meta-initialized
    # model by mmgp's offload loader.
    from mmgp import offload

    offload.load_model_data(model, model_filename)

    # Original safetensors loading path, kept for reference:
    # # load_sft doesn't support torch.device
    # sd = load_sft(ckpt_path, device=str(device))
    # sd = optionally_expand_state_dict(model, sd)
    # missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
    # if verbose:
    #     print_load_warning(missing, unexpected)

    # if config.lora_repo_id is not None and config.lora_filename is not None:
    #     print("Loading LoRA")
    #     lora_path = str(get_checkpoint_path(config.lora_repo_id, config.lora_filename, "FLUX_LORA"))
    #     lora_sd = load_sft(lora_path, device=str(device))
    #     # loading the lora params + overwriting scale values in the norms
    #     missing, unexpected = model.load_state_dict(lora_sd, strict=False, assign=True)
    #     if verbose:
    #         print_load_warning(missing, unexpected)
    return model


def load_t5(device: str | torch.device = "cuda", text_encoder_filename=None, max_length: int = 512) -> HFEmbedder:
    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    return HFEmbedder("", text_encoder_filename, max_length=max_length, torch_dtype=torch.bfloat16).to(device)


def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    return HFEmbedder("ckpts/clip_vit_large_patch14", "", max_length=77, torch_dtype=torch.bfloat16, is_clip=True).to(device)


def load_ae(name: str, device: str | torch.device = "cuda") -> AutoEncoder:
    config = configs[name]
    ckpt_path = str(get_checkpoint_path(config.repo_id, config.repo_ae, "FLUX_AE"))

    # Loading the autoencoder
    with torch.device("meta"):
        ae = AutoEncoder(config.ae_params)

    # print(f"Loading AE checkpoint: {ckpt_path}")
    sd = load_sft(ckpt_path, device=str(device))
    missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
    print_load_warning(missing, unexpected)
    return ae


def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict:
    """
    Optionally expand the state dict to match the model's parameter shapes.
    """
    for name, param in model.named_parameters():
        if name in state_dict:
            if state_dict[name].shape != param.shape:
                print(
                    f"Expanding '{name}' with shape {state_dict[name].shape} to model parameter with shape {param.shape}."
                )
                # expand with zeros:
                expanded_state_dict_weight = torch.zeros_like(param, device=state_dict[name].device)
                slices = tuple(slice(0, dim) for dim in state_dict[name].shape)
                expanded_state_dict_weight[slices] = state_dict[name]
                state_dict[name] = expanded_state_dict_weight

    return state_dict
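
# The expansion zero-pads smaller checkpoint tensors into larger model parameters:
# e.g. a (64, 3072) weight loaded into a (128, 3072) parameter (shapes illustrative)
# keeps its values in the first 64 rows and zeros elsewhere.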