Restore '/controlnet/control_types' API endpoint (#713)
Restores the '/controlnet/control_types' API endpoint, which is immensely useful for anyone using ControlNet via the API
This commit is contained in:
parent
eb1e12b0dc
commit
7eb5cbad01
@ -11,6 +11,8 @@ from .global_state import (
|
|||||||
get_all_preprocessor_names,
|
get_all_preprocessor_names,
|
||||||
get_all_controlnet_names,
|
get_all_controlnet_names,
|
||||||
get_preprocessor,
|
get_preprocessor,
|
||||||
|
get_all_preprocessor_tags,
|
||||||
|
select_control_type,
|
||||||
)
|
)
|
||||||
from .utils import judge_image_type
|
from .utils import judge_image_type
|
||||||
from .logging import logger
|
from .logging import logger
|
||||||
@ -53,6 +55,30 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
|
|||||||
# "module_detail": external_code.get_modules_detail(alias_names),
|
# "module_detail": external_code.get_modules_detail(alias_names),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@app.get("/controlnet/control_types")
|
||||||
|
async def control_types():
|
||||||
|
def format_control_type(
|
||||||
|
filtered_preprocessor_list,
|
||||||
|
filtered_model_list,
|
||||||
|
default_option,
|
||||||
|
default_model,
|
||||||
|
):
|
||||||
|
control_dict = {
|
||||||
|
"module_list": filtered_preprocessor_list,
|
||||||
|
"model_list": filtered_model_list,
|
||||||
|
"default_option": default_option,
|
||||||
|
"default_model": default_model,
|
||||||
|
}
|
||||||
|
|
||||||
|
return control_dict
|
||||||
|
|
||||||
|
return {
|
||||||
|
"control_types": {
|
||||||
|
control_type: format_control_type(*select_control_type(control_type))
|
||||||
|
for control_type in get_all_preprocessor_tags()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@app.post("/controlnet/detect")
|
@app.post("/controlnet/detect")
|
||||||
async def detect(
|
async def detect(
|
||||||
controlnet_module: str = Body("none", title="Controlnet Module"),
|
controlnet_module: str = Body("none", title="Controlnet Module"),
|
||||||
|
@ -6,6 +6,7 @@ from modules import shared, sd_models
|
|||||||
from lib_controlnet.enums import StableDiffusionVersion
|
from lib_controlnet.enums import StableDiffusionVersion
|
||||||
from modules_forge.shared import controlnet_dir, supported_preprocessors
|
from modules_forge.shared import controlnet_dir, supported_preprocessors
|
||||||
|
|
||||||
|
from typing import Dict, Tuple, List
|
||||||
|
|
||||||
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin", ".patch"]
|
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin", ".patch"]
|
||||||
|
|
||||||
@ -56,6 +57,10 @@ controlnet_names = ['None']
|
|||||||
def get_preprocessor(name):
|
def get_preprocessor(name):
|
||||||
return supported_preprocessors.get(name, None)
|
return supported_preprocessors.get(name, None)
|
||||||
|
|
||||||
|
def get_default_preprocessor(tag):
|
||||||
|
ps = get_filtered_preprocessor_names(tag)
|
||||||
|
assert len(ps) > 0
|
||||||
|
return ps[0] if len(ps) == 1 else ps[1]
|
||||||
|
|
||||||
def get_sorted_preprocessors():
|
def get_sorted_preprocessors():
|
||||||
preprocessors = [p for k, p in supported_preprocessors.items() if k != 'None']
|
preprocessors = [p for k, p in supported_preprocessors.items() if k != 'None']
|
||||||
@ -144,3 +149,44 @@ def get_sd_version() -> StableDiffusionVersion:
|
|||||||
return StableDiffusionVersion.SD1x
|
return StableDiffusionVersion.SD1x
|
||||||
else:
|
else:
|
||||||
return StableDiffusionVersion.UNKNOWN
|
return StableDiffusionVersion.UNKNOWN
|
||||||
|
|
||||||
|
|
||||||
|
def select_control_type(
|
||||||
|
control_type: str,
|
||||||
|
sd_version: StableDiffusionVersion = StableDiffusionVersion.UNKNOWN,
|
||||||
|
) -> Tuple[List[str], List[str], str, str]:
|
||||||
|
global controlnet_names
|
||||||
|
|
||||||
|
pattern = control_type.lower()
|
||||||
|
all_models = list(controlnet_names)
|
||||||
|
|
||||||
|
if pattern == "all":
|
||||||
|
preprocessors = get_sorted_preprocessors().values()
|
||||||
|
return [
|
||||||
|
[p.name for p in preprocessors],
|
||||||
|
all_models,
|
||||||
|
'none', # default option
|
||||||
|
"None" # default model
|
||||||
|
]
|
||||||
|
|
||||||
|
filtered_model_list = get_filtered_controlnet_names(control_type)
|
||||||
|
|
||||||
|
if pattern == "none":
|
||||||
|
filtered_model_list.append("None")
|
||||||
|
|
||||||
|
assert len(filtered_model_list) > 0, "'None' model should always be available."
|
||||||
|
if len(filtered_model_list) == 1:
|
||||||
|
default_model = "None"
|
||||||
|
else:
|
||||||
|
default_model = filtered_model_list[1]
|
||||||
|
for x in filtered_model_list:
|
||||||
|
if "11" in x.split("[")[0]:
|
||||||
|
default_model = x
|
||||||
|
break
|
||||||
|
|
||||||
|
return (
|
||||||
|
get_filtered_preprocessor_names(control_type),
|
||||||
|
filtered_model_list,
|
||||||
|
get_default_preprocessor(control_type),
|
||||||
|
default_model
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user