Fix controlnet/detect API endpoint (#187)
This commit is contained in:
parent
6a854fcb38
commit
e11753ff84
3
.github/workflows/run_tests.yaml
vendored
3
.github/workflows/run_tests.yaml
vendored
@ -4,6 +4,9 @@ on:
|
||||
- push
|
||||
- pull_request
|
||||
|
||||
env:
|
||||
FORGE_CQ_TEST: "True"
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: tests on CPU
|
||||
|
@ -12,17 +12,21 @@ from .global_state import (
|
||||
get_all_controlnet_names,
|
||||
get_preprocessor,
|
||||
)
|
||||
from .utils import judge_image_type
|
||||
from .logging import logger
|
||||
|
||||
|
||||
def encode_to_base64(image):
|
||||
if isinstance(image, str):
|
||||
return image
|
||||
elif not judge_image_type(image):
|
||||
return "Detect result is not image"
|
||||
elif isinstance(image, Image.Image):
|
||||
return api.encode_pil_to_base64(image)
|
||||
elif isinstance(image, np.ndarray):
|
||||
return encode_np_to_base64(image)
|
||||
else:
|
||||
logger.warn("Unable to encode image.")
|
||||
return ""
|
||||
|
||||
|
||||
@ -88,18 +92,18 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
|
||||
results.append(
|
||||
processor_module(
|
||||
img,
|
||||
res=controlnet_processor_res,
|
||||
thr_a=controlnet_threshold_a,
|
||||
thr_b=controlnet_threshold_b,
|
||||
resolution=controlnet_processor_res,
|
||||
slider_1=controlnet_threshold_a,
|
||||
slider_2=controlnet_threshold_b,
|
||||
json_pose_callback=json_acceptor.accept,
|
||||
)[0]
|
||||
)
|
||||
)
|
||||
|
||||
if "openpose" in controlnet_module:
|
||||
assert json_acceptor.value is not None
|
||||
poses.append(json_acceptor.value)
|
||||
|
||||
results64 = list(map(encode_to_base64, results))
|
||||
results64 = [encode_to_base64(img) for img in results]
|
||||
res = {"images": results64, "info": "Success"}
|
||||
if poses:
|
||||
res["poses"] = poses
|
||||
|
@ -5,4 +5,3 @@ def pytest_configure(config):
|
||||
# We don't want to fail on Py.test command line arguments being
|
||||
# parsed by webui:
|
||||
os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")
|
||||
os.environ.setdefault("FORGE_CQ_TEST", "1")
|
||||
|
Binary file not shown.
After Width: | Height: | Size: 20 KiB |
Binary file not shown.
After Width: | Height: | Size: 37 KiB |
Binary file not shown.
After Width: | Height: | Size: 22 KiB |
Binary file not shown.
After Width: | Height: | Size: 6.4 KiB |
Binary file not shown.
After Width: | Height: | Size: 202 KiB |
Binary file not shown.
After Width: | Height: | Size: 15 KiB |
@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
import requests
|
||||
from typing import List
|
||||
|
||||
from .template import (
|
||||
APITestTemplate,
|
||||
realistic_girl_face_img,
|
||||
save_base64,
|
||||
get_dest_dir,
|
||||
disable_in_cq,
|
||||
)
|
||||
|
||||
|
||||
def get_modules() -> List[str]:
|
||||
return requests.get(APITestTemplate.BASE_URL + "controlnet/module_list").json()[
|
||||
"module_list"
|
||||
]
|
||||
|
||||
|
||||
def detect_template(payload, output_name: str):
|
||||
url = APITestTemplate.BASE_URL + "controlnet/detect"
|
||||
resp = requests.post(url, json=payload)
|
||||
assert resp.status_code == 200
|
||||
resp_json = resp.json()
|
||||
assert "images" in resp_json
|
||||
assert len(resp_json["images"]) == len(payload["controlnet_input_images"])
|
||||
if not APITestTemplate.is_cq_run:
|
||||
for i, img in enumerate(resp_json["images"]):
|
||||
if img == "Detect result is not image":
|
||||
continue
|
||||
dest = get_dest_dir() / f"{output_name}_{i}.png"
|
||||
save_base64(img, dest)
|
||||
return resp_json
|
||||
|
||||
|
||||
@disable_in_cq
|
||||
@pytest.mark.parametrize("module", get_modules())
|
||||
def test_detect_all_modules(module: str):
|
||||
payload = dict(
|
||||
controlnet_input_images=[realistic_girl_face_img],
|
||||
controlnet_module=module,
|
||||
)
|
||||
detect_template(payload, f"detect_{module}")
|
||||
|
||||
|
||||
def test_detect_simple():
|
||||
detect_template(
|
||||
dict(
|
||||
controlnet_input_images=[realistic_girl_face_img],
|
||||
controlnet_module="canny", # Canny does not require model download.
|
||||
),
|
||||
"simple_detect",
|
||||
)
|
||||
|
||||
|
||||
def test_detect_multiple_inputs():
|
||||
detect_template(
|
||||
dict(
|
||||
controlnet_input_images=[realistic_girl_face_img, realistic_girl_face_img],
|
||||
controlnet_module="canny", # Canny does not require model download.
|
||||
),
|
||||
"multiple_inputs_detect",
|
||||
)
|
@ -2,16 +2,28 @@ import io
|
||||
import os
|
||||
import cv2
|
||||
import base64
|
||||
import functools
|
||||
from typing import Dict, Any, List, Union, Literal, Optional
|
||||
from pathlib import Path
|
||||
import datetime
|
||||
from enum import Enum
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def disable_in_cq(func):
|
||||
"""Skips the decorated test func in CQ run."""
|
||||
@functools.wraps(func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
if APITestTemplate.is_cq_run:
|
||||
pytest.skip()
|
||||
return func(*args, **kwargs)
|
||||
return wrapped_func
|
||||
|
||||
|
||||
PayloadOverrideType = Dict[str, Any]
|
||||
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
@ -21,6 +33,17 @@ os.makedirs(test_expectation_dir, exist_ok=True)
|
||||
resource_dir = Path(__file__).parents[1] / "images"
|
||||
|
||||
|
||||
def get_dest_dir():
|
||||
if APITestTemplate.is_set_expectation_run:
|
||||
return test_expectation_dir
|
||||
else:
|
||||
return test_result_dir
|
||||
|
||||
|
||||
def save_base64(base64img: str, dest: Path):
|
||||
Image.open(io.BytesIO(base64.b64decode(base64img.split(",", 1)[0]))).save(dest)
|
||||
|
||||
|
||||
def read_image(img_path: Path) -> str:
|
||||
img = cv2.imread(str(img_path))
|
||||
_, bytes = cv2.imencode(".png", img)
|
||||
@ -45,6 +68,8 @@ def read_image_dir(img_dir: Path, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -
|
||||
girl_img = read_image(resource_dir / "1girl.png")
|
||||
mask_img = read_image(resource_dir / "mask.png")
|
||||
mask_small_img = read_image(resource_dir / "mask_small.png")
|
||||
portrait_imgs = read_image_dir(resource_dir / "portrait")
|
||||
realistic_girl_face_img = portrait_imgs[0]
|
||||
|
||||
|
||||
general_negative_prompt = """
|
||||
@ -74,7 +99,7 @@ is_full_coverage = os.environ.get("CONTROLNET_TEST_FULL_COVERAGE", None) is not
|
||||
|
||||
class APITestTemplate:
|
||||
is_set_expectation_run = os.environ.get("CONTROLNET_SET_EXP", "True") == "True"
|
||||
is_cq_run = bool(os.environ.get("FORGE_CQ_TEST", ""))
|
||||
is_cq_run = os.environ.get("FORGE_CQ_TEST", "False") == "True"
|
||||
BASE_URL = "http://localhost:7860/"
|
||||
|
||||
def __init__(
|
||||
@ -152,17 +177,11 @@ class APITestTemplate:
|
||||
print(response.keys())
|
||||
return False
|
||||
|
||||
dest_dir = (
|
||||
test_expectation_dir
|
||||
if APITestTemplate.is_set_expectation_run
|
||||
else test_result_dir
|
||||
)
|
||||
dest_dir = get_dest_dir()
|
||||
results = response["images"][:1] if result_only else response["images"]
|
||||
for i, base64image in enumerate(results):
|
||||
img_file_name = f"{self.name}_{i}.png"
|
||||
Image.open(io.BytesIO(base64.b64decode(base64image.split(",", 1)[0]))).save(
|
||||
dest_dir / img_file_name
|
||||
)
|
||||
save_base64(base64image, dest_dir / img_file_name)
|
||||
|
||||
if not APITestTemplate.is_set_expectation_run:
|
||||
try:
|
||||
|
@ -11,7 +11,6 @@ def pytest_configure(config):
|
||||
# We don't want to fail on Py.test command line arguments being
|
||||
# parsed by webui:
|
||||
os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")
|
||||
os.environ.setdefault("FORGE_CQ_TEST", "1")
|
||||
|
||||
|
||||
def file_to_base64(filename):
|
||||
|
Loading…
Reference in New Issue
Block a user