diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py index a61df353..928348dd 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py @@ -11,6 +11,7 @@ from lib_controlnet import ( global_state, external_code, ) +from lib_controlnet.external_code import ControlNetUnit from lib_controlnet.logging import logger from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI @@ -21,7 +22,6 @@ from lib_controlnet.enums import InputMode, HiResFixOption from modules import shared, script_callbacks from modules.ui_components import FormRow from modules_forge.forge_util import HWC3 -from lib_controlnet.external_code import UiControlNetUnit @dataclass @@ -172,10 +172,10 @@ class ControlNetUiGroup(object): self.webcam_mirrored = False # Note: All gradio elements declared in `render` will be defined as member variable. - # Update counter to trigger a force update of UiControlNetUnit. + # Update counter to trigger a force update of ControlNetUnit. # dummy_gradio_update_trigger is useful when a field with no event subscriber available changes. # e.g. gr.Gallery, gr.State, etc. After an update to gr.State / gr.Gallery, please increment - # this counter to trigger a sync update of UiControlNetUnit. + # this counter to trigger a sync update of ControlNetUnit. self.dummy_gradio_update_trigger = None self.enabled = None self.upload_tab = None @@ -610,6 +610,12 @@ class ControlNetUiGroup(object): ) unit = gr.State(self.default_unit) + def create_unit(*args): + return ControlNetUnit.from_dict({ + k: v + for k, v in zip(vars(ControlNetUnit()).keys(), args) + }) + for comp in unit_args + (self.dummy_gradio_update_trigger,): event_subscribers = [] if hasattr(comp, "edit"): @@ -626,7 +632,7 @@ class ControlNetUiGroup(object): for event_subscriber in event_subscribers: event_subscriber( - fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit + fn=create_unit, inputs=list(unit_args), outputs=unit ) ( @@ -634,7 +640,7 @@ class ControlNetUiGroup(object): if self.is_img2img else ControlNetUiGroup.a1111_context.txt2img_submit_button ).click( - fn=UiControlNetUnit, + fn=create_unit, inputs=list(unit_args), outputs=unit, queue=False, diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py index 405719a2..4e47bd05 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py @@ -194,7 +194,7 @@ class ControlNetUnit: mask_image: Optional[GradioImageMaskPair] = None # Specifies how this unit should be applied in each pass of high-resolution fix. # Ignored if high-resolution fix is not enabled. - hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH + hr_option: HiResFixOption = HiResFixOption.BOTH # Indicates whether the unit is enabled; defaults to True. enabled: bool = True # Name of the module being used; defaults to "None". @@ -206,7 +206,7 @@ class ControlNetUnit: # Optional image for input; defaults to None. image: Optional[GradioImageMaskPair] = None # Specifies the mode of image resizing; defaults to inner fit. - resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT + resize_mode: ResizeMode = ResizeMode.INNER_FIT # Resolution for processing by the unit; defaults to -1 (unspecified). processor_res: int = -1 # Threshold A for processing; defaults to -1 (unspecified). @@ -220,7 +220,7 @@ class ControlNetUnit: # Enables pixel-perfect processing; defaults to False. pixel_perfect: bool = False # Control mode for the unit; defaults to balanced. - control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED + control_mode: ControlMode = ControlMode.BALANCED # Following fields should only be used in the API. # ====== Start of API only fields ====== @@ -277,6 +277,11 @@ class ControlNetUnit: "image": mask, "mask": np.zeros_like(mask), } + # Convert strings to enums. + unit.input_mode = InputMode(unit.input_mode) + unit.hr_option = HiResFixOption(unit.hr_option) + unit.resize_mode = ResizeMode(unit.resize_mode) + unit.control_mode = ControlMode(unit.control_mode) return unit diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index 9b3615b3..844c5f90 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -164,9 +164,9 @@ class ControlNetForForgeOfficial(scripts.Script): def get_input_data(self, p, unit, preprocessor, h, w): logger.info(f'ControlNet Input Mode: {unit.input_mode}') image_list = [] - resize_mode = external_code.resize_mode_from_value(unit.resize_mode) + resize_mode = unit.resize_mode - if unit.input_mode == external_code.InputMode.MERGE: + if unit.input_mode == InputMode.MERGE: for idx, item in enumerate(unit.batch_input_gallery): img_path = item['name'] logger.info(f'Try to read image: {img_path}') @@ -180,7 +180,7 @@ class ControlNetForForgeOfficial(scripts.Script): mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy() if img is not None: image_list.append([img, mask]) - elif unit.input_mode == external_code.InputMode.BATCH: + elif unit.input_mode == InputMode.BATCH: image_list = [] image_extensions = ['.jpg', '.jpeg', '.png', '.bmp'] batch_image_files = shared.listfiles(unit.batch_image_dir) @@ -350,7 +350,7 @@ class ControlNetForForgeOfficial(scripts.Script): break if has_high_res_fix: - hr_option = HiResFixOption.from_value(unit.hr_option) + hr_option = unit.hr_option else: hr_option = HiResFixOption.BOTH @@ -441,7 +441,7 @@ class ControlNetForForgeOfficial(scripts.Script): ) if has_high_res_fix: - hr_option = HiResFixOption.from_value(unit.hr_option) + hr_option = unit.hr_option else: hr_option = HiResFixOption.BOTH