Update forge_svd.py

This commit is contained in:
lllyasviel 2024-01-25 23:10:59 -08:00
parent 8c10ec65f0
commit 77803db5fe

View File

@ -38,7 +38,7 @@ def update_svd_filenames():
def predict(filename, width, height, video_frames, motion_bucket_id, fps, augmentation_level,
sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler,
sampling_denoise, guidance_min_cfg):
sampling_denoise, guidance_min_cfg, input_image):
filename = os.path.join(svd_root, filename)
model, _, vae, clip_vision = load_checkpoint_guess_config(filename, output_vae=True, output_clip=False, output_clipvision=True)
a = 0
@ -63,6 +63,8 @@ def on_ui_tabs():
with gr.Blocks(analytics_enabled=False) as svd_block:
with gr.Row():
with gr.Column():
input_image = gr.Image(label='Drag above image to here', source='upload', type='numpy', height=400)
with gr.Row():
filename = gr.Dropdown(label="SVD Checkpoint Filename",
choices=svd_filenames,
@ -100,7 +102,7 @@ def on_ui_tabs():
ctrls = [filename, width, height, video_frames, motion_bucket_id, fps, augmentation_level,
sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler,
sampling_denoise, guidance_min_cfg]
sampling_denoise, guidance_min_cfg, input_image]
with gr.Column():
output_gallery = gr.Gallery(label='Gallery', show_label=False, object_fit='contain',