Restore '/controlnet/control_types' API endpoint (#713)
Restores the '/controlnet/control_types' API endpoint, which is immensely useful for anyone using ControlNet via the API
This commit is contained in:
parent
eb1e12b0dc
commit
7eb5cbad01
@ -11,6 +11,8 @@ from .global_state import (
|
||||
get_all_preprocessor_names,
|
||||
get_all_controlnet_names,
|
||||
get_preprocessor,
|
||||
get_all_preprocessor_tags,
|
||||
select_control_type,
|
||||
)
|
||||
from .utils import judge_image_type
|
||||
from .logging import logger
|
||||
@ -53,6 +55,30 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
|
||||
# "module_detail": external_code.get_modules_detail(alias_names),
|
||||
}
|
||||
|
||||
@app.get("/controlnet/control_types")
|
||||
async def control_types():
|
||||
def format_control_type(
|
||||
filtered_preprocessor_list,
|
||||
filtered_model_list,
|
||||
default_option,
|
||||
default_model,
|
||||
):
|
||||
control_dict = {
|
||||
"module_list": filtered_preprocessor_list,
|
||||
"model_list": filtered_model_list,
|
||||
"default_option": default_option,
|
||||
"default_model": default_model,
|
||||
}
|
||||
|
||||
return control_dict
|
||||
|
||||
return {
|
||||
"control_types": {
|
||||
control_type: format_control_type(*select_control_type(control_type))
|
||||
for control_type in get_all_preprocessor_tags()
|
||||
}
|
||||
}
|
||||
|
||||
@app.post("/controlnet/detect")
|
||||
async def detect(
|
||||
controlnet_module: str = Body("none", title="Controlnet Module"),
|
||||
|
@ -6,6 +6,7 @@ from modules import shared, sd_models
|
||||
from lib_controlnet.enums import StableDiffusionVersion
|
||||
from modules_forge.shared import controlnet_dir, supported_preprocessors
|
||||
|
||||
from typing import Dict, Tuple, List
|
||||
|
||||
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin", ".patch"]
|
||||
|
||||
@ -56,6 +57,10 @@ controlnet_names = ['None']
|
||||
def get_preprocessor(name):
|
||||
return supported_preprocessors.get(name, None)
|
||||
|
||||
def get_default_preprocessor(tag):
|
||||
ps = get_filtered_preprocessor_names(tag)
|
||||
assert len(ps) > 0
|
||||
return ps[0] if len(ps) == 1 else ps[1]
|
||||
|
||||
def get_sorted_preprocessors():
|
||||
preprocessors = [p for k, p in supported_preprocessors.items() if k != 'None']
|
||||
@ -144,3 +149,44 @@ def get_sd_version() -> StableDiffusionVersion:
|
||||
return StableDiffusionVersion.SD1x
|
||||
else:
|
||||
return StableDiffusionVersion.UNKNOWN
|
||||
|
||||
|
||||
def select_control_type(
|
||||
control_type: str,
|
||||
sd_version: StableDiffusionVersion = StableDiffusionVersion.UNKNOWN,
|
||||
) -> Tuple[List[str], List[str], str, str]:
|
||||
global controlnet_names
|
||||
|
||||
pattern = control_type.lower()
|
||||
all_models = list(controlnet_names)
|
||||
|
||||
if pattern == "all":
|
||||
preprocessors = get_sorted_preprocessors().values()
|
||||
return [
|
||||
[p.name for p in preprocessors],
|
||||
all_models,
|
||||
'none', # default option
|
||||
"None" # default model
|
||||
]
|
||||
|
||||
filtered_model_list = get_filtered_controlnet_names(control_type)
|
||||
|
||||
if pattern == "none":
|
||||
filtered_model_list.append("None")
|
||||
|
||||
assert len(filtered_model_list) > 0, "'None' model should always be available."
|
||||
if len(filtered_model_list) == 1:
|
||||
default_model = "None"
|
||||
else:
|
||||
default_model = filtered_model_list[1]
|
||||
for x in filtered_model_list:
|
||||
if "11" in x.split("[")[0]:
|
||||
default_model = x
|
||||
break
|
||||
|
||||
return (
|
||||
get_filtered_preprocessor_names(control_type),
|
||||
filtered_model_list,
|
||||
get_default_preprocessor(control_type),
|
||||
default_model
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user