diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py index b3e16df6..a61df353 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py @@ -16,6 +16,7 @@ from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI from lib_controlnet.controlnet_ui.tool_button import ToolButton from lib_controlnet.controlnet_ui.photopea import Photopea +from lib_controlnet.controlnet_ui.multi_inputs_gallery import MultiInputsGallery from lib_controlnet.enums import InputMode, HiResFixOption from modules import shared, script_callbacks from modules.ui_components import FormRow @@ -185,10 +186,11 @@ class ControlNetUiGroup(object): self.mask_image = None self.batch_tab = None self.batch_image_dir = None - self.merge_tab = None + self.batch_upload_tab = None self.batch_input_gallery = None - self.merge_upload_button = None - self.merge_clear_button = None + self.batch_mask_gallery = None + self.multi_inputs_upload_tab = None + self.multi_inputs_input_gallery = None self.create_canvas = None self.canvas_width = None self.canvas_height = None @@ -331,31 +333,17 @@ class ControlNetUiGroup(object): visible=False, ) - with gr.Tab(label="Batch Upload") as self.merge_tab: + with gr.Tab(label="Batch Upload") as self.batch_upload_tab: with gr.Row(): - with gr.Column(): - self.batch_input_gallery = gr.Gallery( - columns=[4], rows=[2], object_fit="contain", height="auto", label="Images" - ) - with gr.Row(): - self.merge_upload_button = gr.UploadButton( - "Upload Images", - file_types=["image"], - file_count="multiple", - ) - self.merge_clear_button = gr.Button("Clear Images") - with gr.Group(visible=False, elem_classes=["cnet-mask-gallery-group"]) as self.batch_mask_gallery_group: - with gr.Column(): - self.batch_mask_gallery = gr.Gallery( - columns=[4], rows=[2], object_fit="contain", height="auto", label="Masks" - ) - with gr.Row(): - self.mask_merge_upload_button = gr.UploadButton( - "Upload Masks", - file_types=["image"], - file_count="multiple", - ) - self.mask_merge_clear_button = gr.Button("Clear Masks") + self.batch_input_gallery = MultiInputsGallery() + self.batch_mask_gallery = MultiInputsGallery( + visible=False, + elem_classes=["cnet-mask-gallery-group"] + ) + + with gr.Tab(label="Multi-Inputs") as self.multi_inputs_upload_tab: + with gr.Row(): + self.multi_inputs_gallery = MultiInputsGallery() if self.photopea: self.photopea.attach_photopea_output(self.generated_image) @@ -600,8 +588,9 @@ class ControlNetUiGroup(object): self.use_preview_as_input, self.batch_image_dir, self.batch_mask_dir, - self.batch_input_gallery, - self.batch_mask_gallery, + self.batch_input_gallery.input_gallery, + self.batch_mask_gallery.input_gallery, + self.multi_inputs_gallery.input_gallery, self.generated_image, self.mask_image, self.hr_option, @@ -981,20 +970,41 @@ class ControlNetUiGroup(object): """Controls whether the upload mask input should be visible.""" def on_checkbox_click(checked: bool, canvas_height: int, canvas_width: int): if not checked: - # Clear mask_image if unchecked. - return gr.update(visible=False), gr.update(value=None), gr.update(value=None, visible=False), \ - gr.update(visible=False), gr.update(value=None) + # Clear mask inputs if unchecked. + return ( + # Single mask upload. + gr.update(visible=False), + gr.update(value=None), + # Batch mask upload dir. + gr.update(value=None, visible=False), + # Multi mask upload gallery. + gr.update(visible=False), + gr.update(value=None) + ) else: # Init an empty canvas the same size as the generation target. empty_canvas = np.zeros(shape=(canvas_height, canvas_width, 3), dtype=np.uint8) - return gr.update(visible=True), gr.update(value=empty_canvas), gr.update(visible=True), \ - gr.update(visible=True), gr.update() + return ( + # Single mask upload. + gr.update(visible=True), + gr.update(value=empty_canvas), + # Batch mask upload dir. + gr.update(visible=True), + # Multi mask upload gallery. + gr.update(visible=True), + gr.update(), + ) self.mask_upload.change( fn=on_checkbox_click, inputs=[self.mask_upload, self.height_slider, self.width_slider], - outputs=[self.mask_image_group, self.mask_image, self.batch_mask_dir, - self.batch_mask_gallery_group, self.batch_mask_gallery], + outputs=[ + self.mask_image_group, + self.mask_image, + self.batch_mask_dir, + self.batch_mask_gallery.group, + self.batch_mask_gallery.input_gallery, + ], show_progress=False, ) @@ -1079,51 +1089,14 @@ class ControlNetUiGroup(object): def register_multi_images_upload(self): """Register callbacks on merge tab multiple images upload.""" - self.merge_clear_button.click( - fn=lambda: [], - inputs=[], - outputs=[self.batch_input_gallery], - ).then( - fn=lambda x: gr.update(value=x + 1), + trigger_dict = dict( + fn=lambda n: gr.update(value=n + 1), inputs=[self.dummy_gradio_update_trigger], outputs=[self.dummy_gradio_update_trigger], ) - self.mask_merge_clear_button.click( - fn=lambda: [], - inputs=[], - outputs=[self.batch_mask_gallery], - ).then( - fn=lambda x: gr.update(value=x + 1), - inputs=[self.dummy_gradio_update_trigger], - outputs=[self.dummy_gradio_update_trigger], - ) - - def upload_file(files, current_files): - return {file_d["name"] for file_d in current_files} | { - file.name for file in files - } - - self.merge_upload_button.upload( - upload_file, - inputs=[self.merge_upload_button, self.batch_input_gallery], - outputs=[self.batch_input_gallery], - queue=False, - ).then( - fn=lambda x: gr.update(value=x + 1), - inputs=[self.dummy_gradio_update_trigger], - outputs=[self.dummy_gradio_update_trigger], - ) - self.mask_merge_upload_button.upload( - upload_file, - inputs=[self.mask_merge_upload_button, self.batch_mask_gallery], - outputs=[self.batch_mask_gallery], - queue=False, - ).then( - fn=lambda x: gr.update(value=x + 1), - inputs=[self.dummy_gradio_update_trigger], - outputs=[self.dummy_gradio_update_trigger], - ) - return + self.batch_input_gallery.register_callbacks(change_trigger=trigger_dict) + self.batch_mask_gallery.register_callbacks(change_trigger=trigger_dict) + self.multi_inputs_gallery.register_callbacks(change_trigger=trigger_dict) def register_core_callbacks(self): """Register core callbacks that only involves gradio components defined @@ -1213,7 +1186,8 @@ class ControlNetUiGroup(object): for input_tab, fn in ( (ui_group.upload_tab, simple_fn), (ui_group.batch_tab, batch_fn), - (ui_group.merge_tab, merge_fn), + (ui_group.batch_upload_tab, batch_fn), + (ui_group.multi_inputs_upload_tab, merge_fn), ): # Sync input_mode. input_tab.select( diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/multi_inputs_gallery.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/multi_inputs_gallery.py new file mode 100644 index 00000000..15cba1df --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/multi_inputs_gallery.py @@ -0,0 +1,65 @@ +import gradio as gr +from typing import Optional + + +class MultiInputsGallery: + """A gallery object that accepts multiple input images.""" + + def __init__(self, row: int = 2, column: int = 4, **group_kwargs) -> None: + self.gallery_row_num = row + self.gallery_column_num = column + self.group_kwargs = group_kwargs + + self.group = None + self.input_gallery = None + self.upload_button = None + self.clear_button = None + self.render() + + def render(self): + with gr.Group(**self.group_kwargs) as self.group: + with gr.Column(): + self.input_gallery = gr.Gallery( + columns=[self.gallery_column_num], + rows=[self.gallery_row_num], + object_fit="contain", + height="auto", + label="Images", + ) + with gr.Row(): + self.upload_button = gr.UploadButton( + "Upload Images", + file_types=["image"], + file_count="multiple", + ) + self.clear_button = gr.Button("Clear Images") + + def register_callbacks(self, change_trigger: Optional[dict] = None): + """Register callbacks on multiple images upload. + Argument: + - change_trigger: An optional gradio callback param dict to be called + after gallery content change. This is necessary as gallery has no + event subscriber. If the state change of gallery needs to be observed, + the caller needs to pass a change trigger to observe the change. + """ + handle1 = self.clear_button.click( + fn=lambda: [], + inputs=[], + outputs=[self.input_gallery], + ) + + def upload_file(files, current_files): + return {file_d["name"] for file_d in current_files} | { + file.name for file in files + } + + handle2 = self.upload_button.upload( + upload_file, + inputs=[self.upload_button, self.input_gallery], + outputs=[self.input_gallery], + queue=False, + ) + + if change_trigger: + for handle in (handle1, handle2): + handle.then(**change_trigger) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py index 4954478a..405719a2 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py @@ -183,6 +183,8 @@ class ControlNetUnit: batch_input_gallery: Optional[List[str]] = None # Optional list of gallery masks for batch processing; defaults to None. batch_mask_gallery: Optional[List[str]] = None + # Optional list of gallery images for multi-inputs; defaults to None. + multi_inputs_gallery: Optional[List[str]] = None # Holds the preview image as a NumPy array; defaults to None. generated_image: Optional[np.ndarray] = None # ====== End of UI only fields ====== diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py index c911d1c4..087b3298 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/utils.py @@ -1,7 +1,10 @@ from typing import Optional +from copy import copy from modules import processing +from modules.api import api from lib_controlnet import external_code +from lib_controlnet.external_code import InputMode, ControlNetUnit from modules_forge.forge_util import HWC3 @@ -361,3 +364,22 @@ def crop_and_resize_image(detected_map, resize_mode, h, w, fill_border_with_255= def judge_image_type(img): return isinstance(img, np.ndarray) and img.ndim == 3 and int(img.shape[2]) in [3, 4] + + +def try_unfold_unit(unit: ControlNetUnit) -> List[ControlNetUnit]: + """Unfolds an multi-inputs unit into multiple units with one input.""" + if unit.input_mode != InputMode.MERGE: + return [unit] + + def extract_unit(gallery_item: dict) -> ControlNetUnit: + r_unit = copy(unit) + img = np.array(api.decode_base64_to_image(read_image(gallery_item["name"]))).astype('uint8') + r_unit.image = { + "image": img, + "mask": np.zeros_like(img), + } + r_unit.input_mode = InputMode.SIMPLE + r_unit.weight = unit.weight / len(unit.multi_inputs_gallery) + return r_unit + + return [extract_unit(item) for item in unit.multi_inputs_gallery] diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index 08aa5ff2..9b3615b3 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -12,8 +12,14 @@ import gradio as gr from lib_controlnet import global_state, external_code from lib_controlnet.external_code import ControlNetUnit -from lib_controlnet.utils import align_dim_latent, set_numpy_seed, crop_and_resize_image, \ - prepare_mask, judge_image_type +from lib_controlnet.utils import ( + align_dim_latent, + set_numpy_seed, + crop_and_resize_image, + prepare_mask, + judge_image_type, + try_unfold_unit, +) from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup from lib_controlnet.controlnet_ui.photopea import Photopea from lib_controlnet.logging import logger @@ -105,8 +111,13 @@ class ControlNetForForgeOfficial(scripts.Script): for unit in units ] assert all(isinstance(unit, ControlNetUnit) for unit in units) - enabled_units = [x for x in units if x.enabled] - return enabled_units + return [ + simple_unit + for unit in units + # Unfolds multi-inputs units. + for simple_unit in try_unfold_unit(unit) + if simple_unit.enabled + ] @staticmethod def try_crop_image_with_a1111_mask(