diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py index 29e7d662..8bd43ae8 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py @@ -11,6 +11,8 @@ from .global_state import ( get_all_preprocessor_names, get_all_controlnet_names, get_preprocessor, + get_all_preprocessor_tags, + select_control_type, ) from .utils import judge_image_type from .logging import logger @@ -53,6 +55,30 @@ def controlnet_api(_: gr.Blocks, app: FastAPI): # "module_detail": external_code.get_modules_detail(alias_names), } + @app.get("/controlnet/control_types") + async def control_types(): + def format_control_type( + filtered_preprocessor_list, + filtered_model_list, + default_option, + default_model, + ): + control_dict = { + "module_list": filtered_preprocessor_list, + "model_list": filtered_model_list, + "default_option": default_option, + "default_model": default_model, + } + + return control_dict + + return { + "control_types": { + control_type: format_control_type(*select_control_type(control_type)) + for control_type in get_all_preprocessor_tags() + } + } + @app.post("/controlnet/detect") async def detect( controlnet_module: str = Body("none", title="Controlnet Module"), diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py index dc87dda8..15bef936 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py @@ -6,6 +6,7 @@ from modules import shared, sd_models from lib_controlnet.enums import StableDiffusionVersion from modules_forge.shared import controlnet_dir, supported_preprocessors +from typing import Dict, Tuple, List CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin", ".patch"] @@ -56,6 +57,10 @@ controlnet_names = ['None'] def get_preprocessor(name): return supported_preprocessors.get(name, None) +def get_default_preprocessor(tag): + ps = get_filtered_preprocessor_names(tag) + assert len(ps) > 0 + return ps[0] if len(ps) == 1 else ps[1] def get_sorted_preprocessors(): preprocessors = [p for k, p in supported_preprocessors.items() if k != 'None'] @@ -144,3 +149,44 @@ def get_sd_version() -> StableDiffusionVersion: return StableDiffusionVersion.SD1x else: return StableDiffusionVersion.UNKNOWN + + +def select_control_type( + control_type: str, + sd_version: StableDiffusionVersion = StableDiffusionVersion.UNKNOWN, +) -> Tuple[List[str], List[str], str, str]: + global controlnet_names + + pattern = control_type.lower() + all_models = list(controlnet_names) + + if pattern == "all": + preprocessors = get_sorted_preprocessors().values() + return [ + [p.name for p in preprocessors], + all_models, + 'none', # default option + "None" # default model + ] + + filtered_model_list = get_filtered_controlnet_names(control_type) + + if pattern == "none": + filtered_model_list.append("None") + + assert len(filtered_model_list) > 0, "'None' model should always be available." + if len(filtered_model_list) == 1: + default_model = "None" + else: + default_model = filtered_model_list[1] + for x in filtered_model_list: + if "11" in x.split("[")[0]: + default_model = x + break + + return ( + get_filtered_preprocessor_names(control_type), + filtered_model_list, + get_default_preprocessor(control_type), + default_model + )