ControlNet API (#162)

* ControlNet API

* update cache key

* nits

* disable controlnet tests
This commit is contained in:
Chenlei Hu 2024-02-10 06:16:13 +00:00 committed by GitHub
parent bd0878754c
commit 5a7e755528
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 630 additions and 30 deletions

View File

@ -52,6 +52,18 @@ jobs:
curl -Lo "$filename" "$url" curl -Lo "$filename" "$url"
fi fi
done done
# - name: Download ControlNet models
# run: |
# declare -a urls=(
# "https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny.pth"
# )
# for url in "${urls[@]}"; do
# filename="models/ControlNet/${url##*/}" # Extracts the last part of the URL
# if [ ! -f "$filename" ]; then
# curl -Lo "$filename" "$url"
# fi
# done
- name: Start test server - name: Start test server
run: > run: >
python -m coverage run python -m coverage run
@ -71,6 +83,16 @@ jobs:
run: | run: |
wait-for-it --service 127.0.0.1:7860 -t 20 wait-for-it --service 127.0.0.1:7860 -t 20
python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test
# TODO(huchenlei): Enable ControlNet tests. Currently it is too slow to run these tests on CPU with
# a real SD model. We need to find a way to load an empty SD model.
# - name: Run ControlNet tests
# run: >
# python -m pytest
# --junitxml=test/results.xml
# --cov ./extensions-builtin/sd_forge_controlnet
# --cov-report=xml
# --verify-base-url
# ./extensions-builtin/sd_forge_controlnet/tests
- name: Kill test server - name: Kill test server
if: always() if: always()
run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10 run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10

3
.gitignore vendored
View File

@ -40,4 +40,5 @@ notification.mp3
/.coverage* /.coverage*
/test/test_outputs /test/test_outputs
/test/results.xml /test/results.xml
coverage.xml coverage.xml
**/tests/**/expectations

View File

@ -0,0 +1,108 @@
from typing import List
import numpy as np
from fastapi import FastAPI, Body
from fastapi.exceptions import HTTPException
from PIL import Image
import gradio as gr
from modules.api import api
from .global_state import (
get_all_preprocessor_names,
get_all_controlnet_names,
get_preprocessor,
)
from .logging import logger
def encode_to_base64(image):
    """Return a base64 string for ``image``.

    Strings are passed through unchanged, PIL images and numpy arrays are
    handed to the matching encoder, and any other type yields an empty
    string.
    """
    if isinstance(image, str):
        return image
    if isinstance(image, Image.Image):
        return api.encode_pil_to_base64(image)
    if isinstance(image, np.ndarray):
        return encode_np_to_base64(image)
    return ""
def encode_np_to_base64(image):
    """Wrap the array in a PIL image and encode it with the webui API helper."""
    return api.encode_pil_to_base64(Image.fromarray(image))
class _JsonPoseAcceptor:
    """Callback sink that remembers the last JSON dict it was handed."""

    def __init__(self) -> None:
        self.value = None

    def accept(self, json_dict: dict) -> None:
        self.value = json_dict


def controlnet_api(_: gr.Blocks, app: FastAPI):
    """Register the ControlNet REST endpoints on ``app``.

    Adds three routes:

    * ``GET /controlnet/model_list`` - names from ``get_all_controlnet_names``.
    * ``GET /controlnet/module_list`` - names from ``get_all_preprocessor_names``.
    * ``POST /controlnet/detect`` - runs a preprocessor over base64 images.

    Args:
        _: The gradio Blocks instance passed by the callback; unused.
        app: The FastAPI application to attach the routes to.
    """

    @app.get("/controlnet/model_list")
    async def model_list():
        up_to_date_model_list = get_all_controlnet_names()
        logger.debug(up_to_date_model_list)
        return {"model_list": up_to_date_model_list}

    @app.get("/controlnet/module_list")
    async def module_list():
        preprocessor_names = get_all_preprocessor_names()
        logger.debug(preprocessor_names)
        return {
            "module_list": preprocessor_names,
            # TODO: Add back module detail.
            # "module_detail": external_code.get_modules_detail(alias_names),
        }

    @app.post("/controlnet/detect")
    async def detect(
        controlnet_module: str = Body("none", title="Controlnet Module"),
        controlnet_input_images: List[str] = Body([], title="Controlnet Input Images"),
        controlnet_processor_res: int = Body(
            512, title="Controlnet Processor Resolution"
        ),
        controlnet_threshold_a: float = Body(64, title="Controlnet Threshold a"),
        controlnet_threshold_b: float = Body(64, title="Controlnet Threshold b"),
    ):
        """Run ``controlnet_module`` on every input image.

        Returns a dict with ``images`` (base64 results) and ``info``; a
        ``poses`` list is added when any pose JSON was collected.

        Raises:
            HTTPException: 422 when the module is unknown or no image was
                sent; 500 when an openpose module produced no pose JSON.
        """
        processor_module = get_preprocessor(controlnet_module)
        if processor_module is None:
            raise HTTPException(status_code=422, detail="Module not available")

        if not controlnet_input_images:
            raise HTTPException(status_code=422, detail="No image selected")

        logger.debug(
            f"Detecting {len(controlnet_input_images)} images with the {controlnet_module} module."
        )

        # Loop-invariant: whether pose JSON is expected from this module.
        expects_pose = "openpose" in controlnet_module

        results = []
        poses = []
        for input_image in controlnet_input_images:
            img = np.array(api.decode_base64_to_image(input_image)).astype('uint8')

            # A fresh acceptor per image so a pose from a previous image can
            # never be mistaken for the current one.
            json_acceptor = _JsonPoseAcceptor()
            detected = processor_module(
                img,
                res=controlnet_processor_res,
                thr_a=controlnet_threshold_a,
                thr_b=controlnet_threshold_b,
                json_pose_callback=json_acceptor.accept,
            )
            results.append(detected[0])

            if expects_pose:
                # Explicit check instead of `assert`, which is stripped
                # under `python -O` and would let None into the response.
                if json_acceptor.value is None:
                    raise HTTPException(
                        status_code=500,
                        detail="Preprocessor did not report pose data",
                    )
                poses.append(json_acceptor.value)

        results64 = list(map(encode_to_base64, results))
        res = {"images": results64, "info": "Success"}
        if poses:
            res["poses"] = poses

        return res

View File

@ -850,7 +850,6 @@ class ControlNetUiGroup(object):
slider_1=pthr_a, slider_1=pthr_a,
slider_2=pthr_b, slider_2=pthr_b,
input_mask=mask, input_mask=mask,
low_vram=shared.opts.data.get("controlnet_clip_detector_on_cpu", False),
json_pose_callback=json_acceptor.accept json_pose_callback=json_acceptor.accept
if is_openpose(module) if is_openpose(module)
else None, else None,

View File

@ -148,13 +148,15 @@ InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputIm
@dataclass @dataclass
class UiControlNetUnit: class UiControlNetUnit:
input_mode: InputMode = InputMode.SIMPLE input_mode: InputMode = InputMode.SIMPLE
use_preview_as_input: bool = False, use_preview_as_input: bool = False
batch_image_dir: str = '', batch_image_dir: str = ''
batch_mask_dir: str = '', batch_mask_dir: str = ''
batch_input_gallery: list = [], batch_input_gallery: Optional[List[str]] = None
batch_mask_gallery: list = [], batch_mask_gallery: Optional[List[str]] = None
generated_image: Optional[np.ndarray] = None, generated_image: Optional[np.ndarray] = None
mask_image: Optional[np.ndarray] = None, mask_image: Optional[np.ndarray] = None
# If hires fix is enabled in A1111, how should this ControlNet unit be applied.
# The value is ignored if the generation is not using hires fix.
hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH hr_option: Union[HiResFixOption, int, str] = HiResFixOption.BOTH
enabled: bool = True enabled: bool = True
module: str = "None" module: str = "None"
@ -169,6 +171,13 @@ class UiControlNetUnit:
guidance_end: float = 1.0 guidance_end: float = 1.0
pixel_perfect: bool = False pixel_perfect: bool = False
control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED
# ====== Start of API only fields ======
# Whether to save the detected map of this unit. Setting this option to False
# prevents saving the detected map or sending the detected map along with
# generated images via API. Currently the option is only accessible in API
# calls.
save_detected_map: bool = True
# ====== End of API only fields ======
@staticmethod @staticmethod
def infotext_fields(): def infotext_fields():
@ -192,6 +201,23 @@ class UiControlNetUnit:
"hr_option", "hr_option",
) )
@staticmethod
def from_dict(d: Dict) -> "UiControlNetUnit":
    """Create UiControlNetUnit from dict. This is primarily used to convert
    API json dict to UiControlNetUnit.

    Keys that are not dataclass fields of UiControlNetUnit are ignored.
    A base64 string in ``image`` is decoded into an ``{"image", "mask"}``
    dict with an all-zero mask; a base64 string in ``mask_image`` is decoded
    into a uint8 array.
    """
    # Local import keeps this method self-contained.
    from dataclasses import fields

    # Filter on real dataclass fields. `vars(UiControlNetUnit)` would also
    # match methods and dunder attributes (e.g. "from_dict", "__doc__"),
    # which the constructor rejects with a TypeError.
    field_names = {f.name for f in fields(UiControlNetUnit)}
    unit = UiControlNetUnit(
        **{k: v for k, v in d.items() if k in field_names}
    )
    if isinstance(unit.image, str):
        img = np.array(api.decode_base64_to_image(unit.image)).astype('uint8')
        unit.image = {
            "image": img,
            "mask": np.zeros_like(img),
        }
    if isinstance(unit.mask_image, str):
        unit.mask_image = np.array(api.decode_base64_to_image(unit.mask_image)).astype('uint8')
    return unit
# Backward Compatible # Backward Compatible
ControlNetUnit = UiControlNetUnit ControlNetUnit = UiControlNetUnit

View File

@ -11,9 +11,10 @@ from modules.api.api import decode_base64_to_image
import gradio as gr import gradio as gr
from lib_controlnet import global_state, external_code from lib_controlnet import global_state, external_code
from lib_controlnet.external_code import ControlNetUnit
from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, \ from lib_controlnet.utils import align_dim_latent, image_dict_from_any, set_numpy_seed, crop_and_resize_image, \
prepare_mask, judge_image_type prepare_mask, judge_image_type
from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup, UiControlNetUnit from lib_controlnet.controlnet_ui.controlnet_ui_group import ControlNetUiGroup
from lib_controlnet.controlnet_ui.photopea import Photopea from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.logging import logger from lib_controlnet.logging import logger
from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, \ from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessingTxt2Img, \
@ -21,6 +22,7 @@ from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusion
from lib_controlnet.infotext import Infotext from lib_controlnet.infotext import Infotext
from modules_forge.forge_util import HWC3, numpy_to_pytorch from modules_forge.forge_util import HWC3, numpy_to_pytorch
from lib_controlnet.enums import HiResFixOption from lib_controlnet.enums import HiResFixOption
from lib_controlnet.api import controlnet_api
import numpy as np import numpy as np
import functools import functools
@ -67,7 +69,7 @@ class ControlNetForForgeOfficial(scripts.Script):
max_models = shared.opts.data.get("control_net_unit_count", 3) max_models = shared.opts.data.get("control_net_unit_count", 3)
gen_type = "img2img" if is_img2img else "txt2img" gen_type = "img2img" if is_img2img else "txt2img"
elem_id_tabname = gen_type + "_controlnet" elem_id_tabname = gen_type + "_controlnet"
default_unit = UiControlNetUnit(enabled=False, module="None", model="None") default_unit = ControlNetUnit(enabled=False, module="None", model="None")
with gr.Group(elem_id=elem_id_tabname): with gr.Group(elem_id=elem_id_tabname):
with gr.Accordion(f"ControlNet Integrated", open=False, elem_id="controlnet", with gr.Accordion(f"ControlNet Integrated", open=False, elem_id="controlnet",
elem_classes=["controlnet"]): elem_classes=["controlnet"]):
@ -95,13 +97,19 @@ class ControlNetForForgeOfficial(scripts.Script):
return tuple(controls) return tuple(controls)
def get_enabled_units(self, units): def get_enabled_units(self, units):
# Parse dict from API calls.
units = [
ControlNetUnit.from_dict(unit) if isinstance(unit, dict) else unit
for unit in units
]
assert all(isinstance(unit, ControlNetUnit) for unit in units)
enabled_units = [x for x in units if x.enabled] enabled_units = [x for x in units if x.enabled]
return enabled_units return enabled_units
@staticmethod @staticmethod
def try_crop_image_with_a1111_mask( def try_crop_image_with_a1111_mask(
p: StableDiffusionProcessing, p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit, unit: ControlNetUnit,
input_image: np.ndarray, input_image: np.ndarray,
resize_mode: external_code.ResizeMode, resize_mode: external_code.ResizeMode,
preprocessor preprocessor
@ -252,7 +260,7 @@ class ControlNetForForgeOfficial(scripts.Script):
@torch.no_grad() @torch.no_grad()
def process_unit_after_click_generate(self, def process_unit_after_click_generate(self,
p: StableDiffusionProcessing, p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit, unit: ControlNetUnit,
params: ControlNetCachedParameters, params: ControlNetCachedParameters,
*args, **kwargs): *args, **kwargs):
@ -279,8 +287,6 @@ class ControlNetForForgeOfficial(scripts.Script):
return tqdm(iterable) if use_tqdm else iterable return tqdm(iterable) if use_tqdm else iterable
for input_image, input_mask in optional_tqdm(input_list, len(input_list) > 1): for input_image, input_mask in optional_tqdm(input_list, len(input_list) > 1):
# p.extra_result_images.append(input_image)
if unit.pixel_perfect: if unit.pixel_perfect:
unit.processor_res = external_code.pixel_perfect_resolution( unit.processor_res = external_code.pixel_perfect_resolution(
input_image, input_image,
@ -319,14 +325,20 @@ class ControlNetForForgeOfficial(scripts.Script):
hr_option = HiResFixOption.BOTH hr_option = HiResFixOption.BOTH
alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)] alignment_indices = [i % len(preprocessor_outputs) for i in range(p.batch_size)]
def attach_extra_result_image(img: np.ndarray, is_high_res: bool = False):
if (
(is_high_res and hr_option.high_res_enabled) or
(not is_high_res and hr_option.low_res_enabled)
) and unit.save_detected_map:
p.extra_result_images.append(img)
if preprocessor_output_is_image: if preprocessor_output_is_image:
params.control_cond = [] params.control_cond = []
params.control_cond_for_hr_fix = [] params.control_cond_for_hr_fix = []
for preprocessor_output in preprocessor_outputs: for preprocessor_output in preprocessor_outputs:
control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w) control_cond = crop_and_resize_image(preprocessor_output, resize_mode, h, w)
if hr_option.low_res_enabled: attach_extra_result_image(external_code.visualize_inpaint_mask(control_cond))
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond))
params.control_cond.append(numpy_to_pytorch(control_cond).movedim(-1, 1)) params.control_cond.append(numpy_to_pytorch(control_cond).movedim(-1, 1))
params.control_cond = torch.cat(params.control_cond, dim=0)[alignment_indices].contiguous() params.control_cond = torch.cat(params.control_cond, dim=0)[alignment_indices].contiguous()
@ -334,8 +346,7 @@ class ControlNetForForgeOfficial(scripts.Script):
if has_high_res_fix: if has_high_res_fix:
for preprocessor_output in preprocessor_outputs: for preprocessor_output in preprocessor_outputs:
control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x) control_cond_for_hr_fix = crop_and_resize_image(preprocessor_output, resize_mode, hr_y, hr_x)
if hr_option.high_res_enabled: attach_extra_result_image(external_code.visualize_inpaint_mask(control_cond_for_hr_fix), is_high_res=True)
p.extra_result_images.append(external_code.visualize_inpaint_mask(control_cond_for_hr_fix))
params.control_cond_for_hr_fix.append(numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1)) params.control_cond_for_hr_fix.append(numpy_to_pytorch(control_cond_for_hr_fix).movedim(-1, 1))
params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0)[alignment_indices].contiguous() params.control_cond_for_hr_fix = torch.cat(params.control_cond_for_hr_fix, dim=0)[alignment_indices].contiguous()
else: else:
@ -343,7 +354,7 @@ class ControlNetForForgeOfficial(scripts.Script):
else: else:
params.control_cond = preprocessor_output params.control_cond = preprocessor_output
params.control_cond_for_hr_fix = preprocessor_output params.control_cond_for_hr_fix = preprocessor_output
p.extra_result_images.append(input_image) attach_extra_result_image(input_image)
if len(control_masks) > 0: if len(control_masks) > 0:
params.control_mask = [] params.control_mask = []
@ -352,15 +363,13 @@ class ControlNetForForgeOfficial(scripts.Script):
for input_mask in control_masks: for input_mask in control_masks:
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border) control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
if hr_option.low_res_enabled: attach_extra_result_image(control_mask)
p.extra_result_images.append(control_mask)
control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1] control_mask = numpy_to_pytorch(control_mask).movedim(-1, 1)[:, :1]
params.control_mask.append(control_mask) params.control_mask.append(control_mask)
if has_high_res_fix: if has_high_res_fix:
control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border) control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
if hr_option.high_res_enabled: attach_extra_result_image(control_mask_for_hr_fix, is_high_res=True)
p.extra_result_images.append(control_mask_for_hr_fix)
control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1] control_mask_for_hr_fix = numpy_to_pytorch(control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
params.control_mask_for_hr_fix.append(control_mask_for_hr_fix) params.control_mask_for_hr_fix.append(control_mask_for_hr_fix)
@ -390,7 +399,7 @@ class ControlNetForForgeOfficial(scripts.Script):
@torch.no_grad() @torch.no_grad()
def process_unit_before_every_sampling(self, def process_unit_before_every_sampling(self,
p: StableDiffusionProcessing, p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit, unit: ControlNetUnit,
params: ControlNetCachedParameters, params: ControlNetCachedParameters,
*args, **kwargs): *args, **kwargs):
@ -473,14 +482,14 @@ class ControlNetForForgeOfficial(scripts.Script):
return return
@staticmethod @staticmethod
def bound_check_params(unit: external_code.ControlNetUnit) -> None: def bound_check_params(unit: ControlNetUnit) -> None:
""" """
Checks and corrects negative parameters in ControlNetUnit 'unit'. Checks and corrects negative parameters in ControlNetUnit 'unit'.
Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
their default values if negative. their default values if negative.
Args: Args:
unit (external_code.ControlNetUnit): The ControlNetUnit instance to check. unit (ControlNetUnit): The ControlNetUnit instance to check.
""" """
preprocessor = global_state.get_preprocessor(unit.module) preprocessor = global_state.get_preprocessor(unit.module)
@ -498,7 +507,7 @@ class ControlNetForForgeOfficial(scripts.Script):
@torch.no_grad() @torch.no_grad()
def process_unit_after_every_sampling(self, def process_unit_after_every_sampling(self,
p: StableDiffusionProcessing, p: StableDiffusionProcessing,
unit: external_code.ControlNetUnit, unit: ControlNetUnit,
params: ControlNetCachedParameters, params: ControlNetCachedParameters,
*args, **kwargs): *args, **kwargs):
@ -577,3 +586,4 @@ script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_infotext_pasted(Infotext.on_infotext_pasted) script_callbacks.on_infotext_pasted(Infotext.on_infotext_pasted)
script_callbacks.on_after_component(ControlNetUiGroup.on_after_component) script_callbacks.on_after_component(ControlNetUiGroup.on_after_component)
script_callbacks.on_before_reload(ControlNetUiGroup.reset) script_callbacks.on_before_reload(ControlNetUiGroup.reset)
script_callbacks.on_app_started(controlnet_api)

View File

@ -0,0 +1,8 @@
import os
def pytest_configure(config):
    """Seed the environment variables the test run relies on.

    Values already present in the environment are left as they are.
    """
    env_defaults = (
        # We don't want to fail on Py.test command line arguments being
        # parsed by webui:
        ("IGNORE_CMD_ARGS_ERRORS", "1"),
        ("FORGE_CQ_TEST", "1"),
    )
    for var_name, var_value in env_defaults:
        os.environ.setdefault(var_name, var_value)

Binary file not shown.

After

Width:  |  Height:  |  Size: 482 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 244 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 226 B

View File

@ -0,0 +1,85 @@
import pytest
from .template import (
APITestTemplate,
girl_img,
)
@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"])
def test_no_unit(gen_type):
    """Run a generation request with an empty ControlNet unit list."""
    template = APITestTemplate(
        name=f"test_no_unit{gen_type}",
        gen_type=gen_type,
        payload_overrides={},
        unit_overrides=[],
        input_image=girl_img,
    )
    assert template.exec()
@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"])
def test_multiple_iter(gen_type):
    """Run a generation request with ``n_iter`` set to 2."""
    template = APITestTemplate(
        name=f"test_multiple_iter{gen_type}",
        gen_type=gen_type,
        payload_overrides={"n_iter": 2},
        unit_overrides={},
        input_image=girl_img,
    )
    assert template.exec()
@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"])
def test_batch_size(gen_type):
    """Run a generation request with ``batch_size`` set to 2."""
    template = APITestTemplate(
        name=f"test_batch_size{gen_type}",
        gen_type=gen_type,
        payload_overrides={"batch_size": 2},
        unit_overrides={},
        input_image=girl_img,
    )
    assert template.exec()
@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"])
def test_2_units(gen_type):
    """Run a generation request carrying two default ControlNet units."""
    template = APITestTemplate(
        name=f"test_2_units{gen_type}",
        gen_type=gen_type,
        payload_overrides={},
        unit_overrides=[{}, {}],
        input_image=girl_img,
    )
    assert template.exec()
@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"])
def test_preprocessor(gen_type):
    """Run a generation request whose unit uses the ``canny`` module."""
    template = APITestTemplate(
        name=f"test_preprocessor{gen_type}",
        gen_type=gen_type,
        payload_overrides={},
        unit_overrides={"module": "canny"},
        input_image=girl_img,
    )
    assert template.exec()
@pytest.mark.parametrize("param_name", ("processor_res", "threshold_a", "threshold_b"))
@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"])
def test_invalid_param(gen_type, param_name):
    """Run a generation request with one unit parameter set to -1."""
    template = APITestTemplate(
        name=f"test_invalid_param{(gen_type, param_name)}",
        gen_type=gen_type,
        payload_overrides={},
        unit_overrides={param_name: -1},
        input_image=girl_img,
    )
    assert template.exec()
@pytest.mark.parametrize("save_map", [True, False])
@pytest.mark.parametrize("gen_type", ["img2img", "txt2img"])
def test_save_map(gen_type, save_map):
    """Check the image count with ``save_detected_map`` on and off.

    Two images are expected when the detected map is saved, one otherwise.
    """
    template = APITestTemplate(
        name=f"test_save_map{(gen_type, save_map)}",
        gen_type=gen_type,
        payload_overrides={},
        unit_overrides={"save_detected_map": save_map},
        input_image=girl_img,
    )
    expected_images = 2 if save_map else 1
    assert template.exec(expected_output_num=expected_images)

View File

@ -0,0 +1,329 @@
import io
import os
import cv2
import base64
from typing import Dict, Any, List, Union, Literal, Optional
from pathlib import Path
import datetime
from enum import Enum
import numpy as np
import requests
from PIL import Image
# Alias for the dicts used to override request payload and unit settings.
PayloadOverrideType = Dict[str, Any]
# Taken once at import time; it names this run's result directory.
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
test_result_dir = Path(__file__).parent / "results" / f"test_result_{timestamp}"
test_expectation_dir = Path(__file__).parent / "expectations"
# The expectation directory is created eagerly, as an import side effect.
os.makedirs(test_expectation_dir, exist_ok=True)
# Input images live in "images" one level above this file's directory.
resource_dir = Path(__file__).parents[1] / "images"
def read_image(img_path: Path) -> str:
    """Load an image file and return it as a base64-encoded PNG string.

    Args:
        img_path: Location of the image; a ``Path`` or a plain string.

    Returns:
        The PNG-encoded image as base64 text, without a data-URL prefix.

    Raises:
        FileNotFoundError: ``img_path`` is not an existing file.
        IOError: The file exists but could not be decoded or re-encoded.
    """
    path = Path(img_path)
    if not path.is_file():
        raise FileNotFoundError(f"Image file not found: {path}")

    # cv2.imread reports failure by returning None instead of raising, so
    # turn that into an explicit error callers can catch as IOError.
    img = cv2.imread(str(path))
    if img is None:
        raise IOError(f"Failed to decode image: {path}")

    encoded_ok, png_buffer = cv2.imencode(".png", img)
    if not encoded_ok:
        raise IOError(f"Failed to encode image as PNG: {path}")
    return base64.b64encode(png_buffer).decode("utf-8")
def read_image_dir(img_dir: Path, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> List[str]:
    """Encode every file in ``img_dir`` whose name ends with one of ``suffixes``.

    Files that raise ``IOError`` while being read are reported on stdout
    and left out of the result.
    """
    directory = str(img_dir)
    encoded_images = []
    for entry in os.listdir(directory):
        if not entry.endswith(suffixes):
            continue
        full_path = os.path.join(directory, entry)
        try:
            encoded_images.append(read_image(full_path))
        except IOError:
            print(f"Error opening {full_path}")
    return encoded_images
# Base64-encoded test inputs, read from resource_dir once at import time.
girl_img = read_image(resource_dir / "1girl.png")
mask_img = read_image(resource_dir / "mask.png")
mask_small_img = read_image(resource_dir / "mask_small.png")
general_negative_prompt = """
(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality,
((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot,
backlight,(ugly:1.331), (duplicate:1.331), (morbid:1.21), (mutilated:1.21),
(tranny:1.331), mutated hands, (poorly drawn hands:1.331), blurry, (bad anatomy:1.21),
(bad proportions:1.331), extra limbs, (missing arms:1.331), (extra legs:1.331),
(fused fingers:1.61051), (too many fingers:1.61051), (unclear eyes:1.331), bad hands,
missing fingers, extra digit, bad body, easynegative, nsfw"""
class StableDiffusionVersion(Enum):
    """The version family of stable diffusion model."""
    # The integer values are what the CONTROLNET_TEST_SD_VERSION environment
    # variable is parsed into when `sd_version` is selected.
    UNKNOWN = 0
    SD1x = 1
    SD2x = 2
    SDXL = 3
# Model family under test. Override with CONTROLNET_TEST_SD_VERSION, an
# integer matching a StableDiffusionVersion value; defaults to SD1x.
sd_version = StableDiffusionVersion(
    int(os.environ.get("CONTROLNET_TEST_SD_VERSION", StableDiffusionVersion.SD1x.value))
)
# True whenever CONTROLNET_TEST_FULL_COVERAGE is set, whatever its value.
is_full_coverage = os.environ.get("CONTROLNET_TEST_FULL_COVERAGE", None) is not None
class APITestTemplate:
    """Build and run one txt2img/img2img API request carrying ControlNet units.

    The request payload starts from ``txt2img_payload`` or
    ``img2img_payload``, has ``payload_overrides`` applied on top, and gets
    one ControlNet unit per entry of ``unit_overrides`` (each merged over
    ``default_unit``).
    """

    # Only the literal string "True" keeps this an expectation-setting run;
    # any other value of CONTROLNET_SET_EXP makes it a comparison run.
    is_set_expectation_run = os.environ.get("CONTROLNET_SET_EXP", "True") == "True"
    # Any non-empty FORGE_CQ_TEST value selects the CQ code path.
    is_cq_run = bool(os.environ.get("FORGE_CQ_TEST", ""))
    BASE_URL = "http://localhost:7860/"

    def __init__(
        self,
        name: str,
        gen_type: Union[Literal["img2img"], Literal["txt2img"]],
        payload_overrides: PayloadOverrideType,
        unit_overrides: Union[PayloadOverrideType, List[PayloadOverrideType]],
        input_image: Optional[str] = None,
    ):
        """Assemble the request payload.

        Args:
            name: Prefix for the image files written by ``exec_local``.
            gen_type: Which endpoint to call, "img2img" or "txt2img".
            payload_overrides: Top-level payload keys to override.
            unit_overrides: One dict, or a list/tuple of dicts, of unit
                settings; each becomes one ControlNet unit.
            input_image: Base64 image. Used as ``init_images`` for img2img
                and as each unit's ``image`` for txt2img.
        """
        self.name = name
        self.url = APITestTemplate.BASE_URL + "sdapi/v1/" + gen_type
        base_payload = txt2img_payload if gen_type == "txt2img" else img2img_payload
        self.payload = {**base_payload, **payload_overrides}
        if gen_type == "img2img" and input_image is not None:
            self.payload["init_images"] = [input_image]

        # CQ runs on CPU. Reduce steps to increase test speed.
        if "steps" not in payload_overrides and APITestTemplate.is_cq_run:
            self.payload["steps"] = 3

        if not isinstance(unit_overrides, (list, tuple)):
            unit_overrides = [unit_overrides]
        image_override = (
            {"image": input_image}
            if gen_type == "txt2img" and input_image is not None
            else {}
        )

        # self.payload is only a shallow copy, so its nested dicts are still
        # the module-level template objects. Build fresh nested dicts
        # instead of assigning into them; otherwise every instance would
        # overwrite the units of all other instances and of the templates.
        scripts = dict(self.payload.get("alwayson_scripts") or {})
        scripts["ControlNet"] = {
            **(scripts.get("ControlNet") or {}),
            "args": [
                {**default_unit, **unit_override, **image_override}
                for unit_override in unit_overrides
            ],
        }
        self.payload["alwayson_scripts"] = scripts
        self.active_unit_count = len(unit_overrides)

    def exec(self, *args, **kwargs) -> bool:
        """Run the request via ``exec_cq`` or ``exec_local`` per ``is_cq_run``."""
        if APITestTemplate.is_cq_run:
            return self.exec_cq(*args, **kwargs)
        return self.exec_local(*args, **kwargs)

    def exec_cq(self, expected_output_num: Optional[int] = None, *args, **kwargs) -> bool:
        """Execute test in CQ environment.

        Only the HTTP status and the number of returned images are checked.
        When ``expected_output_num`` is None it defaults to
        ``n_iter * batch_size + active_unit_count``.
        """
        res = requests.post(url=self.url, json=self.payload)
        if res.status_code != 200:
            print(f"Unexpected status code {res.status_code}")
            return False

        response = res.json()
        if "images" not in response:
            print(response.keys())
            return False

        if expected_output_num is None:
            expected_output_num = self.payload["n_iter"] * self.payload["batch_size"] + self.active_unit_count

        if len(response["images"]) != expected_output_num:
            print(f"{len(response['images'])} != {expected_output_num}")
            return False

        return True

    def exec_local(self, result_only: bool = True, *args, **kwargs) -> bool:
        """Execute test in local environment.

        In an expectation-setting run the returned images are written to the
        expectation directory. Otherwise they are written to the result
        directory and compared against the stored expectations.

        Args:
            result_only: When True only the first returned image is used.

        Returns:
            False when the response has no images or a comparison failed,
            True otherwise.
        """
        if not APITestTemplate.is_set_expectation_run:
            os.makedirs(test_result_dir, exist_ok=True)

        response = requests.post(url=self.url, json=self.payload).json()
        if "images" not in response:
            print(response.keys())
            return False

        dest_dir = (
            test_expectation_dir
            if APITestTemplate.is_set_expectation_run
            else test_result_dir
        )
        results = response["images"][:1] if result_only else response["images"]
        failed = False
        for i, base64image in enumerate(results):
            img_file_name = f"{self.name}_{i}.png"
            # Take what follows the comma so that both bare base64 and
            # "data:image/png;base64,<data>" strings decode; index 0 would
            # pick the header of a data URL.
            encoded = base64image.split(",", 1)[-1]
            Image.open(io.BytesIO(base64.b64decode(encoded))).save(
                dest_dir / img_file_name
            )

            if APITestTemplate.is_set_expectation_run:
                continue

            try:
                img1 = cv2.imread(os.path.join(test_expectation_dir, img_file_name))
                img2 = cv2.imread(os.path.join(test_result_dir, img_file_name))
            except Exception as e:
                print(f"Get exception reading imgs: {e}")
                failed = True
                continue

            # A missing expectation is only a warning, not a failure.
            if img1 is None:
                print(f"Warn: No expectation file found {img_file_name}.")
                continue

            if not expect_same_image(
                img1,
                img2,
                diff_img_path=str(test_result_dir
                                  / img_file_name.replace(".png", "_diff.png")),
            ):
                failed = True
        return not failed
def expect_same_image(img1, img2, diff_img_path: str) -> bool:
    """Report whether two images are similar enough to count as the same.

    Pixels are compared with ``np.isclose(rtol=0.5, atol=1)``. If every
    value matches, the images are the same. Otherwise a highlighted
    difference image is written to ``diff_img_path`` and the images still
    pass when at least 95% of the values match.

    Args:
        img1: Expected image as a numpy array.
        img2: Actual image as a numpy array.
        diff_img_path: Where to write the difference image on a mismatch.

    Returns:
        True when the images are considered the same, False otherwise
        (including when their shapes differ).
    """
    # Differently sized images cannot be compared element-wise; report a
    # plain mismatch instead of failing inside OpenCV/numpy.
    if img1.shape != img2.shape:
        print(f"Image shape mismatch: {img1.shape} != {img2.shape}")
        return False

    # Computed once and reused for both the exact and the 95% check.
    matching_pixels = np.isclose(img1, img2, rtol=0.5, atol=1)
    if matching_pixels.all():
        return True

    # Only build and save the diff image when something actually differs.
    threshold = 30
    diff = cv2.absdiff(img1, img2)
    diff_highlighted = np.where(diff > threshold, 255, 0).astype(np.uint8)
    cv2.imwrite(diff_img_path, diff_highlighted)

    similar_in_general = (matching_pixels.sum() / matching_pixels.size) >= 0.95
    return similar_in_general
def get_model(model_name: str) -> str:
    """Find an available model with specified model name.

    The match is a case-insensitive substring test against the names
    returned by the server's ``controlnet/model_list`` endpoint; the first
    match wins. The name "none" (any case) maps to "None" with no request.

    Raises:
        ValueError: The response has no "model_list" key, or no listed
            model contains ``model_name``.
    """
    wanted = model_name.lower()
    if wanted == "none":
        return "None"

    # Bound the wait: this is called at import time (for default_unit), and
    # requests waits indefinitely when no timeout is given.
    r = requests.get(APITestTemplate.BASE_URL + "controlnet/model_list", timeout=30)
    result = r.json()
    if "model_list" not in result:
        raise ValueError("No model available")

    for model in result["model_list"]:
        if wanted in model.lower():
            return model
    raise ValueError("No suitable model available")
# Baseline settings for one ControlNet unit; APITestTemplate merges each
# test's unit overrides on top of this dict.
default_unit = {
    "control_mode": 0,
    "enabled": True,
    "guidance_end": 1,
    "guidance_start": 0,
    "low_vram": False,
    "pixel_perfect": True,
    "processor_res": 512,
    "resize_mode": 1,
    "threshold_a": 64,
    "threshold_b": 64,
    "weight": 1,
    "module": "None",
    # Resolved at import time by querying the server's model list.
    "model": get_model("sd15_canny"),
}
# Base request body for img2img requests; APITestTemplate copies it and
# applies per-test overrides.
img2img_payload = {
    "batch_size": 1,
    "cfg_scale": 7,
    "height": 768,
    "width": 512,
    "n_iter": 1,
    "steps": 10,
    "sampler_name": "Euler a",
    "prompt": "(masterpiece: 1.3), (highres: 1.3), best quality,",
    "negative_prompt": "",
    "seed": 42,
    "seed_enable_extras": False,
    "seed_resize_from_h": 0,
    "seed_resize_from_w": 0,
    "subseed": -1,
    "subseed_strength": 0,
    "override_settings": {},
    "override_settings_restore_afterwards": False,
    "do_not_save_grid": False,
    "do_not_save_samples": False,
    "s_churn": 0,
    "s_min_uncond": 0,
    "s_noise": 1,
    "s_tmax": None,
    "s_tmin": 0,
    "script_args": [],
    "script_name": None,
    "styles": [],
    # Placeholder unit list; APITestTemplate supplies the real units.
    "alwayson_scripts": {"ControlNet": {"args": [default_unit]}},
    "denoising_strength": 0.75,
    "initial_noise_multiplier": 1,
    "inpaint_full_res": 0,
    "inpaint_full_res_padding": 32,
    "inpainting_fill": 1,
    "inpainting_mask_invert": 0,
    "mask_blur_x": 4,
    "mask_blur_y": 4,
    "mask_blur": 4,
    "resize_mode": 0,
}
# Base request body for txt2img requests; APITestTemplate copies it and
# applies per-test overrides.
txt2img_payload = {
    # Placeholder unit list; APITestTemplate supplies the real units.
    "alwayson_scripts": {"ControlNet": {"args": [default_unit]}},
    "batch_size": 1,
    "cfg_scale": 7,
    "comments": {},
    "disable_extra_networks": False,
    "do_not_save_grid": False,
    "do_not_save_samples": False,
    "enable_hr": False,
    "height": 768,
    "hr_negative_prompt": "",
    "hr_prompt": "",
    "hr_resize_x": 0,
    "hr_resize_y": 0,
    "hr_scale": 2,
    "hr_second_pass_steps": 0,
    "hr_upscaler": "Latent",
    "n_iter": 1,
    "negative_prompt": "",
    "override_settings": {},
    "override_settings_restore_afterwards": True,
    "prompt": "(masterpiece: 1.3), (highres: 1.3), best quality,",
    "restore_faces": False,
    "s_churn": 0.0,
    "s_min_uncond": 0,
    "s_noise": 1.0,
    "s_tmax": None,
    "s_tmin": 0.0,
    "sampler_name": "Euler a",
    "script_args": [],
    "script_name": None,
    "seed": 42,
    "seed_enable_extras": True,
    "seed_resize_from_h": -1,
    "seed_resize_from_w": -1,
    "steps": 10,
    "styles": [],
    "subseed": -1,
    "subseed_strength": 0,
    "tiling": False,
    "width": 512,
}

View File

@ -2,11 +2,13 @@ import base64
import io import io
import os import os
import time import time
import itertools
import datetime import datetime
import uvicorn import uvicorn
import ipaddress import ipaddress
import requests import requests
import gradio as gr import gradio as gr
import numpy as np
from threading import Lock from threading import Lock
from io import BytesIO from io import BytesIO
from fastapi import APIRouter, Depends, FastAPI, Request, Response from fastapi import APIRouter, Depends, FastAPI, Request, Response
@ -103,6 +105,8 @@ def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes: with io.BytesIO() as output_bytes:
if isinstance(image, str): if isinstance(image, str):
return image return image
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
if opts.samples_format.lower() == 'png': if opts.samples_format.lower() == 'png':
use_metadata = False use_metadata = False
metadata = PngImagePlugin.PngInfo() metadata = PngImagePlugin.PngInfo()
@ -480,7 +484,11 @@ class Api:
shared.state.end() shared.state.end()
shared.total_tqdm.clear() shared.total_tqdm.clear()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] b64images = [
encode_pil_to_base64(image)
for image in itertools.chain(processed.images, processed.extra_images)
if send_images
]
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
@ -547,7 +555,11 @@ class Api:
shared.state.end() shared.state.end()
shared.total_tqdm.clear() shared.total_tqdm.clear()
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] b64images = [
encode_pil_to_base64(image)
for image in itertools.chain(processed.images, processed.extra_images)
if send_images
]
if not img2imgreq.include_init_images: if not img2imgreq.include_init_images:
img2imgreq.init_images = None img2imgreq.init_images = None