This commit is contained in:
commit
51cf69328e
@ -11,6 +11,8 @@ from .global_state import (
|
||||
get_all_preprocessor_names,
|
||||
get_all_controlnet_names,
|
||||
get_preprocessor,
|
||||
get_all_preprocessor_tags,
|
||||
select_control_type,
|
||||
)
|
||||
from .utils import judge_image_type
|
||||
from .logging import logger
|
||||
@ -53,6 +55,30 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
|
||||
# "module_detail": external_code.get_modules_detail(alias_names),
|
||||
}
|
||||
|
||||
@app.get("/controlnet/control_types")
|
||||
async def control_types():
|
||||
def format_control_type(
|
||||
filtered_preprocessor_list,
|
||||
filtered_model_list,
|
||||
default_option,
|
||||
default_model,
|
||||
):
|
||||
control_dict = {
|
||||
"module_list": filtered_preprocessor_list,
|
||||
"model_list": filtered_model_list,
|
||||
"default_option": default_option,
|
||||
"default_model": default_model,
|
||||
}
|
||||
|
||||
return control_dict
|
||||
|
||||
return {
|
||||
"control_types": {
|
||||
control_type: format_control_type(*select_control_type(control_type))
|
||||
for control_type in get_all_preprocessor_tags()
|
||||
}
|
||||
}
|
||||
|
||||
@app.post("/controlnet/detect")
|
||||
async def detect(
|
||||
controlnet_module: str = Body("none", title="Controlnet Module"),
|
||||
|
@ -11,16 +11,17 @@ from lib_controlnet import (
|
||||
global_state,
|
||||
external_code,
|
||||
)
|
||||
from lib_controlnet.external_code import ControlNetUnit
|
||||
from lib_controlnet.logging import logger
|
||||
from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor
|
||||
from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI
|
||||
from lib_controlnet.controlnet_ui.tool_button import ToolButton
|
||||
from lib_controlnet.controlnet_ui.photopea import Photopea
|
||||
from lib_controlnet.controlnet_ui.multi_inputs_gallery import MultiInputsGallery
|
||||
from lib_controlnet.enums import InputMode, HiResFixOption
|
||||
from modules import shared, script_callbacks
|
||||
from modules.ui_components import FormRow
|
||||
from modules_forge.forge_util import HWC3
|
||||
from lib_controlnet.external_code import UiControlNetUnit
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -171,10 +172,10 @@ class ControlNetUiGroup(object):
|
||||
self.webcam_mirrored = False
|
||||
|
||||
# Note: All gradio elements declared in `render` will be defined as member variable.
|
||||
# Update counter to trigger a force update of UiControlNetUnit.
|
||||
# Update counter to trigger a force update of ControlNetUnit.
|
||||
# dummy_gradio_update_trigger is useful when a field with no event subscriber available changes.
|
||||
# e.g. gr.Gallery, gr.State, etc. After an update to gr.State / gr.Gallery, please increment
|
||||
# this counter to trigger a sync update of UiControlNetUnit.
|
||||
# this counter to trigger a sync update of ControlNetUnit.
|
||||
self.dummy_gradio_update_trigger = None
|
||||
self.enabled = None
|
||||
self.upload_tab = None
|
||||
@ -185,10 +186,11 @@ class ControlNetUiGroup(object):
|
||||
self.mask_image = None
|
||||
self.batch_tab = None
|
||||
self.batch_image_dir = None
|
||||
self.merge_tab = None
|
||||
self.batch_upload_tab = None
|
||||
self.batch_input_gallery = None
|
||||
self.merge_upload_button = None
|
||||
self.merge_clear_button = None
|
||||
self.batch_mask_gallery = None
|
||||
self.multi_inputs_upload_tab = None
|
||||
self.multi_inputs_input_gallery = None
|
||||
self.create_canvas = None
|
||||
self.canvas_width = None
|
||||
self.canvas_height = None
|
||||
@ -225,6 +227,7 @@ class ControlNetUiGroup(object):
|
||||
self.hr_option = None
|
||||
self.batch_image_dir_state = None
|
||||
self.output_dir_state = None
|
||||
self.advanced_weighting = gr.State(None)
|
||||
|
||||
# Internal states for UI state pasting.
|
||||
self.prevent_next_n_module_update = 0
|
||||
@ -331,31 +334,17 @@ class ControlNetUiGroup(object):
|
||||
visible=False,
|
||||
)
|
||||
|
||||
with gr.Tab(label="Batch Upload") as self.merge_tab:
|
||||
with gr.Tab(label="Batch Upload") as self.batch_upload_tab:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
self.batch_input_gallery = gr.Gallery(
|
||||
columns=[4], rows=[2], object_fit="contain", height="auto", label="Images"
|
||||
)
|
||||
with gr.Row():
|
||||
self.merge_upload_button = gr.UploadButton(
|
||||
"Upload Images",
|
||||
file_types=["image"],
|
||||
file_count="multiple",
|
||||
)
|
||||
self.merge_clear_button = gr.Button("Clear Images")
|
||||
with gr.Group(visible=False, elem_classes=["cnet-mask-gallery-group"]) as self.batch_mask_gallery_group:
|
||||
with gr.Column():
|
||||
self.batch_mask_gallery = gr.Gallery(
|
||||
columns=[4], rows=[2], object_fit="contain", height="auto", label="Masks"
|
||||
)
|
||||
with gr.Row():
|
||||
self.mask_merge_upload_button = gr.UploadButton(
|
||||
"Upload Masks",
|
||||
file_types=["image"],
|
||||
file_count="multiple",
|
||||
)
|
||||
self.mask_merge_clear_button = gr.Button("Clear Masks")
|
||||
self.batch_input_gallery = MultiInputsGallery()
|
||||
self.batch_mask_gallery = MultiInputsGallery(
|
||||
visible=False,
|
||||
elem_classes=["cnet-mask-gallery-group"]
|
||||
)
|
||||
|
||||
with gr.Tab(label="Multi-Inputs") as self.multi_inputs_upload_tab:
|
||||
with gr.Row():
|
||||
self.multi_inputs_gallery = MultiInputsGallery()
|
||||
|
||||
if self.photopea:
|
||||
self.photopea.attach_photopea_output(self.generated_image)
|
||||
@ -600,8 +589,9 @@ class ControlNetUiGroup(object):
|
||||
self.use_preview_as_input,
|
||||
self.batch_image_dir,
|
||||
self.batch_mask_dir,
|
||||
self.batch_input_gallery,
|
||||
self.batch_mask_gallery,
|
||||
self.batch_input_gallery.input_gallery,
|
||||
self.batch_mask_gallery.input_gallery,
|
||||
self.multi_inputs_gallery.input_gallery,
|
||||
self.generated_image,
|
||||
self.mask_image,
|
||||
self.hr_option,
|
||||
@ -618,9 +608,16 @@ class ControlNetUiGroup(object):
|
||||
self.guidance_end,
|
||||
self.pixel_perfect,
|
||||
self.control_mode,
|
||||
self.advanced_weighting,
|
||||
)
|
||||
|
||||
unit = gr.State(self.default_unit)
|
||||
def create_unit(*args):
|
||||
return ControlNetUnit.from_dict({
|
||||
k: v
|
||||
for k, v in zip(vars(ControlNetUnit()).keys(), args)
|
||||
})
|
||||
|
||||
for comp in unit_args + (self.dummy_gradio_update_trigger,):
|
||||
event_subscribers = []
|
||||
if hasattr(comp, "edit"):
|
||||
@ -637,7 +634,7 @@ class ControlNetUiGroup(object):
|
||||
|
||||
for event_subscriber in event_subscribers:
|
||||
event_subscriber(
|
||||
fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit
|
||||
fn=create_unit, inputs=list(unit_args), outputs=unit
|
||||
)
|
||||
|
||||
(
|
||||
@ -645,7 +642,7 @@ class ControlNetUiGroup(object):
|
||||
if self.is_img2img
|
||||
else ControlNetUiGroup.a1111_context.txt2img_submit_button
|
||||
).click(
|
||||
fn=UiControlNetUnit,
|
||||
fn=create_unit,
|
||||
inputs=list(unit_args),
|
||||
outputs=unit,
|
||||
queue=False,
|
||||
@ -981,20 +978,41 @@ class ControlNetUiGroup(object):
|
||||
"""Controls whether the upload mask input should be visible."""
|
||||
def on_checkbox_click(checked: bool, canvas_height: int, canvas_width: int):
|
||||
if not checked:
|
||||
# Clear mask_image if unchecked.
|
||||
return gr.update(visible=False), gr.update(value=None), gr.update(value=None, visible=False), \
|
||||
gr.update(visible=False), gr.update(value=None)
|
||||
# Clear mask inputs if unchecked.
|
||||
return (
|
||||
# Single mask upload.
|
||||
gr.update(visible=False),
|
||||
gr.update(value=None),
|
||||
# Batch mask upload dir.
|
||||
gr.update(value=None, visible=False),
|
||||
# Multi mask upload gallery.
|
||||
gr.update(visible=False),
|
||||
gr.update(value=None)
|
||||
)
|
||||
else:
|
||||
# Init an empty canvas the same size as the generation target.
|
||||
empty_canvas = np.zeros(shape=(canvas_height, canvas_width, 3), dtype=np.uint8)
|
||||
return gr.update(visible=True), gr.update(value=empty_canvas), gr.update(visible=True), \
|
||||
gr.update(visible=True), gr.update()
|
||||
return (
|
||||
# Single mask upload.
|
||||
gr.update(visible=True),
|
||||
gr.update(value=empty_canvas),
|
||||
# Batch mask upload dir.
|
||||
gr.update(visible=True),
|
||||
# Multi mask upload gallery.
|
||||
gr.update(visible=True),
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
self.mask_upload.change(
|
||||
fn=on_checkbox_click,
|
||||
inputs=[self.mask_upload, self.height_slider, self.width_slider],
|
||||
outputs=[self.mask_image_group, self.mask_image, self.batch_mask_dir,
|
||||
self.batch_mask_gallery_group, self.batch_mask_gallery],
|
||||
outputs=[
|
||||
self.mask_image_group,
|
||||
self.mask_image,
|
||||
self.batch_mask_dir,
|
||||
self.batch_mask_gallery.group,
|
||||
self.batch_mask_gallery.input_gallery,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
@ -1079,51 +1097,14 @@ class ControlNetUiGroup(object):
|
||||
|
||||
def register_multi_images_upload(self):
|
||||
"""Register callbacks on merge tab multiple images upload."""
|
||||
self.merge_clear_button.click(
|
||||
fn=lambda: [],
|
||||
inputs=[],
|
||||
outputs=[self.batch_input_gallery],
|
||||
).then(
|
||||
fn=lambda x: gr.update(value=x + 1),
|
||||
trigger_dict = dict(
|
||||
fn=lambda n: gr.update(value=n + 1),
|
||||
inputs=[self.dummy_gradio_update_trigger],
|
||||
outputs=[self.dummy_gradio_update_trigger],
|
||||
)
|
||||
self.mask_merge_clear_button.click(
|
||||
fn=lambda: [],
|
||||
inputs=[],
|
||||
outputs=[self.batch_mask_gallery],
|
||||
).then(
|
||||
fn=lambda x: gr.update(value=x + 1),
|
||||
inputs=[self.dummy_gradio_update_trigger],
|
||||
outputs=[self.dummy_gradio_update_trigger],
|
||||
)
|
||||
|
||||
def upload_file(files, current_files):
|
||||
return {file_d["name"] for file_d in current_files} | {
|
||||
file.name for file in files
|
||||
}
|
||||
|
||||
self.merge_upload_button.upload(
|
||||
upload_file,
|
||||
inputs=[self.merge_upload_button, self.batch_input_gallery],
|
||||
outputs=[self.batch_input_gallery],
|
||||
queue=False,
|
||||
).then(
|
||||
fn=lambda x: gr.update(value=x + 1),
|
||||
inputs=[self.dummy_gradio_update_trigger],
|
||||
outputs=[self.dummy_gradio_update_trigger],
|
||||
)
|
||||
self.mask_merge_upload_button.upload(
|
||||
upload_file,
|
||||
inputs=[self.mask_merge_upload_button, self.batch_mask_gallery],
|
||||
outputs=[self.batch_mask_gallery],
|
||||
queue=False,
|
||||
).then(
|
||||
fn=lambda x: gr.update(value=x + 1),
|
||||
inputs=[self.dummy_gradio_update_trigger],
|
||||
outputs=[self.dummy_gradio_update_trigger],
|
||||
)
|
||||
return
|
||||
self.batch_input_gallery.register_callbacks(change_trigger=trigger_dict)
|
||||
self.batch_mask_gallery.register_callbacks(change_trigger=trigger_dict)
|
||||
self.multi_inputs_gallery.register_callbacks(change_trigger=trigger_dict)
|
||||
|
||||
def register_core_callbacks(self):
|
||||
"""Register core callbacks that only involves gradio components defined
|
||||
@ -1213,7 +1194,8 @@ class ControlNetUiGroup(object):
|
||||
for input_tab, fn in (
|
||||
(ui_group.upload_tab, simple_fn),
|
||||
(ui_group.batch_tab, batch_fn),
|
||||
(ui_group.merge_tab, merge_fn),
|
||||
(ui_group.batch_upload_tab, batch_fn),
|
||||
(ui_group.multi_inputs_upload_tab, merge_fn),
|
||||
):
|
||||
# Sync input_mode.
|
||||
input_tab.select(
|
||||
|
@ -0,0 +1,65 @@
|
||||
import gradio as gr
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class MultiInputsGallery:
|
||||
"""A gallery object that accepts multiple input images."""
|
||||
|
||||
def __init__(self, row: int = 2, column: int = 4, **group_kwargs) -> None:
|
||||
self.gallery_row_num = row
|
||||
self.gallery_column_num = column
|
||||
self.group_kwargs = group_kwargs
|
||||
|
||||
self.group = None
|
||||
self.input_gallery = None
|
||||
self.upload_button = None
|
||||
self.clear_button = None
|
||||
self.render()
|
||||
|
||||
def render(self):
|
||||
with gr.Group(**self.group_kwargs) as self.group:
|
||||
with gr.Column():
|
||||
self.input_gallery = gr.Gallery(
|
||||
columns=[self.gallery_column_num],
|
||||
rows=[self.gallery_row_num],
|
||||
object_fit="contain",
|
||||
height="auto",
|
||||
label="Images",
|
||||
)
|
||||
with gr.Row():
|
||||
self.upload_button = gr.UploadButton(
|
||||
"Upload Images",
|
||||
file_types=["image"],
|
||||
file_count="multiple",
|
||||
)
|
||||
self.clear_button = gr.Button("Clear Images")
|
||||
|
||||
def register_callbacks(self, change_trigger: Optional[dict] = None):
|
||||
"""Register callbacks on multiple images upload.
|
||||
Argument:
|
||||
- change_trigger: An optional gradio callback param dict to be called
|
||||
after gallery content change. This is necessary as gallery has no
|
||||
event subscriber. If the state change of gallery needs to be observed,
|
||||
the caller needs to pass a change trigger to observe the change.
|
||||
"""
|
||||
handle1 = self.clear_button.click(
|
||||
fn=lambda: [],
|
||||
inputs=[],
|
||||
outputs=[self.input_gallery],
|
||||
)
|
||||
|
||||
def upload_file(files, current_files):
|
||||
return {file_d["name"] for file_d in current_files} | {
|
||||
file.name for file in files
|
||||
}
|
||||
|
||||
handle2 = self.upload_button.upload(
|
||||
upload_file,
|
||||
inputs=[self.upload_button, self.input_gallery],
|
||||
outputs=[self.input_gallery],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
if change_trigger:
|
||||
for handle in (handle1, handle2):
|
||||
handle.then(**change_trigger)
|
@ -183,6 +183,8 @@ class ControlNetUnit:
|
||||
batch_input_gallery: Optional[List[str]] = None
|
||||
# Optional list of gallery masks for batch processing; defaults to None.
|
||||
batch_mask_gallery: Optional[List[str]] = None
|
||||
# Optional list of gallery images for multi-inputs; defaults to None.
|
||||
multi_inputs_gallery: Optional[List[str]] = None
|
||||
# Holds the preview image as a NumPy array; defaults to None.
|
||||
generated_image: Optional[np.ndarray] = None
|
||||
# ====== End of UI only fields ======
|
||||
@ -192,7 +194,7 @@ class ControlNetUnit:
|
||||
mask_image: Optional[GradioImageMaskPair] = None
|
||||
# Specifies how this unit should be applied in each pass of high-resolution fix.
|
||||
# Ignored if high-resolution fix is not enabled.
|
||||
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
|
||||
hr_option: HiResFixOption = HiResFixOption.BOTH
|
||||
# Indicates whether the unit is enabled; defaults to True.
|
||||
enabled: bool = True
|
||||
# Name of the module being used; defaults to "None".
|
||||
@ -204,7 +206,7 @@ class ControlNetUnit:
|
||||
# Optional image for input; defaults to None.
|
||||
image: Optional[GradioImageMaskPair] = None
|
||||
# Specifies the mode of image resizing; defaults to inner fit.
|
||||
resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT
|
||||
resize_mode: ResizeMode = ResizeMode.INNER_FIT
|
||||
# Resolution for processing by the unit; defaults to -1 (unspecified).
|
||||
processor_res: int = -1
|
||||
# Threshold A for processing; defaults to -1 (unspecified).
|
||||
@ -218,7 +220,22 @@ class ControlNetUnit:
|
||||
# Enables pixel-perfect processing; defaults to False.
|
||||
pixel_perfect: bool = False
|
||||
# Control mode for the unit; defaults to balanced.
|
||||
control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
|
||||
control_mode: ControlMode = ControlMode.BALANCED
|
||||
# Weight for each layer of ControlNet params.
|
||||
# For ControlNet:
|
||||
# - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block)
|
||||
# - SDXL: 10 weights (3 encoder block * 3 + 1 middle block)
|
||||
# For T2IAdapter
|
||||
# - SD1.5: 5 weights (4 encoder block + 1 middle block)
|
||||
# - SDXL: 4 weights (3 encoder block + 1 middle block)
|
||||
# For IPAdapter
|
||||
# - SD15: 16 (6 input blocks + 9 output blocks + 1 middle block)
|
||||
# - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block)
|
||||
# Note1: Setting advanced weighting will disable `soft_injection`, i.e.
|
||||
# It is recommended to set ControlMode = BALANCED when using `advanced_weighting`.
|
||||
# Note2: The field `weight` is still used in some places, e.g. reference_only,
|
||||
# even advanced_weighting is set.
|
||||
advanced_weighting: Optional[List[float]] = None
|
||||
|
||||
# Following fields should only be used in the API.
|
||||
# ====== Start of API only fields ======
|
||||
@ -275,6 +292,11 @@ class ControlNetUnit:
|
||||
"image": mask,
|
||||
"mask": np.zeros_like(mask),
|
||||
}
|
||||
# Convert strings to enums.
|
||||
unit.input_mode = InputMode(unit.input_mode)
|
||||
unit.hr_option = HiResFixOption(unit.hr_option)
|
||||
unit.resize_mode = ResizeMode(unit.resize_mode)
|
||||
unit.control_mode = ControlMode(unit.control_mode)
|
||||
return unit
|
||||
|
||||
|
||||
|
@ -6,6 +6,7 @@ from modules import shared, sd_models
|
||||
from lib_controlnet.enums import StableDiffusionVersion
|
||||
from modules_forge.shared import controlnet_dir, supported_preprocessors
|
||||
|
||||
from typing import Dict, Tuple, List
|
||||
|
||||
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin", ".patch"]
|
||||
|
||||
@ -56,6 +57,10 @@ controlnet_names = ['None']
|
||||
def get_preprocessor(name):
|
||||
return supported_preprocessors.get(name, None)
|
||||
|
||||
def get_default_preprocessor(tag):
|
||||
ps = get_filtered_preprocessor_names(tag)
|
||||
assert len(ps) > 0
|
||||
return ps[0] if len(ps) == 1 else ps[1]
|
||||
|
||||
def get_sorted_preprocessors():
|
||||
preprocessors = [p for k, p in supported_preprocessors.items() if k != 'None']
|
||||
@ -144,3 +149,44 @@ def get_sd_version() -> StableDiffusionVersion:
|
||||
return StableDiffusionVersion.SD1x
|
||||
else:
|
||||
return StableDiffusionVersion.UNKNOWN
|
||||
|
||||
|
||||
def select_control_type(
|
||||
control_type: str,
|
||||
sd_version: StableDiffusionVersion = StableDiffusionVersion.UNKNOWN,
|
||||
) -> Tuple[List[str], List[str], str, str]:
|
||||
global controlnet_names
|
||||
|
||||
pattern = control_type.lower()
|
||||
all_models = list(controlnet_names)
|
||||
|
||||
if pattern == "all":
|
||||
preprocessors = get_sorted_preprocessors().values()
|
||||
return [
|
||||
[p.name for p in preprocessors],
|
||||
all_models,
|
||||
'none', # default option
|
||||
"None" # default model
|
||||
]
|
||||
|
||||
filtered_model_list = get_filtered_controlnet_names(control_type)
|
||||
|
||||
if pattern == "none":
|
||||
filtered_model_list.append("None")
|
||||
|
||||
assert len(filtered_model_list) > 0, "'None' model should always be available."
|
||||
if len(filtered_model_list) == 1:
|
||||
default_model = "None"
|
||||
else:
|
||||
default_model = filtered_model_list[1]
|
||||
for x in filtered_model_list:
|
||||
if "11" in x.split("[")[0]:
|
||||
default_model = x
|
||||
break
|
||||
|
||||
return (
|
||||
get_filtered_preprocessor_names(control_type),
|
||||
filtered_model_list,
|
||||
get_default_preprocessor(control_type),
|
||||
default_model
|
||||
)
|
||||
|
@ -1,7 +1,10 @@
|
||||
from typing import Optional
|
||||
from copy import copy
|
||||
from modules import processing
|
||||
from modules.api import api
|
||||
|
||||
from lib_controlnet import external_code
|
||||
from lib_controlnet.external_code import InputMode, ControlNetUnit
|
||||
|
||||
from modules_forge.forge_util import HWC3
|
||||
|
||||
@ -361,3 +364,22 @@ def crop_and_resize_image(detected_map, resize_mode, h, w, fill_border_with_255=
|
||||
|
||||
def judge_image_type(img):
|
||||
return isinstance(img, np.ndarray) and img.ndim == 3 and int(img.shape[2]) in [3, 4]
|
||||
|
||||
|
||||
def try_unfold_unit(unit: ControlNetUnit) -> List[ControlNetUnit]:
|
||||
"""Unfolds an multi-inputs unit into multiple units with one input."""
|
||||
if unit.input_mode != InputMode.MERGE:
|
||||
return [unit]
|
||||
|
||||
def extract_unit(gallery_item: dict) -> ControlNetUnit:
|
||||
r_unit = copy(unit)
|
||||
img = np.array(api.decode_base64_to_image(read_image(gallery_item["name"]))).astype('uint8')
|
||||
r_unit.image = {
|
||||
"image": img,
|
||||
"mask": np.zeros_like(img),
|
||||
}
|
||||
r_unit.input_mode = InputMode.SIMPLE
|
||||
r_unit.weight = unit.weight / len(unit.multi_inputs_gallery)
|
||||
return r_unit
|
||||
|
||||
return [extract_unit(item) for item in unit.multi_inputs_gallery]
|
||||
|
@ -11,9 +11,15 @@ from modules.api.api import decode_base64_to_image
|
||||
import gradio as gr
|
||||
|
||||
from lib_controlnet import global_state, external_code
|
||||
from lib_controlnet.external_code import ControlNetUnit
|
||||
from lib_controlnet.utils import align_dim_latent, set_numpy_seed, crop_and_resize_image, \
|
||||
prepare_mask, judge_image_type
|
||||
from lib_controlnet.external_code import ControlNetUnit, InputMode
|
||||
from lib_controlnet.utils import (
|
||||
align_dim_latent,
|
||||
set_numpy_seed,
|
||||
crop_and_resize_image,
|
||||
prepare_mask,
|
||||
judge_image_type,
|
||||
try_unfold_unit,
|
||||
)
|
||||
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup
|
||||
from lib_controlnet.controlnet_ui.photopea import Photopea
|
||||
from lib_controlnet.logging import logger
|
||||
@ -53,6 +59,7 @@ class ControlNetCachedParameters:
|
||||
self.control_cond_for_hr_fix = None
|
||||
self.control_mask = None
|
||||
self.control_mask_for_hr_fix = None
|
||||
self.advanced_weighting = None
|
||||
|
||||
|
||||
class ControlNetForForgeOfficial(scripts.Script):
|
||||
@ -105,8 +112,13 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
for unit in units
|
||||
]
|
||||
assert all(isinstance(unit, ControlNetUnit) for unit in units)
|
||||
enabled_units = [x for x in units if x.enabled]
|
||||
return enabled_units
|
||||
return [
|
||||
simple_unit
|
||||
for unit in units
|
||||
# Unfolds multi-inputs units.
|
||||
for simple_unit in try_unfold_unit(unit)
|
||||
if simple_unit.enabled
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def try_crop_image_with_a1111_mask(
|
||||
@ -153,9 +165,9 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
def get_input_data(self, p, unit, preprocessor, h, w):
|
||||
logger.info(f'ControlNet Input Mode: {unit.input_mode}')
|
||||
image_list = []
|
||||
resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
|
||||
resize_mode = unit.resize_mode
|
||||
|
||||
if unit.input_mode == external_code.InputMode.MERGE:
|
||||
if unit.input_mode == InputMode.MERGE:
|
||||
for idx, item in enumerate(unit.batch_input_gallery):
|
||||
img_path = item['name']
|
||||
logger.info(f'Try to read image: {img_path}')
|
||||
@ -169,7 +181,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy()
|
||||
if img is not None:
|
||||
image_list.append([img, mask])
|
||||
elif unit.input_mode == external_code.InputMode.BATCH:
|
||||
elif unit.input_mode == InputMode.BATCH:
|
||||
image_list = []
|
||||
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
|
||||
batch_image_files = shared.listfiles(unit.batch_image_dir)
|
||||
@ -339,7 +351,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
break
|
||||
|
||||
if has_high_res_fix:
|
||||
hr_option = HiResFixOption.from_value(unit.hr_option)
|
||||
hr_option = unit.hr_option
|
||||
else:
|
||||
hr_option = HiResFixOption.BOTH
|
||||
|
||||
@ -430,7 +442,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
)
|
||||
|
||||
if has_high_res_fix:
|
||||
hr_option = HiResFixOption.from_value(unit.hr_option)
|
||||
hr_option = unit.hr_option
|
||||
else:
|
||||
hr_option = HiResFixOption.BOTH
|
||||
|
||||
@ -494,6 +506,11 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
params.model.positive_advanced_weighting = soft_weighting.copy()
|
||||
params.model.negative_advanced_weighting = soft_weighting.copy()
|
||||
|
||||
if unit.advanced_weighting is not None:
|
||||
if params.model.positive_advanced_weighting is None:
|
||||
logger.warn("advanced_weighting overwrite control_mode")
|
||||
params.model.positive_advanced_weighting = unit.advanced_weighting
|
||||
|
||||
cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs)
|
||||
|
||||
params.model.advanced_mask_weighting = mask
|
||||
|
@ -4,6 +4,7 @@ from .template import (
|
||||
APITestTemplate,
|
||||
girl_img,
|
||||
mask_img,
|
||||
portrait_imgs,
|
||||
disable_in_cq,
|
||||
get_model,
|
||||
)
|
||||
@ -169,3 +170,60 @@ def test_lama_outpaint():
|
||||
"resize_mode": "Resize and Fill", # OUTER_FIT
|
||||
},
|
||||
).exec()
|
||||
|
||||
|
||||
@disable_in_cq
|
||||
def test_instant_id_sdxl():
|
||||
assert len(portrait_imgs) > 0
|
||||
assert APITestTemplate(
|
||||
"instant_id_sdxl",
|
||||
"txt2img",
|
||||
payload_overrides={
|
||||
"width": 1000,
|
||||
"height": 1000,
|
||||
"prompt": "1girl, red background",
|
||||
},
|
||||
unit_overrides=[
|
||||
dict(
|
||||
image=portrait_imgs[0],
|
||||
model=get_model("ip-adapter_instant_id_sdxl"),
|
||||
module="InsightFace (InstantID)",
|
||||
),
|
||||
dict(
|
||||
image=portrait_imgs[1],
|
||||
model=get_model("control_instant_id_sdxl"),
|
||||
module="instant_id_face_keypoints",
|
||||
),
|
||||
],
|
||||
).exec()
|
||||
|
||||
|
||||
@disable_in_cq
|
||||
def test_instant_id_sdxl_multiple_units():
|
||||
assert len(portrait_imgs) > 0
|
||||
assert APITestTemplate(
|
||||
"instant_id_sdxl_multiple_units",
|
||||
"txt2img",
|
||||
payload_overrides={
|
||||
"width": 1000,
|
||||
"height": 1000,
|
||||
"prompt": "1girl, red background",
|
||||
},
|
||||
unit_overrides=[
|
||||
dict(
|
||||
image=portrait_imgs[0],
|
||||
model=get_model("ip-adapter_instant_id_sdxl"),
|
||||
module="InsightFace (InstantID)",
|
||||
),
|
||||
dict(
|
||||
image=portrait_imgs[1],
|
||||
model=get_model("control_instant_id_sdxl"),
|
||||
module="instant_id_face_keypoints",
|
||||
),
|
||||
dict(
|
||||
image=portrait_imgs[1],
|
||||
model=get_model("diffusers_xl_canny"),
|
||||
module="canny",
|
||||
),
|
||||
],
|
||||
).exec()
|
||||
|
@ -0,0 +1,43 @@
|
||||
from .template import (
|
||||
APITestTemplate,
|
||||
realistic_girl_face_img,
|
||||
disable_in_cq,
|
||||
get_model,
|
||||
)
|
||||
|
||||
|
||||
@disable_in_cq
|
||||
def test_ipadapter_advanced_weighting():
|
||||
weights = [0.0] * 16 # 16 weights for SD15 / 11 weights for SDXL
|
||||
# SD15 composition
|
||||
weights[4] = 0.25
|
||||
weights[5] = 1.0
|
||||
|
||||
APITestTemplate(
|
||||
"test_ipadapter_advanced_weighting",
|
||||
"txt2img",
|
||||
payload_overrides={
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
},
|
||||
unit_overrides={
|
||||
"image": realistic_girl_face_img,
|
||||
"module": "CLIP-ViT-H (IPAdapter)",
|
||||
"model": get_model("ip-adapter_sd15"),
|
||||
"advanced_weighting": weights,
|
||||
},
|
||||
).exec()
|
||||
|
||||
APITestTemplate(
|
||||
"test_ipadapter_advanced_weighting_ref",
|
||||
"txt2img",
|
||||
payload_overrides={
|
||||
"width": 512,
|
||||
"height": 512,
|
||||
},
|
||||
unit_overrides={
|
||||
"image": realistic_girl_face_img,
|
||||
"module": "CLIP-ViT-H (IPAdapter)",
|
||||
"model": get_model("ip-adapter_sd15"),
|
||||
},
|
||||
).exec()
|
@ -248,13 +248,13 @@ def get_model(model_name: str) -> str:
|
||||
|
||||
|
||||
default_unit = {
|
||||
"control_mode": 0,
|
||||
"control_mode": "Balanced",
|
||||
"enabled": True,
|
||||
"guidance_end": 1,
|
||||
"guidance_start": 0,
|
||||
"pixel_perfect": True,
|
||||
"processor_res": 512,
|
||||
"resize_mode": 1,
|
||||
"resize_mode": "Crop and Resize",
|
||||
"threshold_a": 64,
|
||||
"threshold_b": 64,
|
||||
"weight": 1,
|
||||
|
@ -1,6 +1,12 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import scripts
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
from typing import Any
|
||||
from functools import partial
|
||||
|
||||
from modules import script_callbacks, scripts
|
||||
from ldm_patched.contrib.external_freelunch import FreeU_V2
|
||||
|
||||
|
||||
@ -68,6 +74,18 @@ class FreeUForForge(scripts.Script):
|
||||
|
||||
freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2 = script_args
|
||||
|
||||
xyz = getattr(p, "_freeu_xyz", {})
|
||||
if "freeu_enabled" in xyz:
|
||||
freeu_enabled = xyz["freeu_enabled"] == "True"
|
||||
if "freeu_b1" in xyz:
|
||||
freeu_b1 = xyz["freeu_b1"]
|
||||
if "freeu_b2" in xyz:
|
||||
freeu_b2 = xyz["freeu_b2"]
|
||||
if "freeu_s1" in xyz:
|
||||
freeu_s1 = xyz["freeu_s1"]
|
||||
if "freeu_s2" in xyz:
|
||||
freeu_s2 = xyz["freeu_s2"]
|
||||
|
||||
if not freeu_enabled:
|
||||
return
|
||||
|
||||
@ -89,3 +107,62 @@ class FreeUForForge(scripts.Script):
|
||||
))
|
||||
|
||||
return
|
||||
|
||||
def set_value(p, x: Any, xs: Any, *, field: str):
|
||||
if not hasattr(p, "_freeu_xyz"):
|
||||
p._freeu_xyz = {}
|
||||
p._freeu_xyz[field] = x
|
||||
|
||||
def make_axis_on_xyz_grid():
|
||||
xyz_grid = None
|
||||
for script in scripts.scripts_data:
|
||||
if script.script_class.__module__ == "xyz_grid.py":
|
||||
xyz_grid = script.module
|
||||
break
|
||||
|
||||
if xyz_grid is None:
|
||||
return
|
||||
|
||||
axis = [
|
||||
xyz_grid.AxisOption(
|
||||
"FreeU Enabled",
|
||||
str,
|
||||
partial(set_value, field="freeu_enabled"),
|
||||
choices=lambda: ["True", "False"]
|
||||
),
|
||||
xyz_grid.AxisOption(
|
||||
"FreeU B1",
|
||||
float,
|
||||
partial(set_value, field="freeu_b1"),
|
||||
),
|
||||
xyz_grid.AxisOption(
|
||||
"FreeU B2",
|
||||
float,
|
||||
partial(set_value, field="freeu_b2"),
|
||||
),
|
||||
xyz_grid.AxisOption(
|
||||
"FreeU S1",
|
||||
float,
|
||||
partial(set_value, field="freeu_s1"),
|
||||
),
|
||||
xyz_grid.AxisOption(
|
||||
"FreeU S2",
|
||||
float,
|
||||
partial(set_value, field="freeu_s2"),
|
||||
),
|
||||
]
|
||||
|
||||
if not any(x.label.startswith("FreeU") for x in xyz_grid.axis_options):
|
||||
xyz_grid.axis_options.extend(axis)
|
||||
|
||||
def on_before_ui():
|
||||
try:
|
||||
make_axis_on_xyz_grid()
|
||||
except Exception:
|
||||
error = traceback.format_exc()
|
||||
print(
|
||||
f"[-] FreeU Integrated: xyz_grid error:\n{error}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
script_callbacks.on_before_ui(on_before_ui)
|
||||
|
@ -7,6 +7,7 @@ import math
|
||||
|
||||
import ldm_patched.modules.utils
|
||||
import ldm_patched.modules.model_management
|
||||
from ldm_patched.modules.controlnet import ControlNet
|
||||
from ldm_patched.modules.clip_vision import clip_preprocess
|
||||
from ldm_patched.ldm.modules.attention import optimized_attention
|
||||
from ldm_patched.utils import path_utils as folder_paths
|
||||
@ -402,7 +403,7 @@ class CrossAttentionPatch:
|
||||
batch_prompt = b // len(cond_or_uncond)
|
||||
out = optimized_attention(q, k, v, extra_options["n_heads"])
|
||||
_, _, lh, lw = extra_options["original_shape"]
|
||||
|
||||
|
||||
for weight, cond, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch in zip(self.weights, self.conds, self.unconds, self.ipadapters, self.masks, self.weight_type, self.sigma_start, self.sigma_end, self.unfold_batch):
|
||||
if sigma > sigma_start or sigma < sigma_end:
|
||||
continue
|
||||
@ -465,8 +466,18 @@ class CrossAttentionPatch:
|
||||
ip_v = ip_v_offset + ip_v_mean * W
|
||||
|
||||
out_ip = optimized_attention(q, ip_k.to(org_dtype), ip_v.to(org_dtype), extra_options["n_heads"])
|
||||
if weight_type.startswith("original"):
|
||||
out_ip = out_ip * weight
|
||||
|
||||
if weight_type == "original":
|
||||
assert isinstance(weight, (float, int))
|
||||
weight = weight
|
||||
elif weight_type == "advanced":
|
||||
assert isinstance(weight, list)
|
||||
transformer_index: int = extra_options["transformer_index"]
|
||||
assert transformer_index < len(weight)
|
||||
weight = weight[transformer_index]
|
||||
else:
|
||||
weight = 1.0
|
||||
out_ip = out_ip * weight
|
||||
|
||||
if mask is not None:
|
||||
# TODO: needs checking
|
||||
@ -732,7 +743,7 @@ class IPAdapterApply:
|
||||
is_faceid=self.is_faceid,
|
||||
is_instant_id=self.is_instant_id
|
||||
)
|
||||
|
||||
|
||||
self.ipadapter.to(self.device, dtype=self.dtype)
|
||||
|
||||
if self.is_instant_id:
|
||||
@ -749,13 +760,27 @@ class IPAdapterApply:
|
||||
work_model = model.clone()
|
||||
|
||||
if self.is_instant_id:
|
||||
def modifier(cnet, x_noisy, t, cond, batched_number):
|
||||
def instant_id_modifier(cnet: ControlNet, x_noisy, t, cond, batched_number):
|
||||
"""Overwrites crossattn inputs to InstantID ControlNet with ipadapter image embeds.
|
||||
|
||||
TODO: There can be multiple pairs of InstantID (ipadapter/controlnet) to control
|
||||
rendering of multiple faces on canvas. We need to find a way to pair them. Currently,
|
||||
the modifier is unconditionally applied to all instant id ControlNet units.
|
||||
"""
|
||||
if (
|
||||
not isinstance(cnet, ControlNet) or
|
||||
# model_file_name is None for Control LoRA.
|
||||
cnet.control_model.model_file_name is None or
|
||||
"instant_id" not in cnet.control_model.model_file_name.lower()
|
||||
):
|
||||
return x_noisy, t, cond, batched_number
|
||||
|
||||
cond_mark = cond['transformer_options']['cond_mark'][:, None, None].to(cond['c_crossattn']) # cond is 0
|
||||
c_crossattn = image_prompt_embeds * (1.0 - cond_mark) + uncond_image_prompt_embeds * cond_mark
|
||||
cond['c_crossattn'] = c_crossattn
|
||||
return x_noisy, t, cond, batched_number
|
||||
|
||||
work_model.add_controlnet_conditioning_modifier(modifier)
|
||||
work_model.add_controlnet_conditioning_modifier(instant_id_modifier)
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.to(self.device)
|
||||
|
@ -143,11 +143,17 @@ class IPAdapterPatcher(ControlModelPatcher):
|
||||
|
||||
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
|
||||
unet = process.sd_model.forge_objects.unet
|
||||
if self.positive_advanced_weighting is None:
|
||||
weight = self.strength
|
||||
cond["weight_type"] = "original"
|
||||
else:
|
||||
weight = self.positive_advanced_weighting
|
||||
cond["weight_type"] = "advanced"
|
||||
|
||||
unet = opIPAdapterApply(
|
||||
ipadapter=self.ip_adapter,
|
||||
model=unet,
|
||||
weight=self.strength,
|
||||
weight=weight,
|
||||
start_at=self.start_percent,
|
||||
end_at=self.end_percent,
|
||||
faceid_v2=self.faceid_v2,
|
||||
|
@ -1,7 +1,7 @@
|
||||
<div>
|
||||
<a href="{api_docs}">API</a>
|
||||
•
|
||||
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
|
||||
<a href="https://github.com/lllyasviel/stable-diffusion-webui-forge">Github</a>
|
||||
•
|
||||
<a href="https://gradio.app">Gradio</a>
|
||||
•
|
||||
|
@ -4,6 +4,7 @@
|
||||
import torch
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
|
||||
from ldm_patched.ldm.modules.diffusionmodules.util import (
|
||||
zero_module,
|
||||
@ -54,9 +55,12 @@ class ControlNet(nn.Module):
|
||||
transformer_depth_output=None,
|
||||
device=None,
|
||||
operations=ldm_patched.modules.ops.disable_weight_init,
|
||||
model_file_name: Optional[str] = None, # Name of model file.
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_file_name = model_file_name
|
||||
|
||||
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
|
||||
if use_spatial_transformer:
|
||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||
|
@ -6,6 +6,7 @@ import math
|
||||
|
||||
from scipy import integrate
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torchsde
|
||||
from tqdm.auto import trange, tqdm
|
||||
@ -38,6 +39,35 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
|
||||
sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
|
||||
return append_zero(sigmas)
|
||||
|
||||
# align your steps
|
||||
def get_sigmas_ays(n, sigma_min, sigma_max, is_sdxl=False, device='cpu'):
|
||||
# https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
|
||||
def loglinear_interp(t_steps, num_steps):
|
||||
"""
|
||||
Performs log-linear interpolation of a given array of decreasing numbers.
|
||||
"""
|
||||
xs = torch.linspace(0, 1, len(t_steps))
|
||||
ys = torch.log(torch.tensor(t_steps[::-1]))
|
||||
|
||||
new_xs = torch.linspace(0, 1, num_steps)
|
||||
new_ys = np.interp(new_xs, xs, ys)
|
||||
|
||||
interped_ys = torch.exp(torch.tensor(new_ys)).numpy()[::-1].copy()
|
||||
return interped_ys
|
||||
|
||||
if is_sdxl:
|
||||
sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029]
|
||||
else:
|
||||
# Default to SD 1.5 sigmas.
|
||||
sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029]
|
||||
|
||||
if n != len(sigmas):
|
||||
sigmas = np.append(loglinear_interp(sigmas, n), [0.0])
|
||||
else:
|
||||
sigmas.append(0.0)
|
||||
|
||||
return torch.FloatTensor(sigmas).to(device)
|
||||
|
||||
|
||||
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
|
||||
"""Constructs a continuous VP noise schedule."""
|
||||
|
@ -487,7 +487,7 @@ def load_controlnet(ckpt_path, model=None):
|
||||
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
|
||||
control_model = ldm_patched.controlnet.cldm.ControlNet(model_file_name=ckpt_path, **controlnet_config)
|
||||
|
||||
if pth:
|
||||
if 'difference' in controlnet_data:
|
||||
|
@ -107,9 +107,11 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
|
||||
|
||||
sigma_min = sampling_settings.get("sigma_min", 0.002)
|
||||
sigma_max = sampling_settings.get("sigma_max", 120.0)
|
||||
self.set_sigma_range(sigma_min, sigma_max)
|
||||
sigma_data = sampling_settings.get("sigma_data", 1.0)
|
||||
self.set_sigma_range(sigma_min, sigma_max, sigma_data)
|
||||
|
||||
def set_sigma_range(self, sigma_min, sigma_max):
|
||||
def set_sigma_range(self, sigma_min, sigma_max, sigma_data):
|
||||
self.sigma_data = sigma_data
|
||||
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
|
||||
|
||||
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
|
||||
|
@ -662,14 +662,16 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "ays"]
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def calculate_sigmas_scheduler(model, scheduler_name, steps):
|
||||
def calculate_sigmas_scheduler(model, scheduler_name, steps, is_sdxl=False):
|
||||
if scheduler_name == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
elif scheduler_name == "exponential":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
elif scheduler_name == "ays":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_ays(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max), is_sdxl=is_sdxl)
|
||||
elif scheduler_name == "normal":
|
||||
sigmas = normal_scheduler(model, steps)
|
||||
elif scheduler_name == "simple":
|
||||
|
@ -169,6 +169,11 @@ class SDXL(supported_models_base.BASE):
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
if "v_pred" in state_dict:
|
||||
return model_base.ModelType.V_PREDICTION
|
||||
elif "edm_vpred.sigma_max" in state_dict:
|
||||
self.sampling_settings["sigma_max"] = round(float(state_dict["edm_vpred.sigma_max"].item()),3)
|
||||
if "edm_vpred.sigma_min" in state_dict:
|
||||
self.sampling_settings["sigma_min"] = round(float(state_dict["edm_vpred.sigma_min"].item()),3)
|
||||
return model_base.ModelType.V_PREDICTION_EDM
|
||||
else:
|
||||
return model_base.ModelType.EPS
|
||||
|
||||
|
@ -516,7 +516,7 @@ def configure_forge_reference_checkout(a1111_home: Path):
|
||||
ModelRef(arg_name="--vae-dir", relative_path="models/VAE"),
|
||||
ModelRef(arg_name="--hypernetwork-dir", relative_path="models/hypernetworks"),
|
||||
ModelRef(arg_name="--embeddings-dir", relative_path="embeddings"),
|
||||
ModelRef(arg_name="--lora-dir", relative_path="models/lora"),
|
||||
ModelRef(arg_name="--lora-dir", relative_path="models/Lora"),
|
||||
# Ref A1111 need to have sd-webui-controlnet installed.
|
||||
ModelRef(arg_name="--controlnet-dir", relative_path="models/ControlNet"),
|
||||
ModelRef(arg_name="--controlnet-preprocessor-models-dir", relative_path="extensions/sd-webui-controlnet/annotator/downloads"),
|
||||
|
@ -10,6 +10,7 @@ class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
|
||||
self.sampler_name = sampler_name
|
||||
self.scheduler_name = scheduler_name
|
||||
self.unet = sd_model.forge_objects.unet
|
||||
self.model = sd_model
|
||||
|
||||
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
|
||||
super().__init__(sampler_function, sd_model, None)
|
||||
@ -20,7 +21,7 @@ class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
|
||||
sigmas = self.unet.model.model_sampling.sigma(timesteps)
|
||||
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
|
||||
else:
|
||||
sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps)
|
||||
sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps, is_sdxl=getattr(self.model, "is_sdxl", False))
|
||||
return sigmas.to(self.unet.load_device)
|
||||
|
||||
|
||||
@ -34,9 +35,13 @@ def build_constructor(sampler_name, scheduler_name):
|
||||
samplers_data_alter = [
|
||||
sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm', scheduler_name='normal'), ['ddpm'], {}),
|
||||
sd_samplers_common.SamplerData('DDPM Karras', build_constructor(sampler_name='ddpm', scheduler_name='karras'), ['ddpm_karras'], {}),
|
||||
sd_samplers_common.SamplerData('Euler AYS', build_constructor(sampler_name='euler', scheduler_name='ays'), ['euler_ays'], {}),
|
||||
sd_samplers_common.SamplerData('Euler A Turbo', build_constructor(sampler_name='euler_ancestral', scheduler_name='turbo'), ['euler_ancestral_turbo'], {}),
|
||||
sd_samplers_common.SamplerData('Euler A AYS', build_constructor(sampler_name='euler_ancestral', scheduler_name='ays'), ['euler_ancestral_ays'], {}),
|
||||
sd_samplers_common.SamplerData('DPM++ 2M Turbo', build_constructor(sampler_name='dpmpp_2m', scheduler_name='turbo'), ['dpmpp_2m_turbo'], {}),
|
||||
sd_samplers_common.SamplerData('DPM++ 2M AYS', build_constructor(sampler_name='dpmpp_2m', scheduler_name='ays'), ['dpmpp_2m_ays'], {}),
|
||||
sd_samplers_common.SamplerData('DPM++ 2M SDE Turbo', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='turbo'), ['dpmpp_2m_sde_turbo'], {}),
|
||||
sd_samplers_common.SamplerData('DPM++ 2M SDE AYS', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='ays'), ['dpmpp_2m_sde_ays'], {}),
|
||||
sd_samplers_common.SamplerData('LCM Karras', build_constructor(sampler_name='lcm', scheduler_name='karras'), ['lcm_karras'], {}),
|
||||
sd_samplers_common.SamplerData('Euler SGMUniform', build_constructor(sampler_name='euler', scheduler_name='sgm_uniform'), ['euler_sgm_uniform'], {}),
|
||||
sd_samplers_common.SamplerData('Euler A SGMUniform', build_constructor(sampler_name='euler_ancestral', scheduler_name='sgm_uniform'), ['euler_ancestral_sgm_uniform'], {}),
|
||||
|
@ -113,7 +113,7 @@ class ControlNetPatcher(ControlModelPatcher):
|
||||
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
|
||||
control_model = ldm_patched.controlnet.cldm.ControlNet(model_file_name=ckpt_path, **controlnet_config)
|
||||
|
||||
if pth:
|
||||
if 'difference' in controlnet_data:
|
||||
|
Loading…
Reference in New Issue
Block a user