Compare commits

..

12 Commits

Author SHA1 Message Date
51cf69328e Merge branch 'dev2' into mysd
Some checks failed
Linter / ruff (push) Has been cancelled
2024-06-05 17:22:55 +03:00
GavChap
77bdb9208d
Add cos xl (#710)
* Add V_PREDICTION_EDM handling for CosXL models

Add V_PREDICTION_EDM handling for CosXL models

* Get correct sigmas from checkpoint.

* Round to 3 significant digits in order to make it compatible with the ComfyUI implementation

* Add sigma data like ComfyUI has

---------

Co-authored-by: Gavin Chapman <gchapman@MAINPC>
2024-05-23 16:42:56 -04:00
altoiddealer
7eb5cbad01
Restore '/controlnet/control_types' API endpoint (#713)
Restores the '/controlnet/control_types' API endpoint, which is immensely useful for anyone using ControlNet via the API
2024-05-23 16:39:26 -04:00
Chenlei Hu
eb1e12b0dc
Add advanced weighting support (#754) 2024-05-23 13:11:11 -04:00
Chakib Benziane
49c3a080b5
implement align your steps scheduler (#726)
Signed-off-by: blob42 <contact@blob42.xyz>
2024-05-22 15:11:10 -04:00
Chenlei Hu
62e60ad403
Fix import of enum (#755) 2024-05-21 20:55:24 -04:00
Chenlei Hu
66845160de
Fix enum (#706)
* Fix enum conversion from string

* More fixes
2024-05-02 19:09:53 -04:00
bluelovers
b996316b20
Update footer.html (#591) 2024-05-02 18:31:37 -04:00
Chenlei Hu
d5cc799eec
Only feed image embed to instant id ControlNet unit (#309)
* Only feed image embed to instant id ControlNet units

* Add tests

* Add more tests
2024-05-02 18:30:54 -04:00
Qun
b55f9e7212
fix lora path (#509) 2024-05-02 18:29:25 -04:00
sJrx233S7UMo
bafd067777
Add xyz plot feature to "sd_forge_freeu" (#417)
* Add xyz plot feature to "sd_forge_freeu"

* Update extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py

Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>

* Update extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py

Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>

* Update extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py

Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>

* Update extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py

Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>

* Update extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py

Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>

---------

Co-authored-by: sJrx233S7UMo <sJrx233S7UMo@gmail.com>
Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>
2024-05-02 18:28:24 -04:00
Chenlei Hu
61a2a9d342
Add back ControlNet multi-inputs upload tab (#264)
* Refactor gallery

* Add back multi-inputs

* nit
2024-05-02 18:25:42 -04:00
23 changed files with 553 additions and 116 deletions

View File

@ -11,6 +11,8 @@ from .global_state import (
get_all_preprocessor_names,
get_all_controlnet_names,
get_preprocessor,
get_all_preprocessor_tags,
select_control_type,
)
from .utils import judge_image_type
from .logging import logger
@ -53,6 +55,30 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
# "module_detail": external_code.get_modules_detail(alias_names),
}
@app.get("/controlnet/control_types")
async def control_types():
    """Return, for every preprocessor tag, the preprocessor/model lists and
    their defaults, keyed under "control_types".

    Response shape:
        {"control_types": {<tag>: {"module_list": [...], "model_list": [...],
                                   "default_option": str, "default_model": str}}}
    """
    def format_control_type(
        filtered_preprocessor_list,
        filtered_model_list,
        default_option,
        default_model,
    ):
        # Repackage the 4-tuple from select_control_type into the JSON dict
        # shape expected by API clients.
        control_dict = {
            "module_list": filtered_preprocessor_list,
            "model_list": filtered_model_list,
            "default_option": default_option,
            "default_model": default_model,
        }
        return control_dict

    return {
        "control_types": {
            # select_control_type returns (preprocessors, models,
            # default preprocessor, default model) for each tag.
            control_type: format_control_type(*select_control_type(control_type))
            for control_type in get_all_preprocessor_tags()
        }
    }
@app.post("/controlnet/detect")
async def detect(
controlnet_module: str = Body("none", title="Controlnet Module"),

View File

@ -11,16 +11,17 @@ from lib_controlnet import (
global_state,
external_code,
)
from lib_controlnet.external_code import ControlNetUnit
from lib_controlnet.logging import logger
from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor
from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI
from lib_controlnet.controlnet_ui.tool_button import ToolButton
from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.controlnet_ui.multi_inputs_gallery import MultiInputsGallery
from lib_controlnet.enums import InputMode, HiResFixOption
from modules import shared, script_callbacks
from modules.ui_components import FormRow
from modules_forge.forge_util import HWC3
from lib_controlnet.external_code import UiControlNetUnit
@dataclass
@ -171,10 +172,10 @@ class ControlNetUiGroup(object):
self.webcam_mirrored = False
# Note: All gradio elements declared in `render` will be defined as member variable.
# Update counter to trigger a force update of UiControlNetUnit.
# Update counter to trigger a force update of ControlNetUnit.
# dummy_gradio_update_trigger is useful when a field with no event subscriber available changes.
# e.g. gr.Gallery, gr.State, etc. After an update to gr.State / gr.Gallery, please increment
# this counter to trigger a sync update of UiControlNetUnit.
# this counter to trigger a sync update of ControlNetUnit.
self.dummy_gradio_update_trigger = None
self.enabled = None
self.upload_tab = None
@ -185,10 +186,11 @@ class ControlNetUiGroup(object):
self.mask_image = None
self.batch_tab = None
self.batch_image_dir = None
self.merge_tab = None
self.batch_upload_tab = None
self.batch_input_gallery = None
self.merge_upload_button = None
self.merge_clear_button = None
self.batch_mask_gallery = None
self.multi_inputs_upload_tab = None
self.multi_inputs_input_gallery = None
self.create_canvas = None
self.canvas_width = None
self.canvas_height = None
@ -225,6 +227,7 @@ class ControlNetUiGroup(object):
self.hr_option = None
self.batch_image_dir_state = None
self.output_dir_state = None
self.advanced_weighting = gr.State(None)
# Internal states for UI state pasting.
self.prevent_next_n_module_update = 0
@ -331,31 +334,17 @@ class ControlNetUiGroup(object):
visible=False,
)
with gr.Tab(label="Batch Upload") as self.merge_tab:
with gr.Tab(label="Batch Upload") as self.batch_upload_tab:
with gr.Row():
with gr.Column():
self.batch_input_gallery = gr.Gallery(
columns=[4], rows=[2], object_fit="contain", height="auto", label="Images"
)
with gr.Row():
self.merge_upload_button = gr.UploadButton(
"Upload Images",
file_types=["image"],
file_count="multiple",
)
self.merge_clear_button = gr.Button("Clear Images")
with gr.Group(visible=False, elem_classes=["cnet-mask-gallery-group"]) as self.batch_mask_gallery_group:
with gr.Column():
self.batch_mask_gallery = gr.Gallery(
columns=[4], rows=[2], object_fit="contain", height="auto", label="Masks"
)
with gr.Row():
self.mask_merge_upload_button = gr.UploadButton(
"Upload Masks",
file_types=["image"],
file_count="multiple",
)
self.mask_merge_clear_button = gr.Button("Clear Masks")
self.batch_input_gallery = MultiInputsGallery()
self.batch_mask_gallery = MultiInputsGallery(
visible=False,
elem_classes=["cnet-mask-gallery-group"]
)
with gr.Tab(label="Multi-Inputs") as self.multi_inputs_upload_tab:
with gr.Row():
self.multi_inputs_gallery = MultiInputsGallery()
if self.photopea:
self.photopea.attach_photopea_output(self.generated_image)
@ -600,8 +589,9 @@ class ControlNetUiGroup(object):
self.use_preview_as_input,
self.batch_image_dir,
self.batch_mask_dir,
self.batch_input_gallery,
self.batch_mask_gallery,
self.batch_input_gallery.input_gallery,
self.batch_mask_gallery.input_gallery,
self.multi_inputs_gallery.input_gallery,
self.generated_image,
self.mask_image,
self.hr_option,
@ -618,9 +608,16 @@ class ControlNetUiGroup(object):
self.guidance_end,
self.pixel_perfect,
self.control_mode,
self.advanced_weighting,
)
unit = gr.State(self.default_unit)
def create_unit(*args):
return ControlNetUnit.from_dict({
k: v
for k, v in zip(vars(ControlNetUnit()).keys(), args)
})
for comp in unit_args + (self.dummy_gradio_update_trigger,):
event_subscribers = []
if hasattr(comp, "edit"):
@ -637,7 +634,7 @@ class ControlNetUiGroup(object):
for event_subscriber in event_subscribers:
event_subscriber(
fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit
fn=create_unit, inputs=list(unit_args), outputs=unit
)
(
@ -645,7 +642,7 @@ class ControlNetUiGroup(object):
if self.is_img2img
else ControlNetUiGroup.a1111_context.txt2img_submit_button
).click(
fn=UiControlNetUnit,
fn=create_unit,
inputs=list(unit_args),
outputs=unit,
queue=False,
@ -981,20 +978,41 @@ class ControlNetUiGroup(object):
"""Controls whether the upload mask input should be visible."""
def on_checkbox_click(checked: bool, canvas_height: int, canvas_width: int):
if not checked:
# Clear mask_image if unchecked.
return gr.update(visible=False), gr.update(value=None), gr.update(value=None, visible=False), \
gr.update(visible=False), gr.update(value=None)
# Clear mask inputs if unchecked.
return (
# Single mask upload.
gr.update(visible=False),
gr.update(value=None),
# Batch mask upload dir.
gr.update(value=None, visible=False),
# Multi mask upload gallery.
gr.update(visible=False),
gr.update(value=None)
)
else:
# Init an empty canvas the same size as the generation target.
empty_canvas = np.zeros(shape=(canvas_height, canvas_width, 3), dtype=np.uint8)
return gr.update(visible=True), gr.update(value=empty_canvas), gr.update(visible=True), \
gr.update(visible=True), gr.update()
return (
# Single mask upload.
gr.update(visible=True),
gr.update(value=empty_canvas),
# Batch mask upload dir.
gr.update(visible=True),
# Multi mask upload gallery.
gr.update(visible=True),
gr.update(),
)
self.mask_upload.change(
fn=on_checkbox_click,
inputs=[self.mask_upload, self.height_slider, self.width_slider],
outputs=[self.mask_image_group, self.mask_image, self.batch_mask_dir,
self.batch_mask_gallery_group, self.batch_mask_gallery],
outputs=[
self.mask_image_group,
self.mask_image,
self.batch_mask_dir,
self.batch_mask_gallery.group,
self.batch_mask_gallery.input_gallery,
],
show_progress=False,
)
@ -1079,51 +1097,14 @@ class ControlNetUiGroup(object):
def register_multi_images_upload(self):
"""Register callbacks on merge tab multiple images upload."""
self.merge_clear_button.click(
fn=lambda: [],
inputs=[],
outputs=[self.batch_input_gallery],
).then(
fn=lambda x: gr.update(value=x + 1),
trigger_dict = dict(
fn=lambda n: gr.update(value=n + 1),
inputs=[self.dummy_gradio_update_trigger],
outputs=[self.dummy_gradio_update_trigger],
)
self.mask_merge_clear_button.click(
fn=lambda: [],
inputs=[],
outputs=[self.batch_mask_gallery],
).then(
fn=lambda x: gr.update(value=x + 1),
inputs=[self.dummy_gradio_update_trigger],
outputs=[self.dummy_gradio_update_trigger],
)
def upload_file(files, current_files):
return {file_d["name"] for file_d in current_files} | {
file.name for file in files
}
self.merge_upload_button.upload(
upload_file,
inputs=[self.merge_upload_button, self.batch_input_gallery],
outputs=[self.batch_input_gallery],
queue=False,
).then(
fn=lambda x: gr.update(value=x + 1),
inputs=[self.dummy_gradio_update_trigger],
outputs=[self.dummy_gradio_update_trigger],
)
self.mask_merge_upload_button.upload(
upload_file,
inputs=[self.mask_merge_upload_button, self.batch_mask_gallery],
outputs=[self.batch_mask_gallery],
queue=False,
).then(
fn=lambda x: gr.update(value=x + 1),
inputs=[self.dummy_gradio_update_trigger],
outputs=[self.dummy_gradio_update_trigger],
)
return
self.batch_input_gallery.register_callbacks(change_trigger=trigger_dict)
self.batch_mask_gallery.register_callbacks(change_trigger=trigger_dict)
self.multi_inputs_gallery.register_callbacks(change_trigger=trigger_dict)
def register_core_callbacks(self):
"""Register core callbacks that only involves gradio components defined
@ -1213,7 +1194,8 @@ class ControlNetUiGroup(object):
for input_tab, fn in (
(ui_group.upload_tab, simple_fn),
(ui_group.batch_tab, batch_fn),
(ui_group.merge_tab, merge_fn),
(ui_group.batch_upload_tab, batch_fn),
(ui_group.multi_inputs_upload_tab, merge_fn),
):
# Sync input_mode.
input_tab.select(

View File

@ -0,0 +1,65 @@
import gradio as gr
from typing import Optional
class MultiInputsGallery:
    """A gallery object that accepts multiple input images.

    Bundles a gr.Gallery with an upload button and a clear button inside a
    gr.Group, so callers can reuse the same multi-image input widget.
    """

    def __init__(self, row: int = 2, column: int = 4, **group_kwargs) -> None:
        # Gallery grid layout.
        self.gallery_row_num = row
        self.gallery_column_num = column
        # Extra kwargs forwarded to the wrapping gr.Group
        # (e.g. visible=..., elem_classes=[...]).
        self.group_kwargs = group_kwargs
        # Gradio components; populated by render().
        self.group = None
        self.input_gallery = None
        self.upload_button = None
        self.clear_button = None
        self.render()

    def render(self):
        """Build the gradio component tree (must run inside a Blocks context)."""
        with gr.Group(**self.group_kwargs) as self.group:
            with gr.Column():
                self.input_gallery = gr.Gallery(
                    columns=[self.gallery_column_num],
                    rows=[self.gallery_row_num],
                    object_fit="contain",
                    height="auto",
                    label="Images",
                )
                with gr.Row():
                    self.upload_button = gr.UploadButton(
                        "Upload Images",
                        file_types=["image"],
                        file_count="multiple",
                    )
                    self.clear_button = gr.Button("Clear Images")

    def register_callbacks(self, change_trigger: Optional[dict] = None):
        """Register upload/clear callbacks for the gallery.

        Args:
            change_trigger: Optional gradio event-param dict (fn/inputs/outputs)
                chained after every gallery content change. This is necessary as
                gr.Gallery has no change-event subscriber; a caller that needs to
                observe gallery state changes passes a trigger here.
        """
        # Clearing simply resets the gallery to an empty list.
        handle1 = self.clear_button.click(
            fn=lambda: [],
            inputs=[],
            outputs=[self.input_gallery],
        )

        def upload_file(files, current_files):
            # Union of already-present file paths and the newly uploaded ones.
            # NOTE(review): this returns a set, so display order is
            # unspecified — confirm gr.Gallery accepts any iterable here.
            return {file_d["name"] for file_d in current_files} | {
                file.name for file in files
            }

        handle2 = self.upload_button.upload(
            upload_file,
            inputs=[self.upload_button, self.input_gallery],
            outputs=[self.input_gallery],
            queue=False,
        )

        if change_trigger:
            for handle in (handle1, handle2):
                handle.then(**change_trigger)

View File

@ -183,6 +183,8 @@ class ControlNetUnit:
batch_input_gallery: Optional[List[str]] = None
# Optional list of gallery masks for batch processing; defaults to None.
batch_mask_gallery: Optional[List[str]] = None
# Optional list of gallery images for multi-inputs; defaults to None.
multi_inputs_gallery: Optional[List[str]] = None
# Holds the preview image as a NumPy array; defaults to None.
generated_image: Optional[np.ndarray] = None
# ====== End of UI only fields ======
@ -192,7 +194,7 @@ class ControlNetUnit:
mask_image: Optional[GradioImageMaskPair] = None
# Specifies how this unit should be applied in each pass of high-resolution fix.
# Ignored if high-resolution fix is not enabled.
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
hr_option: HiResFixOption = HiResFixOption.BOTH
# Indicates whether the unit is enabled; defaults to True.
enabled: bool = True
# Name of the module being used; defaults to "None".
@ -204,7 +206,7 @@ class ControlNetUnit:
# Optional image for input; defaults to None.
image: Optional[GradioImageMaskPair] = None
# Specifies the mode of image resizing; defaults to inner fit.
resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT
resize_mode: ResizeMode = ResizeMode.INNER_FIT
# Resolution for processing by the unit; defaults to -1 (unspecified).
processor_res: int = -1
# Threshold A for processing; defaults to -1 (unspecified).
@ -218,7 +220,22 @@ class ControlNetUnit:
# Enables pixel-perfect processing; defaults to False.
pixel_perfect: bool = False
# Control mode for the unit; defaults to balanced.
control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
control_mode: ControlMode = ControlMode.BALANCED
# Weight for each layer of ControlNet params.
# For ControlNet:
# - SD1.5: 13 weights (4 encoder block * 3 + 1 middle block)
# - SDXL: 10 weights (3 encoder block * 3 + 1 middle block)
# For T2IAdapter
# - SD1.5: 5 weights (4 encoder block + 1 middle block)
# - SDXL: 4 weights (3 encoder block + 1 middle block)
# For IPAdapter
# - SD15: 16 (6 input blocks + 9 output blocks + 1 middle block)
# - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block)
# Note1: Setting advanced weighting will disable `soft_injection`, i.e.
# It is recommended to set ControlMode = BALANCED when using `advanced_weighting`.
# Note2: The field `weight` is still used in some places, e.g. reference_only,
# even advanced_weighting is set.
advanced_weighting: Optional[List[float]] = None
# Following fields should only be used in the API.
# ====== Start of API only fields ======
@ -275,6 +292,11 @@ class ControlNetUnit:
"image": mask,
"mask": np.zeros_like(mask),
}
# Convert strings to enums.
unit.input_mode = InputMode(unit.input_mode)
unit.hr_option = HiResFixOption(unit.hr_option)
unit.resize_mode = ResizeMode(unit.resize_mode)
unit.control_mode = ControlMode(unit.control_mode)
return unit

View File

@ -6,6 +6,7 @@ from modules import shared, sd_models
from lib_controlnet.enums import StableDiffusionVersion
from modules_forge.shared import controlnet_dir, supported_preprocessors
from typing import Dict, Tuple, List
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin", ".patch"]
@ -56,6 +57,10 @@ controlnet_names = ['None']
def get_preprocessor(name):
return supported_preprocessors.get(name, None)
def get_default_preprocessor(tag):
    """Return the default preprocessor name for the given control-type tag.

    The filtered list is expected to start with the no-op 'none' entry, so
    when more than one name is available the second entry (the first real
    preprocessor) is preferred.

    Raises:
        ValueError: If no preprocessor is registered for *tag*.
    """
    ps = get_filtered_preprocessor_names(tag)
    # `assert` is stripped under `python -O`; validate explicitly instead.
    if not ps:
        raise ValueError(f"No preprocessor available for tag {tag!r}.")
    return ps[0] if len(ps) == 1 else ps[1]
def get_sorted_preprocessors():
preprocessors = [p for k, p in supported_preprocessors.items() if k != 'None']
@ -144,3 +149,44 @@ def get_sd_version() -> StableDiffusionVersion:
return StableDiffusionVersion.SD1x
else:
return StableDiffusionVersion.UNKNOWN
def select_control_type(
    control_type: str,
    sd_version: StableDiffusionVersion = StableDiffusionVersion.UNKNOWN,
) -> Tuple[List[str], List[str], str, str]:
    """Resolve the preprocessor/model choices for a control-type tag.

    Args:
        control_type: Control-type tag (matched case-insensitively); "All"
            selects every preprocessor and model.
        sd_version: Currently unused; kept for interface compatibility.
            TODO confirm whether filtering models by SD version is intended.

    Returns:
        A 4-tuple of (preprocessor names, model names, default preprocessor
        name, default model name).
    """
    global controlnet_names
    pattern = control_type.lower()
    all_models = list(controlnet_names)

    if pattern == "all":
        preprocessors = get_sorted_preprocessors().values()
        # Fix: return a tuple to match the annotated return type (the
        # original returned a list here; callers unpack it the same way).
        return (
            [p.name for p in preprocessors],
            all_models,
            'none',  # default option
            "None",  # default model
        )

    filtered_model_list = get_filtered_controlnet_names(control_type)
    if pattern == "none":
        filtered_model_list.append("None")
    # `assert` is stripped under `python -O`; validate explicitly instead.
    if not filtered_model_list:
        raise ValueError("'None' model should always be available.")

    if len(filtered_model_list) == 1:
        default_model = "None"
    else:
        default_model = filtered_model_list[1]
    # Prefer a ControlNet v1.1 checkpoint ("11" in the base file name,
    # before any "[hash]" suffix) when one is present.
    for x in filtered_model_list:
        if "11" in x.split("[")[0]:
            default_model = x
            break

    return (
        get_filtered_preprocessor_names(control_type),
        filtered_model_list,
        get_default_preprocessor(control_type),
        default_model
    )

View File

@ -1,7 +1,10 @@
from typing import Optional
from copy import copy
from modules import processing
from modules.api import api
from lib_controlnet import external_code
from lib_controlnet.external_code import InputMode, ControlNetUnit
from modules_forge.forge_util import HWC3
@ -361,3 +364,22 @@ def crop_and_resize_image(detected_map, resize_mode, h, w, fill_border_with_255=
def judge_image_type(img):
return isinstance(img, np.ndarray) and img.ndim == 3 and int(img.shape[2]) in [3, 4]
def try_unfold_unit(unit: ControlNetUnit) -> List[ControlNetUnit]:
    """Unfolds a multi-inputs unit into multiple units with one input each.

    Units that are not in MERGE input mode are returned unchanged, wrapped in
    a single-element list.
    """
    if unit.input_mode != InputMode.MERGE:
        return [unit]

    def extract_unit(gallery_item: dict) -> ControlNetUnit:
        # Shallow copy: clones share all fields except image/input_mode/weight.
        r_unit = copy(unit)
        # Gallery items carry a file path under "name"; the file content is
        # decoded from base64 into an RGB uint8 array.
        # NOTE(review): assumes `read_image` returns a base64 string — confirm.
        img = np.array(api.decode_base64_to_image(read_image(gallery_item["name"]))).astype('uint8')
        r_unit.image = {
            "image": img,
            "mask": np.zeros_like(img),
        }
        r_unit.input_mode = InputMode.SIMPLE
        # Split the unit weight evenly across the unfolded units.
        # NOTE(review): raises ZeroDivisionError if the gallery is empty in
        # MERGE mode — confirm upstream guarantees at least one image.
        r_unit.weight = unit.weight / len(unit.multi_inputs_gallery)
        return r_unit

    return [extract_unit(item) for item in unit.multi_inputs_gallery]

View File

@ -11,9 +11,15 @@ from modules.api.api import decode_base64_to_image
import gradio as gr
from lib_controlnet import global_state, external_code
from lib_controlnet.external_code import ControlNetUnit
from lib_controlnet.utils import align_dim_latent, set_numpy_seed, crop_and_resize_image, \
prepare_mask, judge_image_type
from lib_controlnet.external_code import ControlNetUnit, InputMode
from lib_controlnet.utils import (
align_dim_latent,
set_numpy_seed,
crop_and_resize_image,
prepare_mask,
judge_image_type,
try_unfold_unit,
)
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup
from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.logging import logger
@ -53,6 +59,7 @@ class ControlNetCachedParameters:
self.control_cond_for_hr_fix = None
self.control_mask = None
self.control_mask_for_hr_fix = None
self.advanced_weighting = None
class ControlNetForForgeOfficial(scripts.Script):
@ -105,8 +112,13 @@ class ControlNetForForgeOfficial(scripts.Script):
for unit in units
]
assert all(isinstance(unit, ControlNetUnit) for unit in units)
enabled_units = [x for x in units if x.enabled]
return enabled_units
return [
simple_unit
for unit in units
# Unfolds multi-inputs units.
for simple_unit in try_unfold_unit(unit)
if simple_unit.enabled
]
@staticmethod
def try_crop_image_with_a1111_mask(
@ -153,9 +165,9 @@ class ControlNetForForgeOfficial(scripts.Script):
def get_input_data(self, p, unit, preprocessor, h, w):
logger.info(f'ControlNet Input Mode: {unit.input_mode}')
image_list = []
resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
resize_mode = unit.resize_mode
if unit.input_mode == external_code.InputMode.MERGE:
if unit.input_mode == InputMode.MERGE:
for idx, item in enumerate(unit.batch_input_gallery):
img_path = item['name']
logger.info(f'Try to read image: {img_path}')
@ -169,7 +181,7 @@ class ControlNetForForgeOfficial(scripts.Script):
mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy()
if img is not None:
image_list.append([img, mask])
elif unit.input_mode == external_code.InputMode.BATCH:
elif unit.input_mode == InputMode.BATCH:
image_list = []
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
batch_image_files = shared.listfiles(unit.batch_image_dir)
@ -339,7 +351,7 @@ class ControlNetForForgeOfficial(scripts.Script):
break
if has_high_res_fix:
hr_option = HiResFixOption.from_value(unit.hr_option)
hr_option = unit.hr_option
else:
hr_option = HiResFixOption.BOTH
@ -430,7 +442,7 @@ class ControlNetForForgeOfficial(scripts.Script):
)
if has_high_res_fix:
hr_option = HiResFixOption.from_value(unit.hr_option)
hr_option = unit.hr_option
else:
hr_option = HiResFixOption.BOTH
@ -494,6 +506,11 @@ class ControlNetForForgeOfficial(scripts.Script):
params.model.positive_advanced_weighting = soft_weighting.copy()
params.model.negative_advanced_weighting = soft_weighting.copy()
if unit.advanced_weighting is not None:
if params.model.positive_advanced_weighting is None:
logger.warn("advanced_weighting overwrite control_mode")
params.model.positive_advanced_weighting = unit.advanced_weighting
cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs)
params.model.advanced_mask_weighting = mask

View File

@ -4,6 +4,7 @@ from .template import (
APITestTemplate,
girl_img,
mask_img,
portrait_imgs,
disable_in_cq,
get_model,
)
@ -169,3 +170,60 @@ def test_lama_outpaint():
"resize_mode": "Resize and Fill", # OUTER_FIT
},
).exec()
@disable_in_cq
def test_instant_id_sdxl():
    """SDXL InstantID: ipadapter face-embed unit paired with a keypoints ControlNet unit."""
    # Requires local portrait fixtures to be present.
    assert len(portrait_imgs) > 0
    assert APITestTemplate(
        "instant_id_sdxl",
        "txt2img",
        payload_overrides={
            "width": 1000,
            "height": 1000,
            "prompt": "1girl, red background",
        },
        unit_overrides=[
            # Unit 1: face identity embedding via InsightFace.
            dict(
                image=portrait_imgs[0],
                model=get_model("ip-adapter_instant_id_sdxl"),
                module="InsightFace (InstantID)",
            ),
            # Unit 2: face keypoints ControlNet guiding the composition.
            dict(
                image=portrait_imgs[1],
                model=get_model("control_instant_id_sdxl"),
                module="instant_id_face_keypoints",
            ),
        ],
    ).exec()
@disable_in_cq
def test_instant_id_sdxl_multiple_units():
    """SDXL InstantID combined with an extra non-InstantID (canny) ControlNet unit."""
    # Requires local portrait fixtures to be present.
    assert len(portrait_imgs) > 0
    assert APITestTemplate(
        "instant_id_sdxl_multiple_units",
        "txt2img",
        payload_overrides={
            "width": 1000,
            "height": 1000,
            "prompt": "1girl, red background",
        },
        unit_overrides=[
            # Unit 1: face identity embedding via InsightFace.
            dict(
                image=portrait_imgs[0],
                model=get_model("ip-adapter_instant_id_sdxl"),
                module="InsightFace (InstantID)",
            ),
            # Unit 2: face keypoints ControlNet guiding the composition.
            dict(
                image=portrait_imgs[1],
                model=get_model("control_instant_id_sdxl"),
                module="instant_id_face_keypoints",
            ),
            # Unit 3: unrelated canny unit — must not receive the InstantID
            # image embeds.
            dict(
                image=portrait_imgs[1],
                model=get_model("diffusers_xl_canny"),
                module="canny",
            ),
        ],
    ).exec()

View File

@ -0,0 +1,43 @@
from .template import (
APITestTemplate,
realistic_girl_face_img,
disable_in_cq,
get_model,
)
@disable_in_cq
def test_ipadapter_advanced_weighting():
    """IPAdapter per-layer weighting: weighted run vs. unweighted reference run.

    Runs the same generation twice — once with `advanced_weighting` and once
    without — so the two outputs can be compared.
    """
    weights = [0.0] * 16  # 16 weights for SD15 / 11 weights for SDXL
    # SD15 composition
    weights[4] = 0.25
    weights[5] = 1.0
    # Fix: assert on exec() results, consistent with every other test in
    # this suite (previously the results were silently discarded).
    assert APITestTemplate(
        "test_ipadapter_advanced_weighting",
        "txt2img",
        payload_overrides={
            "width": 512,
            "height": 512,
        },
        unit_overrides={
            "image": realistic_girl_face_img,
            "module": "CLIP-ViT-H (IPAdapter)",
            "model": get_model("ip-adapter_sd15"),
            "advanced_weighting": weights,
        },
    ).exec()

    # Reference run with default (uniform) weighting.
    assert APITestTemplate(
        "test_ipadapter_advanced_weighting_ref",
        "txt2img",
        payload_overrides={
            "width": 512,
            "height": 512,
        },
        unit_overrides={
            "image": realistic_girl_face_img,
            "module": "CLIP-ViT-H (IPAdapter)",
            "model": get_model("ip-adapter_sd15"),
        },
    ).exec()

View File

@ -248,13 +248,13 @@ def get_model(model_name: str) -> str:
default_unit = {
"control_mode": 0,
"control_mode": "Balanced",
"enabled": True,
"guidance_end": 1,
"guidance_start": 0,
"pixel_perfect": True,
"processor_res": 512,
"resize_mode": 1,
"resize_mode": "Crop and Resize",
"threshold_a": 64,
"threshold_b": 64,
"weight": 1,

View File

@ -1,6 +1,12 @@
import gradio as gr
from modules import scripts
import sys
import traceback
from typing import Any
from functools import partial
from modules import script_callbacks, scripts
from ldm_patched.contrib.external_freelunch import FreeU_V2
@ -68,6 +74,18 @@ class FreeUForForge(scripts.Script):
freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2 = script_args
xyz = getattr(p, "_freeu_xyz", {})
if "freeu_enabled" in xyz:
freeu_enabled = xyz["freeu_enabled"] == "True"
if "freeu_b1" in xyz:
freeu_b1 = xyz["freeu_b1"]
if "freeu_b2" in xyz:
freeu_b2 = xyz["freeu_b2"]
if "freeu_s1" in xyz:
freeu_s1 = xyz["freeu_s1"]
if "freeu_s2" in xyz:
freeu_s2 = xyz["freeu_s2"]
if not freeu_enabled:
return
@ -89,3 +107,62 @@ class FreeUForForge(scripts.Script):
))
return
def set_value(p, x: Any, xs: Any, *, field: str):
    """Stash an xyz-grid axis value on the processing object under *field*.

    Values accumulate in `p._freeu_xyz` and are read back by the FreeU
    script's `process` hook; *xs* (the full axis value list) is unused.
    """
    try:
        store = p._freeu_xyz
    except AttributeError:
        # First axis applied to this processing object: create the store.
        store = p._freeu_xyz = {}
    store[field] = x
def make_axis_on_xyz_grid():
    """Register FreeU axis options on the xyz-grid script, if it is installed."""
    xyz_grid = next(
        (
            data.module
            for data in scripts.scripts_data
            if data.script_class.__module__ == "xyz_grid.py"
        ),
        None,
    )
    if xyz_grid is None:
        # xyz-grid script not present; nothing to register.
        return

    def float_axis(label, field):
        # The four numeric FreeU axes differ only by label and target field.
        return xyz_grid.AxisOption(label, float, partial(set_value, field=field))

    axis = [
        xyz_grid.AxisOption(
            "FreeU Enabled",
            str,
            partial(set_value, field="freeu_enabled"),
            choices=lambda: ["True", "False"]
        ),
        float_axis("FreeU B1", "freeu_b1"),
        float_axis("FreeU B2", "freeu_b2"),
        float_axis("FreeU S1", "freeu_s1"),
        float_axis("FreeU S2", "freeu_s2"),
    ]

    # Avoid duplicate registration when the UI is rebuilt.
    if not any(opt.label.startswith("FreeU") for opt in xyz_grid.axis_options):
        xyz_grid.axis_options.extend(axis)
def on_before_ui():
    """Best-effort xyz-grid integration, run before the UI is built."""
    try:
        make_axis_on_xyz_grid()
    except Exception:
        # xyz-grid support is optional: report the failure to stderr but do
        # not prevent the rest of the UI from loading.
        error = traceback.format_exc()
        print(
            f"[-] FreeU Integrated: xyz_grid error:\n{error}",
            file=sys.stderr,
        )


script_callbacks.on_before_ui(on_before_ui)

View File

@ -7,6 +7,7 @@ import math
import ldm_patched.modules.utils
import ldm_patched.modules.model_management
from ldm_patched.modules.controlnet import ControlNet
from ldm_patched.modules.clip_vision import clip_preprocess
from ldm_patched.ldm.modules.attention import optimized_attention
from ldm_patched.utils import path_utils as folder_paths
@ -465,8 +466,18 @@ class CrossAttentionPatch:
ip_v = ip_v_offset + ip_v_mean * W
out_ip = optimized_attention(q, ip_k.to(org_dtype), ip_v.to(org_dtype), extra_options["n_heads"])
if weight_type.startswith("original"):
out_ip = out_ip * weight
if weight_type == "original":
assert isinstance(weight, (float, int))
weight = weight
elif weight_type == "advanced":
assert isinstance(weight, list)
transformer_index: int = extra_options["transformer_index"]
assert transformer_index < len(weight)
weight = weight[transformer_index]
else:
weight = 1.0
out_ip = out_ip * weight
if mask is not None:
# TODO: needs checking
@ -749,13 +760,27 @@ class IPAdapterApply:
work_model = model.clone()
if self.is_instant_id:
def modifier(cnet, x_noisy, t, cond, batched_number):
def instant_id_modifier(cnet: ControlNet, x_noisy, t, cond, batched_number):
"""Overwrites crossattn inputs to InstantID ControlNet with ipadapter image embeds.
TODO: There can be multiple pairs of InstantID (ipadapter/controlnet) to control
rendering of multiple faces on canvas. We need to find a way to pair them. Currently,
the modifier is unconditionally applied to all instant id ControlNet units.
"""
if (
not isinstance(cnet, ControlNet) or
# model_file_name is None for Control LoRA.
cnet.control_model.model_file_name is None or
"instant_id" not in cnet.control_model.model_file_name.lower()
):
return x_noisy, t, cond, batched_number
cond_mark = cond['transformer_options']['cond_mark'][:, None, None].to(cond['c_crossattn']) # cond is 0
c_crossattn = image_prompt_embeds * (1.0 - cond_mark) + uncond_image_prompt_embeds * cond_mark
cond['c_crossattn'] = c_crossattn
return x_noisy, t, cond, batched_number
work_model.add_controlnet_conditioning_modifier(modifier)
work_model.add_controlnet_conditioning_modifier(instant_id_modifier)
if attn_mask is not None:
attn_mask = attn_mask.to(self.device)

View File

@ -143,11 +143,17 @@ class IPAdapterPatcher(ControlModelPatcher):
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
unet = process.sd_model.forge_objects.unet
if self.positive_advanced_weighting is None:
weight = self.strength
cond["weight_type"] = "original"
else:
weight = self.positive_advanced_weighting
cond["weight_type"] = "advanced"
unet = opIPAdapterApply(
ipadapter=self.ip_adapter,
model=unet,
weight=self.strength,
weight=weight,
start_at=self.start_percent,
end_at=self.end_percent,
faceid_v2=self.faceid_v2,

View File

@ -1,7 +1,7 @@
<div>
<a href="{api_docs}">API</a>
 • 
<a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
<a href="https://github.com/lllyasviel/stable-diffusion-webui-forge">Github</a>
 • 
<a href="https://gradio.app">Gradio</a>
 • 

View File

@ -4,6 +4,7 @@
import torch
import torch as th
import torch.nn as nn
from typing import Optional
from ldm_patched.ldm.modules.diffusionmodules.util import (
zero_module,
@ -54,9 +55,12 @@ class ControlNet(nn.Module):
transformer_depth_output=None,
device=None,
operations=ldm_patched.modules.ops.disable_weight_init,
model_file_name: Optional[str] = None, # Name of model file.
**kwargs,
):
super().__init__()
self.model_file_name = model_file_name
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

View File

@ -6,6 +6,7 @@ import math
from scipy import integrate
import torch
import numpy as np
from torch import nn
import torchsde
from tqdm.auto import trange, tqdm
@ -38,6 +39,35 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
return append_zero(sigmas)
# align your steps
def get_sigmas_ays(n, sigma_min, sigma_max, is_sdxl=False, device='cpu'):
    """Build the "Align Your Steps" noise schedule.

    https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html

    The schedule is a fixed, hand-tuned list of 11 sigmas per model family;
    other step counts are produced by log-linear interpolation. A terminal
    0.0 sigma is always appended, so the result has n + 1 entries for
    interpolated counts (12 entries at the native 11 steps).

    Note: sigma_min / sigma_max are accepted for scheduler-API compatibility
    but are not used — the AYS sigmas are fixed.
    """
    if is_sdxl:
        base = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029]
    else:
        # Default to SD 1.5 sigmas.
        base = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029]

    def loglinear_interp(t_steps, num_steps):
        """Performs log-linear interpolation of a given array of decreasing numbers."""
        xs = torch.linspace(0, 1, len(t_steps))
        ys = torch.log(torch.tensor(t_steps[::-1]))
        new_xs = torch.linspace(0, 1, num_steps)
        new_ys = np.interp(new_xs, xs, ys)
        # Undo the log transform and restore decreasing order.
        return torch.exp(torch.tensor(new_ys)).numpy()[::-1].copy()

    if n == len(base):
        base.append(0.0)
        schedule = base
    else:
        schedule = np.append(loglinear_interp(base, n), [0.0])
    return torch.FloatTensor(schedule).to(device)
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
"""Constructs a continuous VP noise schedule."""

View File

@@ -487,7 +487,7 @@ def load_controlnet(ckpt_path, model=None):
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
control_model = ldm_patched.controlnet.cldm.ControlNet(model_file_name=ckpt_path, **controlnet_config)
if pth:
if 'difference' in controlnet_data:

View File

@ -107,9 +107,11 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
sigma_min = sampling_settings.get("sigma_min", 0.002)
sigma_max = sampling_settings.get("sigma_max", 120.0)
self.set_sigma_range(sigma_min, sigma_max)
sigma_data = sampling_settings.get("sigma_data", 1.0)
self.set_sigma_range(sigma_min, sigma_max, sigma_data)
def set_sigma_range(self, sigma_min, sigma_max):
def set_sigma_range(self, sigma_min, sigma_max, sigma_data):
    """Set the EDM sigma range and data scale on this sampling module.

    Stores `sigma_data` and registers a buffer of 1000 log-spaced sigmas
    spanning [sigma_min, sigma_max], which some schedulers index directly.
    """
    self.sigma_data = sigma_data
    log_lo = math.log(sigma_min)
    log_hi = math.log(sigma_max)
    # Log-uniform grid of noise levels, exponentiated back to sigma space.
    sigmas = torch.exp(torch.linspace(log_lo, log_hi, 1000))
    self.register_buffer('sigmas', sigmas)  # for compatibility with some schedulers

View File

@ -662,14 +662,16 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32))
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "ays"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
def calculate_sigmas_scheduler(model, scheduler_name, steps):
def calculate_sigmas_scheduler(model, scheduler_name, steps, is_sdxl=False):
if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
elif scheduler_name == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
elif scheduler_name == "ays":
sigmas = k_diffusion_sampling.get_sigmas_ays(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max), is_sdxl=is_sdxl)
elif scheduler_name == "normal":
sigmas = normal_scheduler(model, steps)
elif scheduler_name == "simple":

View File

@ -169,6 +169,11 @@ class SDXL(supported_models_base.BASE):
def model_type(self, state_dict, prefix=""):
    """Select the prediction type for an SDXL checkpoint from its state dict.

    Returns model_base.ModelType.V_PREDICTION when the checkpoint carries a
    "v_pred" key, V_PREDICTION_EDM for CosXL-style checkpoints that embed an
    EDM sigma range, and EPS otherwise. As a side effect, the EDM branch
    copies the checkpoint's sigma range into self.sampling_settings.
    """
    if "v_pred" in state_dict:
        return model_base.ModelType.V_PREDICTION
    elif "edm_vpred.sigma_max" in state_dict:
        # CosXL checkpoints embed their EDM sigma range; round to 3 significant
        # decimals to stay compatible with the ComfyUI implementation.
        self.sampling_settings["sigma_max"] = round(float(state_dict["edm_vpred.sigma_max"].item()),3)
        if "edm_vpred.sigma_min" in state_dict:
            self.sampling_settings["sigma_min"] = round(float(state_dict["edm_vpred.sigma_min"].item()),3)
        return model_base.ModelType.V_PREDICTION_EDM
    else:
        return model_base.ModelType.EPS

View File

@ -516,7 +516,7 @@ def configure_forge_reference_checkout(a1111_home: Path):
ModelRef(arg_name="--vae-dir", relative_path="models/VAE"),
ModelRef(arg_name="--hypernetwork-dir", relative_path="models/hypernetworks"),
ModelRef(arg_name="--embeddings-dir", relative_path="embeddings"),
ModelRef(arg_name="--lora-dir", relative_path="models/lora"),
ModelRef(arg_name="--lora-dir", relative_path="models/Lora"),
# Ref A1111 need to have sd-webui-controlnet installed.
ModelRef(arg_name="--controlnet-dir", relative_path="models/ControlNet"),
ModelRef(arg_name="--controlnet-preprocessor-models-dir", relative_path="extensions/sd-webui-controlnet/annotator/downloads"),

View File

@ -10,6 +10,7 @@ class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
self.sampler_name = sampler_name
self.scheduler_name = scheduler_name
self.unet = sd_model.forge_objects.unet
self.model = sd_model
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
super().__init__(sampler_function, sd_model, None)
@ -20,7 +21,7 @@ class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
sigmas = self.unet.model.model_sampling.sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
else:
sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps)
sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps, is_sdxl=getattr(self.model, "is_sdxl", False))
return sigmas.to(self.unet.load_device)
@ -34,9 +35,13 @@ def build_constructor(sampler_name, scheduler_name):
samplers_data_alter = [
sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm', scheduler_name='normal'), ['ddpm'], {}),
sd_samplers_common.SamplerData('DDPM Karras', build_constructor(sampler_name='ddpm', scheduler_name='karras'), ['ddpm_karras'], {}),
sd_samplers_common.SamplerData('Euler AYS', build_constructor(sampler_name='euler', scheduler_name='ays'), ['euler_ays'], {}),
sd_samplers_common.SamplerData('Euler A Turbo', build_constructor(sampler_name='euler_ancestral', scheduler_name='turbo'), ['euler_ancestral_turbo'], {}),
sd_samplers_common.SamplerData('Euler A AYS', build_constructor(sampler_name='euler_ancestral', scheduler_name='ays'), ['euler_ancestral_ays'], {}),
sd_samplers_common.SamplerData('DPM++ 2M Turbo', build_constructor(sampler_name='dpmpp_2m', scheduler_name='turbo'), ['dpmpp_2m_turbo'], {}),
sd_samplers_common.SamplerData('DPM++ 2M AYS', build_constructor(sampler_name='dpmpp_2m', scheduler_name='ays'), ['dpmpp_2m_ays'], {}),
sd_samplers_common.SamplerData('DPM++ 2M SDE Turbo', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='turbo'), ['dpmpp_2m_sde_turbo'], {}),
sd_samplers_common.SamplerData('DPM++ 2M SDE AYS', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='ays'), ['dpmpp_2m_sde_ays'], {}),
sd_samplers_common.SamplerData('LCM Karras', build_constructor(sampler_name='lcm', scheduler_name='karras'), ['lcm_karras'], {}),
sd_samplers_common.SamplerData('Euler SGMUniform', build_constructor(sampler_name='euler', scheduler_name='sgm_uniform'), ['euler_sgm_uniform'], {}),
sd_samplers_common.SamplerData('Euler A SGMUniform', build_constructor(sampler_name='euler_ancestral', scheduler_name='sgm_uniform'), ['euler_ancestral_sgm_uniform'], {}),

View File

@ -113,7 +113,7 @@ class ControlNetPatcher(ControlModelPatcher):
controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
control_model = ldm_patched.controlnet.cldm.ControlNet(model_file_name=ckpt_path, **controlnet_config)
if pth:
if 'difference' in controlnet_data: