parent
b996316b20
commit
66845160de
@ -11,6 +11,7 @@ from lib_controlnet import (
|
||||
global_state,
|
||||
external_code,
|
||||
)
|
||||
from lib_controlnet.external_code import ControlNetUnit
|
||||
from lib_controlnet.logging import logger
|
||||
from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor
|
||||
from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI
|
||||
@ -21,7 +22,6 @@ from lib_controlnet.enums import InputMode, HiResFixOption
|
||||
from modules import shared, script_callbacks
|
||||
from modules.ui_components import FormRow
|
||||
from modules_forge.forge_util import HWC3
|
||||
from lib_controlnet.external_code import UiControlNetUnit
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -172,10 +172,10 @@ class ControlNetUiGroup(object):
|
||||
self.webcam_mirrored = False
|
||||
|
||||
# Note: All gradio elements declared in `render` will be defined as member variable.
|
||||
# Update counter to trigger a force update of UiControlNetUnit.
|
||||
# Update counter to trigger a force update of ControlNetUnit.
|
||||
# dummy_gradio_update_trigger is useful when a field with no event subscriber available changes.
|
||||
# e.g. gr.Gallery, gr.State, etc. After an update to gr.State / gr.Gallery, please increment
|
||||
# this counter to trigger a sync update of UiControlNetUnit.
|
||||
# this counter to trigger a sync update of ControlNetUnit.
|
||||
self.dummy_gradio_update_trigger = None
|
||||
self.enabled = None
|
||||
self.upload_tab = None
|
||||
@ -610,6 +610,12 @@ class ControlNetUiGroup(object):
|
||||
)
|
||||
|
||||
unit = gr.State(self.default_unit)
|
||||
def create_unit(*args):
|
||||
return ControlNetUnit.from_dict({
|
||||
k: v
|
||||
for k, v in zip(vars(ControlNetUnit()).keys(), args)
|
||||
})
|
||||
|
||||
for comp in unit_args + (self.dummy_gradio_update_trigger,):
|
||||
event_subscribers = []
|
||||
if hasattr(comp, "edit"):
|
||||
@ -626,7 +632,7 @@ class ControlNetUiGroup(object):
|
||||
|
||||
for event_subscriber in event_subscribers:
|
||||
event_subscriber(
|
||||
fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit
|
||||
fn=create_unit, inputs=list(unit_args), outputs=unit
|
||||
)
|
||||
|
||||
(
|
||||
@ -634,7 +640,7 @@ class ControlNetUiGroup(object):
|
||||
if self.is_img2img
|
||||
else ControlNetUiGroup.a1111_context.txt2img_submit_button
|
||||
).click(
|
||||
fn=UiControlNetUnit,
|
||||
fn=create_unit,
|
||||
inputs=list(unit_args),
|
||||
outputs=unit,
|
||||
queue=False,
|
||||
|
@ -194,7 +194,7 @@ class ControlNetUnit:
|
||||
mask_image: Optional[GradioImageMaskPair] = None
|
||||
# Specifies how this unit should be applied in each pass of high-resolution fix.
|
||||
# Ignored if high-resolution fix is not enabled.
|
||||
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
|
||||
hr_option: HiResFixOption = HiResFixOption.BOTH
|
||||
# Indicates whether the unit is enabled; defaults to True.
|
||||
enabled: bool = True
|
||||
# Name of the module being used; defaults to "None".
|
||||
@ -206,7 +206,7 @@ class ControlNetUnit:
|
||||
# Optional image for input; defaults to None.
|
||||
image: Optional[GradioImageMaskPair] = None
|
||||
# Specifies the mode of image resizing; defaults to inner fit.
|
||||
resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT
|
||||
resize_mode: ResizeMode = ResizeMode.INNER_FIT
|
||||
# Resolution for processing by the unit; defaults to -1 (unspecified).
|
||||
processor_res: int = -1
|
||||
# Threshold A for processing; defaults to -1 (unspecified).
|
||||
@ -220,7 +220,7 @@ class ControlNetUnit:
|
||||
# Enables pixel-perfect processing; defaults to False.
|
||||
pixel_perfect: bool = False
|
||||
# Control mode for the unit; defaults to balanced.
|
||||
control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
|
||||
control_mode: ControlMode = ControlMode.BALANCED
|
||||
|
||||
# Following fields should only be used in the API.
|
||||
# ====== Start of API only fields ======
|
||||
@ -277,6 +277,11 @@ class ControlNetUnit:
|
||||
"image": mask,
|
||||
"mask": np.zeros_like(mask),
|
||||
}
|
||||
# Convert strings to enums.
|
||||
unit.input_mode = InputMode(unit.input_mode)
|
||||
unit.hr_option = HiResFixOption(unit.hr_option)
|
||||
unit.resize_mode = ResizeMode(unit.resize_mode)
|
||||
unit.control_mode = ControlMode(unit.control_mode)
|
||||
return unit
|
||||
|
||||
|
||||
|
@ -164,9 +164,9 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
def get_input_data(self, p, unit, preprocessor, h, w):
|
||||
logger.info(f'ControlNet Input Mode: {unit.input_mode}')
|
||||
image_list = []
|
||||
resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
|
||||
resize_mode = unit.resize_mode
|
||||
|
||||
if unit.input_mode == external_code.InputMode.MERGE:
|
||||
if unit.input_mode == InputMode.MERGE:
|
||||
for idx, item in enumerate(unit.batch_input_gallery):
|
||||
img_path = item['name']
|
||||
logger.info(f'Try to read image: {img_path}')
|
||||
@ -180,7 +180,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy()
|
||||
if img is not None:
|
||||
image_list.append([img, mask])
|
||||
elif unit.input_mode == external_code.InputMode.BATCH:
|
||||
elif unit.input_mode == InputMode.BATCH:
|
||||
image_list = []
|
||||
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
|
||||
batch_image_files = shared.listfiles(unit.batch_image_dir)
|
||||
@ -350,7 +350,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
break
|
||||
|
||||
if has_high_res_fix:
|
||||
hr_option = HiResFixOption.from_value(unit.hr_option)
|
||||
hr_option = unit.hr_option
|
||||
else:
|
||||
hr_option = HiResFixOption.BOTH
|
||||
|
||||
@ -441,7 +441,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
)
|
||||
|
||||
if has_high_res_fix:
|
||||
hr_option = HiResFixOption.from_value(unit.hr_option)
|
||||
hr_option = unit.hr_option
|
||||
else:
|
||||
hr_option = HiResFixOption.BOTH
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user