Restore '/controlnet/control_types' API endpoint (#713)

Restores the '/controlnet/control_types' API endpoint, which is immensely useful for anyone using ControlNet via the API
This commit is contained in:
altoiddealer 2024-05-23 16:39:26 -04:00 committed by GitHub
parent eb1e12b0dc
commit 7eb5cbad01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 72 additions and 0 deletions

View File

@ -11,6 +11,8 @@ from .global_state import (
get_all_preprocessor_names, get_all_preprocessor_names,
get_all_controlnet_names, get_all_controlnet_names,
get_preprocessor, get_preprocessor,
get_all_preprocessor_tags,
select_control_type,
) )
from .utils import judge_image_type from .utils import judge_image_type
from .logging import logger from .logging import logger
@ -53,6 +55,30 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
# "module_detail": external_code.get_modules_detail(alias_names), # "module_detail": external_code.get_modules_detail(alias_names),
} }
@app.get("/controlnet/control_types")
async def control_types():
def format_control_type(
filtered_preprocessor_list,
filtered_model_list,
default_option,
default_model,
):
control_dict = {
"module_list": filtered_preprocessor_list,
"model_list": filtered_model_list,
"default_option": default_option,
"default_model": default_model,
}
return control_dict
return {
"control_types": {
control_type: format_control_type(*select_control_type(control_type))
for control_type in get_all_preprocessor_tags()
}
}
@app.post("/controlnet/detect") @app.post("/controlnet/detect")
async def detect( async def detect(
controlnet_module: str = Body("none", title="Controlnet Module"), controlnet_module: str = Body("none", title="Controlnet Module"),

View File

@ -6,6 +6,7 @@ from modules import shared, sd_models
from lib_controlnet.enums import StableDiffusionVersion from lib_controlnet.enums import StableDiffusionVersion
from modules_forge.shared import controlnet_dir, supported_preprocessors from modules_forge.shared import controlnet_dir, supported_preprocessors
from typing import Dict, Tuple, List
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin", ".patch"] CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin", ".patch"]
@ -56,6 +57,10 @@ controlnet_names = ['None']
def get_preprocessor(name): def get_preprocessor(name):
return supported_preprocessors.get(name, None) return supported_preprocessors.get(name, None)
def get_default_preprocessor(tag):
ps = get_filtered_preprocessor_names(tag)
assert len(ps) > 0
return ps[0] if len(ps) == 1 else ps[1]
def get_sorted_preprocessors(): def get_sorted_preprocessors():
preprocessors = [p for k, p in supported_preprocessors.items() if k != 'None'] preprocessors = [p for k, p in supported_preprocessors.items() if k != 'None']
@ -144,3 +149,44 @@ def get_sd_version() -> StableDiffusionVersion:
return StableDiffusionVersion.SD1x return StableDiffusionVersion.SD1x
else: else:
return StableDiffusionVersion.UNKNOWN return StableDiffusionVersion.UNKNOWN
def select_control_type(
control_type: str,
sd_version: StableDiffusionVersion = StableDiffusionVersion.UNKNOWN,
) -> Tuple[List[str], List[str], str, str]:
global controlnet_names
pattern = control_type.lower()
all_models = list(controlnet_names)
if pattern == "all":
preprocessors = get_sorted_preprocessors().values()
return [
[p.name for p in preprocessors],
all_models,
'none', # default option
"None" # default model
]
filtered_model_list = get_filtered_controlnet_names(control_type)
if pattern == "none":
filtered_model_list.append("None")
assert len(filtered_model_list) > 0, "'None' model should always be available."
if len(filtered_model_list) == 1:
default_model = "None"
else:
default_model = filtered_model_list[1]
for x in filtered_model_list:
if "11" in x.split("[")[0]:
default_model = x
break
return (
get_filtered_preprocessor_names(control_type),
filtered_model_list,
get_default_preprocessor(control_type),
default_model
)