* Fix enum conversion from string

* More fixes
This commit is contained in:
Chenlei Hu 2024-05-02 19:09:53 -04:00 committed by GitHub
parent b996316b20
commit 66845160de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 24 additions and 13 deletions

View File

@ -11,6 +11,7 @@ from lib_controlnet import (
global_state,
external_code,
)
from lib_controlnet.external_code import ControlNetUnit
from lib_controlnet.logging import logger
from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor
from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI
@ -21,7 +22,6 @@ from lib_controlnet.enums import InputMode, HiResFixOption
from modules import shared, script_callbacks
from modules.ui_components import FormRow
from modules_forge.forge_util import HWC3
from lib_controlnet.external_code import UiControlNetUnit
@dataclass
@ -172,10 +172,10 @@ class ControlNetUiGroup(object):
self.webcam_mirrored = False
# Note: All gradio elements declared in `render` will be defined as member variable.
# Update counter to trigger a force update of UiControlNetUnit.
# Update counter to trigger a force update of ControlNetUnit.
# dummy_gradio_update_trigger is useful when a field with no event subscriber available changes.
# e.g. gr.Gallery, gr.State, etc. After an update to gr.State / gr.Gallery, please increment
# this counter to trigger a sync update of UiControlNetUnit.
# this counter to trigger a sync update of ControlNetUnit.
self.dummy_gradio_update_trigger = None
self.enabled = None
self.upload_tab = None
@ -610,6 +610,12 @@ class ControlNetUiGroup(object):
)
unit = gr.State(self.default_unit)
def create_unit(*args):
return ControlNetUnit.from_dict({
k: v
for k, v in zip(vars(ControlNetUnit()).keys(), args)
})
for comp in unit_args + (self.dummy_gradio_update_trigger,):
event_subscribers = []
if hasattr(comp, "edit"):
@ -626,7 +632,7 @@ class ControlNetUiGroup(object):
for event_subscriber in event_subscribers:
event_subscriber(
fn=UiControlNetUnit, inputs=list(unit_args), outputs=unit
fn=create_unit, inputs=list(unit_args), outputs=unit
)
(
@ -634,7 +640,7 @@ class ControlNetUiGroup(object):
if self.is_img2img
else ControlNetUiGroup.a1111_context.txt2img_submit_button
).click(
fn=UiControlNetUnit,
fn=create_unit,
inputs=list(unit_args),
outputs=unit,
queue=False,

View File

@ -194,7 +194,7 @@ class ControlNetUnit:
mask_image: Optional[GradioImageMaskPair] = None
# Specifies how this unit should be applied in each pass of high-resolution fix.
# Ignored if high-resolution fix is not enabled.
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
hr_option: HiResFixOption = HiResFixOption.BOTH
# Indicates whether the unit is enabled; defaults to True.
enabled: bool = True
# Name of the module being used; defaults to "None".
@ -206,7 +206,7 @@ class ControlNetUnit:
# Optional image for input; defaults to None.
image: Optional[GradioImageMaskPair] = None
# Specifies the mode of image resizing; defaults to inner fit.
resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT
resize_mode: ResizeMode = ResizeMode.INNER_FIT
# Resolution for processing by the unit; defaults to -1 (unspecified).
processor_res: int = -1
# Threshold A for processing; defaults to -1 (unspecified).
@ -220,7 +220,7 @@ class ControlNetUnit:
# Enables pixel-perfect processing; defaults to False.
pixel_perfect: bool = False
# Control mode for the unit; defaults to balanced.
control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
control_mode: ControlMode = ControlMode.BALANCED
# Following fields should only be used in the API.
# ====== Start of API only fields ======
@ -277,6 +277,11 @@ class ControlNetUnit:
"image": mask,
"mask": np.zeros_like(mask),
}
# Convert strings to enums.
unit.input_mode = InputMode(unit.input_mode)
unit.hr_option = HiResFixOption(unit.hr_option)
unit.resize_mode = ResizeMode(unit.resize_mode)
unit.control_mode = ControlMode(unit.control_mode)
return unit

View File

@ -164,9 +164,9 @@ class ControlNetForForgeOfficial(scripts.Script):
def get_input_data(self, p, unit, preprocessor, h, w):
logger.info(f'ControlNet Input Mode: {unit.input_mode}')
image_list = []
resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
resize_mode = unit.resize_mode
if unit.input_mode == external_code.InputMode.MERGE:
if unit.input_mode == InputMode.MERGE:
for idx, item in enumerate(unit.batch_input_gallery):
img_path = item['name']
logger.info(f'Try to read image: {img_path}')
@ -180,7 +180,7 @@ class ControlNetForForgeOfficial(scripts.Script):
mask = np.ascontiguousarray(cv2.imread(mask_path)[:, :, ::-1]).copy()
if img is not None:
image_list.append([img, mask])
elif unit.input_mode == external_code.InputMode.BATCH:
elif unit.input_mode == InputMode.BATCH:
image_list = []
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
batch_image_files = shared.listfiles(unit.batch_image_dir)
@ -350,7 +350,7 @@ class ControlNetForForgeOfficial(scripts.Script):
break
if has_high_res_fix:
hr_option = HiResFixOption.from_value(unit.hr_option)
hr_option = unit.hr_option
else:
hr_option = HiResFixOption.BOTH
@ -441,7 +441,7 @@ class ControlNetForForgeOfficial(scripts.Script):
)
if has_high_res_fix:
hr_option = HiResFixOption.from_value(unit.hr_option)
hr_option = unit.hr_option
else:
hr_option = HiResFixOption.BOTH