my-sd/extensions-builtin/sd_forge_svd/scripts/forge_svd.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

69 lines
2.3 KiB
Python
Raw Normal View History

2024-01-26 05:48:59 +00:00
import gradio as gr
2024-01-26 05:51:57 +00:00
import os
2024-01-26 06:00:36 +00:00
import pathlib
2024-01-26 05:48:59 +00:00
from modules import scripts, script_callbacks
2024-01-26 05:51:57 +00:00
from modules.paths import models_path
2024-01-26 06:03:43 +00:00
from modules.ui_common import ToolButton, refresh_symbol
2024-01-26 05:54:40 +00:00
from modules import shared
2024-01-26 06:15:39 +00:00
from modules_forge.gradio_compile import gradio_compile
2024-01-26 06:51:45 +00:00
from ldm_patched.contrib.external_video_model import ImageOnlyCheckpointLoader, VideoLinearCFGGuidance, SVD_img2vid_Conditioning
from ldm_patched.contrib.external import KSampler, VAEDecode
2024-01-26 06:15:39 +00:00
2024-01-26 06:57:14 +00:00
# Collect whatever gradio_compile() yields for each node's INPUT_TYPES() spec,
# in order: SVD conditioning (no prefix), KSampler ('sampling'), and
# VideoLinearCFGGuidance ('guidance'). The items are joined with ', ' below,
# so they must be strings.
# NOTE(review): this prints on every import of the extension; it looks like a
# development aid — consider removing once the tab UI is wired up.
ps = [
    *gradio_compile(SVD_img2vid_Conditioning.INPUT_TYPES(), prefix=''),
    *gradio_compile(KSampler.INPUT_TYPES(), prefix='sampling'),
    *gradio_compile(VideoLinearCFGGuidance.INPUT_TYPES(), prefix='guidance'),
]
print(', '.join(ps))
2024-01-26 05:51:57 +00:00
# Folder scanned for SVD checkpoints; created on import so the scan below
# never has to deal with a missing directory.
svd_root = os.path.join(models_path, 'svd')
os.makedirs(svd_root, exist_ok=True)

# Basenames of the checkpoints found under svd_root. Rebound (not mutated) by
# update_svd_filenames(); readers must look it up at call time.
svd_filenames = []


def update_svd_filenames():
    """Rescan ``svd_root`` for checkpoint files and cache their basenames.

    Only files with a ``.pt``, ``.ckpt`` or ``.safetensors`` extension are
    considered. The module-level ``svd_filenames`` is rebound to the new list,
    which is also returned.

    NOTE(review): only the final path component is kept. If ``walk_files``
    recurses into subfolders, nested checkpoints lose their relative path and
    same-named files become indistinguishable. Confirm how a name is resolved
    back to a path.
    """
    global svd_filenames
    checkpoint_paths = shared.walk_files(svd_root, allowed_extensions=[".pt", ".ckpt", ".safetensors"])
    svd_filenames = [pathlib.Path(found).name for found in checkpoint_paths]
    return svd_filenames
2024-01-26 05:48:59 +00:00
class ForgeSVD(scripts.Script):
    """Placeholder script registration for SVD.

    ``show`` returns ``scripts.AlwaysVisible`` for both txt2img and img2img.
    ``ui`` contributes no components. The SVD controls are built separately
    in ``on_ui_tabs``.
    """

    def __init__(self) -> None:
        super().__init__()

    def title(self):
        return "SVD"

    def show(self, is_img2img):
        # The answer does not depend on which tab is asking.
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        # No per-script controls: an empty tuple of components.
        return tuple()
def on_ui_tabs():
    """Build the "SVD" tab: checkpoint picker with refresh, Generate button, gallery.

    Reads the module-level ``svd_filenames`` for the dropdown's initial choices.

    Returns:
        A one-element list ``[(svd_block, "SVD", "svd")]`` handed back to the
        ``script_callbacks.on_ui_tabs`` registration (presumably the Blocks
        object, tab label and element id — confirm against script_callbacks).
    """
    with gr.Blocks(analytics_enabled=False) as svd_block:
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    filename = gr.Dropdown(label="SVD Checkpoint Filename",
                                           choices=svd_filenames,
                                           value=svd_filenames[0] if svd_filenames else None)
                    refresh_button = ToolButton(value=refresh_symbol, tooltip="Refresh")
                    # update_svd_filenames must be *called* here. Passing the
                    # function object itself handed gr.update a callable as
                    # `choices`, so the dropdown never got the rescanned list.
                    refresh_button.click(
                        fn=lambda: gr.update(choices=update_svd_filenames()),
                        inputs=[], outputs=filename)
                # NOTE(review): not yet wired to any handler.
                generate_button = gr.Button(value="Generate")
            with gr.Column():
                # NOTE(review): not yet populated by any handler.
                output_gallery = gr.Gallery(label='Gallery', show_label=False, object_fit='contain',
                                            visible=True, height=1024, columns=4)
    return [(svd_block, "SVD", "svd")]
2024-01-26 05:55:02 +00:00
# Scan the checkpoint folder once at import so svd_filenames is populated;
# on_ui_tabs() reads it for the dropdown's initial choices.
update_svd_filenames()

# Register the tab builder.
script_callbacks.on_ui_tabs(on_ui_tabs)