Add back ControlNet model version filter (#131)
* Add back ControlNet model version filter * Update choice after sd model changes
This commit is contained in:
parent
ac4a8820a5
commit
200f2b69ed
@ -17,7 +17,7 @@ from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI
|
||||
from lib_controlnet.controlnet_ui.tool_button import ToolButton
|
||||
from lib_controlnet.controlnet_ui.photopea import Photopea
|
||||
from lib_controlnet.enums import InputMode, HiResFixOption
|
||||
from modules import shared
|
||||
from modules import shared, script_callbacks
|
||||
from modules.ui_components import FormRow
|
||||
from modules_forge.forge_util import HWC3
|
||||
from lib_controlnet.external_code import UiControlNetUnit
|
||||
@ -47,7 +47,6 @@ class A1111Context:
|
||||
|
||||
img2img_inpaint_area: Optional[gr.components.IOComponent] = None
|
||||
txt2img_enable_hr: Optional[gr.components.IOComponent] = None
|
||||
setting_sd_model_checkpoint: Optional[gr.components.IOComponent] = None
|
||||
|
||||
@property
|
||||
def img2img_inpaint_tabs(self) -> Tuple[gr.components.IOComponent]:
|
||||
@ -75,10 +74,6 @@ class A1111Context:
|
||||
"img2img_inpaint_tab": "img2img_inpaint_tab",
|
||||
"img2img_inpaint_sketch_tab": "img2img_inpaint_sketch_tab",
|
||||
"img2img_inpaint_upload_tab": "img2img_inpaint_upload_tab",
|
||||
# SDNext does not have this field. Temporarily disable the callback on
|
||||
# the checkpoint change until we find a way to register an event when
|
||||
# all A1111 UI components are ready.
|
||||
"setting_sd_model_checkpoint": "setting_sd_model_checkpoint",
|
||||
}
|
||||
return all(
|
||||
c
|
||||
@ -104,8 +99,6 @@ class A1111Context:
|
||||
"img2img_inpaint_upload_tab": "img2img_inpaint_upload_tab",
|
||||
"img2img_inpaint_full_res": "img2img_inpaint_area",
|
||||
"txt2img_hr-checkbox": "txt2img_enable_hr",
|
||||
# setting_sd_model_checkpoint is expected to be initialized last.
|
||||
# "setting_sd_model_checkpoint": "setting_sd_model_checkpoint",
|
||||
}
|
||||
elem_id = getattr(component, "elem_id", None)
|
||||
# Do not set component if it has already been set.
|
||||
@ -1141,7 +1134,6 @@ class ControlNetUiGroup(object):
|
||||
self.register_refresh_all_models()
|
||||
self.register_build_sliders()
|
||||
self.register_shift_preview()
|
||||
self.register_shift_upload_mask()
|
||||
self.register_create_canvas()
|
||||
self.register_clear_preview()
|
||||
self.register_multi_images_upload()
|
||||
@ -1162,6 +1154,26 @@ class ControlNetUiGroup(object):
|
||||
if self.is_img2img:
|
||||
self.register_img2img_same_input()
|
||||
|
||||
def register_sd_model_changed(self):
|
||||
def sd_version_changed(type_filter: str, current_model: str, setting_value: str, setting_name: str):
|
||||
"""When SD version changes, update model dropdown choices."""
|
||||
if setting_name != "sd_model_checkpoint":
|
||||
return gr.update()
|
||||
|
||||
filtered_model_list = global_state.get_filtered_controlnet_names(type_filter)
|
||||
assert len(filtered_model_list) > 0
|
||||
default_model = filtered_model_list[1] if len(filtered_model_list) > 1 else filtered_model_list[0]
|
||||
return gr.Dropdown.update(
|
||||
choices=filtered_model_list,
|
||||
value=current_model if current_model in filtered_model_list else default_model
|
||||
)
|
||||
|
||||
script_callbacks.on_setting_updated_subscriber(dict(
|
||||
fn=sd_version_changed,
|
||||
inputs=[self.type_filter, self.model],
|
||||
outputs=[self.model],
|
||||
))
|
||||
|
||||
def register_callbacks(self):
|
||||
"""Register callbacks that involves A1111 context gradio components."""
|
||||
# Prevent infinite recursion.
|
||||
@ -1172,6 +1184,8 @@ class ControlNetUiGroup(object):
|
||||
self.register_send_dimensions()
|
||||
self.register_run_annotator()
|
||||
self.register_sync_batch_dir()
|
||||
self.register_shift_upload_mask()
|
||||
self.register_sd_model_changed()
|
||||
if self.is_img2img:
|
||||
self.register_shift_crop_input_image()
|
||||
else:
|
||||
|
@ -98,7 +98,7 @@ def get_filtered_preprocessor_names(tag):
|
||||
return list(get_filtered_preprocessors(tag).keys())
|
||||
|
||||
|
||||
def get_filtered_controlnet_names(tag, filter_version: bool = True):
|
||||
def get_filtered_controlnet_names(tag):
|
||||
filtered_preprocessors = get_filtered_preprocessors(tag)
|
||||
model_filename_filters = []
|
||||
for p in filtered_preprocessors.values():
|
||||
@ -106,8 +106,8 @@ def get_filtered_controlnet_names(tag, filter_version: bool = True):
|
||||
return [
|
||||
x for x in controlnet_names
|
||||
if x == 'None' or (
|
||||
any(f.lower() in x.lower() for f in model_filename_filters) # and
|
||||
# get_sd_version().is_compatible_with(StableDiffusionVersion.detect_from_model_name(x))
|
||||
any(f.lower() in x.lower() for f in model_filename_filters) and
|
||||
get_sd_version().is_compatible_with(StableDiffusionVersion.detect_from_model_name(x))
|
||||
)
|
||||
]
|
||||
|
||||
@ -134,6 +134,8 @@ def update_controlnet_filenames():
|
||||
|
||||
|
||||
def get_sd_version() -> StableDiffusionVersion:
|
||||
if not shared.sd_model:
|
||||
return StableDiffusionVersion.UNKNOWN
|
||||
if shared.sd_model.is_sdxl:
|
||||
return StableDiffusionVersion.SDXL
|
||||
elif shared.sd_model.is_sd2:
|
||||
|
@ -129,6 +129,9 @@ callback_map = dict(
|
||||
callbacks_list_optimizers=[],
|
||||
callbacks_list_unets=[],
|
||||
)
|
||||
event_subscriber_map = dict(
|
||||
callbacks_setting_updated=[],
|
||||
)
|
||||
|
||||
|
||||
def clear_callbacks():
|
||||
@ -309,6 +312,23 @@ def list_unets_callback():
|
||||
return res
|
||||
|
||||
|
||||
def setting_updated_event_subscriber_chain(handler, component, setting_name: str):
|
||||
"""
|
||||
Arguments:
|
||||
- handler: The returned handler from calling an event subscriber.
|
||||
- component: The component that is updated. The component should provide
|
||||
the value of setting after update.
|
||||
- setting_name: The name of the setting.
|
||||
"""
|
||||
for param in event_subscriber_map['callbacks_setting_updated']:
|
||||
handler = handler.then(
|
||||
fn=lambda *args: param["fn"](*args, setting_name),
|
||||
inputs=param["inputs"] + [component],
|
||||
outputs=param["outputs"],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
||||
def add_callback(callbacks, fun):
|
||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||
filename = stack[0].filename if stack else 'unknown file'
|
||||
@ -483,3 +503,13 @@ def on_list_unets(callback):
|
||||
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
|
||||
|
||||
add_callback(callback_map['callbacks_list_unets'], callback)
|
||||
|
||||
|
||||
def on_setting_updated_subscriber(subscriber_params):
|
||||
"""register a function to be called after settings update. `subscriber_params`
|
||||
should contain necessary fields to register an gradio event handler. Necessary
|
||||
fields are ["fn", "outputs", "inputs"].
|
||||
Setting name and setting value after update will be append to inputs. So be
|
||||
sure to handle these extra params when defining the callback function.
|
||||
"""
|
||||
event_subscriber_map['callbacks_setting_updated'].append(subscriber_params)
|
||||
|
@ -303,20 +303,30 @@ class UiSettings:
|
||||
methods = [component.change]
|
||||
|
||||
for method in methods:
|
||||
method(
|
||||
handler = method(
|
||||
fn=lambda value, k=k: self.run_settings_single(value, key=k),
|
||||
inputs=[component],
|
||||
outputs=[component, self.text_settings],
|
||||
show_progress=False,
|
||||
)
|
||||
script_callbacks.setting_updated_event_subscriber_chain(
|
||||
handler=handler,
|
||||
component=component,
|
||||
setting_name=k,
|
||||
)
|
||||
|
||||
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
||||
button_set_checkpoint.click(
|
||||
handler = button_set_checkpoint.click(
|
||||
fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'),
|
||||
_js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
|
||||
inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component],
|
||||
outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings],
|
||||
)
|
||||
script_callbacks.setting_updated_event_subscriber_chain(
|
||||
handler=handler,
|
||||
component=self.component_dict['sd_model_checkpoint'],
|
||||
setting_name="sd_model_checkpoint"
|
||||
)
|
||||
|
||||
component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user