417 lines
15 KiB
Python
417 lines
15 KiB
Python
from typing import Optional
|
|
from modules import processing
|
|
|
|
from lib_controlnet import external_code
|
|
|
|
from modules_forge.forge_util import HWC3
|
|
|
|
from PIL import Image, ImageFilter, ImageOps
|
|
from lib_controlnet.lvminthin import lvmin_thin, nake_nms
|
|
|
|
import torch
|
|
import os
|
|
import functools
|
|
import time
|
|
import base64
|
|
import numpy as np
|
|
import safetensors.torch
|
|
import cv2
|
|
import logging
|
|
|
|
from typing import Any, Callable, Dict, List
|
|
from modules.safe import unsafe_torch_load
|
|
from lib_controlnet.logging import logger
|
|
|
|
|
|
def load_state_dict(ckpt_path, location="cpu"):
    """Load a checkpoint file and return its state dict.

    Files whose extension is ``.safetensors`` (case-insensitive) are read with
    ``safetensors.torch.load_file``; every other extension goes through
    ``unsafe_torch_load``. The loaded object is then passed through
    ``get_state_dict`` to unwrap a nested ``"state_dict"`` entry.

    Args:
        ckpt_path: Path of the checkpoint file.
        location: Device identifier the tensors are loaded onto.

    Returns:
        The (possibly unwrapped) state dict.
    """
    suffix = os.path.splitext(ckpt_path)[1].lower()
    if suffix != ".safetensors":
        # NOTE(review): unsafe_torch_load presumably bypasses the restricted
        # unpickler -- confirm that only trusted checkpoints reach this branch.
        loaded = unsafe_torch_load(ckpt_path, map_location=torch.device(location))
    else:
        loaded = safetensors.torch.load_file(ckpt_path, device=location)

    weights = get_state_dict(loaded)
    logger.info(f"Loaded state_dict from [{ckpt_path}]")
    return weights
|
|
|
|
|
|
def get_state_dict(d):
    """Return ``d["state_dict"]`` when that key exists, otherwise ``d`` itself."""
    unwrapped = d.get("state_dict", d)
    return unwrapped
|
|
|
|
|
|
def ndarray_lru_cache(max_size: int = 128, typed: bool = False):
    """
    Decorator to enable caching for functions with numpy array arguments.
    Numpy arrays are mutable, and thus not directly usable as hash keys.

    Incoming ``np.ndarray`` arguments (also when nested inside tuples) are
    wrapped as ``HashableNpArray`` so that ``functools.lru_cache`` can use
    them as cache keys. ``HashableNpArray`` overrides ``__hash__`` and
    ``__eq__`` to compare whole arrays by content.

    Because that ``__eq__`` returns a single bool, the wrapper type is only
    used for the cache lookup: the decorated function itself is called with
    plain ``np.ndarray`` views, so elementwise ``==`` keeps working inside it.

    Args:
        max_size: Forwarded to ``functools.lru_cache`` as ``maxsize``.
        typed: Forwarded to ``functools.lru_cache`` as ``typed``.

    Returns:
        A decorator. The decorated function additionally exposes the
        ``cache_info`` and ``cache_clear`` methods of the underlying cache.
    """

    def decorator(func: Callable):
        """The actual decorator that accepts the function as input."""

        class HashableNpArray(np.ndarray):
            def __new__(cls, input_array):
                # The view makes the input array and the returned array
                # share the same data (no copy is made).
                return np.asarray(input_array).view(cls)

            def __eq__(self, other) -> bool:
                return np.array_equal(self, other)

            def __hash__(self):
                # Hash the bytes representing the data of the array.
                return hash(self.tobytes())

        def to_hashable(item: Any):
            """Wrap arrays (recursively through tuples) for use as cache keys."""
            if isinstance(item, np.ndarray):
                return HashableNpArray(item)
            if isinstance(item, tuple):
                return tuple(to_hashable(i) for i in item)
            return item

        def to_plain(item: Any):
            """Undo ``to_hashable`` so ``func`` sees ordinary ndarrays."""
            if isinstance(item, HashableNpArray):
                return item.view(np.ndarray)
            if isinstance(item, tuple):
                return tuple(to_plain(i) for i in item)
            return item

        @functools.lru_cache(maxsize=max_size, typed=typed)
        def cached_func(*args, **kwargs):
            """This function only accepts hashable (wrapped) input params."""
            plain_args = [to_plain(arg) for arg in args]
            plain_kwargs = {k: to_plain(arg) for k, arg in kwargs.items()}
            return func(*plain_args, **plain_kwargs)

        # Preserves original function.__name__ and __doc__.
        @functools.wraps(func)
        def decorated_func(*args, **kwargs):
            """The decorated function that delegates the original function."""
            hashable_args = [to_hashable(arg) for arg in args]
            hashable_kwargs = {k: to_hashable(arg) for k, arg in kwargs.items()}
            return cached_func(*hashable_args, **hashable_kwargs)

        # Let callers inspect or reset the cache like a regular lru_cache.
        decorated_func.cache_info = cached_func.cache_info
        decorated_func.cache_clear = cached_func.cache_clear
        return decorated_func

    return decorator
|
|
|
|
|
|
def timer_decorator(func):
    """Time the decorated function and output the result to debug logger.

    The decision is made once, at decoration time: when the module logger
    would not emit DEBUG records, ``func`` is returned unwrapped so that no
    per-call overhead is added. ``isEnabledFor`` is used rather than comparing
    ``logger.level`` so that levels inherited from a parent logger, and levels
    more verbose than DEBUG, are honoured as well.

    Args:
        func: The callable to time.

    Returns:
        ``func`` itself, or a wrapper that logs calls taking over 1 ms.
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return func

    # Not every callable carries __name__ (e.g. functools.partial objects).
    func_name = getattr(func, "__name__", repr(func))

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so a wall-clock adjustment during the
        # call cannot produce a negative or inflated duration.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - start_time
        # Only report function that are significant enough.
        if duration > 1e-3:
            logger.debug(f"{func_name} ran in: {duration:.3f} sec")
        return result

    return wrapper
|
|
|
|
|
|
class TimeMeta(type):
    """Metaclass that passes every callable attribute of the class body
    through ``timer_decorator`` before the class object is created."""

    def __new__(cls, name, bases, attrs):
        # Collect the names first, then rebind them in the same namespace
        # dict that is handed on to ``type.__new__``.
        callable_names = [key for key, member in attrs.items() if callable(member)]
        for key in callable_names:
            attrs[key] = timer_decorator(attrs[key])
        return super().__new__(cls, name, bases, attrs)
|
|
|
|
|
|
# Optional SVG support: `svgsupport` is True only when both svglib and
# reportlab can be imported; otherwise it stays False.
try:
    import io
    from svglib.svglib import svg2rlg
    from reportlab.graphics import renderPM
except ImportError:
    svgsupport = False
else:
    svgsupport = True
|
|
|
|
|
|
def svg_preprocess(inputs: Dict, preprocess: Callable):
    """Rasterize a base64 SVG data URL to PNG before calling ``preprocess``.

    When ``inputs["image"]`` starts with the SVG data-URL prefix and SVG
    support is available, the SVG is rendered to PNG and ``inputs["image"]``
    is overwritten in place with a PNG data URL. In every other case
    ``inputs`` is forwarded unchanged.

    Args:
        inputs: Mapping holding the image string under ``"image"``.
        preprocess: Callable invoked with ``inputs``.

    Returns:
        ``None`` when ``inputs`` is falsy, otherwise ``preprocess(inputs)``.
    """
    if not inputs:
        return None

    svg_prefix = "data:image/svg+xml;base64,"
    payload = inputs["image"]
    if payload.startswith(svg_prefix) and svgsupport:
        raw_svg = base64.b64decode(payload.replace(svg_prefix, ""))
        vector_drawing = svg2rlg(io.BytesIO(raw_svg))
        png_bytes = renderPM.drawToString(vector_drawing, fmt="PNG")
        png_b64 = str(base64.b64encode(png_bytes), "utf-8")
        inputs["image"] = "data:image/png;base64," + png_b64

    return preprocess(inputs)
|
|
|
|
|
|
def get_unique_axis0(data):
    """Return the distinct rows of a 2-D array, in sorted order.

    Rows are ordered with ``np.lexsort`` over the transposed array, then a
    row is kept whenever it differs from the row directly before it.

    Args:
        data: Array-like indexed as ``[row, column]``.

    Returns:
        The sorted array with duplicate rows removed.
    """
    rows = np.asanyarray(data)
    rows = rows[np.lexsort(rows.T)]

    keep = np.empty(len(rows), dtype=np.bool_)
    keep[:1] = True  # the first row (if any) is always kept
    keep[1:] = (rows[:-1, :] != rows[1:, :]).any(axis=-1)
    return rows[keep]
|
|
|
|
|
|
def read_image(img_path: str) -> str:
    """Read image from specified path and return a base64 string.

    The image is decoded with ``cv2.imread`` and re-encoded as PNG before
    being base64-encoded.

    Args:
        img_path: Path of the image file.

    Returns:
        The base64 text of the PNG-encoded image.

    Raises:
        IOError: If the file cannot be read as an image or cannot be
            encoded as PNG.
    """
    img = cv2.imread(img_path)
    # cv2.imread reports failure by returning None instead of raising.
    if img is None:
        raise IOError(f"Failed to read image: {img_path}")

    success, png_buffer = cv2.imencode(".png", img)
    if not success:
        raise IOError(f"Failed to encode image as PNG: {img_path}")

    return base64.b64encode(png_buffer).decode("utf-8")
|
|
|
|
|
|
def read_image_dir(img_dir: str, suffixes=('.png', '.jpg', '.jpeg', '.webp')) -> List[str]:
    """Try read all images in given img_dir.

    Files that cannot be read are logged and skipped, so one bad file does
    not abort the whole directory scan.

    Args:
        img_dir: Directory to scan (not recursive).
        suffixes: File-name suffix, or tuple of suffixes, to accept.

    Returns:
        The base64 strings produced by ``read_image`` for each readable file.
    """
    images = []
    for filename in os.listdir(img_dir):
        if not filename.endswith(suffixes):
            continue
        img_path = os.path.join(img_dir, filename)
        try:
            images.append(read_image(img_path))
        except (IOError, cv2.error):
            # OpenCV signals unreadable/undecodable images with cv2.error,
            # which is not an IOError, so both are treated as "skip this file".
            logger.error(f"Error opening {img_path}")
    return images
|
|
|
|
|
|
def align_dim_latent(x: int) -> int:
    """Round a pixel dimension (w/h) down to a multiple of 8.

    Stable diffusion uses a 1:8 latent/pixel ratio, i.e. one latent unit
    corresponds to 8 pixel units.
    """
    latent_units, _ = divmod(x, 8)
    return latent_units * 8
|
|
|
|
|
|
def _image_str_to_array(value: str) -> Optional[np.ndarray]:
    """Turn a string image reference into an array.

    An existing filesystem path is opened with PIL, any other non-empty
    string is handed to ``external_code.to_base64_nparray``, and an empty
    string yields ``None``.
    """
    if os.path.exists(value):
        # The context manager releases the file handle once the pixels
        # have been copied into the array.
        with Image.open(value) as pil_image:
            return np.array(pil_image).astype('uint8')
    if value:
        return external_code.to_base64_nparray(value)
    return None


def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]:
    """Normalize the supported image inputs to an ``{'image', 'mask'}`` dict.

    Accepted inputs:
        * ``None`` -- returned as ``None``.
        * tuple/list -- element 0 is the image, element 1 the mask.
        * dict -- shallow-copied, so the caller's dict is never modified.
        * anything else -- used as the image, with no mask.

    String values for the image or the mask are resolved either as an
    existing file path or as a base64 payload.

    Returns:
        ``None`` for ``None`` input. Otherwise a dict whose ``'image'`` and
        ``'mask'`` are both ``None`` when there is no image; when there is an
        image, a missing or empty mask is replaced by an all-zero uint8 array
        shaped like the image.
    """
    if image is None:
        return None

    if isinstance(image, (tuple, list)):
        image = {'image': image[0], 'mask': image[1]}
    elif not isinstance(image, dict):
        image = {'image': image, 'mask': None}
    else:
        # copy to enable modifying the dict and prevent response serialization error
        image = dict(image)

    source = image.get('image')
    if isinstance(source, str):
        source = _image_str_to_array(source)
    image['image'] = source

    # If there is no image, return image with None image and None mask
    if source is None:
        image['mask'] = None
        return image

    mask = image.get('mask')
    if isinstance(mask, str):
        mask = _image_str_to_array(mask)
    if mask is None:
        mask = np.zeros_like(source, dtype=np.uint8)
    image['mask'] = mask

    return image
|
|
|
|
|
|
def prepare_mask(
    mask: Image.Image, p: processing.StableDiffusionProcessing
) -> Image.Image:
    """
    Prepare an image mask for the inpainting process.

    Steps, in order:

    1. Convert the mask to grayscale (mode "L").
    2. Invert it when ``p.inpainting_mask_invert`` is truthy.
    3. Blur it:
       * if ``p`` has a ``mask_blur_x`` attribute, apply a horizontal
         Gaussian blur with sigma ``mask_blur_x`` and then a vertical one
         with sigma ``mask_blur_y`` (each only when its value is > 0);
       * otherwise apply a PIL Gaussian blur of radius ``p.mask_blur``
         when that value is > 0.

    Args:
        mask (Image.Image): The input mask as a PIL Image object.
        p (processing.StableDiffusionProcessing): Processing parameters.

    Returns:
        mask (Image.Image): The prepared mask as a PIL Image object.
    """
    mask = mask.convert("L")
    if getattr(p, "inpainting_mask_invert", False):
        mask = ImageOps.invert(mask)

    if not hasattr(p, 'mask_blur_x'):
        # Single-radius fallback when no per-axis blur attribute exists.
        if getattr(p, "mask_blur", 0) > 0:
            mask = mask.filter(ImageFilter.GaussianBlur(p.mask_blur))
        return mask

    # Per-axis blur: x first, then y.
    for attr_name, is_horizontal in (("mask_blur_x", True), ("mask_blur_y", False)):
        sigma = getattr(p, attr_name, 0)
        if sigma > 0:
            extent = 2 * int(2.5 * sigma + 0.5) + 1
            kernel = (extent, 1) if is_horizontal else (1, extent)
            blurred = cv2.GaussianBlur(np.array(mask), kernel, sigma)
            mask = Image.fromarray(blurred)

    return mask
|
|
|
|
|
|
def set_numpy_seed(p: processing.StableDiffusionProcessing) -> Optional[int]:
    """
    Seed NumPy's global random generator from the processing parameters.

    Both ``p.seed`` and ``p.subseed`` are resolved the same way: a value of
    -1 is replaced by ``p.all_seeds[0]``, any other value is converted to int
    and clamped to be at least 0. The two results are added and masked with
    0xFFFFFFFF so the seed fits in 32 bits.

    Args:
        p (processing.StableDiffusionProcessing): The processing instance.

    Returns:
        Optional[int]: The seed passed to ``np.random.seed``, or None when
        any exception occurred (a warning is logged in that case).
    """
    try:
        def resolve(raw):
            chosen = p.all_seeds[0] if raw == -1 else max(int(raw), 0)
            return int(chosen)

        seed = (resolve(p.seed) + resolve(p.subseed)) & 0xFFFFFFFF
        np.random.seed(seed)
        return seed
    except Exception as e:
        # Best effort: seeding failures must not abort processing.
        logger.warning(e)
        logger.warning('Warning: Failed to use consistent random seed.')
        return None
|
|
|
|
|
|
def safe_numpy(x):
    """Return an independent, contiguous copy of ``x``.

    The input is copied, made contiguous with ``np.ascontiguousarray``, and
    copied again; the caller's array is never modified. The redundant-looking
    sequence is intentional (see the inline warning below).

    Args:
        x: Object providing ``.copy()`` -- presumably an ``np.ndarray``.

    Returns:
        The copied array.
    """
    # A very safe method to make sure that Apple/Mac works
    y = x

    # below is very boring but do not change these. If you change these Apple or Mac may fail.
    y = y.copy()
    y = np.ascontiguousarray(y)
    y = y.copy()
    return y
|
|
|
|
|
|
def high_quality_resize(x, size):
    """Resize a control map, choosing the strategy from the image content.

    Args:
        x: Image array indexed ``[row, column, channel]``. A 4-channel input
            has its 4th channel split off and handled as an inpaint mask.
        size: Target size as ``(width, height)`` -- it is compared against
            ``(x.shape[1], x.shape[0])`` and passed straight to ``cv2.resize``.

    Returns:
        The resized array. When the input had 4 channels, the resized mask,
        binarized at 127 to 0/255, is appended again as the last channel.
        When the size already matches, the colour channels are returned
        unresized.
    """
    # Written by lvmin
    # Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges

    # Split an optional 4th channel off so the colour analysis below only
    # sees the first three channels.
    inpaint_mask = None
    if x.ndim == 3 and x.shape[2] == 4:
        inpaint_mask = x[:, :, 3]
        x = x[:, :, 0:3]

    if x.shape[0] != size[1] or x.shape[1] != size[0]:
        # Compare pixel counts to tell down-scaling from up-scaling.
        new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
        new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
        # Number of distinct colours (unique per-pixel channel vectors).
        unique_color_count = len(get_unique_axis0(x.reshape(-1, x.shape[2])))
        is_one_pixel_edge = False
        is_binary = False
        if unique_color_count == 2:
            # Two colours count as binary only if they are near black and near white.
            is_binary = np.min(x) < 16 and np.max(x) > 240
            if is_binary:
                # Erode then dilate with a 3x3 kernel: pixels that end up
                # darker than the input are counted as thin-line pixels.
                xc = x
                xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
                xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
                one_pixel_edge_count = np.where(xc < x)[0].shape[0]
                all_edge_count = np.where(x > 127)[0].shape[0]
                # Treated as a one-pixel edge map when more than half of the
                # bright pixels are removed by the opening.
                is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count

        # A small palette (3..199 colours) is resized with nearest-neighbour,
        # which introduces no new colours -- presumably segmentation maps.
        if 2 < unique_color_count < 200:
            interpolation = cv2.INTER_NEAREST
        elif new_size_is_smaller:
            interpolation = cv2.INTER_AREA
        else:
            interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS

        y = cv2.resize(x, size, interpolation=interpolation)
        if inpaint_mask is not None:
            inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)

        if is_binary:
            # Collapse to one channel, re-binarize, then restore 3 channels.
            y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
            if is_one_pixel_edge:
                # NOTE(review): nake_nms / lvmin_thin are project helpers;
                # presumably non-maximum suppression followed by thinning
                # to keep lines one pixel wide -- confirm against lvminthin.
                y = nake_nms(y)
                _, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
                y = lvmin_thin(y, prunings=new_size_is_bigger)
            else:
                _, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            y = np.stack([y] * 3, axis=2)
    else:
        # Already at the requested size: nothing to resize.
        y = x

    if inpaint_mask is not None:
        # Binarize the mask at 127 to 0/255 and append it as the last channel.
        inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
        inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
        y = np.concatenate([y, inpaint_mask], axis=2)

    return y
|
|
|
|
|
|
def crop_and_resize_image(detected_map, resize_mode, h, w):
    """Fit ``detected_map`` to ``h`` x ``w`` according to ``resize_mode``.

    * ``ResizeMode.RESIZE``: resize directly to ``(w, h)``.
    * ``ResizeMode.OUTER_FIT``: scale by the smaller of the two axis ratios
      and paste the result, centred, onto an ``h`` x ``w`` canvas filled with
      the median colour of the map's border pixels (a 4th channel, if
      present, is forced to 255 in the fill colour).
    * any other mode: scale by the larger ratio and take a centred
      ``h`` x ``w`` crop.

    Args:
        detected_map: Image array indexed ``[row, column, channel]``.
        resize_mode: An ``external_code.ResizeMode`` value.
        h: Target height.
        w: Target width.

    Returns:
        The fitted map, passed through ``safe_numpy``.
    """

    def rounded(value):
        return int(np.round(value))

    if resize_mode == external_code.ResizeMode.RESIZE:
        return safe_numpy(high_quality_resize(detected_map, (w, h)))

    src_h, src_w, _ = detected_map.shape
    src_w = float(src_w)
    src_h = float(src_h)
    ratio_h = float(h) / src_h
    ratio_w = float(w) / src_w

    if resize_mode == external_code.ResizeMode.OUTER_FIT:
        scale = min(ratio_h, ratio_w)

        # Fill colour: per-channel median over the four border rows/columns.
        edge_pixels = np.concatenate(
            [
                detected_map[0, :, :],
                detected_map[-1, :, :],
                detected_map[:, 0, :],
                detected_map[:, -1, :],
            ],
            axis=0,
        )
        fill_color = np.median(edge_pixels, axis=0).astype(detected_map.dtype)
        if len(fill_color) == 4:
            # Inpaint hijack
            fill_color[3] = 255
        canvas = np.tile(fill_color[None, None], [h, w, 1])

        scaled = high_quality_resize(
            detected_map, (rounded(src_w * scale), rounded(src_h * scale))
        )
        new_h, new_w, _ = scaled.shape
        top = max(0, (h - new_h) // 2)
        left = max(0, (w - new_w) // 2)
        canvas[top:top + new_h, left:left + new_w] = scaled
        return safe_numpy(canvas)

    scale = max(ratio_h, ratio_w)
    scaled = high_quality_resize(
        detected_map, (rounded(src_w * scale), rounded(src_h * scale))
    )
    new_h, new_w, _ = scaled.shape
    top = max(0, (new_h - h) // 2)
    left = max(0, (new_w - w) // 2)
    return safe_numpy(scaled[top:top + h, left:left + w])
|
|
|
|
|
|
def judge_image_type(img):
    """Classify ``img`` by its array layout.

    Returns:
        A ``(is_image_hw3or4, is_png)`` tuple. The first flag is True for a
        3-dimensional ``np.ndarray`` whose last axis has 3 or 4 entries; the
        second is True only when that axis has exactly 4 entries.
    """
    if not isinstance(img, np.ndarray) or img.ndim != 3:
        return False, False

    channels = int(img.shape[2])
    return channels in (3, 4), channels == 4
|