diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 8f48d85b..27dacbff 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -52,6 +52,18 @@ jobs: curl -Lo "$filename" "$url" fi done + # - name: Download ControlNet models + # run: | + # declare -a urls=( + # "https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny.pth" + # ) + + # for url in "${urls[@]}"; do + # filename="models/ControlNet/${url##*/}" # Extracts the last part of the URL + # if [ ! -f "$filename" ]; then + # curl -Lo "$filename" "$url" + # fi + # done - name: Start test server run: > python -m coverage run @@ -71,6 +83,16 @@ jobs: run: | wait-for-it --service 127.0.0.1:7860 -t 20 python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test + # TODO(huchenlei): Enable ControlNet tests. Currently it is too slow to run these tests on CPU with a + # real SD model. We need to find a way to load an empty SD model. 
+ # - name: Run ControlNet tests + # run: > + # python -m pytest + # --junitxml=test/results.xml + # --cov ./extensions-builtin/sd_forge_controlnet + # --cov-report=xml + # --verify-base-url + # ./extensions-builtin/sd_forge_controlnet/tests - name: Kill test server if: always() run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10 diff --git a/.gitignore b/.gitignore index 4149b673..ca7c47ee 100644 --- a/.gitignore +++ b/.gitignore @@ -40,4 +40,5 @@ notification.mp3 /.coverage* /test/test_outputs /test/results.xml -coverage.xml \ No newline at end of file +coverage.xml +**/tests/**/expectations \ No newline at end of file diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py new file mode 100644 index 00000000..ddfc3153 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py @@ -0,0 +1,108 @@ +from typing import List + +import numpy as np +from fastapi import FastAPI, Body +from fastapi.exceptions import HTTPException +from PIL import Image +import gradio as gr + +from modules.api import api +from .global_state import ( + get_all_preprocessor_names, + get_all_controlnet_names, + get_preprocessor, +) +from .logging import logger + + +def encode_to_base64(image): + if isinstance(image, str): + return image + elif isinstance(image, Image.Image): + return api.encode_pil_to_base64(image) + elif isinstance(image, np.ndarray): + return encode_np_to_base64(image) + else: + return "" + + +def encode_np_to_base64(image): + pil = Image.fromarray(image) + return api.encode_pil_to_base64(pil) + + +def controlnet_api(_: gr.Blocks, app: FastAPI): + @app.get("/controlnet/model_list") + async def model_list(): + up_to_date_model_list = get_all_controlnet_names() + logger.debug(up_to_date_model_list) + return {"model_list": up_to_date_model_list} + + @app.get("/controlnet/module_list") + async def module_list(): + module_list = 
get_all_preprocessor_names() + logger.debug(module_list) + + return { + "module_list": module_list, + # TODO: Add back module detail. + # "module_detail": external_code.get_modules_detail(alias_names), + } + + @app.post("/controlnet/detect") + async def detect( + controlnet_module: str = Body("none", title="Controlnet Module"), + controlnet_input_images: List[str] = Body([], title="Controlnet Input Images"), + controlnet_processor_res: int = Body( + 512, title="Controlnet Processor Resolution" + ), + controlnet_threshold_a: float = Body(64, title="Controlnet Threshold a"), + controlnet_threshold_b: float = Body(64, title="Controlnet Threshold b"), + ): + processor_module = get_preprocessor(controlnet_module) + if processor_module is None: + raise HTTPException(status_code=422, detail="Module not available") + + if len(controlnet_input_images) == 0: + raise HTTPException(status_code=422, detail="No image selected") + + logger.debug( + f"Detecting {str(len(controlnet_input_images))} images with the {controlnet_module} module." 
+ ) + + results = [] + poses = [] + + for input_image in controlnet_input_images: + img = np.array(api.decode_base64_to_image(input_image)).astype('uint8') + + class JsonAcceptor: + def __init__(self) -> None: + self.value = None + + def accept(self, json_dict: dict) -> None: + self.value = json_dict + + json_acceptor = JsonAcceptor() + + results.append( + processor_module( + img, + res=controlnet_processor_res, + thr_a=controlnet_threshold_a, + thr_b=controlnet_threshold_b, + json_pose_callback=json_acceptor.accept, + )[0] + ) + + if "openpose" in controlnet_module: + assert json_acceptor.value is not None + poses.append(json_acceptor.value) + + results64 = list(map(encode_to_base64, results)) + res = {"images": results64, "info": "Success"} + if poses: + res["poses"] = poses + + return res + diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py index 389788bf..b3e16df6 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/controlnet_ui/controlnet_ui_group.py @@ -850,7 +850,6 @@ class ControlNetUiGroup(object): slider_1=pthr_a, slider_2=pthr_b, input_mask=mask, - low_vram=shared.opts.data.get("controlnet_clip_detector_on_cpu", False), json_pose_callback=json_acceptor.accept if is_openpose(module) else None, diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py index a5d2bb4f..95ff32cb 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/external_code.py @@ -148,13 +148,15 @@ InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputIm @dataclass class UiControlNetUnit: input_mode: InputMode = 
InputMode.SIMPLE - use_preview_as_input: bool = False, - batch_image_dir: str = '', - batch_mask_dir: str = '', - batch_input_gallery: list = [], - batch_mask_gallery: list = [], - generated_image: Optional[np.ndarray] = None, - mask_image: Optional[np.ndarray] = None, + use_preview_as_input: bool = False + batch_image_dir: str = '' + batch_mask_dir: str = '' + batch_input_gallery: Optional[List[str]] = None + batch_mask_gallery: Optional[List[str]] = None + generated_image: Optional[np.ndarray] = None + mask_image: Optional[np.ndarray] = None + # How this ControlNet unit should be applied when hires fix is enabled in A1111. + # The value is ignored if the generation is not using hires fix. hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH enabled: bool = True module: str = "None" @@ -169,6 +171,13 @@ class UiControlNetUnit: guidance_end: float = 1.0 pixel_perfect: bool = False control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED + # ====== Start of API only fields ====== + # Whether to save the detected map of this unit. Setting this option to False + # prevents saving the detected map or sending the detected map along with + # generated images via API. Currently the option is only accessible in API + # calls. + save_detected_map: bool = True + # ====== End of API only fields ====== @staticmethod def infotext_fields(): @@ -192,6 +201,23 @@ class UiControlNetUnit: "hr_option", ) + @staticmethod + def from_dict(d: Dict) -> "UiControlNetUnit": + """Create UiControlNetUnit from dict. 
This is primarily used to convert + API json dict to UiControlNetUnit.""" + unit = UiControlNetUnit( + **{k: v for k, v in d.items() if k in vars(UiControlNetUnit)} + ) + if isinstance(unit.image, str): + img = np.array(api.decode_base64_to_image(unit.image)).astype('uint8') + unit.image = { + "image": img, + "mask": np.zeros_like(img), + } + if isinstance(unit.mask_image, str): + unit.mask_image = np.array(api.decode_base64_to_image(unit.mask_image)).astype('uint8') + return unit + # Backward Compatible ControlNetUnit = UiControlNetUnit diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index 5cc50697..dd5e3fa1 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -11,9 +11,10 @@ from modules.api.api import decode_base64_to_image import gradio as gr from lib_controlnet import global_state, external_code +from lib_controlnet.external_code import ControlNetUnit from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, \ prepare_mask, judge_image_type -from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit +from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup from lib_controlnet.controlnet_ui.photopea import Photopea from lib_controlnet.logging import logger from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, \ @@ -21,6 +22,7 @@ from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusion from lib_controlnet.infotext import Infotext from modules_forge.forge_util import HWC3, numpy_to_pytorch from lib_controlnet.enums import HiResFixOption +from lib_controlnet.api import controlnet_api import numpy as np import functools @@ -67,7 +69,7 @@ class ControlNetForForgeOfficial(scripts.Script): max_models = 
shared.opts.data.get("control_net_unit_count", 3) gen_type = "img2img" if is_img2img else "txt2img" elem_id_tabname = gen_type + "_controlnet" - default_unit = UiControlNetUnit(enabled=False, module="None", model="None") + default_unit = ControlNetUnit(enabled=False, module="None", model="None") with gr.Group(elem_id=elem_id_tabname): with gr.Accordion(f"ControlNet Integrated", open=False, elem_id="controlnet", elem_classes=["controlnet"]): @@ -95,13 +97,19 @@ class ControlNetForForgeOfficial(scripts.Script): return tuple(controls) def get_enabled_units(self, units): + # Parse dict from API calls. + units = [ + ControlNetUnit.from_dict(unit) if isinstance(unit, dict) else unit + for unit in units + ] + assert all(isinstance(unit, ControlNetUnit) for unit in units) enabled_units = [x for x in units if x.enabled] return enabled_units @staticmethod def try_crop_image_with_a1111_mask( p: StableDiffusionProcessing, - unit: external_code.ControlNetUnit, + unit: ControlNetUnit, input_image: np.ndarray, resize_mode: external_code.ResizeMode, preprocessor @@ -252,7 +260,7 @@ class ControlNetForForgeOfficial(scripts.Script): @torch.no_grad() def process_unit_after_click_generate(self, p: StableDiffusionProcessing, - unit: external_code.ControlNetUnit, + unit: ControlNetUnit, params: ControlNetCachedParameters, *args, **kwargs): @@ -279,8 +287,6 @@ class ControlNetForForgeOfficial(scripts.Script): return tqdm(iterable) if use_tqdm else iterable for input_image, input_mask in optional_tqdm(input_list, len(input_list) > 1): - # p.extra_result_images.append(input_image) - if unit.pixel_perfect: unit.processor_res = external_code.pixel_perfect_resolution( input_image, @@ -319,14 +325,20 @@ class ControlNetForForgeOfficial(scripts.Script): hr_option = HiResFixOption.BOTH alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)] + def attach_extra_result_image(img: np.ndarray, is_high_res: bool = False): + if ( + (is_high_res and hr_option.high_res_enabled) 
or + (not is_high_res and hr_option.low_res_enabled) + ) and unit.save_detected_map: + p.extra_result_images.append(img) + if preprocessor_output_is_image: params.control_cond = [] params.control_cond_for_hr_fix = [] for preprocessor_output in preprocessor_outputs: control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w) - if hr_option.low_res_enabled: - p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond)) + attach_extra_result_image(external_code.visualize_inpaint_mask(control_cond)) params.control_cond.append(numpy_to_pytorch(control_cond).movedim(-1, 1)) params.control_cond = torch.cat(params.control_cond, dim=0)[alignment_indices].contiguous() @@ -334,8 +346,7 @@ class ControlNetForForgeOfficial(scripts.Script): if has_high_res_fix: for preprocessor_output in preprocessor_outputs: control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x) - if hr_option.high_res_enabled: - p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix)) + attach_extra_result_image(external_code.visualize_inpaint_mask(control_cond_for_hr_fix), is_high_res=True) params.control_cond_for_hr_fix.append(numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1)) params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0)[alignment_indices].contiguous() else: @@ -343,7 +354,7 @@ class ControlNetForForgeOfficial(scripts.Script): else: params.control_cond = preprocessor_output params.control_cond_for_hr_fix = preprocessor_output - p.extra_result_images.append(input_image) + attach_extra_result_image(input_image) if len(control_masks) > 0: params.control_mask = [] @@ -352,15 +363,13 @@ class ControlNetForForgeOfficial(scripts.Script): for input_mask in control_masks: fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border) - if hr_option.low_res_enabled: - 
p.extra_result_images.append(control_mask) + attach_extra_result_image(control_mask) control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1] params.control_mask.append(control_mask) if has_high_res_fix: control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border) - if hr_option.high_res_enabled: - p.extra_result_images.append(control_mask_for_hr_fix) + attach_extra_result_image(control_mask_for_hr_fix, is_high_res=True) control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1] params.control_mask_for_hr_fix.append(control_mask_for_hr_fix) @@ -390,7 +399,7 @@ class ControlNetForForgeOfficial(scripts.Script): @torch.no_grad() def process_unit_before_every_sampling(self, p: StableDiffusionProcessing, - unit: external_code.ControlNetUnit, + unit: ControlNetUnit, params: ControlNetCachedParameters, *args, **kwargs): @@ -473,14 +482,14 @@ class ControlNetForForgeOfficial(scripts.Script): return @staticmethod - def bound_check_params(unit: external_code.ControlNetUnit) -> None: + def bound_check_params(unit: ControlNetUnit) -> None: """ Checks and corrects negative parameters in ControlNetUnit 'unit'. Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to their default values if negative. Args: - unit (external_code.ControlNetUnit): The ControlNetUnit instance to check. + unit (ControlNetUnit): The ControlNetUnit instance to check. 
""" preprocessor = global_state.get_preprocessor(unit.module) @@ -498,7 +507,7 @@ class ControlNetForForgeOfficial(scripts.Script): @torch.no_grad() def process_unit_after_every_sampling(self, p: StableDiffusionProcessing, - unit: external_code.ControlNetUnit, + unit: ControlNetUnit, params: ControlNetCachedParameters, *args, **kwargs): @@ -577,3 +586,4 @@ script_callbacks.on_ui_settings(on_ui_settings) script_callbacks.on_infotext_pasted(Infotext.on_infotext_pasted) script_callbacks.on_after_component(ControlNetUiGroup.on_after_component) script_callbacks.on_before_reload(ControlNetUiGroup.reset) +script_callbacks.on_app_started(controlnet_api) diff --git a/extensions-builtin/sd_forge_controlnet/tests/conftest.py b/extensions-builtin/sd_forge_controlnet/tests/conftest.py new file mode 100644 index 00000000..cfd85061 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/tests/conftest.py @@ -0,0 +1,8 @@ +import os + + +def pytest_configure(config): + # We don't want to fail on Py.test command line arguments being + # parsed by webui: + os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1") + os.environ.setdefault("FORGE_CQ_TEST", "1") diff --git a/extensions-builtin/sd_forge_controlnet/tests/images/1girl.png b/extensions-builtin/sd_forge_controlnet/tests/images/1girl.png new file mode 100644 index 00000000..d825e716 Binary files /dev/null and b/extensions-builtin/sd_forge_controlnet/tests/images/1girl.png differ diff --git a/extensions-builtin/sd_forge_controlnet/tests/images/mask.png b/extensions-builtin/sd_forge_controlnet/tests/images/mask.png new file mode 100644 index 00000000..166203af Binary files /dev/null and b/extensions-builtin/sd_forge_controlnet/tests/images/mask.png differ diff --git a/extensions-builtin/sd_forge_controlnet/tests/images/mask_small.png b/extensions-builtin/sd_forge_controlnet/tests/images/mask_small.png new file mode 100644 index 00000000..c48d77e4 Binary files /dev/null and 
b/extensions-builtin/sd_forge_controlnet/tests/images/mask_small.png differ diff --git a/extensions-builtin/sd_forge_controlnet/tests/web_api/__init__.py b/extensions-builtin/sd_forge_controlnet/tests/web_api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py b/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py new file mode 100644 index 00000000..9656b266 --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/tests/web_api/generation_test.py @@ -0,0 +1,85 @@ +import pytest + +from .template import ( + APITestTemplate, + girl_img, +) + + +@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"]) +def test_no_unit(gen_type): + assert APITestTemplate( + f"test_no_unit{gen_type}", + gen_type, + payload_overrides={}, + unit_overrides=[], + input_image=girl_img, + ).exec() + + +@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"]) +def test_multiple_iter(gen_type): + assert APITestTemplate( + f"test_multiple_iter{gen_type}", + gen_type, + payload_overrides={"n_iter": 2}, + unit_overrides={}, + input_image=girl_img, + ).exec() + + +@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"]) +def test_batch_size(gen_type): + assert APITestTemplate( + f"test_batch_size{gen_type}", + gen_type, + payload_overrides={"batch_size": 2}, + unit_overrides={}, + input_image=girl_img, + ).exec() + + +@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"]) +def test_2_units(gen_type): + assert APITestTemplate( + f"test_2_units{gen_type}", + gen_type, + payload_overrides={}, + unit_overrides=[{}, {}], + input_image=girl_img, + ).exec() + + +@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"]) +def test_preprocessor(gen_type): + assert APITestTemplate( + f"test_preprocessor{gen_type}", + gen_type, + payload_overrides={}, + unit_overrides={"module": "canny"}, + input_image=girl_img, + ).exec() + + +@pytest.mark.parametrize("param_name", 
("processor_res", "threshold_a", "threshold_b")) +@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"]) +def test_invalid_param(gen_type, param_name): + assert APITestTemplate( + f"test_invalid_param{(gen_type, param_name)}", + gen_type, + payload_overrides={}, + unit_overrides={param_name: -1}, + input_image=girl_img, + ).exec() + + +@pytest.mark.parametrize("save_map", [True, False]) +@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"]) +def test_save_map(gen_type, save_map): + assert APITestTemplate( + f"test_save_map{(gen_type, save_map)}", + gen_type, + payload_overrides={}, + unit_overrides={"save_detected_map": save_map}, + input_image=girl_img, + ).exec(expected_output_num=2 if save_map else 1) diff --git a/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py b/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py new file mode 100644 index 00000000..bc3b73ae --- /dev/null +++ b/extensions-builtin/sd_forge_controlnet/tests/web_api/template.py @@ -0,0 +1,329 @@ +import io +import os +import cv2 +import base64 +from typing import Dict, Any, List, Union, Literal, Optional +from pathlib import Path +import datetime +from enum import Enum +import numpy as np + +import requests +from PIL import Image + + +PayloadOverrideType = Dict[str, Any] + +timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") +test_result_dir = Path(__file__).parent / "results" / f"test_result_{timestamp}" +test_expectation_dir = Path(__file__).parent / "expectations" +os.makedirs(test_expectation_dir, exist_ok=True) +resource_dir = Path(__file__).parents[1] / "images" + + +def read_image(img_path: Path) -> str: + img = cv2.imread(str(img_path)) + _, bytes = cv2.imencode(".png", img) + encoded_image = base64.b64encode(bytes).decode("utf-8") + return encoded_image + + +def read_image_dir(img_dir: Path, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> List[str]: + """Try read all images in given img_dir.""" + img_dir = str(img_dir) + images = [] + 
for filename in os.listdir(img_dir): + if filename.endswith(suffixes): + img_path = os.path.join(img_dir, filename) + try: + images.append(read_image(img_path)) + except IOError: + print(f"Error opening {img_path}") + return images + + +girl_img = read_image(resource_dir / "1girl.png") +mask_img = read_image(resource_dir / "mask.png") +mask_small_img = read_image(resource_dir / "mask_small.png") + + +general_negative_prompt = """ +(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, +((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, +backlight,(ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21), +(tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, (bad anatomy:1.21), +(bad proportions:1.331), extra limbs, (missing arms:1.331), (extra legs:1.331), +(fused fingers:1.61051), (too many fingers:1.61051), (unclear eyes:1.331), bad hands, +missing fingers, extra digit, bad body, easynegative, nsfw""" + +class StableDiffusionVersion(Enum): + """The version family of stable diffusion model.""" + + UNKNOWN = 0 + SD1x = 1 + SD2x = 2 + SDXL = 3 + + +sd_version = StableDiffusionVersion( + int(os.environ.get("CONTROLNET_TEST_SD_VERSION", StableDiffusionVersion.SD1x.value)) +) + +is_full_coverage = os.environ.get("CONTROLNET_TEST_FULL_COVERAGE", None) is not None + + +class APITestTemplate: + is_set_expectation_run = os.environ.get("CONTROLNET_SET_EXP", "True") == "True" + is_cq_run = bool(os.environ.get("FORGE_CQ_TEST", "")) + BASE_URL = "http://localhost:7860/" + + def __init__( + self, + name: str, + gen_type: Union[Literal["img2img"], Literal["txt2img"]], + payload_overrides: PayloadOverrideType, + unit_overrides: Union[PayloadOverrideType, List[PayloadOverrideType]], + input_image: Optional[str] = None, + ): + self.name = name + self.url = APITestTemplate.BASE_URL + "sdapi/v1/" + gen_type + self.payload = { + **(txt2img_payload if gen_type == "txt2img" else img2img_payload), + 
**payload_overrides, + } + if gen_type == "img2img" and input_image is not None: + self.payload["init_images"] = [input_image] + + # CQ runs on CPU. Reduce steps to increase test speed. + if "steps" not in payload_overrides and APITestTemplate.is_cq_run: + self.payload["steps"] = 3 + + unit_overrides = ( + unit_overrides + if isinstance(unit_overrides, (list, tuple)) + else [unit_overrides] + ) + self.payload["alwayson_scripts"]["ControlNet"]["args"] = [ + { + **default_unit, + **unit_override, + **({"image": input_image} if gen_type == "txt2img" and input_image is not None else {}), + } + for unit_override in unit_overrides + ] + self.active_unit_count = len(unit_overrides) + + def exec(self, *args, **kwargs) -> bool: + if APITestTemplate.is_cq_run: + return self.exec_cq(*args, **kwargs) + else: + return self.exec_local(*args, **kwargs) + + def exec_cq(self, expected_output_num: Optional[int] = None, *args, **kwargs) -> bool: + """Execute test in CQ environment.""" + res = requests.post(url=self.url, json=self.payload) + if res.status_code != 200: + print(f"Unexpected status code {res.status_code}") + return False + + response = res.json() + if "images" not in response: + print(response.keys()) + return False + + if expected_output_num is None: + expected_output_num = self.payload["n_iter"] * self.payload["batch_size"] + self.active_unit_count + + if len(response["images"]) != expected_output_num: + print(f"{len(response['images'])} != {expected_output_num}") + return False + + return True + + def exec_local(self, result_only: bool = True, *args, **kwargs) -> bool: + """Execute test in local environment.""" + if not APITestTemplate.is_set_expectation_run: + os.makedirs(test_result_dir, exist_ok=True) + + failed = False + + response = requests.post(url=self.url, json=self.payload).json() + if "images" not in response: + print(response.keys()) + return False + + dest_dir = ( + test_expectation_dir + if APITestTemplate.is_set_expectation_run + else test_result_dir + 
) + results = response["images"][:1] if result_only else response["images"] + for i, base64image in enumerate(results): + img_file_name = f"{self.name}_{i}.png" + Image.open(io.BytesIO(base64.b64decode(base64image.split(",", 1)[0]))).save( + dest_dir / img_file_name + ) + + if not APITestTemplate.is_set_expectation_run: + try: + img1 = cv2.imread(os.path.join(test_expectation_dir, img_file_name)) + img2 = cv2.imread(os.path.join(test_result_dir, img_file_name)) + except Exception as e: + print(f"Get exception reading imgs: {e}") + failed = True + continue + + if img1 is None: + print(f"Warn: No expectation file found {img_file_name}.") + continue + + if not expect_same_image( + img1, + img2, + diff_img_path=str(test_result_dir + / img_file_name.replace(".png", "_diff.png")), + ): + failed = True + return not failed + + +def expect_same_image(img1, img2, diff_img_path: str) -> bool: + # Calculate the difference between the two images + diff = cv2.absdiff(img1, img2) + + # Set a threshold to highlight the different pixels + threshold = 30 + diff_highlighted = np.where(diff > threshold, 255, 0).astype(np.uint8) + + # Assert that the two images are similar within a tolerance + similar = np.allclose(img1, img2, rtol=0.5, atol=1) + if not similar: + # Save the diff_highlighted image to inspect the differences + cv2.imwrite(diff_img_path, diff_highlighted) + + matching_pixels = np.isclose(img1, img2, rtol=0.5, atol=1) + similar_in_general = (matching_pixels.sum() / matching_pixels.size) >= 0.95 + return similar_in_general + + +def get_model(model_name: str) -> str: + """ Find an available model with specified model name.""" + if model_name.lower() == "none": + return "None" + + r = requests.get(APITestTemplate.BASE_URL + "controlnet/model_list") + result = r.json() + if "model_list" not in result: + raise ValueError("No model available") + + candidates = [ + model + for model in result["model_list"] + if model_name.lower() in model.lower() + ] + + if not candidates: + 
raise ValueError("No suitable model available") + + return candidates[0] + + +default_unit = { + "control_mode": 0, + "enabled": True, + "guidance_end": 1, + "guidance_start": 0, + "low_vram": False, + "pixel_perfect": True, + "processor_res": 512, + "resize_mode": 1, + "threshold_a": 64, + "threshold_b": 64, + "weight": 1, + "module": "None", + "model": get_model("sd15_canny"), +} + +img2img_payload = { + "batch_size": 1, + "cfg_scale": 7, + "height": 768, + "width": 512, + "n_iter": 1, + "steps": 10, + "sampler_name": "Euler a", + "prompt": "(masterpiece: 1.3), (highres: 1.3), best quality,", + "negative_prompt": "", + "seed": 42, + "seed_enable_extras": False, + "seed_resize_from_h": 0, + "seed_resize_from_w": 0, + "subseed": -1, + "subseed_strength": 0, + "override_settings": {}, + "override_settings_restore_afterwards": False, + "do_not_save_grid": False, + "do_not_save_samples": False, + "s_churn": 0, + "s_min_uncond": 0, + "s_noise": 1, + "s_tmax": None, + "s_tmin": 0, + "script_args": [], + "script_name": None, + "styles": [], + "alwayson_scripts": {"ControlNet": {"args": [default_unit]}}, + "denoising_strength": 0.75, + "initial_noise_multiplier": 1, + "inpaint_full_res": 0, + "inpaint_full_res_padding": 32, + "inpainting_fill": 1, + "inpainting_mask_invert": 0, + "mask_blur_x": 4, + "mask_blur_y": 4, + "mask_blur": 4, + "resize_mode": 0, +} + +txt2img_payload = { + "alwayson_scripts": {"ControlNet": {"args": [default_unit]}}, + "batch_size": 1, + "cfg_scale": 7, + "comments": {}, + "disable_extra_networks": False, + "do_not_save_grid": False, + "do_not_save_samples": False, + "enable_hr": False, + "height": 768, + "hr_negative_prompt": "", + "hr_prompt": "", + "hr_resize_x": 0, + "hr_resize_y": 0, + "hr_scale": 2, + "hr_second_pass_steps": 0, + "hr_upscaler": "Latent", + "n_iter": 1, + "negative_prompt": "", + "override_settings": {}, + "override_settings_restore_afterwards": True, + "prompt": "(masterpiece: 1.3), (highres: 1.3), best quality,", + 
"restore_faces": False, + "s_churn": 0.0, + "s_min_uncond": 0, + "s_noise": 1.0, + "s_tmax": None, + "s_tmin": 0.0, + "sampler_name": "Euler a", + "script_args": [], + "script_name": None, + "seed": 42, + "seed_enable_extras": True, + "seed_resize_from_h": -1, + "seed_resize_from_w": -1, + "steps": 10, + "styles": [], + "subseed": -1, + "subseed_strength": 0, + "tiling": False, + "width": 512, +} diff --git a/modules/api/api.py b/modules/api/api.py index 4e656082..d5348bb2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -2,11 +2,13 @@ import base64 import io import os import time +import itertools import datetime import uvicorn import ipaddress import requests import gradio as gr +import numpy as np from threading import Lock from io import BytesIO from fastapi import APIRouter, Depends, FastAPI, Request, Response @@ -103,6 +105,8 @@ def encode_pil_to_base64(image): with io.BytesIO() as output_bytes: if isinstance(image, str): return image + if isinstance(image, np.ndarray): + image = Image.fromarray(image) if opts.samples_format.lower() == 'png': use_metadata = False metadata = PngImagePlugin.PngInfo() @@ -480,7 +484,11 @@ class Api: shared.state.end() shared.total_tqdm.clear() - b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] + b64images = [ + encode_pil_to_base64(image) + for image in itertools.chain(processed.images, processed.extra_images) + if send_images + ] return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) @@ -547,7 +555,11 @@ class Api: shared.state.end() shared.total_tqdm.clear() - b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] + b64images = [ + encode_pil_to_base64(image) + for image in itertools.chain(processed.images, processed.extra_images) + if send_images + ] if not img2imgreq.include_init_images: img2imgreq.init_images = None