diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py
index 29e7d662..8bd43ae8 100644
--- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py
+++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py
@@ -11,6 +11,8 @@ from .global_state import (
     get_all_preprocessor_names,
     get_all_controlnet_names,
     get_preprocessor,
+    get_all_preprocessor_tags,
+    select_control_type,
 )
 from .utils import judge_image_type
 from .logging import logger
@@ -53,6 +55,30 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
             # "module_detail": external_code.get_modules_detail(alias_names),
         }
 
+    @app.get("/controlnet/control_types")
+    async def control_types():
+        def format_control_type(
+            filtered_preprocessor_list,
+            filtered_model_list,
+            default_option,
+            default_model,
+        ):
+            control_dict = {
+                "module_list": filtered_preprocessor_list,
+                "model_list": filtered_model_list,
+                "default_option": default_option,
+                "default_model": default_model,
+            }
+
+            return control_dict
+
+        return {
+            "control_types": {
+                control_type: format_control_type(*select_control_type(control_type))
+                for control_type in get_all_preprocessor_tags()
+            }
+        }
+
     @app.post("/controlnet/detect")
     async def detect(
         controlnet_module: str = Body("none", title="Controlnet Module"),
diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py
index b3e16df6..0f58b24a 100644
--- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py
+++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py
@@ -11,16 +11,17 @@ from lib_controlnet import (
     global_state,
     external_code,
 )
+from lib_controlnet.external_code import ControlNetUnit
 from lib_controlnet.logging import logger
 from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor
 from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI
 from lib_controlnet.controlnet_ui.tool_button import ToolButton
 from lib_controlnet.controlnet_ui.photopea import Photopea
+from lib_controlnet.controlnet_ui.multi_inputs_gallery import MultiInputsGallery
 from lib_controlnet.enums import InputMode, HiResFixOption
 from modules import shared, script_callbacks
 from modules.ui_components import FormRow
 from modules_forge.forge_util import HWC3
-from lib_controlnet.external_code import UiControlNetUnit
 
 
 @dataclass
@@ -171,10 +172,10 @@ class ControlNetUiGroup(object):
         self.webcam_mirrored = False
 
         # Note: All gradio elements declared in `render` will be defined as member variable.
-        # Update counter to trigger a force update of UiControlNetUnit.
+        # Update counter to trigger a force update of ControlNetUnit.
         # dummy_gradio_update_trigger is useful when a field with no event subscriber available changes.
         # e.g. gr.Gallery, gr.State, etc. After an update to gr.State / gr.Gallery, please increment
-        # this counter to trigger a sync update of UiControlNetUnit.
+        # this counter to trigger a sync update of ControlNetUnit.
         self.dummy_gradio_update_trigger = None
         self.enabled = None
         self.upload_tab = None
@@ -185,10 +186,11 @@ class ControlNetUiGroup(object):
         self.mask_image = None
         self.batch_tab = None
         self.batch_image_dir = None
-        self.merge_tab = None
+        self.batch_upload_tab = None
         self.batch_input_gallery = None
-        self.merge_upload_button = None
-        self.merge_clear_button = None
+        self.batch_mask_gallery = None
+        self.multi_inputs_upload_tab = None
+        self.multi_inputs_gallery = None
         self.create_canvas = None
         self.canvas_width = None
         self.canvas_height = None
@@ -225,6 +227,7 @@ class ControlNetUiGroup(object):
         self.hr_option = None
         self.batch_image_dir_state = None
         self.output_dir_state = None
+        self.advanced_weighting = gr.State(None)
 
         # Internal states for UI state pasting.
         self.prevent_next_n_module_update = 0
@@ -331,31 +334,17 @@ class ControlNetUiGroup(object):
                         visible=False,
                     )
 
-                with gr.Tab(label="Batch Upload") as self.merge_tab:
+                with gr.Tab(label="Batch Upload") as self.batch_upload_tab:
                     with gr.Row():
-                        with gr.Column():
-                            self.batch_input_gallery = gr.Gallery(
-                                columns=[4], rows=[2], object_fit="contain", height="auto", label="Images"
-                            )
-                            with gr.Row():
-                                self.merge_upload_button = gr.UploadButton(
-                                    "Upload Images",
-                                    file_types=["image"],
-                                    file_count="multiple",
-                                )
-                                self.merge_clear_button = gr.Button("Clear Images")
-                        with gr.Group(visible=False, elem_classes=["cnet-mask-gallery-group"]) as self.batch_mask_gallery_group:
-                            with gr.Column():
-                                self.batch_mask_gallery = gr.Gallery(
-                                    columns=[4], rows=[2], object_fit="contain", height="auto", label="Masks"
-                                )
-                                with gr.Row():
-                                    self.mask_merge_upload_button = gr.UploadButton(
-                                        "Upload Masks",
-                                        file_types=["image"],
-                                        file_count="multiple",
-                                    )
-                                    self.mask_merge_clear_button = gr.Button("Clear Masks")
+                        self.batch_input_gallery = MultiInputsGallery()
+                        self.batch_mask_gallery = MultiInputsGallery(
+                            visible=False,
+                            elem_classes=["cnet-mask-gallery-group"]
+                        )
+
+                with gr.Tab(label="Multi-Inputs") as self.multi_inputs_upload_tab:
+                    with gr.Row():
+                        self.multi_inputs_gallery = MultiInputsGallery()
 
         if self.photopea:
             self.photopea.attach_photopea_output(self.generated_image)
@@ -600,8 +589,9 @@ class ControlNetUiGroup(object):
             self.use_preview_as_input,
             self.batch_image_dir,
             self.batch_mask_dir,
-            self.batch_input_gallery,
-            self.batch_mask_gallery,
+            self.batch_input_gallery.input_gallery,
+            self.batch_mask_gallery.input_gallery,
+            self.multi_inputs_gallery.input_gallery,
             self.generated_image,
             self.mask_image,
             self.hr_option,
@@ -618,9 +608,16 @@ class ControlNetUiGroup(object):
             self.guidance_end,
             self.pixel_perfect,
             self.control_mode,
+            self.advanced_weighting,
         )
         unit = gr.State(self.default_unit)
 
+        def create_unit(*args):
+            return ControlNetUnit.from_dict({
+                k: v
+                for k, v in zip(vars(ControlNetUnit()).keys(), args)
+            })
+
         for comp in unit_args + (self.dummy_gradio_update_trigger,):
            event_subscribers = []
            if hasattr(comp, "edit"):
@@ -637,7 +634,7 @@ class ControlNetUiGroup(object):
 
             for event_subscriber in event_subscribers:
                 event_subscriber(
-                    fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit
+                    fn=create_unit, inputs=list(unit_args), outputs=unit
                 )
 
         (
@@ -645,7 +642,7 @@ class ControlNetUiGroup(object):
             if self.is_img2img
             else ControlNetUiGroup.a1111_context.txt2img_submit_button
         ).click(
-            fn=UiControlNetUnit,
+            fn=create_unit,
             inputs=list(unit_args),
             outputs=unit,
             queue=False,
@@ -981,20 +978,41 @@ class ControlNetUiGroup(object):
         """Controls whether the upload mask input should be visible."""
         def on_checkbox_click(checked: bool, canvas_height: int, canvas_width: int):
             if not checked:
-                # Clear mask_image if unchecked.
-                return gr.update(visible=False), gr.update(value=None), gr.update(value=None, visible=False), \
-                    gr.update(visible=False), gr.update(value=None)
+                # Clear mask inputs if unchecked.
+                return (
+                    # Single mask upload.
+                    gr.update(visible=False),
+                    gr.update(value=None),
+                    # Batch mask upload dir.
+                    gr.update(value=None, visible=False),
+                    # Multi mask upload gallery.
+                    gr.update(visible=False),
+                    gr.update(value=None)
+                )
             else:
                 # Init an empty canvas the same size as the generation target.
                 empty_canvas = np.zeros(shape=(canvas_height, canvas_width, 3), dtype=np.uint8)
-                return gr.update(visible=True), gr.update(value=empty_canvas), gr.update(visible=True), \
-                    gr.update(visible=True), gr.update()
+                return (
+                    # Single mask upload.
+                    gr.update(visible=True),
+                    gr.update(value=empty_canvas),
+                    # Batch mask upload dir.
+                    gr.update(visible=True),
+                    # Multi mask upload gallery.
+                    gr.update(visible=True),
+                    gr.update(),
+                )
 
         self.mask_upload.change(
             fn=on_checkbox_click,
             inputs=[self.mask_upload, self.height_slider, self.width_slider],
-            outputs=[self.mask_image_group, self.mask_image, self.batch_mask_dir,
-                     self.batch_mask_gallery_group, self.batch_mask_gallery],
+            outputs=[
+                self.mask_image_group,
+                self.mask_image,
+                self.batch_mask_dir,
+                self.batch_mask_gallery.group,
+                self.batch_mask_gallery.input_gallery,
+            ],
             show_progress=False,
         )
@@ -1079,51 +1097,14 @@ class ControlNetUiGroup(object):
 
     def register_multi_images_upload(self):
-        """Register callbacks on merge tab multiple images upload."""
-        self.merge_clear_button.click(
-            fn=lambda: [],
-            inputs=[],
-            outputs=[self.batch_input_gallery],
-        ).then(
-            fn=lambda x: gr.update(value=x + 1),
+        """Register callbacks on the multi-image upload galleries."""
+        trigger_dict = dict(
+            fn=lambda n: gr.update(value=n + 1),
             inputs=[self.dummy_gradio_update_trigger],
             outputs=[self.dummy_gradio_update_trigger],
         )
-        self.mask_merge_clear_button.click(
-            fn=lambda: [],
-            inputs=[],
-            outputs=[self.batch_mask_gallery],
-        ).then(
-            fn=lambda x: gr.update(value=x + 1),
-            inputs=[self.dummy_gradio_update_trigger],
-            outputs=[self.dummy_gradio_update_trigger],
-        )
-
-        def upload_file(files, current_files):
-            return {file_d["name"] for file_d in current_files} | {
-                file.name for file in files
-            }
-
-        self.merge_upload_button.upload(
-            upload_file,
-            inputs=[self.merge_upload_button, self.batch_input_gallery],
-            outputs=[self.batch_input_gallery],
-            queue=False,
-        ).then(
-            fn=lambda x: gr.update(value=x + 1),
-            inputs=[self.dummy_gradio_update_trigger],
-            outputs=[self.dummy_gradio_update_trigger],
-        )
-        self.mask_merge_upload_button.upload(
-            upload_file,
-            inputs=[self.mask_merge_upload_button, self.batch_mask_gallery],
-            outputs=[self.batch_mask_gallery],
-            queue=False,
-        ).then(
-            fn=lambda x: gr.update(value=x + 1),
-            inputs=[self.dummy_gradio_update_trigger],
-            outputs=[self.dummy_gradio_update_trigger],
-        )
-        return
+        self.batch_input_gallery.register_callbacks(change_trigger=trigger_dict)
+        self.batch_mask_gallery.register_callbacks(change_trigger=trigger_dict)
+        self.multi_inputs_gallery.register_callbacks(change_trigger=trigger_dict)
 
     def register_core_callbacks(self):
         """Register core callbacks that only involves gradio components defined
@@ -1213,7 +1194,8 @@ class ControlNetUiGroup(object):
         for input_tab, fn in (
             (ui_group.upload_tab, simple_fn),
             (ui_group.batch_tab, batch_fn),
-            (ui_group.merge_tab, merge_fn),
+            (ui_group.batch_upload_tab, batch_fn),
+            (ui_group.multi_inputs_upload_tab, merge_fn),
         ):
             # Sync input_mode.
             input_tab.select(
diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/multi_inputs_gallery.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/multi_inputs_gallery.py
new file mode 100644
index 00000000..15cba1df
--- /dev/null
+++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/multi_inputs_gallery.py
@@ -0,0 +1,65 @@
+import gradio as gr
+from typing import Optional
+
+
+class MultiInputsGallery:
+    """A gallery object that accepts multiple input images."""
+
+    def __init__(self, row: int = 2, column: int = 4, **group_kwargs) -> None:
+        self.gallery_row_num = row
+        self.gallery_column_num = column
+        self.group_kwargs = group_kwargs
+
+        self.group = None
+        self.input_gallery = None
+        self.upload_button = None
+        self.clear_button = None
+        self.render()
+
+    def render(self):
+        with gr.Group(**self.group_kwargs) as self.group:
+            with gr.Column():
+                self.input_gallery = gr.Gallery(
+                    columns=[self.gallery_column_num],
+                    rows=[self.gallery_row_num],
+                    object_fit="contain",
+                    height="auto",
+                    label="Images",
+                )
+                with gr.Row():
+                    self.upload_button = gr.UploadButton(
+                        "Upload Images",
+                        file_types=["image"],
+                        file_count="multiple",
+                    )
+                    self.clear_button = gr.Button("Clear Images")
+
+    def register_callbacks(self, change_trigger: Optional[dict] = None):
+        """Register callbacks on multiple images upload.
+        Args:
+            change_trigger: An optional dict of gradio event-handler params to
+                run after the gallery content changes. This is necessary because
+                gr.Gallery exposes no change event of its own, so a caller that
+                needs to observe gallery state changes must pass a trigger here.
+        """
+        handle1 = self.clear_button.click(
+            fn=lambda: [],
+            inputs=[],
+            outputs=[self.input_gallery],
+        )
+
+        def upload_file(files, current_files):
+            return {file_d["name"] for file_d in current_files} | {
+                file.name for file in files
+            }
+
+        handle2 = self.upload_button.upload(
+            upload_file,
+            inputs=[self.upload_button, self.input_gallery],
+            outputs=[self.input_gallery],
+            queue=False,
+        )
+
+        if change_trigger:
+            for handle in (handle1, handle2):
+                handle.then(**change_trigger)
diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py
index 4954478a..857387e0 100644
--- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py
+++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py
@@ -183,6 +183,8 @@ class ControlNetUnit:
     batch_input_gallery: Optional[List[str]] = None
     # Optional list of gallery masks for batch processing; defaults to None.
     batch_mask_gallery: Optional[List[str]] = None
+    # Optional list of gallery images for multi-inputs; defaults to None.
+    multi_inputs_gallery: Optional[List[str]] = None
     # Holds the preview image as a NumPy array; defaults to None.
     generated_image: Optional[np.ndarray] = None
     # ====== End of UI only fields ======
@@ -192,7 +194,7 @@ class ControlNetUnit:
     mask_image: Optional[GradioImageMaskPair] = None
     # Specifies how this unit should be applied in each pass of high-resolution fix.
     # Ignored if high-resolution fix is not enabled.
-    hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
+    hr_option: HiResFixOption = HiResFixOption.BOTH
     # Indicates whether the unit is enabled; defaults to True.
     enabled: bool = True
     # Name of the module being used; defaults to "None".
@@ -204,7 +206,7 @@ class ControlNetUnit:
     # Optional image for input; defaults to None.
     image: Optional[GradioImageMaskPair] = None
     # Specifies the mode of image resizing; defaults to inner fit.
-    resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT
+    resize_mode: ResizeMode = ResizeMode.INNER_FIT
     # Resolution for processing by the unit; defaults to -1 (unspecified).
     processor_res: int = -1
     # Threshold A for processing; defaults to -1 (unspecified).
@@ -218,7 +220,22 @@ class ControlNetUnit:
     # Enables pixel-perfect processing; defaults to False.
     pixel_perfect: bool = False
     # Control mode for the unit; defaults to balanced.
-    control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
+    control_mode: ControlMode = ControlMode.BALANCED
+    # Weight for each layer of ControlNet params.
+    # For ControlNet:
+    # - SD1.5: 13 weights (4 encoder blocks * 3 + 1 middle block)
+    # - SDXL: 10 weights (3 encoder blocks * 3 + 1 middle block)
+    # For T2IAdapter:
+    # - SD1.5: 5 weights (4 encoder blocks + 1 middle block)
+    # - SDXL: 4 weights (3 encoder blocks + 1 middle block)
+    # For IPAdapter:
+    # - SD1.5: 16 weights (6 input blocks + 9 output blocks + 1 middle block)
+    # - SDXL: 11 weights (4 input blocks + 6 output blocks + 1 middle block)
+    # Note 1: Setting `advanced_weighting` disables `soft_injection`, so it is
+    # recommended to set control_mode = BALANCED when using `advanced_weighting`.
+    # Note 2: The field `weight` is still used in some places (e.g. reference_only)
+    # even when `advanced_weighting` is set.
+    advanced_weighting: Optional[List[float]] = None
 
     # Following fields should only be used in the API.
     # ====== Start of API only fields ======
@@ -275,6 +292,11 @@ class ControlNetUnit:
                 "image": mask,
                 "mask": np.zeros_like(mask),
             }
+        # Convert strings to enums.
+        unit.input_mode = InputMode(unit.input_mode)
+        unit.hr_option = HiResFixOption(unit.hr_option)
+        unit.resize_mode = ResizeMode(unit.resize_mode)
+        unit.control_mode = ControlMode(unit.control_mode)
         return unit
diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py
index dc87dda8..15bef936 100644
--- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py
+++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py
@@ -6,6 +6,7 @@
 from modules import shared, sd_models
 from lib_controlnet.enums import StableDiffusionVersion
 from modules_forge.shared import controlnet_dir, supported_preprocessors
+from typing import Dict, Tuple, List
 
 CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin", ".patch"]
@@ -56,6 +57,10 @@ controlnet_names = ['None']
 def get_preprocessor(name):
     return supported_preprocessors.get(name, None)
 
+def get_default_preprocessor(tag):
+    ps = get_filtered_preprocessor_names(tag)
+    assert len(ps) > 0
+    return ps[0] if len(ps) == 1 else ps[1]
 
 def get_sorted_preprocessors():
     preprocessors = [p for k, p in supported_preprocessors.items() if k != 'None']
@@ -144,3 +149,44 @@ def get_sd_version() -> StableDiffusionVersion:
         return StableDiffusionVersion.SD1x
     else:
         return StableDiffusionVersion.UNKNOWN
+
+
+def select_control_type(
+    control_type: str,
+    sd_version: StableDiffusionVersion = StableDiffusionVersion.UNKNOWN,
+) -> Tuple[List[str], List[str], str, str]:
+    global controlnet_names
+
+    pattern = control_type.lower()
+    all_models = list(controlnet_names)
+
+    if pattern == "all":
+        preprocessors = get_sorted_preprocessors().values()
+        return [
+            [p.name for p in preprocessors],
+            all_models,
+            'none',  # default option
+            "None"   # default model
+        ]
+
+    filtered_model_list = get_filtered_controlnet_names(control_type)
+
+    if pattern == "none":
+        filtered_model_list.append("None")
+
+    assert len(filtered_model_list) > 0, "'None' model should always be available."
+
+    if len(filtered_model_list) == 1:
+        default_model = "None"
+    else:
+        default_model = filtered_model_list[1]
+        for x in filtered_model_list:
+            if "11" in x.split("[")[0]:
+                default_model = x
+                break
+
+    return (
+        get_filtered_preprocessor_names(control_type),
+        filtered_model_list,
+        get_default_preprocessor(control_type),
+        default_model
+    )
diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py
index c911d1c4..087b3298 100644
--- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py
+++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py
@@ -1,7 +1,11 @@
-from typing import Optional
+from typing import List, Optional
+from copy import copy
 
 from modules import processing
+from modules.api import api
 from lib_controlnet import external_code
+from lib_controlnet.external_code import InputMode, ControlNetUnit
 from modules_forge.forge_util import HWC3
@@ -361,3 +365,22 @@ def crop_and_resize_image(detected_map, resize_mode, h, w, fill_border_with_255=
 
 def judge_image_type(img):
     return isinstance(img, np.ndarray) and img.ndim == 3 and int(img.shape[2]) in [3, 4]
+
+
+def try_unfold_unit(unit: ControlNetUnit) -> List[ControlNetUnit]:
+    """Unfolds a multi-inputs unit into multiple units with one input each."""
+    if unit.input_mode != InputMode.MERGE:
+        return [unit]
+
+    def extract_unit(gallery_item: dict) -> ControlNetUnit:
+        r_unit = copy(unit)
+        img = np.array(api.decode_base64_to_image(read_image(gallery_item["name"]))).astype('uint8')
+        r_unit.image = {
+            "image": img,
+            "mask": np.zeros_like(img),
+        }
+        r_unit.input_mode = InputMode.SIMPLE
+        r_unit.weight = unit.weight / len(unit.multi_inputs_gallery)
+        return r_unit
+
+    return [extract_unit(item) for item in unit.multi_inputs_gallery]
diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
index 08aa5ff2..436487a0 100644
--- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
+++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
@@ -11,9 +11,15 @@ from modules.api.api import decode_base64_to_image
 import gradio as gr
 
 from lib_controlnet import global_state, external_code
-from lib_controlnet.external_code import ControlNetUnit
-from lib_controlnet.utils import align_dim_latent, set_numpy_seed, crop_and_resize_image, \
-    prepare_mask, judge_image_type
+from lib_controlnet.external_code import ControlNetUnit, InputMode
+from lib_controlnet.utils import (
+    align_dim_latent,
+    set_numpy_seed,
+    crop_and_resize_image,
+    prepare_mask,
+    judge_image_type,
+    try_unfold_unit,
+)
 from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup
 from lib_controlnet.controlnet_ui.photopea import Photopea
 from lib_controlnet.logging import logger
@@ -53,6 +59,7 @@ class ControlNetCachedParameters:
         self.control_cond_for_hr_fix = None
         self.control_mask = None
         self.control_mask_for_hr_fix = None
+        self.advanced_weighting = None
 
 
 class ControlNetForForgeOfficial(scripts.Script):
@@ -105,8 +112,13 @@ class ControlNetForForgeOfficial(scripts.Script):
             for unit in units
         ]
         assert all(isinstance(unit, ControlNetUnit) for unit in units)
-        enabled_units = [x for x in units if x.enabled]
-        return enabled_units
+        return [
+            simple_unit
+            for unit in units
+            # Unfold each multi-inputs unit into single-input units.
+            for simple_unit in try_unfold_unit(unit)
+            if simple_unit.enabled
+        ]
 
     @staticmethod
     def try_crop_image_with_a1111_mask(
@@ -153,9 +165,9 @@ class ControlNetForForgeOfficial(scripts.Script):
     def get_input_data(self, p, unit, preprocessor, h, w):
         logger.info(f'ControlNet Input Mode: {unit.input_mode}')
         image_list = []
-        resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
+        resize_mode = unit.resize_mode
 
-        if unit.input_mode == external_code.InputMode.MERGE:
+        if unit.input_mode == InputMode.MERGE:
             for idx, item in enumerate(unit.batch_input_gallery):
                 img_path = item['name']
                 logger.info(f'Try to read image: {img_path}')
@@ -169,7 +181,7 @@ class ControlNetForForgeOfficial(scripts.Script):
                     mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy()
                 if img is not None:
                     image_list.append([img, mask])
-        elif unit.input_mode == external_code.InputMode.BATCH:
+        elif unit.input_mode == InputMode.BATCH:
             image_list = []
             image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
             batch_image_files = shared.listfiles(unit.batch_image_dir)
@@ -339,7 +351,7 @@ class ControlNetForForgeOfficial(scripts.Script):
                 break
 
         if has_high_res_fix:
-            hr_option = HiResFixOption.from_value(unit.hr_option)
+            hr_option = unit.hr_option
         else:
             hr_option = HiResFixOption.BOTH
 
@@ -430,7 +442,7 @@ class ControlNetForForgeOfficial(scripts.Script):
         )
 
         if has_high_res_fix:
-            hr_option = HiResFixOption.from_value(unit.hr_option)
+            hr_option = unit.hr_option
         else:
             hr_option = HiResFixOption.BOTH
 
@@ -494,6 +506,11 @@ class ControlNetForForgeOfficial(scripts.Script):
             params.model.positive_advanced_weighting = soft_weighting.copy()
             params.model.negative_advanced_weighting = soft_weighting.copy()
 
+        if unit.advanced_weighting is not None:
+            if params.model.positive_advanced_weighting is not None:
+                logger.warning("advanced_weighting overwrites control_mode weighting")
+            params.model.positive_advanced_weighting = unit.advanced_weighting
+
         cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs)
 
         params.model.advanced_mask_weighting = mask
diff --git a/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py b/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py
index 433819d1..8b8a68c9 100644
--- a/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py
+++ b/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py
@@ -4,6 +4,7 @@ from .template import (
     APITestTemplate,
     girl_img,
     mask_img,
+    portrait_imgs,
     disable_in_cq,
     get_model,
 )
@@ -169,3 +170,60 @@ def test_lama_outpaint():
             "resize_mode": "Resize and Fill",  # OUTER_FIT
         },
     ).exec()
+
+
+@disable_in_cq
+def test_instant_id_sdxl():
+    assert len(portrait_imgs) > 0
+    assert APITestTemplate(
+        "instant_id_sdxl",
+        "txt2img",
+        payload_overrides={
+            "width": 1000,
+            "height": 1000,
+            "prompt": "1girl, red background",
+        },
+        unit_overrides=[
+            dict(
+                image=portrait_imgs[0],
+                model=get_model("ip-adapter_instant_id_sdxl"),
+                module="InsightFace (InstantID)",
+            ),
+            dict(
+                image=portrait_imgs[1],
+                model=get_model("control_instant_id_sdxl"),
+                module="instant_id_face_keypoints",
+            ),
+        ],
+    ).exec()
+
+
+@disable_in_cq
+def test_instant_id_sdxl_multiple_units():
+    assert len(portrait_imgs) > 0
+    assert APITestTemplate(
+        "instant_id_sdxl_multiple_units",
+        "txt2img",
+        payload_overrides={
+            "width": 1000,
+            "height": 1000,
+            "prompt": "1girl, red background",
+        },
+        unit_overrides=[
+            dict(
+                image=portrait_imgs[0],
+                model=get_model("ip-adapter_instant_id_sdxl"),
module="InsightFace (InstantID)", + ), + dict( + image=portrait_imgs[1], + model=get_model("control_instant_id_sdxl"), + module="instant_id_face_keypoints", + ), + dict( + image=portrait_imgs[1], + model=get_model("diffusers_xl_canny"), + module="canny", + ), + ], + ).exec() diff --git a/extensions-builtin/sd_forge_controlnet/tests/web_api/ipadapter_advanced_weighting.py b/extensions-builtin/sd_forge_controlnet/tests/web_api/ipadapter_advanced_weighting.py new file mode 100644 index 00000000..c61ef568 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/tests/web_api/ipadapter_advanced_weighting.py @@ -0,0 +1,43 @@ +from .template import ( + APITestTemplate, + realistic_girl_face_img, + disable_in_cq, + get_model, +) + + +@disable_in_cq +def test_ipadapter_advanced_weighting(): + weights = [0.0] * 16 # 16 weights for SD15 / 11 weights for SDXL + # SD15 composition + weights[4] = 0.25 + weights[5] = 1.0 + + APITestTemplate( + "test_ipadapter_advanced_weighting", + "txt2img", + payload_overrides={ + "width": 512, + "height": 512, + }, + unit_overrides={ + "image": realistic_girl_face_img, + "module": "CLIP-ViT-H (IPAdapter)", + "model": get_model("ip-adapter_sd15"), + "advanced_weighting": weights, + }, + ).exec() + + APITestTemplate( + "test_ipadapter_advanced_weighting_ref", + "txt2img", + payload_overrides={ + "width": 512, + "height": 512, + }, + unit_overrides={ + "image": realistic_girl_face_img, + "module": "CLIP-ViT-H (IPAdapter)", + "model": get_model("ip-adapter_sd15"), + }, + ).exec() \ No newline at end of file diff --git a/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py b/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py index 5129e541..5a4eb49f 100644 --- a/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py +++ b/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py @@ -248,13 +248,13 @@ def get_model(model_name: str) -> str: default_unit = { - "control_mode": 0, + "control_mode": "Balanced", "enabled": True, "guidance_end": 1, "guidance_start": 0, "pixel_perfect": True, "processor_res": 512, - "resize_mode": 1, + "resize_mode": "Crop and Resize", "threshold_a": 64, "threshold_b": 64, "weight": 1, diff --git a/extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py b/extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py index 07864802..377ae260 100644 --- a/extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py +++ b/extensions-builtin/sd_forge_freeu/scripts/forge_freeu.py @@ -1,6 +1,12 @@ import gradio as gr -from modules import scripts +import sys +import traceback + +from typing import Any +from functools import partial + +from modules import script_callbacks, scripts from ldm_patched.contrib.external_freelunch import FreeU_V2 @@ -68,6 +74,18 @@ class FreeUForForge(scripts.Script): freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2 = script_args + xyz = getattr(p, "_freeu_xyz", {}) + if "freeu_enabled" in xyz: + freeu_enabled = xyz["freeu_enabled"] == "True" + if "freeu_b1" in xyz: + freeu_b1 = xyz["freeu_b1"] + if "freeu_b2" in xyz: + freeu_b2 = xyz["freeu_b2"] + if "freeu_s1" in xyz: + freeu_s1 = xyz["freeu_s1"] + if "freeu_s2" in xyz: + freeu_s2 = xyz["freeu_s2"] + if not freeu_enabled: return @@ -89,3 +107,62 @@ class FreeUForForge(scripts.Script): )) return + +def set_value(p, x: Any, xs: Any, *, field: str): + if not hasattr(p, "_freeu_xyz"): + p._freeu_xyz = {} + p._freeu_xyz[field] = x + +def make_axis_on_xyz_grid(): + xyz_grid = None + for script in scripts.scripts_data: + if 
+        if script.script_class.__module__ == "xyz_grid.py":
+            xyz_grid = script.module
+            break
+
+    if xyz_grid is None:
+        return
+
+    axis = [
+        xyz_grid.AxisOption(
+            "FreeU Enabled",
+            str,
+            partial(set_value, field="freeu_enabled"),
+            choices=lambda: ["True", "False"]
+        ),
+        xyz_grid.AxisOption(
+            "FreeU B1",
+            float,
+            partial(set_value, field="freeu_b1"),
+        ),
+        xyz_grid.AxisOption(
+            "FreeU B2",
+            float,
+            partial(set_value, field="freeu_b2"),
+        ),
+        xyz_grid.AxisOption(
+            "FreeU S1",
+            float,
+            partial(set_value, field="freeu_s1"),
+        ),
+        xyz_grid.AxisOption(
+            "FreeU S2",
+            float,
+            partial(set_value, field="freeu_s2"),
+        ),
+    ]
+
+    if not any(x.label.startswith("FreeU") for x in xyz_grid.axis_options):
+        xyz_grid.axis_options.extend(axis)
+
+def on_before_ui():
+    try:
+        make_axis_on_xyz_grid()
+    except Exception:
+        error = traceback.format_exc()
+        print(
+            f"[-] FreeU Integrated: xyz_grid error:\n{error}",
+            file=sys.stderr,
+        )
+
+script_callbacks.on_before_ui(on_before_ui)
diff --git a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py
index e35b8d60..1350c1f6 100644
--- a/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py
+++ b/extensions-builtin/sd_forge_ipadapter/lib_ipadapter/IPAdapterPlus.py
@@ -7,6 +7,7 @@ import math
 
 import ldm_patched.modules.utils
 import ldm_patched.modules.model_management
+from ldm_patched.modules.controlnet import ControlNet
 from ldm_patched.modules.clip_vision import clip_preprocess
 from ldm_patched.ldm.modules.attention import optimized_attention
 from ldm_patched.utils import path_utils as folder_paths
@@ -402,7 +403,7 @@ class CrossAttentionPatch:
         batch_prompt = b // len(cond_or_uncond)
         out = optimized_attention(q, k, v, extra_options["n_heads"])
         _, _, lh, lw = extra_options["original_shape"]
-        
+
         for weight, cond, uncond, ipadapter, mask, weight_type, sigma_start, sigma_end, unfold_batch in zip(self.weights, self.conds, self.unconds, self.ipadapters, self.masks, self.weight_type, self.sigma_start, self.sigma_end, self.unfold_batch):
             if sigma > sigma_start or sigma < sigma_end:
                 continue
@@ -465,8 +466,18 @@ class CrossAttentionPatch:
                 ip_v = ip_v_offset + ip_v_mean * W
 
             out_ip = optimized_attention(q, ip_k.to(org_dtype), ip_v.to(org_dtype), extra_options["n_heads"])
-            if weight_type.startswith("original"):
-                out_ip = out_ip * weight
+
+            if weight_type == "original":
+                assert isinstance(weight, (float, int))
+                # The scalar weight is applied as-is below.
+            elif weight_type == "advanced":
+                assert isinstance(weight, list)
+                transformer_index: int = extra_options["transformer_index"]
+                assert transformer_index < len(weight)
+                weight = weight[transformer_index]
+            else:
+                weight = 1.0
+            out_ip = out_ip * weight
 
             if mask is not None:
                 # TODO: needs checking
@@ -732,7 +743,7 @@ class IPAdapterApply:
             is_faceid=self.is_faceid, is_instant_id=self.is_instant_id
         )
-        
+
         self.ipadapter.to(self.device, dtype=self.dtype)
 
         if self.is_instant_id:
@@ -749,13 +760,27 @@ class IPAdapterApply:
         work_model = model.clone()
 
         if self.is_instant_id:
-            def modifier(cnet, x_noisy, t, cond, batched_number):
+            def instant_id_modifier(cnet: ControlNet, x_noisy, t, cond, batched_number):
+                """Overwrites crossattn inputs to InstantID ControlNet with ipadapter image embeds.
+
+                TODO: There can be multiple pairs of InstantID (ipadapter/controlnet) to control
+                rendering of multiple faces on canvas. We need to find a way to pair them. Currently,
+                the modifier is unconditionally applied to all instant id ControlNet units.
+ """ + if ( + not isinstance(cnet, ControlNet) or + # model_file_name is None for Control LoRA. + cnet.control_model.model_file_name is None or + "instant_id" not in cnet.control_model.model_file_name.lower() + ): + return x_noisy, t, cond, batched_number + cond_mark = cond['transformer_options']['cond_mark'][:, None, None].to(cond['c_crossattn']) # cond is 0 c_crossattn = image_prompt_embeds * (1.0 - cond_mark) + uncond_image_prompt_embeds * cond_mark cond['c_crossattn'] = c_crossattn return x_noisy, t, cond, batched_number - work_model.add_controlnet_conditioning_modifier(modifier) + work_model.add_controlnet_conditioning_modifier(instant_id_modifier) if attn_mask is not None: attn_mask = attn_mask.to(self.device) diff --git a/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py b/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py index 47b72044..c3d8c1a9 100644 --- a/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py +++ b/extensions-builtin/sd_forge_ipadapter/scripts/forge_ipadapter.py @@ -143,11 +143,17 @@ class IPAdapterPatcher(ControlModelPatcher): def process_before_every_sampling(self, process, cond, mask, *args, **kwargs): unet = process.sd_model.forge_objects.unet + if self.positive_advanced_weighting is None: + weight = self.strength + cond["weight_type"] = "original" + else: + weight = self.positive_advanced_weighting + cond["weight_type"] = "advanced" unet = opIPAdapterApply( ipadapter=self.ip_adapter, model=unet, - weight=self.strength, + weight=weight, start_at=self.start_percent, end_at=self.end_percent, faceid_v2=self.faceid_v2, diff --git a/html/footer.html b/html/footer.html index 69b2372c..8fe2bf8d 100644 --- a/html/footer.html +++ b/html/footer.html @@ -1,7 +1,7 @@
 API
  • 
-Github
+Github
  • 
 Gradio
  • 
diff --git a/ldm_patched/controlnet/cldm.py b/ldm_patched/controlnet/cldm.py
index 82265ef9..aa8de9bd 100644
--- a/ldm_patched/controlnet/cldm.py
+++ b/ldm_patched/controlnet/cldm.py
@@ -4,6 +4,7 @@
 import torch
 import torch as th
 import torch.nn as nn
+from typing import Optional
 
 from ldm_patched.ldm.modules.diffusionmodules.util import (
     zero_module,
@@ -54,9 +55,12 @@ class ControlNet(nn.Module):
                  transformer_depth_output=None,
                  device=None,
                  operations=ldm_patched.modules.ops.disable_weight_init,
+                 model_file_name: Optional[str] = None,  # Name of model file.
                  **kwargs,
                  ):
         super().__init__()
+        self.model_file_name = model_file_name
+
         assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
         if use_spatial_transformer:
             assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
diff --git a/ldm_patched/k_diffusion/sampling.py b/ldm_patched/k_diffusion/sampling.py
index 6f2fbea7..498d1a6f 100644
--- a/ldm_patched/k_diffusion/sampling.py
+++ b/ldm_patched/k_diffusion/sampling.py
@@ -6,6 +6,7 @@ import math
 
 from scipy import integrate
 import torch
+import numpy as np
 from torch import nn
 import torchsde
 from tqdm.auto import trange, tqdm
@@ -38,6 +39,35 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
     sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
     return append_zero(sigmas)
 
+# align your steps
+def get_sigmas_ays(n, sigma_min, sigma_max, is_sdxl=False, device='cpu'):
+    # https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
+    def loglinear_interp(t_steps, num_steps):
+        """
+        Performs log-linear interpolation of a given array of decreasing numbers.
+        """
+        xs = torch.linspace(0, 1, len(t_steps))
+        ys = torch.log(torch.tensor(t_steps[::-1]))
+
+        new_xs = torch.linspace(0, 1, num_steps)
+        new_ys = np.interp(new_xs, xs, ys)
+
+        interped_ys = torch.exp(torch.tensor(new_ys)).numpy()[::-1].copy()
+        return interped_ys
+
+    if is_sdxl:
+        sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029]
+    else:
+        # Default to SD 1.5 sigmas.
+        sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029]
+
+    if n != len(sigmas):
+        sigmas = np.append(loglinear_interp(sigmas, n), [0.0])
+    else:
+        sigmas.append(0.0)
+
+    return torch.FloatTensor(sigmas).to(device)
+
 
 def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
     """Constructs a continuous VP noise schedule."""
diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py
index 6192d7ae..449716c4 100644
--- a/ldm_patched/modules/controlnet.py
+++ b/ldm_patched/modules/controlnet.py
@@ -487,7 +487,7 @@ def load_controlnet(ckpt_path, model=None):
             controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
         controlnet_config.pop("out_channels")
         controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
-        control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
+        control_model = ldm_patched.controlnet.cldm.ControlNet(model_file_name=ckpt_path, **controlnet_config)
 
         if pth:
             if 'difference' in controlnet_data:
diff --git a/ldm_patched/modules/model_sampling.py b/ldm_patched/modules/model_sampling.py
index da5cc3a6..6bb08fd2 100644
--- a/ldm_patched/modules/model_sampling.py
+++ b/ldm_patched/modules/model_sampling.py
@@ -107,9 +107,11 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
         sigma_min = sampling_settings.get("sigma_min", 0.002)
         sigma_max = sampling_settings.get("sigma_max", 120.0)
-        self.set_sigma_range(sigma_min, sigma_max)
+        sigma_data = sampling_settings.get("sigma_data", 1.0)
+        self.set_sigma_range(sigma_min, sigma_max, sigma_data)
 
-    def set_sigma_range(self, sigma_min, sigma_max):
+    def set_sigma_range(self, sigma_min, sigma_max, sigma_data):
+        self.sigma_data = sigma_data
         sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp()
 
         self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers
diff --git a/ldm_patched/modules/samplers.py b/ldm_patched/modules/samplers.py
index e8f53e13..7f49907a 100644
--- a/ldm_patched/modules/samplers.py
+++ b/ldm_patched/modules/samplers.py
@@ -662,14 +662,16 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
     samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
     return model.process_latent_out(samples.to(torch.float32))
 
-SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
+SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "ays"]
 SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
 
-def calculate_sigmas_scheduler(model, scheduler_name, steps):
+def calculate_sigmas_scheduler(model, scheduler_name, steps, is_sdxl=False):
     if scheduler_name == "karras":
         sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
     elif scheduler_name == "exponential":
         sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
+    elif scheduler_name == "ays":
+        sigmas = k_diffusion_sampling.get_sigmas_ays(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max), is_sdxl=is_sdxl)
     elif scheduler_name == "normal":
         sigmas = normal_scheduler(model, steps)
     elif scheduler_name == "simple":
diff --git a/ldm_patched/modules/supported_models.py b/ldm_patched/modules/supported_models.py
index de21c10d..3e1889bf 100644
--- a/ldm_patched/modules/supported_models.py
+++ b/ldm_patched/modules/supported_models.py
@@ -169,6 +169,11 @@ class SDXL(supported_models_base.BASE):
     def model_type(self, state_dict, prefix=""):
         if "v_pred" in state_dict:
             return model_base.ModelType.V_PREDICTION
+        elif "edm_vpred.sigma_max" in state_dict:
+            self.sampling_settings["sigma_max"] = round(float(state_dict["edm_vpred.sigma_max"].item()), 3)
+            if "edm_vpred.sigma_min" in state_dict:
+                self.sampling_settings["sigma_min"] = round(float(state_dict["edm_vpred.sigma_min"].item()), 3)
+            return model_base.ModelType.V_PREDICTION_EDM
         else:
             return model_base.ModelType.EPS
diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index 596d14b6..775bc3ef 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -516,7 +516,7 @@ def configure_forge_reference_checkout(a1111_home: Path):
         ModelRef(arg_name="--vae-dir", relative_path="models/VAE"),
         ModelRef(arg_name="--hypernetwork-dir", relative_path="models/hypernetworks"),
         ModelRef(arg_name="--embeddings-dir", relative_path="embeddings"),
-        ModelRef(arg_name="--lora-dir", relative_path="models/lora"),
+        ModelRef(arg_name="--lora-dir", relative_path="models/Lora"),
         # Ref A1111 need to have sd-webui-controlnet installed.
         ModelRef(arg_name="--controlnet-dir", relative_path="models/ControlNet"),
         ModelRef(arg_name="--controlnet-preprocessor-models-dir", relative_path="extensions/sd-webui-controlnet/annotator/downloads"),
diff --git a/modules_forge/forge_alter_samplers.py b/modules_forge/forge_alter_samplers.py
index 4e482208..bb9f2499 100644
--- a/modules_forge/forge_alter_samplers.py
+++ b/modules_forge/forge_alter_samplers.py
@@ -10,6 +10,7 @@ class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
         self.sampler_name = sampler_name
         self.scheduler_name = scheduler_name
         self.unet = sd_model.forge_objects.unet
+        self.model = sd_model
         sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
         super().__init__(sampler_function, sd_model, None)
@@ -20,7 +21,7 @@ class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
             sigmas = self.unet.model.model_sampling.sigma(timesteps)
             sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
         else:
-            sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps)
+            sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps, is_sdxl=getattr(self.model, "is_sdxl", False))
         return sigmas.to(self.unet.load_device)
 
 
@@ -34,9 +35,13 @@ def build_constructor(sampler_name, scheduler_name):
 samplers_data_alter = [
     sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm', scheduler_name='normal'), ['ddpm'], {}),
     sd_samplers_common.SamplerData('DDPM Karras', build_constructor(sampler_name='ddpm', scheduler_name='karras'), ['ddpm_karras'], {}),
+    sd_samplers_common.SamplerData('Euler AYS', build_constructor(sampler_name='euler', scheduler_name='ays'), ['euler_ays'], {}),
     sd_samplers_common.SamplerData('Euler A Turbo', build_constructor(sampler_name='euler_ancestral', scheduler_name='turbo'), ['euler_ancestral_turbo'], {}),
+    sd_samplers_common.SamplerData('Euler A AYS', build_constructor(sampler_name='euler_ancestral', scheduler_name='ays'), ['euler_ancestral_ays'], {}),
     sd_samplers_common.SamplerData('DPM++ 2M Turbo', build_constructor(sampler_name='dpmpp_2m', scheduler_name='turbo'), ['dpmpp_2m_turbo'], {}),
+    sd_samplers_common.SamplerData('DPM++ 2M AYS', build_constructor(sampler_name='dpmpp_2m', scheduler_name='ays'), ['dpmpp_2m_ays'], {}),
     sd_samplers_common.SamplerData('DPM++ 2M SDE Turbo', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='turbo'), ['dpmpp_2m_sde_turbo'], {}),
+    sd_samplers_common.SamplerData('DPM++ 2M SDE AYS', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='ays'), ['dpmpp_2m_sde_ays'], {}),
     sd_samplers_common.SamplerData('LCM Karras', build_constructor(sampler_name='lcm', scheduler_name='karras'), ['lcm_karras'], {}),
     sd_samplers_common.SamplerData('Euler SGMUniform', build_constructor(sampler_name='euler', scheduler_name='sgm_uniform'), ['euler_sgm_uniform'], {}),
     sd_samplers_common.SamplerData('Euler A SGMUniform', build_constructor(sampler_name='euler_ancestral', scheduler_name='sgm_uniform'), ['euler_ancestral_sgm_uniform'], {}),
diff --git a/modules_forge/supported_controlnet.py b/modules_forge/supported_controlnet.py
index 1490259e..747fe583 100644
--- a/modules_forge/supported_controlnet.py
+++ b/modules_forge/supported_controlnet.py
@@ -113,7 +113,7 @@ class ControlNetPatcher(ControlModelPatcher):
             controlnet_config["operations"] = ldm_patched.modules.ops.manual_cast
         controlnet_config.pop("out_channels")
         controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
-        control_model = ldm_patched.controlnet.cldm.ControlNet(**controlnet_config)
+        control_model = ldm_patched.controlnet.cldm.ControlNet(model_file_name=ckpt_path, **controlnet_config)
 
         if pth:
             if 'difference' in controlnet_data:
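
For reference, the new `/controlnet/control_types` endpoint in `lib_controlnet/api.py` can be exercised with a short client-side sketch like the one below. This is a hedged illustration, not part of the patch: it assumes a local WebUI Forge instance on the default port, and the "Canny" tag name is only an example of whatever tags `get_all_preprocessor_tags()` actually returns.

```python
# Hypothetical client sketch for the new /controlnet/control_types endpoint.
# Assumptions: server at 127.0.0.1:7860; "Canny" is one of the registered tags.
import requests

resp = requests.get("http://127.0.0.1:7860/controlnet/control_types")
resp.raise_for_status()
control_types = resp.json()["control_types"]

# Each entry mirrors format_control_type() in api.py.
canny = control_types["Canny"]
print(canny["module_list"])     # preprocessors available for this control type
print(canny["model_list"])      # models filtered for this control type
print(canny["default_option"])  # preprocessor preselected for this control type
print(canny["default_model"])   # model preselected for this control type
```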
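A sketch of how the new per-layer `advanced_weighting` field could be used over the API, following the layer counts documented in `external_code.py` (13 weights for an SD1.5 ControlNet) and the unit-dict shape used by the web-API tests. The model name is a placeholder, the weight values are made up for demonstration, and `control_mode` is set to "Balanced" as the field's comments recommend.

```python
# Illustrative txt2img payload using advanced_weighting; values are examples only.
import requests

weights = [0.0] * 13   # SD1.5 ControlNet: 4 encoder blocks * 3 + 1 middle block
weights[-1] = 1.0      # keep only the middle block's influence

payload = {
    "prompt": "1girl",
    "steps": 20,
    "alwayson_scripts": {
        "ControlNet": {
            "args": [{
                "enabled": True,
                "module": "canny",
                "model": "control_v11p_sd15_canny",  # placeholder model name
                "control_mode": "Balanced",
                "advanced_weighting": weights,
            }]
        }
    },
}
requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)
```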
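The FreeU XYZ-grid integration works by stashing each axis value on the processing object: an `AxisOption`'s apply function writes into `p._freeu_xyz`, and `process_before_every_sampling` then prefers those values over the UI ones. A minimal standalone sketch of that mechanism, with a stand-in class for the real processing object:

```python
# Sketch of forge_freeu's xyz mechanism; FakeProcessing is a stand-in for p.
from functools import partial
from typing import Any

class FakeProcessing:
    """Stand-in for the StableDiffusionProcessing object; illustration only."""

def set_value(p, x: Any, xs: Any, *, field: str):
    # Mirrors forge_freeu.set_value: lazily create the dict, record the value.
    if not hasattr(p, "_freeu_xyz"):
        p._freeu_xyz = {}
    p._freeu_xyz[field] = x

p = FakeProcessing()
apply_b1 = partial(set_value, field="freeu_b1")
apply_b1(p, 1.3, xs=None)   # xyz_grid invokes apply(p, x, xs) per grid cell
print(p._freeu_xyz)          # {'freeu_b1': 1.3}
```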
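Finally, the AYS scheduler in `get_sigmas_ays()` ships only 11 hand-tuned sigmas per model family; the log-linear interpolation is what adapts them to an arbitrary step count. A self-contained sketch of that resampling, reusing the SD1.5 anchor values from the patch:

```python
# Standalone demo of the log-linear resampling used by get_sigmas_ays().
import numpy as np
import torch

def loglinear_interp(t_steps, num_steps):
    """Log-linearly resample a decreasing sigma list to num_steps values."""
    xs = torch.linspace(0, 1, len(t_steps))
    ys = torch.log(torch.tensor(t_steps[::-1]))   # ascending order for interpolation
    new_xs = torch.linspace(0, 1, num_steps)
    new_ys = np.interp(new_xs, xs, ys)            # interpolate in log space
    return torch.exp(torch.tensor(new_ys)).numpy()[::-1].copy()  # back to descending

# SD1.5 anchor sigmas from the patch; resampled here to a 20-step schedule.
sd15_sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029]
print(loglinear_interp(sd15_sigmas, 20))
```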