Add back ControlNet model version filter (#131)

* Add back ControlNet model version filter

* Update choice after sd model changes
This commit is contained in:
Chenlei Hu 2024-02-09 21:34:09 +00:00 committed by GitHub
parent ac4a8820a5
commit 200f2b69ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 70 additions and 14 deletions

View File

@ -17,7 +17,7 @@ from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI
from lib_controlnet.controlnet_ui.tool_button import ToolButton
from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.enums import InputMode, HiResFixOption
from modules import shared
from modules import shared, script_callbacks
from modules.ui_components import FormRow
from modules_forge.forge_util import HWC3
from lib_controlnet.external_code import UiControlNetUnit
@ -47,7 +47,6 @@ class A1111Context:
img2img_inpaint_area: Optional[gr.components.IOComponent] = None
txt2img_enable_hr: Optional[gr.components.IOComponent] = None
setting_sd_model_checkpoint: Optional[gr.components.IOComponent] = None
@property
def img2img_inpaint_tabs(self) -> Tuple[gr.components.IOComponent]:
@ -75,10 +74,6 @@ class A1111Context:
"img2img_inpaint_tab": "img2img_inpaint_tab",
"img2img_inpaint_sketch_tab": "img2img_inpaint_sketch_tab",
"img2img_inpaint_upload_tab": "img2img_inpaint_upload_tab",
# SDNext does not have this field. Temporarily disable the callback on
# the checkpoint change until we find a way to register an event when
# all A1111 UI components are ready.
"setting_sd_model_checkpoint": "setting_sd_model_checkpoint",
}
return all(
c
@ -104,8 +99,6 @@ class A1111Context:
"img2img_inpaint_upload_tab": "img2img_inpaint_upload_tab",
"img2img_inpaint_full_res": "img2img_inpaint_area",
"txt2img_hr-checkbox": "txt2img_enable_hr",
# setting_sd_model_checkpoint is expected to be initialized last.
# "setting_sd_model_checkpoint": "setting_sd_model_checkpoint",
}
elem_id = getattr(component, "elem_id", None)
# Do not set component if it has already been set.
@ -1141,7 +1134,6 @@ class ControlNetUiGroup(object):
self.register_refresh_all_models()
self.register_build_sliders()
self.register_shift_preview()
self.register_shift_upload_mask()
self.register_create_canvas()
self.register_clear_preview()
self.register_multi_images_upload()
@ -1162,6 +1154,26 @@ class ControlNetUiGroup(object):
if self.is_img2img:
self.register_img2img_same_input()
def register_sd_model_changed(self):
def sd_version_changed(type_filter: str, current_model: str, setting_value: str, setting_name: str):
"""When SD version changes, update model dropdown choices."""
if setting_name != "sd_model_checkpoint":
return gr.update()
filtered_model_list = global_state.get_filtered_controlnet_names(type_filter)
assert len(filtered_model_list) > 0
default_model = filtered_model_list[1] if len(filtered_model_list) > 1 else filtered_model_list[0]
return gr.Dropdown.update(
choices=filtered_model_list,
value=current_model if current_model in filtered_model_list else default_model
)
script_callbacks.on_setting_updated_subscriber(dict(
fn=sd_version_changed,
inputs=[self.type_filter, self.model],
outputs=[self.model],
))
def register_callbacks(self):
"""Register callbacks that involves A1111 context gradio components."""
# Prevent infinite recursion.
@ -1172,6 +1184,8 @@ class ControlNetUiGroup(object):
self.register_send_dimensions()
self.register_run_annotator()
self.register_sync_batch_dir()
self.register_shift_upload_mask()
self.register_sd_model_changed()
if self.is_img2img:
self.register_shift_crop_input_image()
else:

View File

@ -98,7 +98,7 @@ def get_filtered_preprocessor_names(tag):
return list(get_filtered_preprocessors(tag).keys())
def get_filtered_controlnet_names(tag, filter_version: bool = True):
def get_filtered_controlnet_names(tag):
filtered_preprocessors = get_filtered_preprocessors(tag)
model_filename_filters = []
for p in filtered_preprocessors.values():
@ -106,8 +106,8 @@ def get_filtered_controlnet_names(tag, filter_version: bool = True):
return [
x for x in controlnet_names
if x == 'None' or (
any(f.lower() in x.lower() for f in model_filename_filters) # and
# get_sd_version().is_compatible_with(StableDiffusionVersion.detect_from_model_name(x))
any(f.lower() in x.lower() for f in model_filename_filters) and
get_sd_version().is_compatible_with(StableDiffusionVersion.detect_from_model_name(x))
)
]
@ -134,6 +134,8 @@ def update_controlnet_filenames():
def get_sd_version() -> StableDiffusionVersion:
if not shared.sd_model:
return StableDiffusionVersion.UNKNOWN
if shared.sd_model.is_sdxl:
return StableDiffusionVersion.SDXL
elif shared.sd_model.is_sd2:

View File

@ -129,6 +129,9 @@ callback_map = dict(
callbacks_list_optimizers=[],
callbacks_list_unets=[],
)
event_subscriber_map = dict(
callbacks_setting_updated=[],
)
def clear_callbacks():
@ -309,6 +312,23 @@ def list_unets_callback():
return res
def setting_updated_event_subscriber_chain(handler, component, setting_name: str):
"""
Arguments:
- handler: The returned handler from calling an event subscriber.
- component: The component that is updated. The component should provide
the value of setting after update.
- setting_name: The name of the setting.
"""
for param in event_subscriber_map['callbacks_setting_updated']:
handler = handler.then(
fn=lambda *args: param["fn"](*args, setting_name),
inputs=param["inputs"] + [component],
outputs=param["outputs"],
show_progress=False,
)
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file'
@ -483,3 +503,13 @@ def on_list_unets(callback):
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
add_callback(callback_map['callbacks_list_unets'], callback)
def on_setting_updated_subscriber(subscriber_params):
"""register a function to be called after settings update. `subscriber_params`
should contain necessary fields to register an gradio event handler. Necessary
fields are ["fn", "outputs", "inputs"].
Setting name and setting value after update will be append to inputs. So be
sure to handle these extra params when defining the callback function.
"""
event_subscriber_map['callbacks_setting_updated'].append(subscriber_params)

View File

@ -303,20 +303,30 @@ class UiSettings:
methods = [component.change]
for method in methods:
method(
handler = method(
fn=lambda value, k=k: self.run_settings_single(value, key=k),
inputs=[component],
outputs=[component, self.text_settings],
show_progress=False,
)
script_callbacks.setting_updated_event_subscriber_chain(
handler=handler,
component=component,
setting_name=k,
)
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
button_set_checkpoint.click(
handler = button_set_checkpoint.click(
fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'),
_js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component],
outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings],
)
script_callbacks.setting_updated_event_subscriber_chain(
handler=handler,
component=self.component_dict['sd_model_checkpoint'],
setting_name="sd_model_checkpoint"
)
component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]