2024-01-26 05:48:59 +00:00
|
|
|
import gradio as gr
|
2024-01-26 05:51:57 +00:00
|
|
|
import os
|
2024-01-26 06:00:36 +00:00
|
|
|
import pathlib
|
2024-01-26 05:48:59 +00:00
|
|
|
|
|
|
|
from modules import scripts, script_callbacks
|
2024-01-26 05:51:57 +00:00
|
|
|
from modules.paths import models_path
|
2024-01-26 06:03:43 +00:00
|
|
|
from modules.ui_common import ToolButton, refresh_symbol
|
2024-01-26 05:54:40 +00:00
|
|
|
from modules import shared
|
2024-01-26 06:15:39 +00:00
|
|
|
|
2024-01-26 07:09:10 +00:00
|
|
|
from ldm_patched.modules.sd import load_checkpoint_guess_config
|
|
|
|
from ldm_patched.contrib.external_video_model import VideoLinearCFGGuidance, SVD_img2vid_Conditioning
|
2024-01-26 06:51:45 +00:00
|
|
|
from ldm_patched.contrib.external import KSampler, VAEDecode
|
2024-01-26 06:15:39 +00:00
|
|
|
|
|
|
|
|
2024-01-26 07:09:10 +00:00
|
|
|
# from modules_forge.gradio_compile import gradio_compile
|
2024-01-26 07:02:17 +00:00
|
|
|
# ps = []
|
|
|
|
# ps += gradio_compile(SVD_img2vid_Conditioning.INPUT_TYPES(), prefix='')
|
|
|
|
# ps += gradio_compile(KSampler.INPUT_TYPES(), prefix='sampling')
|
|
|
|
# ps += gradio_compile(VideoLinearCFGGuidance.INPUT_TYPES(), prefix='guidance')
|
|
|
|
# print(', '.join(ps))
|
2024-01-26 05:51:57 +00:00
|
|
|
|
2024-01-26 07:09:10 +00:00
|
|
|
|
|
|
|
# Instances of the ldm_patched node classes, created once when this module is imported.
# The right-hand tuple is evaluated left to right, so construction order is unchanged.
(
    opVideoLinearCFGGuidance,
    opSVD_img2vid_Conditioning,
    opKSampler,
    opVAEDecode,
) = (
    VideoLinearCFGGuidance(),
    SVD_img2vid_Conditioning(),
    KSampler(),
    VAEDecode(),
)
|
|
|
|
|
2024-01-26 05:51:57 +00:00
|
|
|
# Folder scanned for SVD checkpoints: <models_path>/svd. It is created at import time
# if it does not exist yet.
svd_root = os.path.join(models_path, "svd")
os.makedirs(name=svd_root, exist_ok=True)

# Checkpoint names offered by the tab's dropdown; rebuilt by update_svd_filenames().
svd_filenames: list = list()
|
|
|
|
|
|
|
|
|
|
|
|
def update_svd_filenames():
    """Rescan ``svd_root`` for checkpoint files and cache the result in ``svd_filenames``.

    Files with a ``.pt``, ``.ckpt`` or ``.safetensors`` extension are collected via
    ``shared.walk_files``. Each entry is stored relative to ``svd_root``:

    - For a file directly inside ``svd_root`` this is just the file name, as before.
    - For a file in a sub-folder the sub-folder prefix is kept, so that
      ``os.path.join(svd_root, name)`` in ``predict`` resolves to the real file.
      Keeping only the base name would point at a non-existent path.

    Returns:
        The refreshed list (the same object now bound to the module-level ``svd_filenames``).
    """
    global svd_filenames

    # NOTE(review): assumes walk_files yields paths rooted at svd_root (os.path.join(root, file)),
    # as the WebUI's recursive helper does — confirm if that helper changes.
    svd_filenames = [
        os.path.relpath(path, svd_root)
        for path in shared.walk_files(svd_root, allowed_extensions=[".pt", ".ckpt", ".safetensors"])
    ]
    return svd_filenames
|
2024-01-26 05:48:59 +00:00
|
|
|
|
|
|
|
|
2024-01-26 07:04:14 +00:00
|
|
|
def predict(filename, width, height, video_frames, motion_bucket_id, fps, augmentation_level,
            sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler,
            sampling_denoise, guidance_min_cfg):
    """Handle the "Generate" button of the SVD tab (work in progress).

    Parameters arrive positionally, in the same order as the ``ctrls`` list built in
    ``on_ui_tabs``:

    - ``filename``: a checkpoint name relative to ``svd_root``.
    - The remaining arguments are the slider/radio values for conditioning, sampling
      and guidance.

    So far only the checkpoint is loaded. The sampling/decoding node instances are not
    invoked yet, so the other arguments are currently unused.

    Returns:
        ``None`` — nothing is sent to the output gallery yet.

    Raises:
        gr.Error: if no checkpoint is selected. This happens when ``svd_root`` contained
        no checkpoints, so the dropdown value is ``None``/empty. Without the check,
        ``os.path.join`` would fail with an opaque ``TypeError``.
    """
    if not filename:
        raise gr.Error("No SVD checkpoint selected.")

    filename = os.path.join(svd_root, filename)

    # TODO: model / vae / clip_vision are the inputs for the not-yet-wired sampling pipeline.
    model, _, vae, clip_vision = load_checkpoint_guess_config(
        filename, output_vae=True, output_clip=False, output_clipvision=True)

    return None
|
|
|
|
|
|
|
|
|
2024-01-26 05:48:59 +00:00
|
|
|
class ForgeSVD(scripts.Script):
    """Minimal WebUI script stub for SVD.

    It contributes no controls to the generation panels (``ui`` returns an empty tuple).
    The SVD interface itself is registered separately as a tab through ``on_ui_tabs``.
    """

    def __init__(self) -> None:
        super(ForgeSVD, self).__init__()

    def title(self):
        """Name under which the WebUI lists this script."""
        return "SVD"

    def show(self, is_img2img):
        """Always report ``scripts.AlwaysVisible``, independent of ``is_img2img``."""
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Create no Gradio components for this script."""
        components = ()
        return components
|
|
|
|
|
|
|
|
|
|
|
|
def on_ui_tabs():
    """Build the "SVD" tab.

    Layout:

    - Left column: a checkpoint dropdown with a refresh button, the conditioning and
      sampling controls, and a Generate button.
    - Right column: an output gallery.

    Clicking Generate calls ``predict`` with the values of ``ctrls``. The order of that
    list must match ``predict``'s positional parameters.

    Returns:
        ``[(svd_block, "SVD", "svd")]`` — the Gradio Blocks object, the tab title and the
        tab id, in the form handed back to ``script_callbacks.on_ui_tabs``.
    """

    def refresh_svd_filenames():
        # Rescan the checkpoint folder and push the new list into the dropdown.
        # update_svd_filenames must be *called*: passing the function object itself as
        # ``choices`` (as before) never delivered the refreshed list to the UI.
        return gr.update(choices=update_svd_filenames())

    with gr.Blocks(analytics_enabled=False) as svd_block:
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    filename = gr.Dropdown(label="SVD Checkpoint Filename",
                                           choices=svd_filenames,
                                           value=svd_filenames[0] if len(svd_filenames) > 0 else None)
                    refresh_button = ToolButton(value=refresh_symbol, tooltip="Refresh")
                    refresh_button.click(
                        fn=refresh_svd_filenames,
                        inputs=[], outputs=filename)

                width = gr.Slider(label='Width', minimum=16, maximum=8192, step=8, value=1024)
                height = gr.Slider(label='Height', minimum=16, maximum=8192, step=8, value=576)
                video_frames = gr.Slider(label='Video Frames', minimum=1, maximum=4096, step=1, value=14)
                motion_bucket_id = gr.Slider(label='Motion Bucket Id', minimum=1, maximum=1023, step=1, value=127)
                fps = gr.Slider(label='Fps', minimum=1, maximum=1024, step=1, value=6)
                augmentation_level = gr.Slider(label='Augmentation Level', minimum=0.0, maximum=10.0, step=0.01,
                                               value=0.0)
                sampling_seed = gr.Slider(label='Sampling Seed', minimum=0, maximum=18446744073709551615, step=1,
                                          value=0)
                sampling_steps = gr.Slider(label='Sampling Steps', minimum=1, maximum=10000, step=1, value=20)
                sampling_cfg = gr.Slider(label='Sampling Cfg', minimum=0.0, maximum=100.0, step=0.1, value=8.0)
                sampling_sampler_name = gr.Radio(label='Sampling Sampler Name',
                                                 choices=['euler', 'euler_ancestral', 'heun', 'heunpp2', 'dpm_2',
                                                          'dpm_2_ancestral', 'lms', 'dpm_fast', 'dpm_adaptive',
                                                          'dpmpp_2s_ancestral', 'dpmpp_sde', 'dpmpp_sde_gpu',
                                                          'dpmpp_2m', 'dpmpp_2m_sde', 'dpmpp_2m_sde_gpu',
                                                          'dpmpp_3m_sde', 'dpmpp_3m_sde_gpu', 'ddpm', 'lcm', 'ddim',
                                                          'uni_pc', 'uni_pc_bh2'], value='euler')
                sampling_scheduler = gr.Radio(label='Sampling Scheduler',
                                              choices=['normal', 'karras', 'exponential', 'sgm_uniform', 'simple',
                                                       'ddim_uniform'], value='normal')
                sampling_denoise = gr.Slider(label='Sampling Denoise', minimum=0.0, maximum=1.0, step=0.01, value=1.0)
                guidance_min_cfg = gr.Slider(label='Guidance Min Cfg', minimum=0.0, maximum=100.0, step=0.5, value=1.0)

                generate_button = gr.Button(value="Generate")

                # Positional inputs for predict(); keep in the same order as its signature.
                ctrls = [filename, width, height, video_frames, motion_bucket_id, fps, augmentation_level,
                         sampling_seed, sampling_steps, sampling_cfg, sampling_sampler_name, sampling_scheduler,
                         sampling_denoise, guidance_min_cfg]

            with gr.Column():
                output_gallery = gr.Gallery(label='Gallery', show_label=False, object_fit='contain',
                                            visible=True, height=1024, columns=4)

        generate_button.click(predict, inputs=ctrls, outputs=[output_gallery])

    return [(svd_block, "SVD", "svd")]
|
|
|
|
|
|
|
|
|
2024-01-26 05:55:02 +00:00
|
|
|
# Scan the checkpoint folder once at import time so the tab's dropdown starts populated.
update_svd_filenames()

# Register on_ui_tabs as a UI-tab callback; it returns the ("SVD", "svd") tab definition.
script_callbacks.on_ui_tabs(on_ui_tabs)
|