Fix controlnet/detect API endpoint (#187)

This commit is contained in:
Chenlei Hu 2024-02-11 06:15:06 +00:00 committed by GitHub
parent 6a854fcb38
commit e11753ff84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 103 additions and 16 deletions

View File

@ -4,6 +4,9 @@ on:
- push - push
- pull_request - pull_request
env:
FORGE_CQ_TEST: "True"
jobs: jobs:
test: test:
name: tests on CPU name: tests on CPU

View File

@ -12,17 +12,21 @@ from .global_state import (
get_all_controlnet_names, get_all_controlnet_names,
get_preprocessor, get_preprocessor,
) )
from .utils import judge_image_type
from .logging import logger from .logging import logger
def encode_to_base64(image): def encode_to_base64(image):
if isinstance(image, str): if isinstance(image, str):
return image return image
elif not judge_image_type(image):
return "Detect result is not image"
elif isinstance(image, Image.Image): elif isinstance(image, Image.Image):
return api.encode_pil_to_base64(image) return api.encode_pil_to_base64(image)
elif isinstance(image, np.ndarray): elif isinstance(image, np.ndarray):
return encode_np_to_base64(image) return encode_np_to_base64(image)
else: else:
logger.warn("Unable to encode image.")
return "" return ""
@ -88,18 +92,18 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
results.append( results.append(
processor_module( processor_module(
img, img,
res=controlnet_processor_res, resolution=controlnet_processor_res,
thr_a=controlnet_threshold_a, slider_1=controlnet_threshold_a,
thr_b=controlnet_threshold_b, slider_2=controlnet_threshold_b,
json_pose_callback=json_acceptor.accept, json_pose_callback=json_acceptor.accept,
)[0] )
) )
if "openpose" in controlnet_module: if "openpose" in controlnet_module:
assert json_acceptor.value is not None assert json_acceptor.value is not None
poses.append(json_acceptor.value) poses.append(json_acceptor.value)
results64 = list(map(encode_to_base64, results)) results64 = [encode_to_base64(img) for img in results]
res = {"images": results64, "info": "Success"} res = {"images": results64, "info": "Success"}
if poses: if poses:
res["poses"] = poses res["poses"] = poses

View File

@ -5,4 +5,3 @@ def pytest_configure(config):
# We don't want to fail on Py.test command line arguments being # We don't want to fail on Py.test command line arguments being
# parsed by webui: # parsed by webui:
os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1") os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")
os.environ.setdefault("FORGE_CQ_TEST", "1")

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 202 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

View File

@ -0,0 +1,63 @@
import pytest
import requests
from typing import List
from .template import (
APITestTemplate,
realistic_girl_face_img,
save_base64,
get_dest_dir,
disable_in_cq,
)
def get_modules() -> List[str]:
return requests.get(APITestTemplate.BASE_URL + "controlnet/module_list").json()[
"module_list"
]
def detect_template(payload, output_name: str):
url = APITestTemplate.BASE_URL + "controlnet/detect"
resp = requests.post(url, json=payload)
assert resp.status_code == 200
resp_json = resp.json()
assert "images" in resp_json
assert len(resp_json["images"]) == len(payload["controlnet_input_images"])
if not APITestTemplate.is_cq_run:
for i, img in enumerate(resp_json["images"]):
if img == "Detect result is not image":
continue
dest = get_dest_dir() / f"{output_name}_{i}.png"
save_base64(img, dest)
return resp_json
@disable_in_cq
@pytest.mark.parametrize("module", get_modules())
def test_detect_all_modules(module: str):
payload = dict(
controlnet_input_images=[realistic_girl_face_img],
controlnet_module=module,
)
detect_template(payload, f"detect_{module}")
def test_detect_simple():
detect_template(
dict(
controlnet_input_images=[realistic_girl_face_img],
controlnet_module="canny", # Canny does not require model download.
),
"simple_detect",
)
def test_detect_multiple_inputs():
detect_template(
dict(
controlnet_input_images=[realistic_girl_face_img, realistic_girl_face_img],
controlnet_module="canny", # Canny does not require model download.
),
"multiple_inputs_detect",
)

View File

@ -2,16 +2,28 @@ import io
import os import os
import cv2 import cv2
import base64 import base64
import functools
from typing import Dict, Any, List, Union, Literal, Optional from typing import Dict, Any, List, Union, Literal, Optional
from pathlib import Path from pathlib import Path
import datetime import datetime
from enum import Enum from enum import Enum
import numpy as np import numpy as np
import pytest
import requests import requests
from PIL import Image from PIL import Image
def disable_in_cq(func):
"""Skips the decorated test func in CQ run."""
@functools.wraps(func)
def wrapped_func(*args, **kwargs):
if APITestTemplate.is_cq_run:
pytest.skip()
return func(*args, **kwargs)
return wrapped_func
PayloadOverrideType = Dict[str, Any] PayloadOverrideType = Dict[str, Any]
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
@ -21,6 +33,17 @@ os.makedirs(test_expectation_dir, exist_ok=True)
resource_dir = Path(__file__).parents[1] / "images" resource_dir = Path(__file__).parents[1] / "images"
def get_dest_dir():
if APITestTemplate.is_set_expectation_run:
return test_expectation_dir
else:
return test_result_dir
def save_base64(base64img: str, dest: Path):
Image.open(io.BytesIO(base64.b64decode(base64img.split(",", 1)[0]))).save(dest)
def read_image(img_path: Path) -> str: def read_image(img_path: Path) -> str:
img = cv2.imread(str(img_path)) img = cv2.imread(str(img_path))
_, bytes = cv2.imencode(".png", img) _, bytes = cv2.imencode(".png", img)
@ -45,6 +68,8 @@ def read_image_dir(img_dir: Path, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -
girl_img = read_image(resource_dir / "1girl.png") girl_img = read_image(resource_dir / "1girl.png")
mask_img = read_image(resource_dir / "mask.png") mask_img = read_image(resource_dir / "mask.png")
mask_small_img = read_image(resource_dir / "mask_small.png") mask_small_img = read_image(resource_dir / "mask_small.png")
portrait_imgs = read_image_dir(resource_dir / "portrait")
realistic_girl_face_img = portrait_imgs[0]
general_negative_prompt = """ general_negative_prompt = """
@ -74,7 +99,7 @@ is_full_coverage = os.environ.get("CONTROLNET_TEST_FULL_COVERAGE", None) is not
class APITestTemplate: class APITestTemplate:
is_set_expectation_run = os.environ.get("CONTROLNET_SET_EXP", "True") == "True" is_set_expectation_run = os.environ.get("CONTROLNET_SET_EXP", "True") == "True"
is_cq_run = bool(os.environ.get("FORGE_CQ_TEST", "")) is_cq_run = os.environ.get("FORGE_CQ_TEST", "False") == "True"
BASE_URL = "http://localhost:7860/" BASE_URL = "http://localhost:7860/"
def __init__( def __init__(
@ -152,17 +177,11 @@ class APITestTemplate:
print(response.keys()) print(response.keys())
return False return False
dest_dir = ( dest_dir = get_dest_dir()
test_expectation_dir
if APITestTemplate.is_set_expectation_run
else test_result_dir
)
results = response["images"][:1] if result_only else response["images"] results = response["images"][:1] if result_only else response["images"]
for i, base64image in enumerate(results): for i, base64image in enumerate(results):
img_file_name = f"{self.name}_{i}.png" img_file_name = f"{self.name}_{i}.png"
Image.open(io.BytesIO(base64.b64decode(base64image.split(",", 1)[0]))).save( save_base64(base64image, dest_dir / img_file_name)
dest_dir / img_file_name
)
if not APITestTemplate.is_set_expectation_run: if not APITestTemplate.is_set_expectation_run:
try: try:

View File

@ -11,7 +11,6 @@ def pytest_configure(config):
# We don't want to fail on Py.test command line arguments being # We don't want to fail on Py.test command line arguments being
# parsed by webui: # parsed by webui:
os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1") os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")
os.environ.setdefault("FORGE_CQ_TEST", "1")
def file_to_base64(filename): def file_to_base64(filename):