Compare commits

..

12 Commits

Author SHA1 Message Date
51cf69328e Merge branch 'dev2' into mysd
Some checks failed
Linter / ruff (push) Has been cancelled
2024-06-05 17:22:55 +03:00
GavChap
77bdb9208d
Add cos xl (#710)
* Add V_PREDICTION_EDM handing for CosXL models

Add V_PREDICTION_EDM handing for CosXL models

* Get correct sigmas from checkpoint.

* Round to 3 sig digs in order to make compatible with comfy implementation

* Add sigma data like ComfyUI has

---------

Co-authored-by: Gavin Chapman <gchapman@MAINPC>
2024-05-23 16:42:56 -04:00
altoiddealer
7eb5cbad01
Restore '/controlnet/control_types' API endpoint (#713)
Restores the '/controlnet/control_types' API endpoint, which is immensely useful for anyone using ControlNet via the API
2024-05-23 16:39:26 -04:00
Chenlei Hu
eb1e12b0dc
Add advanced weighting support (#754) 2024-05-23 13:11:11 -04:00
Chakib Benziane
49c3a080b5
implement align your steps scheduler (#726)
Signed-off-by: blob42 <contact@blob42.xyz>
2024-05-22 15:11:10 -04:00
Chenlei Hu
62e60ad403
Fix import of enum (#755) 2024-05-21 20:55:24 -04:00
Chenlei Hu
66845160de
Fix enum (#706)
* Fix enum conversion from string

* More fixes
2024-05-02 19:09:53 -04:00
bluelovers
b996316b20
Update footer.html (#591) 2024-05-02 18:31:37 -04:00
Chenlei Hu
d5cc799eec
Only feed image embed to instant id ControlNet unit (#309)
* Only feed image embed to instant id ControlNet units

* Add tests

* Add more tests
2024-05-02 18:30:54 -04:00
Qun
b55f9e7212
fix lora path (#509) 2024-05-02 18:29:25 -04:00
sJrx233S7UMo
bafd067777
Add xyz plot feature to "sd_forge_freeu" (#417)
* Add xyz plot feature to "sd_forge_freeu"

* Update extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py

Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>

* Update extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py

Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>

* Update extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py

Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>

* Update extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py

Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>

* Update extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py

Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>

---------

Co-authored-by: sJrx233S7UMo <sJrx233S7UMo@gmail.com>
Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>
2024-05-02 18:28:24 -04:00
Chenlei Hu
61a2a9d342
Add back ControlNet multi-inputs upload tab (#264)
* Refactor gallery

* Add back multi-inputs

* nit
2024-05-02 18:25:42 -04:00
23 changed files with 553 additions and 116 deletions

View File

@ -11,6 +11,8 @@ from .global_state import (
get_all_preprocessor_names, get_all_preprocessor_names,
get_all_controlnet_names, get_all_controlnet_names,
get_preprocessor, get_preprocessor,
get_all_preprocessor_tags,
select_control_type,
) )
from .utils import judge_image_type from .utils import judge_image_type
from .logging import logger from .logging import logger
@ -53,6 +55,30 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
# "module_detail": external_code.get_modules_detail(alias_names), # "module_detail": external_code.get_modules_detail(alias_names),
} }
@app.get("/controlnet/control_types")
async def control_types():
def format_control_type(
filtered_preprocessor_list,
filtered_model_list,
default_option,
default_model,
):
control_dict = {
"module_list": filtered_preprocessor_list,
"model_list": filtered_model_list,
"default_option": default_option,
"default_model": default_model,
}
return control_dict
return {
"control_types": {
control_type: format_control_type(*select_control_type(control_type))
for control_type in get_all_preprocessor_tags()
}
}
@app.post("/controlnet/detect") @app.post("/controlnet/detect")
async def detect( async def detect(
controlnet_module: str = Body("none", title="Controlnet Module"), controlnet_module: str = Body("none", title="Controlnet Module"),

View File

@ -11,16 +11,17 @@ from lib_controlnet import (
global_state, global_state,
external_code, external_code,
) )
from lib_controlnet.external_code import ControlNetUnit
from lib_controlnet.logging import logger from lib_controlnet.logging import logger
from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor
from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI
from lib_controlnet.controlnet_ui.tool_button import ToolButton from lib_controlnet.controlnet_ui.tool_button import ToolButton
from lib_controlnet.controlnet_ui.photopea import Photopea from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.controlnet_ui.multi_inputs_gallery import MultiInputsGallery
from lib_controlnet.enums import InputMode, HiResFixOption from lib_controlnet.enums import InputMode, HiResFixOption
from modules import shared, script_callbacks from modules import shared, script_callbacks
from modules.ui_components import FormRow from modules.ui_components import FormRow
from modules_forge.forge_util import HWC3 from modules_forge.forge_util import HWC3
from lib_controlnet.external_code import UiControlNetUnit
@dataclass @dataclass
@ -171,10 +172,10 @@ class ControlNetUiGroup(object):
self.webcam_mirrored = False self.webcam_mirrored = False
# Note: All gradio elements declared in `render` will be defined as member variable. # Note: All gradio elements declared in `render` will be defined as member variable.
# Update counter to trigger a force update of UiControlNetUnit. # Update counter to trigger a force update of ControlNetUnit.
# dummy_gradio_update_trigger is useful when a field with no event subscriber available changes. # dummy_gradio_update_trigger is useful when a field with no event subscriber available changes.
# e.g. gr.Gallery, gr.State, etc. After an update to gr.State / gr.Gallery, please increment # e.g. gr.Gallery, gr.State, etc. After an update to gr.State / gr.Gallery, please increment
# this counter to trigger a sync update of UiControlNetUnit. # this counter to trigger a sync update of ControlNetUnit.
self.dummy_gradio_update_trigger = None self.dummy_gradio_update_trigger = None
self.enabled = None self.enabled = None
self.upload_tab = None self.upload_tab = None
@ -185,10 +186,11 @@ class ControlNetUiGroup(object):
self.mask_image = None self.mask_image = None
self.batch_tab = None self.batch_tab = None
self.batch_image_dir = None self.batch_image_dir = None
self.merge_tab = None self.batch_upload_tab = None
self.batch_input_gallery = None self.batch_input_gallery = None
self.merge_upload_button = None self.batch_mask_gallery = None
self.merge_clear_button = None self.multi_inputs_upload_tab = None
self.multi_inputs_input_gallery = None
self.create_canvas = None self.create_canvas = None
self.canvas_width = None self.canvas_width = None
self.canvas_height = None self.canvas_height = None
@ -225,6 +227,7 @@ class ControlNetUiGroup(object):
self.hr_option = None self.hr_option = None
self.batch_image_dir_state = None self.batch_image_dir_state = None
self.output_dir_state = None self.output_dir_state = None
self.advanced_weighting = gr.State(None)
# Internal states for UI state pasting. # Internal states for UI state pasting.
self.prevent_next_n_module_update = 0 self.prevent_next_n_module_update = 0
@ -331,31 +334,17 @@ class ControlNetUiGroup(object):
visible=False, visible=False,
) )
with gr.Tab(label="Batch Upload") as self.merge_tab: with gr.Tab(label="Batch Upload") as self.batch_upload_tab:
with gr.Row(): with gr.Row():
with gr.Column(): self.batch_input_gallery = MultiInputsGallery()
self.batch_input_gallery = gr.Gallery( self.batch_mask_gallery = MultiInputsGallery(
columns=[4], rows=[2], object_fit="contain", height="auto", label="Images" visible=False,
) elem_classes=["cnet-mask-gallery-group"]
with gr.Row(): )
self.merge_upload_button = gr.UploadButton(
"Upload Images", with gr.Tab(label="Multi-Inputs") as self.multi_inputs_upload_tab:
file_types=["image"], with gr.Row():
file_count="multiple", self.multi_inputs_gallery = MultiInputsGallery()
)
self.merge_clear_button = gr.Button("Clear Images")
with gr.Group(visible=False, elem_classes=["cnet-mask-gallery-group"]) as self.batch_mask_gallery_group:
with gr.Column():
self.batch_mask_gallery = gr.Gallery(
columns=[4], rows=[2], object_fit="contain", height="auto", label="Masks"
)
with gr.Row():
self.mask_merge_upload_button = gr.UploadButton(
"Upload Masks",
file_types=["image"],
file_count="multiple",
)
self.mask_merge_clear_button = gr.Button("Clear Masks")
if self.photopea: if self.photopea:
self.photopea.attach_photopea_output(self.generated_image) self.photopea.attach_photopea_output(self.generated_image)
@ -600,8 +589,9 @@ class ControlNetUiGroup(object):
self.use_preview_as_input, self.use_preview_as_input,
self.batch_image_dir, self.batch_image_dir,
self.batch_mask_dir, self.batch_mask_dir,
self.batch_input_gallery, self.batch_input_gallery.input_gallery,
self.batch_mask_gallery, self.batch_mask_gallery.input_gallery,
self.multi_inputs_gallery.input_gallery,
self.generated_image, self.generated_image,
self.mask_image, self.mask_image,
self.hr_option, self.hr_option,
@ -618,9 +608,16 @@ class ControlNetUiGroup(object):
self.guidance_end, self.guidance_end,
self.pixel_perfect, self.pixel_perfect,
self.control_mode, self.control_mode,
self.advanced_weighting,
) )
unit = gr.State(self.default_unit) unit = gr.State(self.default_unit)
def create_unit(*args):
return ControlNetUnit.from_dict({
k: v
for k, v in zip(vars(ControlNetUnit()).keys(), args)
})
for comp in unit_args + (self.dummy_gradio_update_trigger,): for comp in unit_args + (self.dummy_gradio_update_trigger,):
event_subscribers = [] event_subscribers = []
if hasattr(comp, "edit"): if hasattr(comp, "edit"):
@ -637,7 +634,7 @@ class ControlNetUiGroup(object):
for event_subscriber in event_subscribers: for event_subscriber in event_subscribers:
event_subscriber( event_subscriber(
fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit fn=create_unit, inputs=list(unit_args), outputs=unit
) )
( (
@ -645,7 +642,7 @@ class ControlNetUiGroup(object):
if self.is_img2img if self.is_img2img
else ControlNetUiGroup.a1111_context.txt2img_submit_button else ControlNetUiGroup.a1111_context.txt2img_submit_button
).click( ).click(
fn=UiControlNetUnit, fn=create_unit,
inputs=list(unit_args), inputs=list(unit_args),
outputs=unit, outputs=unit,
queue=False, queue=False,
@ -981,20 +978,41 @@ class ControlNetUiGroup(object):
"""Controls whether the upload mask input should be visible.""" """Controls whether the upload mask input should be visible."""
def on_checkbox_click(checked: bool, canvas_height: int, canvas_width: int): def on_checkbox_click(checked: bool, canvas_height: int, canvas_width: int):
if not checked: if not checked:
# Clear mask_image if unchecked. # Clear mask inputs if unchecked.
return gr.update(visible=False), gr.update(value=None), gr.update(value=None, visible=False), \ return (
gr.update(visible=False), gr.update(value=None) # Single mask upload.
gr.update(visible=False),
gr.update(value=None),
# Batch mask upload dir.
gr.update(value=None, visible=False),
# Multi mask upload gallery.
gr.update(visible=False),
gr.update(value=None)
)
else: else:
# Init an empty canvas the same size as the generation target. # Init an empty canvas the same size as the generation target.
empty_canvas = np.zeros(shape=(canvas_height, canvas_width, 3), dtype=np.uint8) empty_canvas = np.zeros(shape=(canvas_height, canvas_width, 3), dtype=np.uint8)
return gr.update(visible=True), gr.update(value=empty_canvas), gr.update(visible=True), \ return (
gr.update(visible=True), gr.update() # Single mask upload.
gr.update(visible=True),
gr.update(value=empty_canvas),
# Batch mask upload dir.
gr.update(visible=True),
# Multi mask upload gallery.
gr.update(visible=True),
gr.update(),
)
self.mask_upload.change( self.mask_upload.change(
fn=on_checkbox_click, fn=on_checkbox_click,
inputs=[self.mask_upload, self.height_slider, self.width_slider], inputs=[self.mask_upload, self.height_slider, self.width_slider],
outputs=[self.mask_image_group, self.mask_image, self.batch_mask_dir, outputs=[
self.batch_mask_gallery_group, self.batch_mask_gallery], self.mask_image_group,
self.mask_image,
self.batch_mask_dir,
self.batch_mask_gallery.group,
self.batch_mask_gallery.input_gallery,
],
show_progress=False, show_progress=False,
) )
@ -1079,51 +1097,14 @@ class ControlNetUiGroup(object):
def register_multi_images_upload(self): def register_multi_images_upload(self):
"""Register callbacks on merge tab multiple images upload.""" """Register callbacks on merge tab multiple images upload."""
self.merge_clear_button.click( trigger_dict = dict(
fn=lambda: [], fn=lambda n: gr.update(value=n + 1),
inputs=[],
outputs=[self.batch_input_gallery],
).then(
fn=lambda x: gr.update(value=x + 1),
inputs=[self.dummy_gradio_update_trigger], inputs=[self.dummy_gradio_update_trigger],
outputs=[self.dummy_gradio_update_trigger], outputs=[self.dummy_gradio_update_trigger],
) )
self.mask_merge_clear_button.click( self.batch_input_gallery.register_callbacks(change_trigger=trigger_dict)
fn=lambda: [], self.batch_mask_gallery.register_callbacks(change_trigger=trigger_dict)
inputs=[], self.multi_inputs_gallery.register_callbacks(change_trigger=trigger_dict)
outputs=[self.batch_mask_gallery],
).then(
fn=lambda x: gr.update(value=x + 1),
inputs=[self.dummy_gradio_update_trigger],
outputs=[self.dummy_gradio_update_trigger],
)
def upload_file(files, current_files):
return {file_d["name"] for file_d in current_files} | {
file.name for file in files
}
self.merge_upload_button.upload(
upload_file,
inputs=[self.merge_upload_button, self.batch_input_gallery],
outputs=[self.batch_input_gallery],
queue=False,
).then(
fn=lambda x: gr.update(value=x + 1),
inputs=[self.dummy_gradio_update_trigger],
outputs=[self.dummy_gradio_update_trigger],
)
self.mask_merge_upload_button.upload(
upload_file,
inputs=[self.mask_merge_upload_button, self.batch_mask_gallery],
outputs=[self.batch_mask_gallery],
queue=False,
).then(
fn=lambda x: gr.update(value=x + 1),
inputs=[self.dummy_gradio_update_trigger],
outputs=[self.dummy_gradio_update_trigger],
)
return
def register_core_callbacks(self): def register_core_callbacks(self):
"""Register core callbacks that only involves gradio components defined """Register core callbacks that only involves gradio components defined
@ -1213,7 +1194,8 @@ class ControlNetUiGroup(object):
for input_tab, fn in ( for input_tab, fn in (
(ui_group.upload_tab, simple_fn), (ui_group.upload_tab, simple_fn),
(ui_group.batch_tab, batch_fn), (ui_group.batch_tab, batch_fn),
(ui_group.merge_tab, merge_fn), (ui_group.batch_upload_tab, batch_fn),
(ui_group.multi_inputs_upload_tab, merge_fn),
): ):
# Sync input_mode. # Sync input_mode.
input_tab.select( input_tab.select(

View File

@ -0,0 +1,65 @@
import gradio as gr
from typing import Optional
class MultiInputsGallery:
"""A gallery object that accepts multiple input images."""
def __init__(self, row: int = 2, column: int = 4, **group_kwargs) -> None:
self.gallery_row_num = row
self.gallery_column_num = column
self.group_kwargs = group_kwargs
self.group = None
self.input_gallery = None
self.upload_button = None
self.clear_button = None
self.render()
def render(self):
with gr.Group(**self.group_kwargs) as self.group:
with gr.Column():
self.input_gallery = gr.Gallery(
columns=[self.gallery_column_num],
rows=[self.gallery_row_num],
object_fit="contain",
height="auto",
label="Images",
)
with gr.Row():
self.upload_button = gr.UploadButton(
"Upload Images",
file_types=["image"],
file_count="multiple",
)
self.clear_button = gr.Button("Clear Images")
def register_callbacks(self, change_trigger: Optional[dict] = None):
"""Register callbacks on multiple images upload.
Argument:
- change_trigger: An optional gradio callback param dict to be called
after gallery content change. This is necessary as gallery has no
event subscriber. If the state change of gallery needs to be observed,
the caller needs to pass a change trigger to observe the change.
"""
handle1 = self.clear_button.click(
fn=lambda: [],
inputs=[],
outputs=[self.input_gallery],
)
def upload_file(files, current_files):
return {file_d["name"] for file_d in current_files} | {
file.name for file in files
}
handle2 = self.upload_button.upload(
upload_file,
inputs=[self.upload_button, self.input_gallery],
outputs=[self.input_gallery],
queue=False,
)
if change_trigger:
for handle in (handle1, handle2):
handle.then(**change_trigger)

View File

@ -183,6 +183,8 @@ class ControlNetUnit:
batch_input_gallery: Optional[List[str]] = None batch_input_gallery: Optional[List[str]] = None
# Optional list of gallery masks for batch processing; defaults to None. # Optional list of gallery masks for batch processing; defaults to None.
batch_mask_gallery: Optional[List[str]] = None batch_mask_gallery: Optional[List[str]] = None
# Optional list of gallery images for multi-inputs; defaults to None.
multi_inputs_gallery: Optional[List[str]] = None
# Holds the preview image as a NumPy array; defaults to None. # Holds the preview image as a NumPy array; defaults to None.
generated_image: Optional[np.ndarray] = None generated_image: Optional[np.ndarray] = None
# ====== End of UI only fields ====== # ====== End of UI only fields ======
@ -192,7 +194,7 @@ class ControlNetUnit:
mask_image: Optional[GradioImageMaskPair] = None mask_image: Optional[GradioImageMaskPair] = None
# Specifies how this unit should be applied in each pass of high-resolution fix. # Specifies how this unit should be applied in each pass of high-resolution fix.
# Ignored if high-resolution fix is not enabled. # Ignored if high-resolution fix is not enabled.
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH hr_option: HiResFixOption = HiResFixOption.BOTH
# Indicates whether the unit is enabled; defaults to True. # Indicates whether the unit is enabled; defaults to True.
enabled: bool = True enabled: bool = True
# Name of the module being used; defaults to "None". # Name of the module being used; defaults to "None".
@ -204,7 +206,7 @@ class ControlNetUnit:
# Optional image for input; defaults to None. # Optional image for input; defaults to None.
image: Optional[GradioImageMaskPair] = None image: Optional[GradioImageMaskPair] = None
# Specifies the mode of image resizing; defaults to inner fit. # Specifies the mode of image resizing; defaults to inner fit.
resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT resize_mode: ResizeMode = ResizeMode.INNER_FIT
# Resolution for processing by the unit; defaults to -1 (unspecified). # Resolution for processing by the unit; defaults to -1 (unspecified).
processor_res: int = -1 processor_res: int = -1
# Threshold A for processing; defaults to -1 (unspecified). # Threshold A for processing; defaults to -1 (unspecified).
@ -218,7 +220,22 @@ class ControlNetUnit:
# Enables pixel-perfect processing; defaults to False. # Enables pixel-perfect processing; defaults to False.
pixel_perfect: bool = False pixel_perfect: bool = False
# Control mode for the unit; defaults to balanced. # Control mode for the unit; defaults to balanced.
control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED control_mode: ControlMode = ControlMode.BALANCED
# Weight for each layer of ControlNet params.
# For ControlNet:
# - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block)
# - SDXL: 10 weights (3 encoder block * 3 + 1 middle block)
# For T2IAdapter
# - SD1.5: 5 weights (4 encoder block + 1 middle block)
# - SDXL: 4 weights (3 encoder block + 1 middle block)
# For IPAdapter
# - SD15: 16 (6 input blocks + 9 output blocks + 1 middle block)
# - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block)
# Note1: Setting advanced weighting will disable `soft_injection`, i.e.
# It is recommended to set ControlMode = BALANCED when using `advanced_weighting`.
# Note2: The field `weight` is still used in some places, e.g. reference_only,
# even advanced_weighting is set.
advanced_weighting: Optional[List[float]] = None
# Following fields should only be used in the API. # Following fields should only be used in the API.
# ====== Start of API only fields ====== # ====== Start of API only fields ======
@ -275,6 +292,11 @@ class ControlNetUnit:
"image": mask, "image": mask,
"mask": np.zeros_like(mask), "mask": np.zeros_like(mask),
} }
# Convert strings to enums.
unit.input_mode = InputMode(unit.input_mode)
unit.hr_option = HiResFixOption(unit.hr_option)
unit.resize_mode = ResizeMode(unit.resize_mode)
unit.control_mode = ControlMode(unit.control_mode)
return unit return unit

View File

@ -6,6 +6,7 @@ from modules import shared, sd_models
from lib_controlnet.enums import StableDiffusionVersion from lib_controlnet.enums import StableDiffusionVersion
from modules_forge.shared import controlnet_dir, supported_preprocessors from modules_forge.shared import controlnet_dir, supported_preprocessors
from typing import Dict, Tuple, List
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin", ".patch"] CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin", ".patch"]
@ -56,6 +57,10 @@ controlnet_names = ['None']
def get_preprocessor(name): def get_preprocessor(name):
return supported_preprocessors.get(name, None) return supported_preprocessors.get(name, None)
def get_default_preprocessor(tag):
ps = get_filtered_preprocessor_names(tag)
assert len(ps) > 0
return ps[0] if len(ps) == 1 else ps[1]
def get_sorted_preprocessors(): def get_sorted_preprocessors():
preprocessors = [p for k, p in supported_preprocessors.items() if k != 'None'] preprocessors = [p for k, p in supported_preprocessors.items() if k != 'None']
@ -144,3 +149,44 @@ def get_sd_version() -> StableDiffusionVersion:
return StableDiffusionVersion.SD1x return StableDiffusionVersion.SD1x
else: else:
return StableDiffusionVersion.UNKNOWN return StableDiffusionVersion.UNKNOWN
def select_control_type(
control_type: str,
sd_version: StableDiffusionVersion = StableDiffusionVersion.UNKNOWN,
) -> Tuple[List[str], List[str], str, str]:
global controlnet_names
pattern = control_type.lower()
all_models = list(controlnet_names)
if pattern == "all":
preprocessors = get_sorted_preprocessors().values()
return [
[p.name for p in preprocessors],
all_models,
'none', # default option
"None" # default model
]
filtered_model_list = get_filtered_controlnet_names(control_type)
if pattern == "none":
filtered_model_list.append("None")
assert len(filtered_model_list) > 0, "'None' model should always be available."
if len(filtered_model_list) == 1:
default_model = "None"
else:
default_model = filtered_model_list[1]
for x in filtered_model_list:
if "11" in x.split("[")[0]:
default_model = x
break
return (
get_filtered_preprocessor_names(control_type),
filtered_model_list,
get_default_preprocessor(control_type),
default_model
)

View File

@ -1,7 +1,10 @@
from typing import Optional from typing import Optional
from copy import copy
from modules import processing from modules import processing
from modules.api import api
from lib_controlnet import external_code from lib_controlnet import external_code
from lib_controlnet.external_code import InputMode, ControlNetUnit
from modules_forge.forge_util import HWC3 from modules_forge.forge_util import HWC3
@ -361,3 +364,22 @@ def crop_and_resize_image(detected_map, resize_mode, h, w, fill_border_with_255=
def judge_image_type(img): def judge_image_type(img):
return isinstance(img, np.ndarray) and img.ndim == 3 and int(img.shape[2]) in [3, 4] return isinstance(img, np.ndarray) and img.ndim == 3 and int(img.shape[2]) in [3, 4]
def try_unfold_unit(unit: ControlNetUnit) -> List[ControlNetUnit]:
"""Unfolds an multi-inputs unit into multiple units with one input."""
if unit.input_mode != InputMode.MERGE:
return [unit]
def extract_unit(gallery_item: dict) -> ControlNetUnit:
r_unit = copy(unit)
img = np.array(api.decode_base64_to_image(read_image(gallery_item["name"]))).astype('uint8')
r_unit.image = {
"image": img,
"mask": np.zeros_like(img),
}
r_unit.input_mode = InputMode.SIMPLE
r_unit.weight = unit.weight / len(unit.multi_inputs_gallery)
return r_unit
return [extract_unit(item) for item in unit.multi_inputs_gallery]

View File

@ -11,9 +11,15 @@ from modules.api.api import decode_base64_to_image
import gradio as gr import gradio as gr
from lib_controlnet import global_state, external_code from lib_controlnet import global_state, external_code
from lib_controlnet.external_code import ControlNetUnit from lib_controlnet.external_code import ControlNetUnit, InputMode
from lib_controlnet.utils import align_dim_latent, set_numpy_seed, crop_and_resize_image, \ from lib_controlnet.utils import (
prepare_mask, judge_image_type align_dim_latent,
set_numpy_seed,
crop_and_resize_image,
prepare_mask,
judge_image_type,
try_unfold_unit,
)
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup
from lib_controlnet.controlnet_ui.photopea import Photopea from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.logging import logger from lib_controlnet.logging import logger
@ -53,6 +59,7 @@ class ControlNetCachedParameters:
self.control_cond_for_hr_fix = None self.control_cond_for_hr_fix = None
self.control_mask = None self.control_mask = None
self.control_mask_for_hr_fix = None self.control_mask_for_hr_fix = None
self.advanced_weighting = None
class ControlNetForForgeOfficial(scripts.Script): class ControlNetForForgeOfficial(scripts.Script):
@ -105,8 +112,13 @@ class ControlNetForForgeOfficial(scripts.Script):
for unit in units for unit in units
] ]
assert all(isinstance(unit, ControlNetUnit) for unit in units) assert all(isinstance(unit, ControlNetUnit) for unit in units)
enabled_units = [x for x in units if x.enabled] return [
return enabled_units simple_unit
for unit in units
# Unfolds multi-inputs units.
for simple_unit in try_unfold_unit(unit)
if simple_unit.enabled
]
@staticmethod @staticmethod
def try_crop_image_with_a1111_mask( def try_crop_image_with_a1111_mask(
@ -153,9 +165,9 @@ class ControlNetForForgeOfficial(scripts.Script):
def get_input_data(self, p, unit, preprocessor, h, w): def get_input_data(self, p, unit, preprocessor, h, w):
logger.info(f'ControlNet Input Mode: {unit.input_mode}') logger.info(f'ControlNet Input Mode: {unit.input_mode}')
image_list = [] image_list = []
resize_mode = external_code.resize_mode_from_value(unit.resize_mode) resize_mode = unit.resize_mode
if unit.input_mode == external_code.InputMode.MERGE: if unit.input_mode == InputMode.MERGE:
for idx, item in enumerate(unit.batch_input_gallery): for idx, item in enumerate(unit.batch_input_gallery):
img_path = item['name'] img_path = item['name']
logger.info(f'Try to read image: {img_path}') logger.info(f'Try to read image: {img_path}')
@ -169,7 +181,7 @@ class ControlNetForForgeOfficial(scripts.Script):
mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy() mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy()
if img is not None: if img is not None:
image_list.append([img, mask]) image_list.append([img, mask])
elif unit.input_mode == external_code.InputMode.BATCH: elif unit.input_mode == InputMode.BATCH:
image_list = [] image_list = []
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp'] image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
batch_image_files = shared.listfiles(unit.batch_image_dir) batch_image_files = shared.listfiles(unit.batch_image_dir)
@ -339,7 +351,7 @@ class ControlNetForForgeOfficial(scripts.Script):
break break
if has_high_res_fix: if has_high_res_fix:
hr_option = HiResFixOption.from_value(unit.hr_option) hr_option = unit.hr_option
else: else:
hr_option = HiResFixOption.BOTH hr_option = HiResFixOption.BOTH
@ -430,7 +442,7 @@ class ControlNetForForgeOfficial(scripts.Script):
) )
if has_high_res_fix: if has_high_res_fix:
hr_option = HiResFixOption.from_value(unit.hr_option) hr_option = unit.hr_option
else: else:
hr_option = HiResFixOption.BOTH hr_option = HiResFixOption.BOTH
@ -494,6 +506,11 @@ class ControlNetForForgeOfficial(scripts.Script):
params.model.positive_advanced_weighting = soft_weighting.copy() params.model.positive_advanced_weighting = soft_weighting.copy()
params.model.negative_advanced_weighting = soft_weighting.copy() params.model.negative_advanced_weighting = soft_weighting.copy()
if unit.advanced_weighting is not None:
if params.model.positive_advanced_weighting is None:
logger.warn("advanced_weighting overwrite control_mode")
params.model.positive_advanced_weighting = unit.advanced_weighting
cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs) cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs)
params.model.advanced_mask_weighting = mask params.model.advanced_mask_weighting = mask

View File

@ -4,6 +4,7 @@ from .template import (
APITestTemplate, APITestTemplate,
girl_img, girl_img,
mask_img, mask_img,
portrait_imgs,
disable_in_cq, disable_in_cq,
get_model, get_model,
) )
@ -169,3 +170,60 @@ def test_lama_outpaint():
"resize_mode": "Resize and Fill", # OUTER_FIT "resize_mode": "Resize and Fill", # OUTER_FIT
}, },
).exec() ).exec()
@disable_in_cq
def test_instant_id_sdxl():
assert len(portrait_imgs) > 0
assert APITestTemplate(
"instant_id_sdxl",
"txt2img",
payload_overrides={
"width": 1000,
"height": 1000,
"prompt": "1girl, red background",
},
unit_overrides=[
dict(
image=portrait_imgs[0],
model=get_model("ip-adapter_instant_id_sdxl"),
module="InsightFace (InstantID)",
),
dict(
image=portrait_imgs[1],
model=get_model("control_instant_id_sdxl"),
module="instant_id_face_keypoints",
),
],
).exec()
@disable_in_cq
def test_instant_id_sdxl_multiple_units():
assert len(portrait_imgs) > 0
assert APITestTemplate(
"instant_id_sdxl_multiple_units",
"txt2img",
payload_overrides={
"width": 1000,
"height": 1000,
"prompt": "1girl, red background",
},
unit_overrides=[
dict(
image=portrait_imgs[0],
model=get_model("ip-adapter_instant_id_sdxl"),
module="InsightFace (InstantID)",
),
dict(
image=portrait_imgs[1],
model=get_model("control_instant_id_sdxl"),
module="instant_id_face_keypoints",
),
dict(
image=portrait_imgs[1],
model=get_model("diffusers_xl_canny"),
module="canny",
),
],
).exec()

View File

@ -0,0 +1,43 @@
from .template import (
APITestTemplate,
realistic_girl_face_img,
disable_in_cq,
get_model,
)
@disable_in_cq
def test_ipadapter_advanced_weighting():
weights = [0.0] * 16 # 16 weights for SD15 / 11 weights for SDXL
# SD15 composition
weights[4] = 0.25
weights[5] = 1.0
APITestTemplate(
"test_ipadapter_advanced_weighting",
"txt2img",
payload_overrides={
"width": 512,
"height": 512,
},
unit_overrides={
"image": realistic_girl_face_img,
"module": "CLIP-ViT-H (IPAdapter)",
"model": get_model("ip-adapter_sd15"),
"advanced_weighting": weights,
},
).exec()
APITestTemplate(
"test_ipadapter_advanced_weighting_ref",
"txt2img",
payload_overrides={
"width": 512,
"height": 512,
},
unit_overrides={
"image": realistic_girl_face_img,
"module": "CLIP-ViT-H (IPAdapter)",
"model": get_model("ip-adapter_sd15"),
},
).exec()

View File

@ -248,13 +248,13 @@ def get_model(model_name: str) -> str:
default_unit = { default_unit = {
"control_mode": 0, "control_mode": "Balanced",
"enabled": True, "enabled": True,
"guidance_end": 1, "guidance_end": 1,
"guidance_start": 0, "guidance_start": 0,
"pixel_perfect": True, "pixel_perfect": True,
"processor_res": 512, "processor_res": 512,
"resize_mode": 1, "resize_mode": "Crop and Resize",
"threshold_a": 64, "threshold_a": 64,
"threshold_b": 64, "threshold_b": 64,
"weight": 1, "weight": 1,

View File

@ -1,6 +1,12 @@
import gradio as gr import gradio as gr
from modules import scripts import sys
import traceback
from typing import Any
from functools import partial
from modules import script_callbacks, scripts
from ldm_patched.contrib.external_freelunch import FreeU_V2 from ldm_patched.contrib.external_freelunch import FreeU_V2
@ -68,6 +74,18 @@ class FreeUForForge(scripts.Script):
freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2 = script_args freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2 = script_args
xyz = getattr(p, "_freeu_xyz", {})
if "freeu_enabled" in xyz:
freeu_enabled = xyz["freeu_enabled"] == "True"
if "freeu_b1" in xyz:
freeu_b1 = xyz["freeu_b1"]
if "freeu_b2" in xyz:
freeu_b2 = xyz["freeu_b2"]
if "freeu_s1" in xyz:
freeu_s1 = xyz["freeu_s1"]
if "freeu_s2" in xyz:
freeu_s2 = xyz["freeu_s2"]
if not freeu_enabled: if not freeu_enabled:
return return
@ -89,3 +107,62 @@ class FreeUForForge(scripts.Script):
)) ))
return return
def set_value(p, x: Any, xs: Any, *, field: str):
if not hasattr(p, "_freeu_xyz"):
p._freeu_xyz = {}
p._freeu_xyz[field] = x
def make_axis_on_xyz_grid():
xyz_grid = None
for script in scripts.scripts_data:
if script.script_class.__module__ == "xyz_grid.py":
xyz_grid = script.module
break
if xyz_grid is None:
return
axis = [
xyz_grid.AxisOption(
"FreeU Enabled",
str,
partial(set_value, field="freeu_enabled"),
choices=lambda: ["True", "False"]
),
xyz_grid.AxisOption(
"FreeU B1",
float,
partial(set_value, field="freeu_b1"),
),
xyz_grid.AxisOption(
"FreeU B2",
float,
partial(set_value, field="freeu_b2"),
),
xyz_grid.AxisOption(
"FreeU S1",
float,
partial(set_value, field="freeu_s1"),
),
xyz_grid.AxisOption(
"FreeU S2",
float,
partial(set_value, field="freeu_s2"),
),
]
if not any(x.label.startswith("FreeU") for x in xyz_grid.axis_options):
xyz_grid.axis_options.extend(axis)
def on_before_ui():
try:
make_axis_on_xyz_grid()
except Exception:
error = traceback.format_exc()
print(
f"[-] FreeU Integrated: xyz_grid error:\n{error}",
file=sys.stderr,
)
script_callbacks.on_before_ui(on_before_ui)

View File

@ -7,6 +7,7 @@ import math
import ldm_patched.modules.utils import ldm_patched.modules.utils
import ldm_patched.modules.model_management import ldm_patched.modules.model_management
from ldm_patched.modules.controlnet import ControlNet
from ldm_patched.modules.clip_vision import clip_preprocess from ldm_patched.modules.clip_vision import clip_preprocess
from ldm_patched.ldm.modules.attention import optimized_attention from ldm_patched.ldm.modules.attention import optimized_attention
from ldm_patched.utils import path_utils as folder_paths from ldm_patched.utils import path_utils as folder_paths
@ -402,7 +403,7 @@ class CrossAttentionPatch:
batch_prompt = b // len(cond_or_uncond) batch_prompt = b // len(cond_or_uncond)
out = optimized_attention(q, k, v, extra_options["n_heads"]) out = optimized_attention(q, k, v, extra_options["n_heads"])
_, _, lh, lw = extra_options["original_shape"] _, _, lh, lw = extra_options["original_shape"]
for weight, cond, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch in zip(self.weights, self.conds, self.unconds, self.ipadapters, self.masks, self.weight_type, self.sigma_start, self.sigma_end, self.unfold_batch): for weight, cond, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch in zip(self.weights, self.conds, self.unconds, self.ipadapters, self.masks, self.weight_type, self.sigma_start, self.sigma_end, self.unfold_batch):
if sigma > sigma_start or sigma < sigma_end: if sigma > sigma_start or sigma < sigma_end:
continue continue
@ -465,8 +466,18 @@ class CrossAttentionPatch:
ip_v = ip_v_offset + ip_v_mean * W ip_v = ip_v_offset + ip_v_mean * W
out_ip = optimized_attention(q, ip_k.to(org_dtype), ip_v.to(org_dtype), extra_options["n_heads"]) out_ip = optimized_attention(q, ip_k.to(org_dtype), ip_v.to(org_dtype), extra_options["n_heads"])
if weight_type.startswith("original"):
out_ip = out_ip * weight if weight_type == "original":
assert isinstance(weight, (float, int))
weight = weight
elif weight_type == "advanced":
assert isinstance(weight, list)
transformer_index: int = extra_options["transformer_index"]
assert transformer_index < len(weight)
weight = weight[transformer_index]
else:
weight = 1.0
out_ip = out_ip * weight
if mask is not None: if mask is not None:
# TODO: needs checking # TODO: needs checking
@ -732,7 +743,7 @@ class IPAdapterApply:
is_faceid=self.is_faceid, is_faceid=self.is_faceid,
is_instant_id=self.is_instant_id is_instant_id=self.is_instant_id
) )
self.ipadapter.to(self.device, dtype=self.dtype) self.ipadapter.to(self.device, dtype=self.dtype)
if self.is_instant_id: if self.is_instant_id:
@ -749,13 +760,27 @@ class IPAdapterApply:
work_model = model.clone() work_model = model.clone()
if self.is_instant_id: if self.is_instant_id:
def modifier(cnet, x_noisy, t, cond, batched_number): def instant_id_modifier(cnet: ControlNet, x_noisy, t, cond, batched_number):
"""Overwrites crossattn inputs to InstantID ControlNet with ipadapter image embeds.
TODO: There can be multiple pairs of InstantID (ipadapter/controlnet) to control
rendering of multiple faces on canvas. We need to find a way to pair them. Currently,
the modifier is unconditionally applied to all instant id ControlNet units.
"""
if (
not isinstance(cnet, ControlNet) or
# model_file_name is None for Control LoRA.
cnet.control_model.model_file_name is None or
"instant_id" not in cnet.control_model.model_file_name.lower()
):
return x_noisy, t, cond, batched_number
cond_mark = cond['transformer_options']['cond_mark'][:, None, None].to(cond['c_crossattn']) # cond is 0 cond_mark = cond['transformer_options']['cond_mark'][:, None, None].to(cond['c_crossattn']) # cond is 0
c_crossattn = image_prompt_embeds * (1.0 - cond_mark) + uncond_image_prompt_embeds * cond_mark c_crossattn = image_prompt_embeds * (1.0 - cond_mark) + uncond_image_prompt_embeds * cond_mark
cond['c_crossattn'] = c_crossattn cond['c_crossattn'] = c_crossattn
return x_noisy, t, cond, batched_number return x_noisy, t, cond, batched_number
work_model.add_controlnet_conditioning_modifier(modifier) work_model.add_controlnet_conditioning_modifier(instant_id_modifier)
if attn_mask is not None: if attn_mask is not None:
attn_mask = attn_mask.to(self.device) attn_mask = attn_mask.to(self.device)

View File

@ -143,11 +143,17 @@ class IPAdapterPatcher(ControlModelPatcher):
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs): def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
unet = process.sd_model.forge_objects.unet unet = process.sd_model.forge_objects.unet
if self.positive_advanced_weighting is None:
weight = self.strength
cond["weight_type"] = "original"
else:
weight = self.positive_advanced_weighting
cond["weight_type"] = "advanced"
unet = opIPAdapterApply( unet = opIPAdapterApply(
ipadapter=self.ip_adapter, ipadapter=self.ip_adapter,
model=unet, model=unet,
weight=self.strength, weight=weight,
start_at=self.start_percent, start_at=self.start_percent,
end_at=self.end_percent, end_at=self.end_percent,
faceid_v2=self.faceid_v2, faceid_v2=self.faceid_v2,

View File

@ -1,7 +1,7 @@
<div> <div>
<a href="{api_docs}">API</a> <a href="{api_docs}">API</a>
 •   • 
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a> <a href="https://github.com/lllyasviel/stable-diffusion-webui-forge">Github</a>
 •   • 
<a href="https://gradio.app">Gradio</a> <a href="https://gradio.app">Gradio</a>
 •   • 

View File

@ -4,6 +4,7 @@
import torch import torch
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
from typing import Optional
from ldm_patched.ldm.modules.diffusionmodules.util import ( from ldm_patched.ldm.modules.diffusionmodules.util import (
zero_module, zero_module,
@ -54,9 +55,12 @@ class ControlNet(nn.Module):
transformer_depth_output=None, transformer_depth_output=None,
device=None, device=None,
operations=ldm_patched.modules.ops.disable_weight_init, operations=ldm_patched.modules.ops.disable_weight_init,
model_file_name: Optional[str] = None, # Name of model file.
**kwargs, **kwargs,
): ):
super().__init__() super().__init__()
self.model_file_name = model_file_name
assert use_spatial_transformer == True, "use_spatial_transformer has to be true" assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
if use_spatial_transformer: if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

View File

@ -6,6 +6,7 @@ import math
from scipy import integrate from scipy import integrate
import torch import torch
import numpy as np
from torch import nn from torch import nn
import torchsde import torchsde
from tqdm.auto import trange, tqdm from tqdm.auto import trange, tqdm
@ -38,6 +39,35 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min)) sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
return append_zero(sigmas) return append_zero(sigmas)
# align your steps
def get_sigmas_ays(n, sigma_min, sigma_max, is_sdxl=False, device='cpu'):
# https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
def loglinear_interp(t_steps, num_steps):
"""
Performs log-linear interpolation of a given array of decreasing numbers.
"""
xs = torch.linspace(0, 1, len(t_steps))
ys = torch.log(torch.tensor(t_steps[::-1]))
new_xs = torch.linspace(0, 1, num_steps)
new_ys = np.interp(new_xs, xs, ys)
interped_ys = torch.exp(torch.tensor(new_ys)).numpy()[::-1].copy()
return interped_ys
if is_sdxl:
sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029]
else:
# Default to SD 1.5 sigmas.
sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029]
if n != len(sigmas):
sigmas = np.append(loglinear_interp(sigmas, n), [0.0])
else:
sigmas.append(0.0)
return torch.FloatTensor(sigmas).to(device)
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
"""Constructs a continuous VP noise schedule.""" """Constructs a continuous VP noise schedule."""

View File

@ -487,7 +487,7 @@ def load_controlnet(ckpt_path, model=None):
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
controlnet_config.pop("out_channels") controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config) control_model = ldm_patched.controlnet.cldm.ControlNet(model_file_name=ckpt_path, **controlnet_config)
if pth: if pth:
if 'difference' in controlnet_data: if 'difference' in controlnet_data:

View File

@ -107,9 +107,11 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
sigma_min = sampling_settings.get("sigma_min", 0.002) sigma_min = sampling_settings.get("sigma_min", 0.002)
sigma_max = sampling_settings.get("sigma_max", 120.0) sigma_max = sampling_settings.get("sigma_max", 120.0)
self.set_sigma_range(sigma_min, sigma_max) sigma_data = sampling_settings.get("sigma_data", 1.0)
self.set_sigma_range(sigma_min, sigma_max, sigma_data)
def set_sigma_range(self, sigma_min, sigma_max): def set_sigma_range(self, sigma_min, sigma_max, sigma_data):
self.sigma_data = sigma_data
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp() sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers

View File

@ -662,14 +662,16 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32)) return model.process_latent_out(samples.to(torch.float32))
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "ays"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
def calculate_sigmas_scheduler(model, scheduler_name, steps): def calculate_sigmas_scheduler(model, scheduler_name, steps, is_sdxl=False):
if scheduler_name == "karras": if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
elif scheduler_name == "exponential": elif scheduler_name == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
elif scheduler_name == "ays":
sigmas = k_diffusion_sampling.get_sigmas_ays(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max), is_sdxl=is_sdxl)
elif scheduler_name == "normal": elif scheduler_name == "normal":
sigmas = normal_scheduler(model, steps) sigmas = normal_scheduler(model, steps)
elif scheduler_name == "simple": elif scheduler_name == "simple":

View File

@ -169,6 +169,11 @@ class SDXL(supported_models_base.BASE):
def model_type(self, state_dict, prefix=""): def model_type(self, state_dict, prefix=""):
if "v_pred" in state_dict: if "v_pred" in state_dict:
return model_base.ModelType.V_PREDICTION return model_base.ModelType.V_PREDICTION
elif "edm_vpred.sigma_max" in state_dict:
self.sampling_settings["sigma_max"] = round(float(state_dict["edm_vpred.sigma_max"].item()),3)
if "edm_vpred.sigma_min" in state_dict:
self.sampling_settings["sigma_min"] = round(float(state_dict["edm_vpred.sigma_min"].item()),3)
return model_base.ModelType.V_PREDICTION_EDM
else: else:
return model_base.ModelType.EPS return model_base.ModelType.EPS

View File

@ -516,7 +516,7 @@ def configure_forge_reference_checkout(a1111_home: Path):
ModelRef(arg_name="--vae-dir", relative_path="models/VAE"), ModelRef(arg_name="--vae-dir", relative_path="models/VAE"),
ModelRef(arg_name="--hypernetwork-dir", relative_path="models/hypernetworks"), ModelRef(arg_name="--hypernetwork-dir", relative_path="models/hypernetworks"),
ModelRef(arg_name="--embeddings-dir", relative_path="embeddings"), ModelRef(arg_name="--embeddings-dir", relative_path="embeddings"),
ModelRef(arg_name="--lora-dir", relative_path="models/lora"), ModelRef(arg_name="--lora-dir", relative_path="models/Lora"),
# Ref A1111 need to have sd-webui-controlnet installed. # Ref A1111 need to have sd-webui-controlnet installed.
ModelRef(arg_name="--controlnet-dir", relative_path="models/ControlNet"), ModelRef(arg_name="--controlnet-dir", relative_path="models/ControlNet"),
ModelRef(arg_name="--controlnet-preprocessor-models-dir", relative_path="extensions/sd-webui-controlnet/annotator/downloads"), ModelRef(arg_name="--controlnet-preprocessor-models-dir", relative_path="extensions/sd-webui-controlnet/annotator/downloads"),

View File

@ -10,6 +10,7 @@ class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
self.sampler_name = sampler_name self.sampler_name = sampler_name
self.scheduler_name = scheduler_name self.scheduler_name = scheduler_name
self.unet = sd_model.forge_objects.unet self.unet = sd_model.forge_objects.unet
self.model = sd_model
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name)) sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
super().__init__(sampler_function, sd_model, None) super().__init__(sampler_function, sd_model, None)
@ -20,7 +21,7 @@ class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
sigmas = self.unet.model.model_sampling.sigma(timesteps) sigmas = self.unet.model.model_sampling.sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
else: else:
sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps) sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps, is_sdxl=getattr(self.model, "is_sdxl", False))
return sigmas.to(self.unet.load_device) return sigmas.to(self.unet.load_device)
@ -34,9 +35,13 @@ def build_constructor(sampler_name, scheduler_name):
samplers_data_alter = [ samplers_data_alter = [
sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm', scheduler_name='normal'), ['ddpm'], {}), sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm', scheduler_name='normal'), ['ddpm'], {}),
sd_samplers_common.SamplerData('DDPM Karras', build_constructor(sampler_name='ddpm', scheduler_name='karras'), ['ddpm_karras'], {}), sd_samplers_common.SamplerData('DDPM Karras', build_constructor(sampler_name='ddpm', scheduler_name='karras'), ['ddpm_karras'], {}),
sd_samplers_common.SamplerData('Euler AYS', build_constructor(sampler_name='euler', scheduler_name='ays'), ['euler_ays'], {}),
sd_samplers_common.SamplerData('Euler A Turbo', build_constructor(sampler_name='euler_ancestral', scheduler_name='turbo'), ['euler_ancestral_turbo'], {}), sd_samplers_common.SamplerData('Euler A Turbo', build_constructor(sampler_name='euler_ancestral', scheduler_name='turbo'), ['euler_ancestral_turbo'], {}),
sd_samplers_common.SamplerData('Euler A AYS', build_constructor(sampler_name='euler_ancestral', scheduler_name='ays'), ['euler_ancestral_ays'], {}),
sd_samplers_common.SamplerData('DPM++ 2M Turbo', build_constructor(sampler_name='dpmpp_2m', scheduler_name='turbo'), ['dpmpp_2m_turbo'], {}), sd_samplers_common.SamplerData('DPM++ 2M Turbo', build_constructor(sampler_name='dpmpp_2m', scheduler_name='turbo'), ['dpmpp_2m_turbo'], {}),
sd_samplers_common.SamplerData('DPM++ 2M AYS', build_constructor(sampler_name='dpmpp_2m', scheduler_name='ays'), ['dpmpp_2m_ays'], {}),
sd_samplers_common.SamplerData('DPM++ 2M SDE Turbo', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='turbo'), ['dpmpp_2m_sde_turbo'], {}), sd_samplers_common.SamplerData('DPM++ 2M SDE Turbo', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='turbo'), ['dpmpp_2m_sde_turbo'], {}),
sd_samplers_common.SamplerData('DPM++ 2M SDE AYS', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='ays'), ['dpmpp_2m_sde_ays'], {}),
sd_samplers_common.SamplerData('LCM Karras', build_constructor(sampler_name='lcm', scheduler_name='karras'), ['lcm_karras'], {}), sd_samplers_common.SamplerData('LCM Karras', build_constructor(sampler_name='lcm', scheduler_name='karras'), ['lcm_karras'], {}),
sd_samplers_common.SamplerData('Euler SGMUniform', build_constructor(sampler_name='euler', scheduler_name='sgm_uniform'), ['euler_sgm_uniform'], {}), sd_samplers_common.SamplerData('Euler SGMUniform', build_constructor(sampler_name='euler', scheduler_name='sgm_uniform'), ['euler_sgm_uniform'], {}),
sd_samplers_common.SamplerData('Euler A SGMUniform', build_constructor(sampler_name='euler_ancestral', scheduler_name='sgm_uniform'), ['euler_ancestral_sgm_uniform'], {}), sd_samplers_common.SamplerData('Euler A SGMUniform', build_constructor(sampler_name='euler_ancestral', scheduler_name='sgm_uniform'), ['euler_ancestral_sgm_uniform'], {}),

View File

@ -113,7 +113,7 @@ class ControlNetPatcher(ControlModelPatcher):
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
controlnet_config.pop("out_channels") controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config) control_model = ldm_patched.controlnet.cldm.ControlNet(model_file_name=ckpt_path, **controlnet_config)
if pth: if pth:
if 'difference' in controlnet_data: if 'difference' in controlnet_data: