mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-16 12:13:27 +00:00
Improved download speed for finetunes on HF
This commit is contained in:
parent
500f48b074
commit
826cc3adb7
27
wgp.py
27
wgp.py
@ -2035,6 +2035,8 @@ def download_models(model_filename, model_type):
|
|||||||
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = targetRoot)
|
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = targetRoot)
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download, snapshot_download
|
from huggingface_hub import hf_hub_download, snapshot_download
|
||||||
|
from urllib.request import urlretrieve
|
||||||
|
from wan.utils.utils import create_progress_hook
|
||||||
|
|
||||||
shared_def = {
|
shared_def = {
|
||||||
"repoId" : "DeepBeepMeep/Wan2.1",
|
"repoId" : "DeepBeepMeep/Wan2.1",
|
||||||
@ -2053,12 +2055,28 @@ def download_models(model_filename, model_type):
|
|||||||
}
|
}
|
||||||
process_files_def(**enhancer_def)
|
process_files_def(**enhancer_def)
|
||||||
|
|
||||||
|
def download_file(url,filename):
|
||||||
|
if url.startswith("https://huggingface.co/") and "/resolve/main/" in url:
|
||||||
|
url = url[len("https://huggingface.co/"):]
|
||||||
|
url_parts = url.split("/resolve/main/")
|
||||||
|
repoId = url_parts[0]
|
||||||
|
onefile = os.path.basename(url_parts[-1])
|
||||||
|
sourceFolder = os.path.dirname(url_parts[-1])
|
||||||
|
if len(sourceFolder) == 0:
|
||||||
|
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/")
|
||||||
|
else:
|
||||||
|
target_path = "ckpts/temp/" + sourceFolder
|
||||||
|
if not os.path.exists(target_path):
|
||||||
|
os.makedirs(target_path)
|
||||||
|
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/temp/", subfolder=sourceFolder)
|
||||||
|
shutil.move(os.path.join( "ckpts", "temp" , sourceFolder , onefile), "ckpts/")
|
||||||
|
shutil.rmtree("ckpts/temp")
|
||||||
|
else:
|
||||||
|
urlretrieve(url,filename, create_progress_hook(filename))
|
||||||
|
|
||||||
model_family = get_model_family(model_type)
|
model_family = get_model_family(model_type)
|
||||||
finetune_def = get_model_finetune_def(model_type)
|
finetune_def = get_model_finetune_def(model_type)
|
||||||
if finetune_def != None:
|
if finetune_def != None:
|
||||||
from urllib.request import urlretrieve
|
|
||||||
from wan.utils.utils import create_progress_hook
|
|
||||||
if not os.path.isfile(model_filename ):
|
if not os.path.isfile(model_filename ):
|
||||||
for url in finetune_def["URLs"]:
|
for url in finetune_def["URLs"]:
|
||||||
if model_filename in url:
|
if model_filename in url:
|
||||||
@ -2066,7 +2084,7 @@ def download_models(model_filename, model_type):
|
|||||||
if not url.startswith("http"):
|
if not url.startswith("http"):
|
||||||
raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. Please add an URL in the finetune definition file.")
|
raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. Please add an URL in the finetune definition file.")
|
||||||
try:
|
try:
|
||||||
urlretrieve(url,model_filename, create_progress_hook(model_filename))
|
download_file(url, model_filename)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if os.path.isfile(model_filename): os.remove(model_filename)
|
if os.path.isfile(model_filename): os.remove(model_filename)
|
||||||
raise Exception(f"URL '{url}' is invalid for Model '{model_filename}' : {str(e)}'")
|
raise Exception(f"URL '{url}' is invalid for Model '{model_filename}' : {str(e)}'")
|
||||||
@ -2076,7 +2094,7 @@ def download_models(model_filename, model_type):
|
|||||||
if not url.startswith("http"):
|
if not url.startswith("http"):
|
||||||
raise Exception(f"File '{filename}' to preload was not found locally and no URL was provided to download it. Please add an URL in the finetune definition file.")
|
raise Exception(f"File '{filename}' to preload was not found locally and no URL was provided to download it. Please add an URL in the finetune definition file.")
|
||||||
try:
|
try:
|
||||||
urlretrieve(url,filename, create_progress_hook(filename))
|
download_file(url, filename)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if os.path.isfile(filename): os.remove(filename)
|
if os.path.isfile(filename): os.remove(filename)
|
||||||
raise Exception(f"Preload URL '{url}' is invalid: {str(e)}'")
|
raise Exception(f"Preload URL '{url}' is invalid: {str(e)}'")
|
||||||
@ -2114,7 +2132,6 @@ def download_models(model_filename, model_type):
|
|||||||
|
|
||||||
process_files_def(**model_def)
|
process_files_def(**model_def)
|
||||||
|
|
||||||
|
|
||||||
offload.default_verboseLevel = verbose_level
|
offload.default_verboseLevel = verbose_level
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user