diff --git a/gradio_server.py b/gradio_server.py index 777e51a..57ab19b 100644 --- a/gradio_server.py +++ b/gradio_server.py @@ -269,6 +269,7 @@ if not Path(server_config_filename).is_file(): "transformer_filename": transformer_choices_t2v[0], "transformer_filename_i2v": transformer_choices_i2v[1], ######## "text_encoder_filename" : text_encoder_choices[1], + "save_path": os.path.join(os.getcwd(), "gradio_outputs"), "compile" : "", "default_ui": "t2v", "boost" : 1, @@ -643,6 +644,7 @@ def apply_changes( state, transformer_t2v_choice, transformer_i2v_choice, text_encoder_choice, + save_path_choice, attention_choice, compile_choice, profile_choice, @@ -660,6 +662,7 @@ def apply_changes( state, "transformer_filename": transformer_choices_t2v[transformer_t2v_choice], "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice], ########## "text_encoder_filename" : text_encoder_choices[text_encoder_choice], + "save_path" : save_path_choice, "compile" : compile_choice, "profile" : profile_choice, "vae_config" : vae_config_choice, @@ -1088,7 +1091,8 @@ def generate_video( file_list = [] state["file_list"] = file_list - save_path = os.path.join(os.getcwd(), "gradio_outputs") + global save_path + save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs")) os.makedirs(save_path, exist_ok=True) video_no = 0 total_video = repeat_generation * len(prompts) @@ -1229,7 +1233,7 @@ def generate_video( file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4" else: file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4" - video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name) + video_path = os.path.join(save_path, file_name) cache_video( tensor=sample[None], save_file=video_path, @@ -1675,6 +1679,10 @@ def create_demo(): value= index, label="Text Encoder model" ) + save_path_choice = gr.Textbox( + label="Output Folder for Generated Videos", + value=server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs")) + ) def check(mode): if not mode in attention_modes_supported: return " (NOT INSTALLED)" @@ -2030,6 +2038,7 @@ def create_demo(): transformer_t2v_choice, transformer_i2v_choice, text_encoder_choice, + save_path_choice, attention_choice, compile_choice, profile_choice,