ControlNet API (#162)
* ControlNet API * update cache key * nits * disable controlnet tests
This commit is contained in:
parent
bd0878754c
commit
5a7e755528
22
.github/workflows/run_tests.yaml
vendored
22
.github/workflows/run_tests.yaml
vendored
@ -52,6 +52,18 @@ jobs:
|
||||
curl -Lo "$filename" "$url"
|
||||
fi
|
||||
done
|
||||
# - name: Download ControlNet models
|
||||
# run: |
|
||||
# declare -a urls=(
|
||||
# "https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny.pth"
|
||||
# )
|
||||
|
||||
# for url in "${urls[@]}"; do
|
||||
# filename="models/ControlNet/${url##*/}" # Extracts the last part of the URL
|
||||
# if [ ! -f "$filename" ]; then
|
||||
# curl -Lo "$filename" "$url"
|
||||
# fi
|
||||
# done
|
||||
- name: Start test server
|
||||
run: >
|
||||
python -m coverage run
|
||||
@ -71,6 +83,16 @@ jobs:
|
||||
run: |
|
||||
wait-for-it --service 127.0.0.1:7860 -t 20
|
||||
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
|
||||
# TODO(huchenlei): Enable ControlNet tests. Currently it is too slow to run these tests on CPU with
|
||||
# real SD model. We need to find a way to load empty SD model.
|
||||
# - name: Run ControlNet tests
|
||||
# run: >
|
||||
# python -m pytest
|
||||
# --junitxml=test/results.xml
|
||||
# --cov ./extensions-builtin/sd_forge_controlnet
|
||||
# --cov-report=xml
|
||||
# --verify-base-url
|
||||
# ./extensions-builtin/sd_forge_controlnet/tests
|
||||
- name: Kill test server
|
||||
if: always()
|
||||
run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10
|
||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -40,4 +40,5 @@ notification.mp3
|
||||
/.coverage*
|
||||
/test/test_outputs
|
||||
/test/results.xml
|
||||
coverage.xml
|
||||
coverage.xml
|
||||
**/tests/**/expectations
|
108
extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py
Normal file
108
extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py
Normal file
@ -0,0 +1,108 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, Body
|
||||
from fastapi.exceptions import HTTPException
|
||||
from PIL import Image
|
||||
import gradio as gr
|
||||
|
||||
from modules.api import api
|
||||
from .global_state import (
|
||||
get_all_preprocessor_names,
|
||||
get_all_controlnet_names,
|
||||
get_preprocessor,
|
||||
)
|
||||
from .logging import logger
|
||||
|
||||
|
||||
def encode_to_base64(image):
|
||||
if isinstance(image, str):
|
||||
return image
|
||||
elif isinstance(image, Image.Image):
|
||||
return api.encode_pil_to_base64(image)
|
||||
elif isinstance(image, np.ndarray):
|
||||
return encode_np_to_base64(image)
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
def encode_np_to_base64(image):
|
||||
pil = Image.fromarray(image)
|
||||
return api.encode_pil_to_base64(pil)
|
||||
|
||||
|
||||
def controlnet_api(_: gr.Blocks, app: FastAPI):
|
||||
@app.get("/controlnet/model_list")
|
||||
async def model_list():
|
||||
up_to_date_model_list = get_all_controlnet_names()
|
||||
logger.debug(up_to_date_model_list)
|
||||
return {"model_list": up_to_date_model_list}
|
||||
|
||||
@app.get("/controlnet/module_list")
|
||||
async def module_list():
|
||||
module_list = get_all_preprocessor_names()
|
||||
logger.debug(module_list)
|
||||
|
||||
return {
|
||||
"module_list": module_list,
|
||||
# TODO: Add back module detail.
|
||||
# "module_detail": external_code.get_modules_detail(alias_names),
|
||||
}
|
||||
|
||||
@app.post("/controlnet/detect")
|
||||
async def detect(
|
||||
controlnet_module: str = Body("none", title="Controlnet Module"),
|
||||
controlnet_input_images: List[str] = Body([], title="Controlnet Input Images"),
|
||||
controlnet_processor_res: int = Body(
|
||||
512, title="Controlnet Processor Resolution"
|
||||
),
|
||||
controlnet_threshold_a: float = Body(64, title="Controlnet Threshold a"),
|
||||
controlnet_threshold_b: float = Body(64, title="Controlnet Threshold b"),
|
||||
):
|
||||
processor_module = get_preprocessor(controlnet_module)
|
||||
if processor_module is None:
|
||||
raise HTTPException(status_code=422, detail="Module not available")
|
||||
|
||||
if len(controlnet_input_images) == 0:
|
||||
raise HTTPException(status_code=422, detail="No image selected")
|
||||
|
||||
logger.debug(
|
||||
f"Detecting {str(len(controlnet_input_images))} images with the {controlnet_module} module."
|
||||
)
|
||||
|
||||
results = []
|
||||
poses = []
|
||||
|
||||
for input_image in controlnet_input_images:
|
||||
img = np.array(api.decode_base64_to_image(input_image)).astype('uint8')
|
||||
|
||||
class JsonAcceptor:
|
||||
def __init__(self) -> None:
|
||||
self.value = None
|
||||
|
||||
def accept(self, json_dict: dict) -> None:
|
||||
self.value = json_dict
|
||||
|
||||
json_acceptor = JsonAcceptor()
|
||||
|
||||
results.append(
|
||||
processor_module(
|
||||
img,
|
||||
res=controlnet_processor_res,
|
||||
thr_a=controlnet_threshold_a,
|
||||
thr_b=controlnet_threshold_b,
|
||||
json_pose_callback=json_acceptor.accept,
|
||||
)[0]
|
||||
)
|
||||
|
||||
if "openpose" in controlnet_module:
|
||||
assert json_acceptor.value is not None
|
||||
poses.append(json_acceptor.value)
|
||||
|
||||
results64 = list(map(encode_to_base64, results))
|
||||
res = {"images": results64, "info": "Success"}
|
||||
if poses:
|
||||
res["poses"] = poses
|
||||
|
||||
return res
|
||||
|
@ -850,7 +850,6 @@ class ControlNetUiGroup(object):
|
||||
slider_1=pthr_a,
|
||||
slider_2=pthr_b,
|
||||
input_mask=mask,
|
||||
low_vram=shared.opts.data.get("controlnet_clip_detector_on_cpu", False),
|
||||
json_pose_callback=json_acceptor.accept
|
||||
if is_openpose(module)
|
||||
else None,
|
||||
|
@ -148,13 +148,15 @@ InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputIm
|
||||
@dataclass
|
||||
class UiControlNetUnit:
|
||||
input_mode: InputMode = InputMode.SIMPLE
|
||||
use_preview_as_input: bool = False,
|
||||
batch_image_dir: str = '',
|
||||
batch_mask_dir: str = '',
|
||||
batch_input_gallery: list = [],
|
||||
batch_mask_gallery: list = [],
|
||||
generated_image: Optional[np.ndarray] = None,
|
||||
mask_image: Optional[np.ndarray] = None,
|
||||
use_preview_as_input: bool = False
|
||||
batch_image_dir: str = ''
|
||||
batch_mask_dir: str = ''
|
||||
batch_input_gallery: Optional[List[str]] = None
|
||||
batch_mask_gallery: Optional[List[str]] = None
|
||||
generated_image: Optional[np.ndarray] = None
|
||||
mask_image: Optional[np.ndarray] = None
|
||||
# If hires fix is enabled in A1111, how should this ControlNet unit be applied.
|
||||
# The value is ignored if the generation is not using hires fix.
|
||||
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
|
||||
enabled: bool = True
|
||||
module: str = "None"
|
||||
@ -169,6 +171,13 @@ class UiControlNetUnit:
|
||||
guidance_end: float = 1.0
|
||||
pixel_perfect: bool = False
|
||||
control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
|
||||
# ====== Start of API only fields ======
|
||||
# Whether save the detected map of this unit. Setting this option to False
|
||||
# prevents saving the detected map or sending detected map along with
|
||||
# generated images via API. Currently the option is only accessible in API
|
||||
# calls.
|
||||
save_detected_map: bool = True
|
||||
# ====== End of API only fields ======
|
||||
|
||||
@staticmethod
|
||||
def infotext_fields():
|
||||
@ -192,6 +201,23 @@ class UiControlNetUnit:
|
||||
"hr_option",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d: Dict) -> "UiControlNetUnit":
|
||||
"""Create UiControlNetUnit from dict. This is primarily used to convert
|
||||
API json dict to UiControlNetUnit."""
|
||||
unit = UiControlNetUnit(
|
||||
**{k: v for k, v in d.items() if k in vars(UiControlNetUnit)}
|
||||
)
|
||||
if isinstance(unit.image, str):
|
||||
img = np.array(api.decode_base64_to_image(unit.image)).astype('uint8')
|
||||
unit.image = {
|
||||
"image": img,
|
||||
"mask": np.zeros_like(img),
|
||||
}
|
||||
if isinstance(unit.mask_image, str):
|
||||
unit.mask_image = np.array(api.decode_base64_to_image(unit.mask_image)).astype('uint8')
|
||||
return unit
|
||||
|
||||
|
||||
# Backward Compatible
|
||||
ControlNetUnit = UiControlNetUnit
|
||||
|
@ -11,9 +11,10 @@ from modules.api.api import decode_base64_to_image
|
||||
import gradio as gr
|
||||
|
||||
from lib_controlnet import global_state, external_code
|
||||
from lib_controlnet.external_code import ControlNetUnit
|
||||
from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, \
|
||||
prepare_mask, judge_image_type
|
||||
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit
|
||||
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup
|
||||
from lib_controlnet.controlnet_ui.photopea import Photopea
|
||||
from lib_controlnet.logging import logger
|
||||
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, \
|
||||
@ -21,6 +22,7 @@ from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusion
|
||||
from lib_controlnet.infotext import Infotext
|
||||
from modules_forge.forge_util import HWC3, numpy_to_pytorch
|
||||
from lib_controlnet.enums import HiResFixOption
|
||||
from lib_controlnet.api import controlnet_api
|
||||
|
||||
import numpy as np
|
||||
import functools
|
||||
@ -67,7 +69,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
max_models = shared.opts.data.get("control_net_unit_count", 3)
|
||||
gen_type = "img2img" if is_img2img else "txt2img"
|
||||
elem_id_tabname = gen_type + "_controlnet"
|
||||
default_unit = UiControlNetUnit(enabled=False, module="None", model="None")
|
||||
default_unit = ControlNetUnit(enabled=False, module="None", model="None")
|
||||
with gr.Group(elem_id=elem_id_tabname):
|
||||
with gr.Accordion(f"ControlNet Integrated", open=False, elem_id="controlnet",
|
||||
elem_classes=["controlnet"]):
|
||||
@ -95,13 +97,19 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
return tuple(controls)
|
||||
|
||||
def get_enabled_units(self, units):
|
||||
# Parse dict from API calls.
|
||||
units = [
|
||||
ControlNetUnit.from_dict(unit) if isinstance(unit, dict) else unit
|
||||
for unit in units
|
||||
]
|
||||
assert all(isinstance(unit, ControlNetUnit) for unit in units)
|
||||
enabled_units = [x for x in units if x.enabled]
|
||||
return enabled_units
|
||||
|
||||
@staticmethod
|
||||
def try_crop_image_with_a1111_mask(
|
||||
p: StableDiffusionProcessing,
|
||||
unit: external_code.ControlNetUnit,
|
||||
unit: ControlNetUnit,
|
||||
input_image: np.ndarray,
|
||||
resize_mode: external_code.ResizeMode,
|
||||
preprocessor
|
||||
@ -252,7 +260,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
@torch.no_grad()
|
||||
def process_unit_after_click_generate(self,
|
||||
p: StableDiffusionProcessing,
|
||||
unit: external_code.ControlNetUnit,
|
||||
unit: ControlNetUnit,
|
||||
params: ControlNetCachedParameters,
|
||||
*args, **kwargs):
|
||||
|
||||
@ -279,8 +287,6 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
return tqdm(iterable) if use_tqdm else iterable
|
||||
|
||||
for input_image, input_mask in optional_tqdm(input_list, len(input_list) > 1):
|
||||
# p.extra_result_images.append(input_image)
|
||||
|
||||
if unit.pixel_perfect:
|
||||
unit.processor_res = external_code.pixel_perfect_resolution(
|
||||
input_image,
|
||||
@ -319,14 +325,20 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
hr_option = HiResFixOption.BOTH
|
||||
|
||||
alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)]
|
||||
def attach_extra_result_image(img: np.ndarray, is_high_res: bool = False):
|
||||
if (
|
||||
(is_high_res and hr_option.high_res_enabled) or
|
||||
(not is_high_res and hr_option.low_res_enabled)
|
||||
) and unit.save_detected_map:
|
||||
p.extra_result_images.append(img)
|
||||
|
||||
if preprocessor_output_is_image:
|
||||
params.control_cond = []
|
||||
params.control_cond_for_hr_fix = []
|
||||
|
||||
for preprocessor_output in preprocessor_outputs:
|
||||
control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w)
|
||||
if hr_option.low_res_enabled:
|
||||
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond))
|
||||
attach_extra_result_image(external_code.visualize_inpaint_mask(control_cond))
|
||||
params.control_cond.append(numpy_to_pytorch(control_cond).movedim(-1, 1))
|
||||
|
||||
params.control_cond = torch.cat(params.control_cond, dim=0)[alignment_indices].contiguous()
|
||||
@ -334,8 +346,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
if has_high_res_fix:
|
||||
for preprocessor_output in preprocessor_outputs:
|
||||
control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x)
|
||||
if hr_option.high_res_enabled:
|
||||
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix))
|
||||
attach_extra_result_image(external_code.visualize_inpaint_mask(control_cond_for_hr_fix), is_high_res=True)
|
||||
params.control_cond_for_hr_fix.append(numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1))
|
||||
params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0)[alignment_indices].contiguous()
|
||||
else:
|
||||
@ -343,7 +354,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
else:
|
||||
params.control_cond = preprocessor_output
|
||||
params.control_cond_for_hr_fix = preprocessor_output
|
||||
p.extra_result_images.append(input_image)
|
||||
attach_extra_result_image(input_image)
|
||||
|
||||
if len(control_masks) > 0:
|
||||
params.control_mask = []
|
||||
@ -352,15 +363,13 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
for input_mask in control_masks:
|
||||
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
|
||||
control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
|
||||
if hr_option.low_res_enabled:
|
||||
p.extra_result_images.append(control_mask)
|
||||
attach_extra_result_image(control_mask)
|
||||
control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1]
|
||||
params.control_mask.append(control_mask)
|
||||
|
||||
if has_high_res_fix:
|
||||
control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
|
||||
if hr_option.high_res_enabled:
|
||||
p.extra_result_images.append(control_mask_for_hr_fix)
|
||||
attach_extra_result_image(control_mask_for_hr_fix, is_high_res=True)
|
||||
control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
|
||||
params.control_mask_for_hr_fix.append(control_mask_for_hr_fix)
|
||||
|
||||
@ -390,7 +399,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
@torch.no_grad()
|
||||
def process_unit_before_every_sampling(self,
|
||||
p: StableDiffusionProcessing,
|
||||
unit: external_code.ControlNetUnit,
|
||||
unit: ControlNetUnit,
|
||||
params: ControlNetCachedParameters,
|
||||
*args, **kwargs):
|
||||
|
||||
@ -473,14 +482,14 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def bound_check_params(unit: external_code.ControlNetUnit) -> None:
|
||||
def bound_check_params(unit: ControlNetUnit) -> None:
|
||||
"""
|
||||
Checks and corrects negative parameters in ControlNetUnit 'unit'.
|
||||
Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
|
||||
their default values if negative.
|
||||
|
||||
Args:
|
||||
unit (external_code.ControlNetUnit): The ControlNetUnit instance to check.
|
||||
unit (ControlNetUnit): The ControlNetUnit instance to check.
|
||||
"""
|
||||
preprocessor = global_state.get_preprocessor(unit.module)
|
||||
|
||||
@ -498,7 +507,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
@torch.no_grad()
|
||||
def process_unit_after_every_sampling(self,
|
||||
p: StableDiffusionProcessing,
|
||||
unit: external_code.ControlNetUnit,
|
||||
unit: ControlNetUnit,
|
||||
params: ControlNetCachedParameters,
|
||||
*args, **kwargs):
|
||||
|
||||
@ -577,3 +586,4 @@ script_callbacks.on_ui_settings(on_ui_settings)
|
||||
script_callbacks.on_infotext_pasted(Infotext.on_infotext_pasted)
|
||||
script_callbacks.on_after_component(ControlNetUiGroup.on_after_component)
|
||||
script_callbacks.on_before_reload(ControlNetUiGroup.reset)
|
||||
script_callbacks.on_app_started(controlnet_api)
|
||||
|
8
extensions-builtin/sd_forge_controlnet/tests/conftest.py
Normal file
8
extensions-builtin/sd_forge_controlnet/tests/conftest.py
Normal file
@ -0,0 +1,8 @@
|
||||
import os
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
# We don't want to fail on Py.test command line arguments being
|
||||
# parsed by webui:
|
||||
os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")
|
||||
os.environ.setdefault("FORGE_CQ_TEST", "1")
|
BIN
extensions-builtin/sd_forge_controlnet/tests/images/1girl.png
Normal file
BIN
extensions-builtin/sd_forge_controlnet/tests/images/1girl.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 482 KiB |
BIN
extensions-builtin/sd_forge_controlnet/tests/images/mask.png
Normal file
BIN
extensions-builtin/sd_forge_controlnet/tests/images/mask.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 244 B |
Binary file not shown.
After Width: | Height: | Size: 226 B |
@ -0,0 +1,85 @@
|
||||
import pytest
|
||||
|
||||
from .template import (
|
||||
APITestTemplate,
|
||||
girl_img,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"])
|
||||
def test_no_unit(gen_type):
|
||||
assert APITestTemplate(
|
||||
f"test_no_unit{gen_type}",
|
||||
gen_type,
|
||||
payload_overrides={},
|
||||
unit_overrides=[],
|
||||
input_image=girl_img,
|
||||
).exec()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"])
|
||||
def test_multiple_iter(gen_type):
|
||||
assert APITestTemplate(
|
||||
f"test_multiple_iter{gen_type}",
|
||||
gen_type,
|
||||
payload_overrides={"n_iter": 2},
|
||||
unit_overrides={},
|
||||
input_image=girl_img,
|
||||
).exec()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"])
|
||||
def test_batch_size(gen_type):
|
||||
assert APITestTemplate(
|
||||
f"test_batch_size{gen_type}",
|
||||
gen_type,
|
||||
payload_overrides={"batch_size": 2},
|
||||
unit_overrides={},
|
||||
input_image=girl_img,
|
||||
).exec()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"])
|
||||
def test_2_units(gen_type):
|
||||
assert APITestTemplate(
|
||||
f"test_2_units{gen_type}",
|
||||
gen_type,
|
||||
payload_overrides={},
|
||||
unit_overrides=[{}, {}],
|
||||
input_image=girl_img,
|
||||
).exec()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"])
|
||||
def test_preprocessor(gen_type):
|
||||
assert APITestTemplate(
|
||||
f"test_preprocessor{gen_type}",
|
||||
gen_type,
|
||||
payload_overrides={},
|
||||
unit_overrides={"module": "canny"},
|
||||
input_image=girl_img,
|
||||
).exec()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("param_name", ("processor_res", "threshold_a", "threshold_b"))
|
||||
@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"])
|
||||
def test_invalid_param(gen_type, param_name):
|
||||
assert APITestTemplate(
|
||||
f"test_invalid_param{(gen_type, param_name)}",
|
||||
gen_type,
|
||||
payload_overrides={},
|
||||
unit_overrides={param_name: -1},
|
||||
input_image=girl_img,
|
||||
).exec()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("save_map", [True, False])
|
||||
@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"])
|
||||
def test_save_map(gen_type, save_map):
|
||||
assert APITestTemplate(
|
||||
f"test_save_map{(gen_type, save_map)}",
|
||||
gen_type,
|
||||
payload_overrides={},
|
||||
unit_overrides={"save_detected_map": save_map},
|
||||
input_image=girl_img,
|
||||
).exec(expected_output_num=2 if save_map else 1)
|
329
extensions-builtin/sd_forge_controlnet/tests/web_api/template.py
Normal file
329
extensions-builtin/sd_forge_controlnet/tests/web_api/template.py
Normal file
@ -0,0 +1,329 @@
|
||||
import io
|
||||
import os
|
||||
import cv2
|
||||
import base64
|
||||
from typing import Dict, Any, List, Union, Literal, Optional
|
||||
from pathlib import Path
|
||||
import datetime
|
||||
from enum import Enum
|
||||
import numpy as np
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
|
||||
PayloadOverrideType = Dict[str, Any]
|
||||
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
test_result_dir = Path(__file__).parent / "results" / f"test_result_{timestamp}"
|
||||
test_expectation_dir = Path(__file__).parent / "expectations"
|
||||
os.makedirs(test_expectation_dir, exist_ok=True)
|
||||
resource_dir = Path(__file__).parents[1] / "images"
|
||||
|
||||
|
||||
def read_image(img_path: Path) -> str:
|
||||
img = cv2.imread(str(img_path))
|
||||
_, bytes = cv2.imencode(".png", img)
|
||||
encoded_image = base64.b64encode(bytes).decode("utf-8")
|
||||
return encoded_image
|
||||
|
||||
|
||||
def read_image_dir(img_dir: Path, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> List[str]:
|
||||
"""Try read all images in given img_dir."""
|
||||
img_dir = str(img_dir)
|
||||
images = []
|
||||
for filename in os.listdir(img_dir):
|
||||
if filename.endswith(suffixes):
|
||||
img_path = os.path.join(img_dir, filename)
|
||||
try:
|
||||
images.append(read_image(img_path))
|
||||
except IOError:
|
||||
print(f"Error opening {img_path}")
|
||||
return images
|
||||
|
||||
|
||||
girl_img = read_image(resource_dir / "1girl.png")
|
||||
mask_img = read_image(resource_dir / "mask.png")
|
||||
mask_small_img = read_image(resource_dir / "mask_small.png")
|
||||
|
||||
|
||||
general_negative_prompt = """
|
||||
(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality,
|
||||
((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot,
|
||||
backlight,(ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21),
|
||||
(tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, (bad anatomy:1.21),
|
||||
(bad proportions:1.331), extra limbs, (missing arms:1.331), (extra legs:1.331),
|
||||
(fused fingers:1.61051), (too many fingers:1.61051), (unclear eyes:1.331), bad hands,
|
||||
missing fingers, extra digit, bad body, easynegative, nsfw"""
|
||||
|
||||
class StableDiffusionVersion(Enum):
|
||||
"""The version family of stable diffusion model."""
|
||||
|
||||
UNKNOWN = 0
|
||||
SD1x = 1
|
||||
SD2x = 2
|
||||
SDXL = 3
|
||||
|
||||
|
||||
sd_version = StableDiffusionVersion(
|
||||
int(os.environ.get("CONTROLNET_TEST_SD_VERSION", StableDiffusionVersion.SD1x.value))
|
||||
)
|
||||
|
||||
is_full_coverage = os.environ.get("CONTROLNET_TEST_FULL_COVERAGE", None) is not None
|
||||
|
||||
|
||||
class APITestTemplate:
|
||||
is_set_expectation_run = os.environ.get("CONTROLNET_SET_EXP", "True") == "True"
|
||||
is_cq_run = bool(os.environ.get("FORGE_CQ_TEST", ""))
|
||||
BASE_URL = "http://localhost:7860/"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
gen_type: Union[Literal["img2img"], Literal["txt2img"]],
|
||||
payload_overrides: PayloadOverrideType,
|
||||
unit_overrides: Union[PayloadOverrideType, List[PayloadOverrideType]],
|
||||
input_image: Optional[str] = None,
|
||||
):
|
||||
self.name = name
|
||||
self.url = APITestTemplate.BASE_URL + "sdapi/v1/" + gen_type
|
||||
self.payload = {
|
||||
**(txt2img_payload if gen_type == "txt2img" else img2img_payload),
|
||||
**payload_overrides,
|
||||
}
|
||||
if gen_type == "img2img" and input_image is not None:
|
||||
self.payload["init_images"] = [input_image]
|
||||
|
||||
# CQ runs on CPU. Reduce steps to increase test speed.
|
||||
if "steps" not in payload_overrides and APITestTemplate.is_cq_run:
|
||||
self.payload["steps"] = 3
|
||||
|
||||
unit_overrides = (
|
||||
unit_overrides
|
||||
if isinstance(unit_overrides, (list, tuple))
|
||||
else [unit_overrides]
|
||||
)
|
||||
self.payload["alwayson_scripts"]["ControlNet"]["args"] = [
|
||||
{
|
||||
**default_unit,
|
||||
**unit_override,
|
||||
**({"image": input_image} if gen_type == "txt2img" and input_image is not None else {}),
|
||||
}
|
||||
for unit_override in unit_overrides
|
||||
]
|
||||
self.active_unit_count = len(unit_overrides)
|
||||
|
||||
def exec(self, *args, **kwargs) -> bool:
|
||||
if APITestTemplate.is_cq_run:
|
||||
return self.exec_cq(*args, **kwargs)
|
||||
else:
|
||||
return self.exec_local(*args, **kwargs)
|
||||
|
||||
def exec_cq(self, expected_output_num: Optional[int] = None, *args, **kwargs) -> bool:
|
||||
"""Execute test in CQ environment."""
|
||||
res = requests.post(url=self.url, json=self.payload)
|
||||
if res.status_code != 200:
|
||||
print(f"Unexpected status code {res.status_code}")
|
||||
return False
|
||||
|
||||
response = res.json()
|
||||
if "images" not in response:
|
||||
print(response.keys())
|
||||
return False
|
||||
|
||||
if expected_output_num is None:
|
||||
expected_output_num = self.payload["n_iter"] * self.payload["batch_size"] + self.active_unit_count
|
||||
|
||||
if len(response["images"]) != expected_output_num:
|
||||
print(f"{len(response['images'])} != {expected_output_num}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def exec_local(self, result_only: bool = True, *args, **kwargs) -> bool:
|
||||
"""Execute test in local environment."""
|
||||
if not APITestTemplate.is_set_expectation_run:
|
||||
os.makedirs(test_result_dir, exist_ok=True)
|
||||
|
||||
failed = False
|
||||
|
||||
response = requests.post(url=self.url, json=self.payload).json()
|
||||
if "images" not in response:
|
||||
print(response.keys())
|
||||
return False
|
||||
|
||||
dest_dir = (
|
||||
test_expectation_dir
|
||||
if APITestTemplate.is_set_expectation_run
|
||||
else test_result_dir
|
||||
)
|
||||
results = response["images"][:1] if result_only else response["images"]
|
||||
for i, base64image in enumerate(results):
|
||||
img_file_name = f"{self.name}_{i}.png"
|
||||
Image.open(io.BytesIO(base64.b64decode(base64image.split(",", 1)[0]))).save(
|
||||
dest_dir / img_file_name
|
||||
)
|
||||
|
||||
if not APITestTemplate.is_set_expectation_run:
|
||||
try:
|
||||
img1 = cv2.imread(os.path.join(test_expectation_dir, img_file_name))
|
||||
img2 = cv2.imread(os.path.join(test_result_dir, img_file_name))
|
||||
except Exception as e:
|
||||
print(f"Get exception reading imgs: {e}")
|
||||
failed = True
|
||||
continue
|
||||
|
||||
if img1 is None:
|
||||
print(f"Warn: No expectation file found {img_file_name}.")
|
||||
continue
|
||||
|
||||
if not expect_same_image(
|
||||
img1,
|
||||
img2,
|
||||
diff_img_path=str(test_result_dir
|
||||
/ img_file_name.replace(".png", "_diff.png")),
|
||||
):
|
||||
failed = True
|
||||
return not failed
|
||||
|
||||
|
||||
def expect_same_image(img1, img2, diff_img_path: str) -> bool:
|
||||
# Calculate the difference between the two images
|
||||
diff = cv2.absdiff(img1, img2)
|
||||
|
||||
# Set a threshold to highlight the different pixels
|
||||
threshold = 30
|
||||
diff_highlighted = np.where(diff > threshold, 255, 0).astype(np.uint8)
|
||||
|
||||
# Assert that the two images are similar within a tolerance
|
||||
similar = np.allclose(img1, img2, rtol=0.5, atol=1)
|
||||
if not similar:
|
||||
# Save the diff_highlighted image to inspect the differences
|
||||
cv2.imwrite(diff_img_path, diff_highlighted)
|
||||
|
||||
matching_pixels = np.isclose(img1, img2, rtol=0.5, atol=1)
|
||||
similar_in_general = (matching_pixels.sum() / matching_pixels.size) >= 0.95
|
||||
return similar_in_general
|
||||
|
||||
|
||||
def get_model(model_name: str) -> str:
|
||||
""" Find an available model with specified model name."""
|
||||
if model_name.lower() == "none":
|
||||
return "None"
|
||||
|
||||
r = requests.get(APITestTemplate.BASE_URL + "controlnet/model_list")
|
||||
result = r.json()
|
||||
if "model_list" not in result:
|
||||
raise ValueError("No model available")
|
||||
|
||||
candidates = [
|
||||
model
|
||||
for model in result["model_list"]
|
||||
if model_name.lower() in model.lower()
|
||||
]
|
||||
|
||||
if not candidates:
|
||||
raise ValueError("No suitable model available")
|
||||
|
||||
return candidates[0]
|
||||
|
||||
|
||||
default_unit = {
|
||||
"control_mode": 0,
|
||||
"enabled": True,
|
||||
"guidance_end": 1,
|
||||
"guidance_start": 0,
|
||||
"low_vram": False,
|
||||
"pixel_perfect": True,
|
||||
"processor_res": 512,
|
||||
"resize_mode": 1,
|
||||
"threshold_a": 64,
|
||||
"threshold_b": 64,
|
||||
"weight": 1,
|
||||
"module": "None",
|
||||
"model": get_model("sd15_canny"),
|
||||
}
|
||||
|
||||
img2img_payload = {
|
||||
"batch_size": 1,
|
||||
"cfg_scale": 7,
|
||||
"height": 768,
|
||||
"width": 512,
|
||||
"n_iter": 1,
|
||||
"steps": 10,
|
||||
"sampler_name": "Euler a",
|
||||
"prompt": "(masterpiece: 1.3), (highres: 1.3), best quality,",
|
||||
"negative_prompt": "",
|
||||
"seed": 42,
|
||||
"seed_enable_extras": False,
|
||||
"seed_resize_from_h": 0,
|
||||
"seed_resize_from_w": 0,
|
||||
"subseed": -1,
|
||||
"subseed_strength": 0,
|
||||
"override_settings": {},
|
||||
"override_settings_restore_afterwards": False,
|
||||
"do_not_save_grid": False,
|
||||
"do_not_save_samples": False,
|
||||
"s_churn": 0,
|
||||
"s_min_uncond": 0,
|
||||
"s_noise": 1,
|
||||
"s_tmax": None,
|
||||
"s_tmin": 0,
|
||||
"script_args": [],
|
||||
"script_name": None,
|
||||
"styles": [],
|
||||
"alwayson_scripts": {"ControlNet": {"args": [default_unit]}},
|
||||
"denoising_strength": 0.75,
|
||||
"initial_noise_multiplier": 1,
|
||||
"inpaint_full_res": 0,
|
||||
"inpaint_full_res_padding": 32,
|
||||
"inpainting_fill": 1,
|
||||
"inpainting_mask_invert": 0,
|
||||
"mask_blur_x": 4,
|
||||
"mask_blur_y": 4,
|
||||
"mask_blur": 4,
|
||||
"resize_mode": 0,
|
||||
}
|
||||
|
||||
txt2img_payload = {
|
||||
"alwayson_scripts": {"ControlNet": {"args": [default_unit]}},
|
||||
"batch_size": 1,
|
||||
"cfg_scale": 7,
|
||||
"comments": {},
|
||||
"disable_extra_networks": False,
|
||||
"do_not_save_grid": False,
|
||||
"do_not_save_samples": False,
|
||||
"enable_hr": False,
|
||||
"height": 768,
|
||||
"hr_negative_prompt": "",
|
||||
"hr_prompt": "",
|
||||
"hr_resize_x": 0,
|
||||
"hr_resize_y": 0,
|
||||
"hr_scale": 2,
|
||||
"hr_second_pass_steps": 0,
|
||||
"hr_upscaler": "Latent",
|
||||
"n_iter": 1,
|
||||
"negative_prompt": "",
|
||||
"override_settings": {},
|
||||
"override_settings_restore_afterwards": True,
|
||||
"prompt": "(masterpiece: 1.3), (highres: 1.3), best quality,",
|
||||
"restore_faces": False,
|
||||
"s_churn": 0.0,
|
||||
"s_min_uncond": 0,
|
||||
"s_noise": 1.0,
|
||||
"s_tmax": None,
|
||||
"s_tmin": 0.0,
|
||||
"sampler_name": "Euler a",
|
||||
"script_args": [],
|
||||
"script_name": None,
|
||||
"seed": 42,
|
||||
"seed_enable_extras": True,
|
||||
"seed_resize_from_h": -1,
|
||||
"seed_resize_from_w": -1,
|
||||
"steps": 10,
|
||||
"styles": [],
|
||||
"subseed": -1,
|
||||
"subseed_strength": 0,
|
||||
"tiling": False,
|
||||
"width": 512,
|
||||
}
|
@ -2,11 +2,13 @@ import base64
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
import itertools
|
||||
import datetime
|
||||
import uvicorn
|
||||
import ipaddress
|
||||
import requests
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from threading import Lock
|
||||
from io import BytesIO
|
||||
from fastapi import APIRouter, Depends, FastAPI, Request, Response
|
||||
@ -103,6 +105,8 @@ def encode_pil_to_base64(image):
|
||||
with io.BytesIO() as output_bytes:
|
||||
if isinstance(image, str):
|
||||
return image
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
if opts.samples_format.lower() == 'png':
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
@ -480,7 +484,11 @@ class Api:
|
||||
shared.state.end()
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||
b64images = [
|
||||
encode_pil_to_base64(image)
|
||||
for image in itertools.chain(processed.images, processed.extra_images)
|
||||
if send_images
|
||||
]
|
||||
|
||||
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||
|
||||
@ -547,7 +555,11 @@ class Api:
|
||||
shared.state.end()
|
||||
shared.total_tqdm.clear()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||
b64images = [
|
||||
encode_pil_to_base64(image)
|
||||
for image in itertools.chain(processed.images, processed.extra_images)
|
||||
if send_images
|
||||
]
|
||||
|
||||
if not img2imgreq.include_init_images:
|
||||
img2imgreq.init_images = None
|
||||
|
Loading…
Reference in New Issue
Block a user