Add back ControlNet multi-inputs upload tab (#264)
* Refactor gallery * Add back multi-inputs * nit
This commit is contained in:
parent
29be1da7cf
commit
61a2a9d342
@ -16,6 +16,7 @@ from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor
|
||||
from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI
|
||||
from lib_controlnet.controlnet_ui.tool_button import ToolButton
|
||||
from lib_controlnet.controlnet_ui.photopea import Photopea
|
||||
from lib_controlnet.controlnet_ui.multi_inputs_gallery import MultiInputsGallery
|
||||
from lib_controlnet.enums import InputMode, HiResFixOption
|
||||
from modules import shared, script_callbacks
|
||||
from modules.ui_components import FormRow
|
||||
@ -185,10 +186,11 @@ class ControlNetUiGroup(object):
|
||||
self.mask_image = None
|
||||
self.batch_tab = None
|
||||
self.batch_image_dir = None
|
||||
self.merge_tab = None
|
||||
self.batch_upload_tab = None
|
||||
self.batch_input_gallery = None
|
||||
self.merge_upload_button = None
|
||||
self.merge_clear_button = None
|
||||
self.batch_mask_gallery = None
|
||||
self.multi_inputs_upload_tab = None
|
||||
self.multi_inputs_input_gallery = None
|
||||
self.create_canvas = None
|
||||
self.canvas_width = None
|
||||
self.canvas_height = None
|
||||
@ -331,31 +333,17 @@ class ControlNetUiGroup(object):
|
||||
visible=False,
|
||||
)
|
||||
|
||||
with gr.Tab(label="Batch Upload") as self.merge_tab:
|
||||
with gr.Tab(label="Batch Upload") as self.batch_upload_tab:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
self.batch_input_gallery = gr.Gallery(
|
||||
columns=[4], rows=[2], object_fit="contain", height="auto", label="Images"
|
||||
)
|
||||
with gr.Row():
|
||||
self.merge_upload_button = gr.UploadButton(
|
||||
"Upload Images",
|
||||
file_types=["image"],
|
||||
file_count="multiple",
|
||||
)
|
||||
self.merge_clear_button = gr.Button("Clear Images")
|
||||
with gr.Group(visible=False, elem_classes=["cnet-mask-gallery-group"]) as self.batch_mask_gallery_group:
|
||||
with gr.Column():
|
||||
self.batch_mask_gallery = gr.Gallery(
|
||||
columns=[4], rows=[2], object_fit="contain", height="auto", label="Masks"
|
||||
)
|
||||
with gr.Row():
|
||||
self.mask_merge_upload_button = gr.UploadButton(
|
||||
"Upload Masks",
|
||||
file_types=["image"],
|
||||
file_count="multiple",
|
||||
)
|
||||
self.mask_merge_clear_button = gr.Button("Clear Masks")
|
||||
self.batch_input_gallery = MultiInputsGallery()
|
||||
self.batch_mask_gallery = MultiInputsGallery(
|
||||
visible=False,
|
||||
elem_classes=["cnet-mask-gallery-group"]
|
||||
)
|
||||
|
||||
with gr.Tab(label="Multi-Inputs") as self.multi_inputs_upload_tab:
|
||||
with gr.Row():
|
||||
self.multi_inputs_gallery = MultiInputsGallery()
|
||||
|
||||
if self.photopea:
|
||||
self.photopea.attach_photopea_output(self.generated_image)
|
||||
@ -600,8 +588,9 @@ class ControlNetUiGroup(object):
|
||||
self.use_preview_as_input,
|
||||
self.batch_image_dir,
|
||||
self.batch_mask_dir,
|
||||
self.batch_input_gallery,
|
||||
self.batch_mask_gallery,
|
||||
self.batch_input_gallery.input_gallery,
|
||||
self.batch_mask_gallery.input_gallery,
|
||||
self.multi_inputs_gallery.input_gallery,
|
||||
self.generated_image,
|
||||
self.mask_image,
|
||||
self.hr_option,
|
||||
@ -981,20 +970,41 @@ class ControlNetUiGroup(object):
|
||||
"""Controls whether the upload mask input should be visible."""
|
||||
def on_checkbox_click(checked: bool, canvas_height: int, canvas_width: int):
|
||||
if not checked:
|
||||
# Clear mask_image if unchecked.
|
||||
return gr.update(visible=False), gr.update(value=None), gr.update(value=None, visible=False), \
|
||||
gr.update(visible=False), gr.update(value=None)
|
||||
# Clear mask inputs if unchecked.
|
||||
return (
|
||||
# Single mask upload.
|
||||
gr.update(visible=False),
|
||||
gr.update(value=None),
|
||||
# Batch mask upload dir.
|
||||
gr.update(value=None, visible=False),
|
||||
# Multi mask upload gallery.
|
||||
gr.update(visible=False),
|
||||
gr.update(value=None)
|
||||
)
|
||||
else:
|
||||
# Init an empty canvas the same size as the generation target.
|
||||
empty_canvas = np.zeros(shape=(canvas_height, canvas_width, 3), dtype=np.uint8)
|
||||
return gr.update(visible=True), gr.update(value=empty_canvas), gr.update(visible=True), \
|
||||
gr.update(visible=True), gr.update()
|
||||
return (
|
||||
# Single mask upload.
|
||||
gr.update(visible=True),
|
||||
gr.update(value=empty_canvas),
|
||||
# Batch mask upload dir.
|
||||
gr.update(visible=True),
|
||||
# Multi mask upload gallery.
|
||||
gr.update(visible=True),
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
self.mask_upload.change(
|
||||
fn=on_checkbox_click,
|
||||
inputs=[self.mask_upload, self.height_slider, self.width_slider],
|
||||
outputs=[self.mask_image_group, self.mask_image, self.batch_mask_dir,
|
||||
self.batch_mask_gallery_group, self.batch_mask_gallery],
|
||||
outputs=[
|
||||
self.mask_image_group,
|
||||
self.mask_image,
|
||||
self.batch_mask_dir,
|
||||
self.batch_mask_gallery.group,
|
||||
self.batch_mask_gallery.input_gallery,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
@ -1079,51 +1089,14 @@ class ControlNetUiGroup(object):
|
||||
|
||||
def register_multi_images_upload(self):
|
||||
"""Register callbacks on merge tab multiple images upload."""
|
||||
self.merge_clear_button.click(
|
||||
fn=lambda: [],
|
||||
inputs=[],
|
||||
outputs=[self.batch_input_gallery],
|
||||
).then(
|
||||
fn=lambda x: gr.update(value=x + 1),
|
||||
trigger_dict = dict(
|
||||
fn=lambda n: gr.update(value=n + 1),
|
||||
inputs=[self.dummy_gradio_update_trigger],
|
||||
outputs=[self.dummy_gradio_update_trigger],
|
||||
)
|
||||
self.mask_merge_clear_button.click(
|
||||
fn=lambda: [],
|
||||
inputs=[],
|
||||
outputs=[self.batch_mask_gallery],
|
||||
).then(
|
||||
fn=lambda x: gr.update(value=x + 1),
|
||||
inputs=[self.dummy_gradio_update_trigger],
|
||||
outputs=[self.dummy_gradio_update_trigger],
|
||||
)
|
||||
|
||||
def upload_file(files, current_files):
|
||||
return {file_d["name"] for file_d in current_files} | {
|
||||
file.name for file in files
|
||||
}
|
||||
|
||||
self.merge_upload_button.upload(
|
||||
upload_file,
|
||||
inputs=[self.merge_upload_button, self.batch_input_gallery],
|
||||
outputs=[self.batch_input_gallery],
|
||||
queue=False,
|
||||
).then(
|
||||
fn=lambda x: gr.update(value=x + 1),
|
||||
inputs=[self.dummy_gradio_update_trigger],
|
||||
outputs=[self.dummy_gradio_update_trigger],
|
||||
)
|
||||
self.mask_merge_upload_button.upload(
|
||||
upload_file,
|
||||
inputs=[self.mask_merge_upload_button, self.batch_mask_gallery],
|
||||
outputs=[self.batch_mask_gallery],
|
||||
queue=False,
|
||||
).then(
|
||||
fn=lambda x: gr.update(value=x + 1),
|
||||
inputs=[self.dummy_gradio_update_trigger],
|
||||
outputs=[self.dummy_gradio_update_trigger],
|
||||
)
|
||||
return
|
||||
self.batch_input_gallery.register_callbacks(change_trigger=trigger_dict)
|
||||
self.batch_mask_gallery.register_callbacks(change_trigger=trigger_dict)
|
||||
self.multi_inputs_gallery.register_callbacks(change_trigger=trigger_dict)
|
||||
|
||||
def register_core_callbacks(self):
|
||||
"""Register core callbacks that only involves gradio components defined
|
||||
@ -1213,7 +1186,8 @@ class ControlNetUiGroup(object):
|
||||
for input_tab, fn in (
|
||||
(ui_group.upload_tab, simple_fn),
|
||||
(ui_group.batch_tab, batch_fn),
|
||||
(ui_group.merge_tab, merge_fn),
|
||||
(ui_group.batch_upload_tab, batch_fn),
|
||||
(ui_group.multi_inputs_upload_tab, merge_fn),
|
||||
):
|
||||
# Sync input_mode.
|
||||
input_tab.select(
|
||||
|
@ -0,0 +1,65 @@
|
||||
import gradio as gr
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class MultiInputsGallery:
|
||||
"""A gallery object that accepts multiple input images."""
|
||||
|
||||
def __init__(self, row: int = 2, column: int = 4, **group_kwargs) -> None:
|
||||
self.gallery_row_num = row
|
||||
self.gallery_column_num = column
|
||||
self.group_kwargs = group_kwargs
|
||||
|
||||
self.group = None
|
||||
self.input_gallery = None
|
||||
self.upload_button = None
|
||||
self.clear_button = None
|
||||
self.render()
|
||||
|
||||
def render(self):
|
||||
with gr.Group(**self.group_kwargs) as self.group:
|
||||
with gr.Column():
|
||||
self.input_gallery = gr.Gallery(
|
||||
columns=[self.gallery_column_num],
|
||||
rows=[self.gallery_row_num],
|
||||
object_fit="contain",
|
||||
height="auto",
|
||||
label="Images",
|
||||
)
|
||||
with gr.Row():
|
||||
self.upload_button = gr.UploadButton(
|
||||
"Upload Images",
|
||||
file_types=["image"],
|
||||
file_count="multiple",
|
||||
)
|
||||
self.clear_button = gr.Button("Clear Images")
|
||||
|
||||
def register_callbacks(self, change_trigger: Optional[dict] = None):
|
||||
"""Register callbacks on multiple images upload.
|
||||
Argument:
|
||||
- change_trigger: An optional gradio callback param dict to be called
|
||||
after gallery content change. This is necessary as gallery has no
|
||||
event subscriber. If the state change of gallery needs to be observed,
|
||||
the caller needs to pass a change trigger to observe the change.
|
||||
"""
|
||||
handle1 = self.clear_button.click(
|
||||
fn=lambda: [],
|
||||
inputs=[],
|
||||
outputs=[self.input_gallery],
|
||||
)
|
||||
|
||||
def upload_file(files, current_files):
|
||||
return {file_d["name"] for file_d in current_files} | {
|
||||
file.name for file in files
|
||||
}
|
||||
|
||||
handle2 = self.upload_button.upload(
|
||||
upload_file,
|
||||
inputs=[self.upload_button, self.input_gallery],
|
||||
outputs=[self.input_gallery],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
if change_trigger:
|
||||
for handle in (handle1, handle2):
|
||||
handle.then(**change_trigger)
|
@ -183,6 +183,8 @@ class ControlNetUnit:
|
||||
batch_input_gallery: Optional[List[str]] = None
|
||||
# Optional list of gallery masks for batch processing; defaults to None.
|
||||
batch_mask_gallery: Optional[List[str]] = None
|
||||
# Optional list of gallery images for multi-inputs; defaults to None.
|
||||
multi_inputs_gallery: Optional[List[str]] = None
|
||||
# Holds the preview image as a NumPy array; defaults to None.
|
||||
generated_image: Optional[np.ndarray] = None
|
||||
# ====== End of UI only fields ======
|
||||
|
@ -1,7 +1,10 @@
|
||||
from typing import Optional
|
||||
from copy import copy
|
||||
from modules import processing
|
||||
from modules.api import api
|
||||
|
||||
from lib_controlnet import external_code
|
||||
from lib_controlnet.external_code import InputMode, ControlNetUnit
|
||||
|
||||
from modules_forge.forge_util import HWC3
|
||||
|
||||
@ -361,3 +364,22 @@ def crop_and_resize_image(detected_map, resize_mode, h, w, fill_border_with_255=
|
||||
|
||||
def judge_image_type(img):
|
||||
return isinstance(img, np.ndarray) and img.ndim == 3 and int(img.shape[2]) in [3, 4]
|
||||
|
||||
|
||||
def try_unfold_unit(unit: ControlNetUnit) -> List[ControlNetUnit]:
|
||||
"""Unfolds an multi-inputs unit into multiple units with one input."""
|
||||
if unit.input_mode != InputMode.MERGE:
|
||||
return [unit]
|
||||
|
||||
def extract_unit(gallery_item: dict) -> ControlNetUnit:
|
||||
r_unit = copy(unit)
|
||||
img = np.array(api.decode_base64_to_image(read_image(gallery_item["name"]))).astype('uint8')
|
||||
r_unit.image = {
|
||||
"image": img,
|
||||
"mask": np.zeros_like(img),
|
||||
}
|
||||
r_unit.input_mode = InputMode.SIMPLE
|
||||
r_unit.weight = unit.weight / len(unit.multi_inputs_gallery)
|
||||
return r_unit
|
||||
|
||||
return [extract_unit(item) for item in unit.multi_inputs_gallery]
|
||||
|
@ -12,8 +12,14 @@ import gradio as gr
|
||||
|
||||
from lib_controlnet import global_state, external_code
|
||||
from lib_controlnet.external_code import ControlNetUnit
|
||||
from lib_controlnet.utils import align_dim_latent, set_numpy_seed, crop_and_resize_image, \
|
||||
prepare_mask, judge_image_type
|
||||
from lib_controlnet.utils import (
|
||||
align_dim_latent,
|
||||
set_numpy_seed,
|
||||
crop_and_resize_image,
|
||||
prepare_mask,
|
||||
judge_image_type,
|
||||
try_unfold_unit,
|
||||
)
|
||||
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup
|
||||
from lib_controlnet.controlnet_ui.photopea import Photopea
|
||||
from lib_controlnet.logging import logger
|
||||
@ -105,8 +111,13 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
for unit in units
|
||||
]
|
||||
assert all(isinstance(unit, ControlNetUnit) for unit in units)
|
||||
enabled_units = [x for x in units if x.enabled]
|
||||
return enabled_units
|
||||
return [
|
||||
simple_unit
|
||||
for unit in units
|
||||
# Unfolds multi-inputs units.
|
||||
for simple_unit in try_unfold_unit(unit)
|
||||
if simple_unit.enabled
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def try_crop_image_with_a1111_mask(
|
||||
|
Loading…
Reference in New Issue
Block a user