Fix controlnet/detect API endpoint (#187)
This commit is contained in:
parent
6a854fcb38
commit
e11753ff84
3
.github/workflows/run_tests.yaml
vendored
3
.github/workflows/run_tests.yaml
vendored
@ -4,6 +4,9 @@ on:
|
|||||||
- push
|
- push
|
||||||
- pull_request
|
- pull_request
|
||||||
|
|
||||||
|
env:
|
||||||
|
FORGE_CQ_TEST: "True"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
name: tests on CPU
|
name: tests on CPU
|
||||||
|
@ -12,17 +12,21 @@ from .global_state import (
|
|||||||
get_all_controlnet_names,
|
get_all_controlnet_names,
|
||||||
get_preprocessor,
|
get_preprocessor,
|
||||||
)
|
)
|
||||||
|
from .utils import judge_image_type
|
||||||
from .logging import logger
|
from .logging import logger
|
||||||
|
|
||||||
|
|
||||||
def encode_to_base64(image):
|
def encode_to_base64(image):
|
||||||
if isinstance(image, str):
|
if isinstance(image, str):
|
||||||
return image
|
return image
|
||||||
|
elif not judge_image_type(image):
|
||||||
|
return "Detect result is not image"
|
||||||
elif isinstance(image, Image.Image):
|
elif isinstance(image, Image.Image):
|
||||||
return api.encode_pil_to_base64(image)
|
return api.encode_pil_to_base64(image)
|
||||||
elif isinstance(image, np.ndarray):
|
elif isinstance(image, np.ndarray):
|
||||||
return encode_np_to_base64(image)
|
return encode_np_to_base64(image)
|
||||||
else:
|
else:
|
||||||
|
logger.warn("Unable to encode image.")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
@ -88,18 +92,18 @@ def controlnet_api(_: gr.Blocks, app: FastAPI):
|
|||||||
results.append(
|
results.append(
|
||||||
processor_module(
|
processor_module(
|
||||||
img,
|
img,
|
||||||
res=controlnet_processor_res,
|
resolution=controlnet_processor_res,
|
||||||
thr_a=controlnet_threshold_a,
|
slider_1=controlnet_threshold_a,
|
||||||
thr_b=controlnet_threshold_b,
|
slider_2=controlnet_threshold_b,
|
||||||
json_pose_callback=json_acceptor.accept,
|
json_pose_callback=json_acceptor.accept,
|
||||||
)[0]
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if "openpose" in controlnet_module:
|
if "openpose" in controlnet_module:
|
||||||
assert json_acceptor.value is not None
|
assert json_acceptor.value is not None
|
||||||
poses.append(json_acceptor.value)
|
poses.append(json_acceptor.value)
|
||||||
|
|
||||||
results64 = list(map(encode_to_base64, results))
|
results64 = [encode_to_base64(img) for img in results]
|
||||||
res = {"images": results64, "info": "Success"}
|
res = {"images": results64, "info": "Success"}
|
||||||
if poses:
|
if poses:
|
||||||
res["poses"] = poses
|
res["poses"] = poses
|
||||||
|
@ -5,4 +5,3 @@ def pytest_configure(config):
|
|||||||
# We don't want to fail on Py.test command line arguments being
|
# We don't want to fail on Py.test command line arguments being
|
||||||
# parsed by webui:
|
# parsed by webui:
|
||||||
os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")
|
os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")
|
||||||
os.environ.setdefault("FORGE_CQ_TEST", "1")
|
|
||||||
|
Binary file not shown.
After Width: | Height: | Size: 20 KiB |
Binary file not shown.
After Width: | Height: | Size: 37 KiB |
Binary file not shown.
After Width: | Height: | Size: 22 KiB |
Binary file not shown.
After Width: | Height: | Size: 6.4 KiB |
Binary file not shown.
After Width: | Height: | Size: 202 KiB |
Binary file not shown.
After Width: | Height: | Size: 15 KiB |
@ -0,0 +1,63 @@
|
|||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from .template import (
|
||||||
|
APITestTemplate,
|
||||||
|
realistic_girl_face_img,
|
||||||
|
save_base64,
|
||||||
|
get_dest_dir,
|
||||||
|
disable_in_cq,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_modules() -> List[str]:
|
||||||
|
return requests.get(APITestTemplate.BASE_URL + "controlnet/module_list").json()[
|
||||||
|
"module_list"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def detect_template(payload, output_name: str):
|
||||||
|
url = APITestTemplate.BASE_URL + "controlnet/detect"
|
||||||
|
resp = requests.post(url, json=payload)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
resp_json = resp.json()
|
||||||
|
assert "images" in resp_json
|
||||||
|
assert len(resp_json["images"]) == len(payload["controlnet_input_images"])
|
||||||
|
if not APITestTemplate.is_cq_run:
|
||||||
|
for i, img in enumerate(resp_json["images"]):
|
||||||
|
if img == "Detect result is not image":
|
||||||
|
continue
|
||||||
|
dest = get_dest_dir() / f"{output_name}_{i}.png"
|
||||||
|
save_base64(img, dest)
|
||||||
|
return resp_json
|
||||||
|
|
||||||
|
|
||||||
|
@disable_in_cq
|
||||||
|
@pytest.mark.parametrize("module", get_modules())
|
||||||
|
def test_detect_all_modules(module: str):
|
||||||
|
payload = dict(
|
||||||
|
controlnet_input_images=[realistic_girl_face_img],
|
||||||
|
controlnet_module=module,
|
||||||
|
)
|
||||||
|
detect_template(payload, f"detect_{module}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_simple():
|
||||||
|
detect_template(
|
||||||
|
dict(
|
||||||
|
controlnet_input_images=[realistic_girl_face_img],
|
||||||
|
controlnet_module="canny", # Canny does not require model download.
|
||||||
|
),
|
||||||
|
"simple_detect",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_multiple_inputs():
|
||||||
|
detect_template(
|
||||||
|
dict(
|
||||||
|
controlnet_input_images=[realistic_girl_face_img, realistic_girl_face_img],
|
||||||
|
controlnet_module="canny", # Canny does not require model download.
|
||||||
|
),
|
||||||
|
"multiple_inputs_detect",
|
||||||
|
)
|
@ -2,16 +2,28 @@ import io
|
|||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
import base64
|
import base64
|
||||||
|
import functools
|
||||||
from typing import Dict, Any, List, Union, Literal, Optional
|
from typing import Dict, Any, List, Union, Literal, Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import datetime
|
import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def disable_in_cq(func):
|
||||||
|
"""Skips the decorated test func in CQ run."""
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapped_func(*args, **kwargs):
|
||||||
|
if APITestTemplate.is_cq_run:
|
||||||
|
pytest.skip()
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapped_func
|
||||||
|
|
||||||
|
|
||||||
PayloadOverrideType = Dict[str, Any]
|
PayloadOverrideType = Dict[str, Any]
|
||||||
|
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
@ -21,6 +33,17 @@ os.makedirs(test_expectation_dir, exist_ok=True)
|
|||||||
resource_dir = Path(__file__).parents[1] / "images"
|
resource_dir = Path(__file__).parents[1] / "images"
|
||||||
|
|
||||||
|
|
||||||
|
def get_dest_dir():
|
||||||
|
if APITestTemplate.is_set_expectation_run:
|
||||||
|
return test_expectation_dir
|
||||||
|
else:
|
||||||
|
return test_result_dir
|
||||||
|
|
||||||
|
|
||||||
|
def save_base64(base64img: str, dest: Path):
|
||||||
|
Image.open(io.BytesIO(base64.b64decode(base64img.split(",", 1)[0]))).save(dest)
|
||||||
|
|
||||||
|
|
||||||
def read_image(img_path: Path) -> str:
|
def read_image(img_path: Path) -> str:
|
||||||
img = cv2.imread(str(img_path))
|
img = cv2.imread(str(img_path))
|
||||||
_, bytes = cv2.imencode(".png", img)
|
_, bytes = cv2.imencode(".png", img)
|
||||||
@ -45,6 +68,8 @@ def read_image_dir(img_dir: Path, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -
|
|||||||
girl_img = read_image(resource_dir / "1girl.png")
|
girl_img = read_image(resource_dir / "1girl.png")
|
||||||
mask_img = read_image(resource_dir / "mask.png")
|
mask_img = read_image(resource_dir / "mask.png")
|
||||||
mask_small_img = read_image(resource_dir / "mask_small.png")
|
mask_small_img = read_image(resource_dir / "mask_small.png")
|
||||||
|
portrait_imgs = read_image_dir(resource_dir / "portrait")
|
||||||
|
realistic_girl_face_img = portrait_imgs[0]
|
||||||
|
|
||||||
|
|
||||||
general_negative_prompt = """
|
general_negative_prompt = """
|
||||||
@ -74,7 +99,7 @@ is_full_coverage = os.environ.get("CONTROLNET_TEST_FULL_COVERAGE", None) is not
|
|||||||
|
|
||||||
class APITestTemplate:
|
class APITestTemplate:
|
||||||
is_set_expectation_run = os.environ.get("CONTROLNET_SET_EXP", "True") == "True"
|
is_set_expectation_run = os.environ.get("CONTROLNET_SET_EXP", "True") == "True"
|
||||||
is_cq_run = bool(os.environ.get("FORGE_CQ_TEST", ""))
|
is_cq_run = os.environ.get("FORGE_CQ_TEST", "False") == "True"
|
||||||
BASE_URL = "http://localhost:7860/"
|
BASE_URL = "http://localhost:7860/"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -152,17 +177,11 @@ class APITestTemplate:
|
|||||||
print(response.keys())
|
print(response.keys())
|
||||||
return False
|
return False
|
||||||
|
|
||||||
dest_dir = (
|
dest_dir = get_dest_dir()
|
||||||
test_expectation_dir
|
|
||||||
if APITestTemplate.is_set_expectation_run
|
|
||||||
else test_result_dir
|
|
||||||
)
|
|
||||||
results = response["images"][:1] if result_only else response["images"]
|
results = response["images"][:1] if result_only else response["images"]
|
||||||
for i, base64image in enumerate(results):
|
for i, base64image in enumerate(results):
|
||||||
img_file_name = f"{self.name}_{i}.png"
|
img_file_name = f"{self.name}_{i}.png"
|
||||||
Image.open(io.BytesIO(base64.b64decode(base64image.split(",", 1)[0]))).save(
|
save_base64(base64image, dest_dir / img_file_name)
|
||||||
dest_dir / img_file_name
|
|
||||||
)
|
|
||||||
|
|
||||||
if not APITestTemplate.is_set_expectation_run:
|
if not APITestTemplate.is_set_expectation_run:
|
||||||
try:
|
try:
|
||||||
|
@ -11,7 +11,6 @@ def pytest_configure(config):
|
|||||||
# We don't want to fail on Py.test command line arguments being
|
# We don't want to fail on Py.test command line arguments being
|
||||||
# parsed by webui:
|
# parsed by webui:
|
||||||
os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")
|
os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")
|
||||||
os.environ.setdefault("FORGE_CQ_TEST", "1")
|
|
||||||
|
|
||||||
|
|
||||||
def file_to_base64(filename):
|
def file_to_base64(filename):
|
||||||
|
Loading…
Reference in New Issue
Block a user