Add back ControlNet multi-inputs upload tab (#264)

* Refactor gallery

* Add back multi-inputs

* nit
This commit is contained in:
Chenlei Hu 2024-05-02 18:25:42 -04:00 committed by GitHub
parent 29be1da7cf
commit 61a2a9d342
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 157 additions and 83 deletions

View File

@ -16,6 +16,7 @@ from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor
from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI
from lib_controlnet.controlnet_ui.tool_button import ToolButton
from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.controlnet_ui.multi_inputs_gallery import MultiInputsGallery
from lib_controlnet.enums import InputMode, HiResFixOption
from modules import shared, script_callbacks
from modules.ui_components import FormRow
@ -185,10 +186,11 @@ class ControlNetUiGroup(object):
self.mask_image = None
self.batch_tab = None
self.batch_image_dir = None
self.merge_tab = None
self.batch_upload_tab = None
self.batch_input_gallery = None
self.merge_upload_button = None
self.merge_clear_button = None
self.batch_mask_gallery = None
self.multi_inputs_upload_tab = None
self.multi_inputs_input_gallery = None
self.create_canvas = None
self.canvas_width = None
self.canvas_height = None
@ -331,31 +333,17 @@ class ControlNetUiGroup(object):
visible=False,
)
with gr.Tab(label="Batch Upload") as self.merge_tab:
with gr.Tab(label="Batch Upload") as self.batch_upload_tab:
with gr.Row():
with gr.Column():
self.batch_input_gallery = gr.Gallery(
columns=[4], rows=[2], object_fit="contain", height="auto", label="Images"
)
with gr.Row():
self.merge_upload_button = gr.UploadButton(
"Upload Images",
file_types=["image"],
file_count="multiple",
)
self.merge_clear_button = gr.Button("Clear Images")
with gr.Group(visible=False, elem_classes=["cnet-mask-gallery-group"]) as self.batch_mask_gallery_group:
with gr.Column():
self.batch_mask_gallery = gr.Gallery(
columns=[4], rows=[2], object_fit="contain", height="auto", label="Masks"
)
with gr.Row():
self.mask_merge_upload_button = gr.UploadButton(
"Upload Masks",
file_types=["image"],
file_count="multiple",
)
self.mask_merge_clear_button = gr.Button("Clear Masks")
self.batch_input_gallery = MultiInputsGallery()
self.batch_mask_gallery = MultiInputsGallery(
visible=False,
elem_classes=["cnet-mask-gallery-group"]
)
with gr.Tab(label="Multi-Inputs") as self.multi_inputs_upload_tab:
with gr.Row():
self.multi_inputs_gallery = MultiInputsGallery()
if self.photopea:
self.photopea.attach_photopea_output(self.generated_image)
@ -600,8 +588,9 @@ class ControlNetUiGroup(object):
self.use_preview_as_input,
self.batch_image_dir,
self.batch_mask_dir,
self.batch_input_gallery,
self.batch_mask_gallery,
self.batch_input_gallery.input_gallery,
self.batch_mask_gallery.input_gallery,
self.multi_inputs_gallery.input_gallery,
self.generated_image,
self.mask_image,
self.hr_option,
@ -981,20 +970,41 @@ class ControlNetUiGroup(object):
"""Controls whether the upload mask input should be visible."""
def on_checkbox_click(checked: bool, canvas_height: int, canvas_width: int):
if not checked:
# Clear mask_image if unchecked.
return gr.update(visible=False), gr.update(value=None), gr.update(value=None, visible=False), \
gr.update(visible=False), gr.update(value=None)
# Clear mask inputs if unchecked.
return (
# Single mask upload.
gr.update(visible=False),
gr.update(value=None),
# Batch mask upload dir.
gr.update(value=None, visible=False),
# Multi mask upload gallery.
gr.update(visible=False),
gr.update(value=None)
)
else:
# Init an empty canvas the same size as the generation target.
empty_canvas = np.zeros(shape=(canvas_height, canvas_width, 3), dtype=np.uint8)
return gr.update(visible=True), gr.update(value=empty_canvas), gr.update(visible=True), \
gr.update(visible=True), gr.update()
return (
# Single mask upload.
gr.update(visible=True),
gr.update(value=empty_canvas),
# Batch mask upload dir.
gr.update(visible=True),
# Multi mask upload gallery.
gr.update(visible=True),
gr.update(),
)
self.mask_upload.change(
fn=on_checkbox_click,
inputs=[self.mask_upload, self.height_slider, self.width_slider],
outputs=[self.mask_image_group, self.mask_image, self.batch_mask_dir,
self.batch_mask_gallery_group, self.batch_mask_gallery],
outputs=[
self.mask_image_group,
self.mask_image,
self.batch_mask_dir,
self.batch_mask_gallery.group,
self.batch_mask_gallery.input_gallery,
],
show_progress=False,
)
@ -1079,51 +1089,14 @@ class ControlNetUiGroup(object):
def register_multi_images_upload(self):
"""Register callbacks on merge tab multiple images upload."""
self.merge_clear_button.click(
fn=lambda: [],
inputs=[],
outputs=[self.batch_input_gallery],
).then(
fn=lambda x: gr.update(value=x + 1),
trigger_dict = dict(
fn=lambda n: gr.update(value=n + 1),
inputs=[self.dummy_gradio_update_trigger],
outputs=[self.dummy_gradio_update_trigger],
)
self.mask_merge_clear_button.click(
fn=lambda: [],
inputs=[],
outputs=[self.batch_mask_gallery],
).then(
fn=lambda x: gr.update(value=x + 1),
inputs=[self.dummy_gradio_update_trigger],
outputs=[self.dummy_gradio_update_trigger],
)
def upload_file(files, current_files):
return {file_d["name"] for file_d in current_files} | {
file.name for file in files
}
self.merge_upload_button.upload(
upload_file,
inputs=[self.merge_upload_button, self.batch_input_gallery],
outputs=[self.batch_input_gallery],
queue=False,
).then(
fn=lambda x: gr.update(value=x + 1),
inputs=[self.dummy_gradio_update_trigger],
outputs=[self.dummy_gradio_update_trigger],
)
self.mask_merge_upload_button.upload(
upload_file,
inputs=[self.mask_merge_upload_button, self.batch_mask_gallery],
outputs=[self.batch_mask_gallery],
queue=False,
).then(
fn=lambda x: gr.update(value=x + 1),
inputs=[self.dummy_gradio_update_trigger],
outputs=[self.dummy_gradio_update_trigger],
)
return
self.batch_input_gallery.register_callbacks(change_trigger=trigger_dict)
self.batch_mask_gallery.register_callbacks(change_trigger=trigger_dict)
self.multi_inputs_gallery.register_callbacks(change_trigger=trigger_dict)
def register_core_callbacks(self):
"""Register core callbacks that only involves gradio components defined
@ -1213,7 +1186,8 @@ class ControlNetUiGroup(object):
for input_tab, fn in (
(ui_group.upload_tab, simple_fn),
(ui_group.batch_tab, batch_fn),
(ui_group.merge_tab, merge_fn),
(ui_group.batch_upload_tab, batch_fn),
(ui_group.multi_inputs_upload_tab, merge_fn),
):
# Sync input_mode.
input_tab.select(

View File

@ -0,0 +1,65 @@
import gradio as gr
from typing import Optional
class MultiInputsGallery:
"""A gallery object that accepts multiple input images."""
def __init__(self, row: int = 2, column: int = 4, **group_kwargs) -> None:
self.gallery_row_num = row
self.gallery_column_num = column
self.group_kwargs = group_kwargs
self.group = None
self.input_gallery = None
self.upload_button = None
self.clear_button = None
self.render()
def render(self):
with gr.Group(**self.group_kwargs) as self.group:
with gr.Column():
self.input_gallery = gr.Gallery(
columns=[self.gallery_column_num],
rows=[self.gallery_row_num],
object_fit="contain",
height="auto",
label="Images",
)
with gr.Row():
self.upload_button = gr.UploadButton(
"Upload Images",
file_types=["image"],
file_count="multiple",
)
self.clear_button = gr.Button("Clear Images")
def register_callbacks(self, change_trigger: Optional[dict] = None):
"""Register callbacks on multiple images upload.
Argument:
- change_trigger: An optional gradio callback param dict to be called
after gallery content change. This is necessary as gallery has no
event subscriber. If the state change of gallery needs to be observed,
the caller needs to pass a change trigger to observe the change.
"""
handle1 = self.clear_button.click(
fn=lambda: [],
inputs=[],
outputs=[self.input_gallery],
)
def upload_file(files, current_files):
return {file_d["name"] for file_d in current_files} | {
file.name for file in files
}
handle2 = self.upload_button.upload(
upload_file,
inputs=[self.upload_button, self.input_gallery],
outputs=[self.input_gallery],
queue=False,
)
if change_trigger:
for handle in (handle1, handle2):
handle.then(**change_trigger)

View File

@ -183,6 +183,8 @@ class ControlNetUnit:
batch_input_gallery: Optional[List[str]] = None
# Optional list of gallery masks for batch processing; defaults to None.
batch_mask_gallery: Optional[List[str]] = None
# Optional list of gallery images for multi-inputs; defaults to None.
multi_inputs_gallery: Optional[List[str]] = None
# Holds the preview image as a NumPy array; defaults to None.
generated_image: Optional[np.ndarray] = None
# ====== End of UI only fields ======

View File

@ -1,7 +1,10 @@
from typing import Optional
from copy import copy
from modules import processing
from modules.api import api
from lib_controlnet import external_code
from lib_controlnet.external_code import InputMode, ControlNetUnit
from modules_forge.forge_util import HWC3
@ -361,3 +364,22 @@ def crop_and_resize_image(detected_map, resize_mode, h, w, fill_border_with_255=
def judge_image_type(img):
return isinstance(img, np.ndarray) and img.ndim == 3 and int(img.shape[2]) in [3, 4]
def try_unfold_unit(unit: ControlNetUnit) -> List[ControlNetUnit]:
"""Unfolds an multi-inputs unit into multiple units with one input."""
if unit.input_mode != InputMode.MERGE:
return [unit]
def extract_unit(gallery_item: dict) -> ControlNetUnit:
r_unit = copy(unit)
img = np.array(api.decode_base64_to_image(read_image(gallery_item["name"]))).astype('uint8')
r_unit.image = {
"image": img,
"mask": np.zeros_like(img),
}
r_unit.input_mode = InputMode.SIMPLE
r_unit.weight = unit.weight / len(unit.multi_inputs_gallery)
return r_unit
return [extract_unit(item) for item in unit.multi_inputs_gallery]

View File

@ -12,8 +12,14 @@ import gradio as gr
from lib_controlnet import global_state, external_code
from lib_controlnet.external_code import ControlNetUnit
from lib_controlnet.utils import align_dim_latent, set_numpy_seed, crop_and_resize_image, \
prepare_mask, judge_image_type
from lib_controlnet.utils import (
align_dim_latent,
set_numpy_seed,
crop_and_resize_image,
prepare_mask,
judge_image_type,
try_unfold_unit,
)
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup
from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.logging import logger
@ -105,8 +111,13 @@ class ControlNetForForgeOfficial(scripts.Script):
for unit in units
]
assert all(isinstance(unit, ControlNetUnit) for unit in units)
enabled_units = [x for x in units if x.enabled]
return enabled_units
return [
simple_unit
for unit in units
# Unfolds multi-inputs units.
for simple_unit in try_unfold_unit(unit)
if simple_unit.enabled
]
@staticmethod
def try_crop_image_with_a1111_mask(