Update controlnet.py

This commit is contained in:
lllyasviel 2024-02-01 17:35:56 -08:00
parent d4d794d09e
commit 07977a1fec

View File

@ -250,28 +250,13 @@ class ControlNetForForgeOfficial(scripts.Script):
input_image = np.stack(input_image, axis=2)
return input_image
@staticmethod
def bound_check_params(unit: external_code.ControlNetUnit) -> None:
"""
Checks and corrects negative parameters in ControlNetUnit 'unit'.
Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
their default values if negative.
Args:
unit (external_code.ControlNetUnit): The ControlNetUnit instance to check.
"""
preprocessor = global_state.get_preprocessor(unit.module)
if unit.processor_res < 0:
unit.processor_res = int(preprocessor.slider_resolution.gradio_update_kwargs.get('value', 512))
if unit.threshold_a < 0:
unit.threshold_a = int(preprocessor.slider_1.gradio_update_kwargs.get('value', 1.0))
if unit.threshold_b < 0:
unit.threshold_b = int(preprocessor.slider_2.gradio_update_kwargs.get('value', 1.0))
return
def get_input_data(self, p, unit, preprocessor):
mask = None
input_image, resize_mode = self.choose_input_image(p, unit)
assert isinstance(input_image, np.ndarray), 'Invalid input image!'
input_image = self.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode, preprocessor)
input_image = np.ascontiguousarray(input_image.copy()).copy() # safe numpy
return input_image, mask, resize_mode
@staticmethod
def get_target_dimensions(p: StableDiffusionProcessing) -> Tuple[int, int, int, int]:
@ -297,14 +282,6 @@ class ControlNetForForgeOfficial(scripts.Script):
return h, w, hr_y, hr_x
def get_input_data(self, p, unit, preprocessor):
mask = None
input_image, resize_mode = self.choose_input_image(p, unit)
assert isinstance(input_image, np.ndarray), 'Invalid input image!'
input_image = self.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode, preprocessor)
input_image = np.ascontiguousarray(input_image.copy()).copy() # safe numpy
return input_image, mask, resize_mode
@torch.no_grad()
def process_unit_after_click_generate(self,
p: StableDiffusionProcessing,
@ -440,6 +417,29 @@ class ControlNetForForgeOfficial(scripts.Script):
logger.info(f"ControlNet Method {params.preprocessor.name} patched.")
return
@staticmethod
def bound_check_params(unit: external_code.ControlNetUnit) -> None:
"""
Checks and corrects negative parameters in ControlNetUnit 'unit'.
Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
their default values if negative.
Args:
unit (external_code.ControlNetUnit): The ControlNetUnit instance to check.
"""
preprocessor = global_state.get_preprocessor(unit.module)
if unit.processor_res < 0:
unit.processor_res = int(preprocessor.slider_resolution.gradio_update_kwargs.get('value', 512))
if unit.threshold_a < 0:
unit.threshold_a = int(preprocessor.slider_1.gradio_update_kwargs.get('value', 1.0))
if unit.threshold_b < 0:
unit.threshold_b = int(preprocessor.slider_2.gradio_update_kwargs.get('value', 1.0))
return
@torch.no_grad()
def process_unit_after_every_sampling(self,
p: StableDiffusionProcessing,