From 6fc12428e3c5f903584ca7986e0c441f80fa2807 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Mon, 4 Dec 2023 19:42:59 -0700 Subject: [PATCH] Fixed issue where batched inpainting (batch size > 1) wouldn't work because of mismatched tensor sizes. The 'already_decoded' decoded case should also be handled correctly (tested indirectly). --- modules/processing.py | 23 ++++++++----- modules/soft_inpainting.py | 66 ++++++++++++++++++++++++++++++++------ 2 files changed, 71 insertions(+), 18 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 7fc282cf..71bb056a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -883,20 +883,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim # todo: generate adaptive masks based on pixel differences. - # if p.masks_for_overlay is used, it will already be populated with masks + if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: + si.apply_masks(soft_inpainting=p.soft_inpainting, + nmask=p.nmask, + overlay_images=p.overlay_images, + masks_for_overlay=p.masks_for_overlay, + width=p.width, + height=p.height, + paste_to=p.paste_to) else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method # Generate the mask(s) based on similarity between the original and denoised latent vectors if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None: - si.generate_adaptive_masks(latent_orig=p.init_latent, - latent_processed=samples_ddim, - overlay_images=p.overlay_images, - masks_for_overlay=p.masks_for_overlay, - width=p.width, - height=p.height, - paste_to=p.paste_to) + si.apply_adaptive_masks(latent_orig=p.init_latent, + latent_processed=samples_ddim, + overlay_images=p.overlay_images, + masks_for_overlay=p.masks_for_overlay, + width=p.width, + height=p.height, + paste_to=p.paste_to) x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) diff --git a/modules/soft_inpainting.py b/modules/soft_inpainting.py index 56a87774..b36ac8fa 100644 --- a/modules/soft_inpainting.py +++ b/modules/soft_inpainting.py @@ -25,26 +25,32 @@ def latent_blend(soft_inpainting, a, b, t): # NOTE: We use inplace operations wherever possible. - one_minus_t = 1 - t + # [4][w][h] to [1][4][w][h] + t2 = t.unsqueeze(0) + # [4][w][h] to [1][1][w][h] - the [4] seem redundant. + t3 = t[0].unsqueeze(0).unsqueeze(0) + + one_minus_t2 = 1 - t2 + one_minus_t3 = 1 - t3 # Linearly interpolate the image vectors. - a_scaled = a * one_minus_t - b_scaled = b * t + a_scaled = a * one_minus_t2 + b_scaled = b * t2 image_interp = a_scaled image_interp.add_(b_scaled) result_type = image_interp.dtype - del a_scaled, b_scaled + del a_scaled, b_scaled, t2, one_minus_t2 # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) # 64-bit operations are used here to allow large exponents. - current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001) + current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t - b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t + a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t3 + b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t3 desired_magnitude = a_magnitude desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation) - del a_magnitude, b_magnitude, one_minus_t + del a_magnitude, b_magnitude, t3, one_minus_t3 # Change the linearly interpolated image vectors' magnitudes to the value we want. # This is the last 64-bit operation. @@ -78,10 +84,11 @@ def get_modified_nmask(soft_inpainting, nmask, sigma): NOTE: "mask" is not used """ import torch - return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) + # todo: Why is sigma 2D? Both values are the same. + return torch.pow(nmask, (sigma[0] ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale) -def generate_adaptive_masks( +def apply_adaptive_masks( latent_orig, latent_processed, overlay_images, @@ -142,6 +149,45 @@ def generate_adaptive_masks( overlay_images[i] = image_masked.convert('RGBA') +def apply_masks( + soft_inpainting, + nmask, + overlay_images, + masks_for_overlay, + width, height, + paste_to): + import torch + import numpy as np + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + converted_mask = nmask[0].float() + converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2) + converted_mask = 255. * converted_mask + converted_mask = converted_mask.cpu().numpy().astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc.uncrop(converted_mask, + (width, height), + paste_to) + + for i, overlay_image in enumerate(overlay_images): + masks_for_overlay[i] = converted_mask + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + # ------------------- Constants -------------------