mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-15 11:43:21 +00:00
Improved download speed for finetunes on HF
This commit is contained in:
parent
500f48b074
commit
826cc3adb7
27
wgp.py
27
wgp.py
@ -2035,6 +2035,8 @@ def download_models(model_filename, model_type):
|
||||
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = targetRoot)
|
||||
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from urllib.request import urlretrieve
|
||||
from wan.utils.utils import create_progress_hook
|
||||
|
||||
shared_def = {
|
||||
"repoId" : "DeepBeepMeep/Wan2.1",
|
||||
@ -2053,12 +2055,28 @@ def download_models(model_filename, model_type):
|
||||
}
|
||||
process_files_def(**enhancer_def)
|
||||
|
||||
def download_file(url,filename):
|
||||
if url.startswith("https://huggingface.co/") and "/resolve/main/" in url:
|
||||
url = url[len("https://huggingface.co/"):]
|
||||
url_parts = url.split("/resolve/main/")
|
||||
repoId = url_parts[0]
|
||||
onefile = os.path.basename(url_parts[-1])
|
||||
sourceFolder = os.path.dirname(url_parts[-1])
|
||||
if len(sourceFolder) == 0:
|
||||
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/")
|
||||
else:
|
||||
target_path = "ckpts/temp/" + sourceFolder
|
||||
if not os.path.exists(target_path):
|
||||
os.makedirs(target_path)
|
||||
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/temp/", subfolder=sourceFolder)
|
||||
shutil.move(os.path.join( "ckpts", "temp" , sourceFolder , onefile), "ckpts/")
|
||||
shutil.rmtree("ckpts/temp")
|
||||
else:
|
||||
urlretrieve(url,filename, create_progress_hook(filename))
|
||||
|
||||
model_family = get_model_family(model_type)
|
||||
finetune_def = get_model_finetune_def(model_type)
|
||||
if finetune_def != None:
|
||||
from urllib.request import urlretrieve
|
||||
from wan.utils.utils import create_progress_hook
|
||||
if not os.path.isfile(model_filename ):
|
||||
for url in finetune_def["URLs"]:
|
||||
if model_filename in url:
|
||||
@ -2066,7 +2084,7 @@ def download_models(model_filename, model_type):
|
||||
if not url.startswith("http"):
|
||||
raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. Please add an URL in the finetune definition file.")
|
||||
try:
|
||||
urlretrieve(url,model_filename, create_progress_hook(model_filename))
|
||||
download_file(url, model_filename)
|
||||
except Exception as e:
|
||||
if os.path.isfile(model_filename): os.remove(model_filename)
|
||||
raise Exception(f"URL '{url}' is invalid for Model '{model_filename}' : {str(e)}'")
|
||||
@ -2076,7 +2094,7 @@ def download_models(model_filename, model_type):
|
||||
if not url.startswith("http"):
|
||||
raise Exception(f"File '{filename}' to preload was not found locally and no URL was provided to download it. Please add an URL in the finetune definition file.")
|
||||
try:
|
||||
urlretrieve(url,filename, create_progress_hook(filename))
|
||||
download_file(url, filename)
|
||||
except Exception as e:
|
||||
if os.path.isfile(filename): os.remove(filename)
|
||||
raise Exception(f"Preload URL '{url}' is invalid: {str(e)}'")
|
||||
@ -2114,7 +2132,6 @@ def download_models(model_filename, model_type):
|
||||
|
||||
process_files_def(**model_def)
|
||||
|
||||
|
||||
offload.default_verboseLevel = verbose_level
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user