From 7d7094b6ec29c964fdbf9920ec35aeffde6cf072 Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Fri, 2 Feb 2024 00:14:37 -0800
Subject: [PATCH] marigold_ini

---
 .../marigold/model/__init__.py          |   0
 .../marigold/model/marigold_pipeline.py | 313 ++++++++++++++++++
 .../marigold/model/rgb_encoder.py       |  36 ++
 .../marigold/model/stacked_depth_AE.py  |  52 +++
 .../marigold/util/batchsize.py          |  38 +++
 .../marigold/util/ensemble.py           | 103 ++++++
 .../marigold/util/image_util.py         |  66 ++++
 .../marigold/util/seed_all.py           |  14 +
 .../scripts/preprocessor_marigold.py    |  75 +++++
 modules_forge/diffusers_patcher.py      |   6 +-
 requirements_versions.txt               |   1 +
 11 files changed, 703 insertions(+), 1 deletion(-)
 create mode 100644 extensions-builtin/forge_preprocessor_marigold/marigold/model/__init__.py
 create mode 100644 extensions-builtin/forge_preprocessor_marigold/marigold/model/marigold_pipeline.py
 create mode 100644 extensions-builtin/forge_preprocessor_marigold/marigold/model/rgb_encoder.py
 create mode 100644 extensions-builtin/forge_preprocessor_marigold/marigold/model/stacked_depth_AE.py
 create mode 100644 extensions-builtin/forge_preprocessor_marigold/marigold/util/batchsize.py
 create mode 100644 extensions-builtin/forge_preprocessor_marigold/marigold/util/ensemble.py
 create mode 100644 extensions-builtin/forge_preprocessor_marigold/marigold/util/image_util.py
 create mode 100644 extensions-builtin/forge_preprocessor_marigold/marigold/util/seed_all.py
 create mode 100644 extensions-builtin/forge_preprocessor_marigold/scripts/preprocessor_marigold.py

diff --git a/extensions-builtin/forge_preprocessor_marigold/marigold/model/__init__.py b/extensions-builtin/forge_preprocessor_marigold/marigold/model/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/extensions-builtin/forge_preprocessor_marigold/marigold/model/marigold_pipeline.py b/extensions-builtin/forge_preprocessor_marigold/marigold/model/marigold_pipeline.py
new file mode 100644
index 00000000..6a69b42e
--- /dev/null
+++ b/extensions-builtin/forge_preprocessor_marigold/marigold/model/marigold_pipeline.py
@@ -0,0 +1,313 @@
+# Author: Bingxin Ke
+# Last modified: 2023-12-11
+
+import logging
+from typing import Dict
+
+import numpy as np
+import torch
+from diffusers import (
+    DDIMScheduler,
+    DDPMScheduler,
+    PNDMScheduler,
+    DEISMultistepScheduler,
+    SchedulerMixin,
+    UNet2DConditionModel,
+)
+from torch import nn
+from torch.nn import Conv2d
+from torch.nn.parameter import Parameter
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from .rgb_encoder import RGBEncoder
+from .stacked_depth_AE import StackedDepthAE
+
+
+class MarigoldPipeline(nn.Module):
+    """
+    Marigold monocular depth estimator.
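+
+    Concatenates the encoded RGB latent with a noisy depth latent along the
+    channel dimension (8 UNet input channels in total, hence the replaced
+    conv_in) and iteratively denoises the depth latent with the configured
+    noise scheduler, conditioned on the embedding of an empty text prompt.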
+    """
+
+    def __init__(
+        self,
+        unet_pretrained_path: Dict,  # {path: xxx, subfolder: xxx}
+        rgb_encoder_pretrained_path: Dict,
+        depht_ae_pretrained_path: Dict,
+        noise_scheduler_pretrained_path: Dict,
+        tokenizer_pretrained_path: Dict,
+        text_encoder_pretrained_path: Dict,
+        empty_text_embed=None,
+        trainable_unet=False,
+        rgb_latent_scale_factor=0.18215,
+        depth_latent_scale_factor=0.18215,
+        noise_scheduler_type=None,
+        enable_gradient_checkpointing=False,
+        enable_xformers=True,
+    ) -> None:
+        super().__init__()
+
+        self.rgb_latent_scale_factor = rgb_latent_scale_factor
+        self.depth_latent_scale_factor = depth_latent_scale_factor
+        self.device = "cpu"
+
+        # ******* Initialize modules *******
+        # Trainable modules
+        self.trainable_module_dic: Dict[str, nn.Module] = {}
+        self.trainable_unet = trainable_unet
+
+        # Denoising UNet
+        self.unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
+            unet_pretrained_path["path"], subfolder=unet_pretrained_path["subfolder"]
+        )
+        logging.info(f"pretrained UNet loaded from: {unet_pretrained_path}")
+        if 8 != self.unet.config["in_channels"]:
+            self._replace_unet_conv_in()
+            logging.warning("Unet conv_in layer is replaced")
+        if enable_xformers:
+            self.unet.enable_xformers_memory_efficient_attention()
+        else:
+            self.unet.disable_xformers_memory_efficient_attention()
+
+        # Image encoder
+        self.rgb_encoder = RGBEncoder(
+            pretrained_path=rgb_encoder_pretrained_path["path"],
+            subfolder=rgb_encoder_pretrained_path["subfolder"],
+        )
+        logging.info(
+            f"pretrained RGBEncoder loaded from: {rgb_encoder_pretrained_path}"
+        )
+        self.rgb_encoder.requires_grad_(False)
+
+        # Depth encoder-decoder
+        self.depth_ae = StackedDepthAE(
+            pretrained_path=depht_ae_pretrained_path["path"],
+            subfolder=depht_ae_pretrained_path["subfolder"],
+        )
+        logging.info(
+            f"pretrained Depth Autoencoder loaded from: {depht_ae_pretrained_path}"
+        )
+
+        # Trainability
+        # unet
+        if self.trainable_unet:
+            self.unet.requires_grad_(True)
+            self.trainable_module_dic["unet"] = self.unet
+            logging.debug("UNet is set to trainable")
+        else:
+            self.unet.requires_grad_(False)
+            logging.debug("UNet is set to frozen")
+
+        # Gradient checkpointing
+        if enable_gradient_checkpointing:
+            self.unet.enable_gradient_checkpointing()
+            self.depth_ae.vae.enable_gradient_checkpointing()
+
+        # Noise scheduler
+        if "DDPMScheduler" == noise_scheduler_type:
+            self.noise_scheduler: SchedulerMixin = DDPMScheduler.from_pretrained(
+                noise_scheduler_pretrained_path["path"],
+                subfolder=noise_scheduler_pretrained_path["subfolder"],
+            )
+        elif "DDIMScheduler" == noise_scheduler_type:
+            self.noise_scheduler: SchedulerMixin = DDIMScheduler.from_pretrained(
+                noise_scheduler_pretrained_path["path"],
+                subfolder=noise_scheduler_pretrained_path["subfolder"],
+            )
+        elif "PNDMScheduler" == noise_scheduler_type:
+            self.noise_scheduler: SchedulerMixin = PNDMScheduler.from_pretrained(
+                noise_scheduler_pretrained_path["path"],
+                subfolder=noise_scheduler_pretrained_path["subfolder"],
+            )
+        elif "DEISMultistepScheduler" == noise_scheduler_type:
+            self.noise_scheduler: SchedulerMixin = DEISMultistepScheduler.from_pretrained(
+                noise_scheduler_pretrained_path["path"],
+                subfolder=noise_scheduler_pretrained_path["subfolder"],
+            )
+        else:
+            raise NotImplementedError
+
+        # Text embed for empty prompt (always in CPU)
+        if empty_text_embed is None:
+            tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
+                tokenizer_pretrained_path["path"],
+                subfolder=tokenizer_pretrained_path["subfolder"],
+            )
+            text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(
+                text_encoder_pretrained_path["path"],
+                subfolder=text_encoder_pretrained_path["subfolder"],
+            )
+            with torch.no_grad():
+                self.empty_text_embed = self._encode_text(
+                    "", tokenizer, text_encoder
+                ).detach()  # .to(dtype=precision)  # [1, 2, 1024]
+        else:
+            self.empty_text_embed = empty_text_embed
+
+    def from_pretrained(pretrained_path, **kwargs):
+        return __class__(
+            unet_pretrained_path={"path": pretrained_path, "subfolder": "unet"},
+            rgb_encoder_pretrained_path={"path": pretrained_path, "subfolder": "vae"},
+            depht_ae_pretrained_path={"path": pretrained_path, "subfolder": "vae"},
+            noise_scheduler_pretrained_path={
+                "path": pretrained_path,
+                "subfolder": "scheduler",
+            },
+            tokenizer_pretrained_path={
+                "path": pretrained_path,
+                "subfolder": "tokenizer",
+            },
+            text_encoder_pretrained_path={
+                "path": pretrained_path,
+                "subfolder": "text_encoder",
+            },
+            **kwargs,
+        )
+
+    def _replace_unet_conv_in(self):
+        # Replace the first layer to accept 8 in_channels. Only applied when loading pretrained SD U-Net
+        _weight = self.unet.conv_in.weight.clone()  # [320, 4, 3, 3]
+        _bias = self.unet.conv_in.bias.clone()  # [320]
+        _weight = _weight.repeat((1, 2, 1, 1))  # Keep selected channel(s)
+        # half the activation magnitude
+        _weight *= 0.5
+        _bias *= 0.5
+        # new conv_in channel
+        _n_convin_out_channel = self.unet.conv_in.out_channels
+        _new_conv_in = Conv2d(
+            8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
+        )
+        _new_conv_in.weight = Parameter(_weight)
+        _new_conv_in.bias = Parameter(_bias)
+        self.unet.conv_in = _new_conv_in
+        # replace config
+        self.unet.config["in_channels"] = 8
+        return
+
+    def to(self, device):
+        self.rgb_encoder.to(device)
+        self.depth_ae.to(device)
+        self.unet.to(device)
+        self.empty_text_embed = self.empty_text_embed.to(device)
+        self.device = device
+        return self
+
+    def forward(
+        self,
+        rgb_in,
+        num_inference_steps: int = 50,
+        num_output_inter_results: int = 0,
+        show_pbar=False,
+        init_depth_latent=None,
+        return_depth_latent=False,
+    ):
+        device = rgb_in.device
+        precision = self.unet.dtype
+        # Set timesteps
+        self.noise_scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.noise_scheduler.timesteps  # [T]
+
+        # Encode image
+        rgb_latent = self.encode_rgb(rgb_in)
+
+        # Initial depth map (noise)
+        if init_depth_latent is not None:
+            init_depth_latent = init_depth_latent.to(dtype=precision)
+            assert (
+                init_depth_latent.shape == rgb_latent.shape
+            ), "initial depth latent should be the size of [B, 4, H/8, W/8]"
+            depth_latent = init_depth_latent
+        else:
+            depth_latent = torch.randn(rgb_latent.shape, device=device)  # [B, 4, h, w]
+
+        # Expand text embedding for batch
+        batch_empty_text_embed = self.empty_text_embed.repeat(
+            (rgb_latent.shape[0], 1, 1)
+        ).to(device=device, dtype=precision)  # [B, 2, 1024]
+
+        # Export intermediate denoising steps
+        if num_output_inter_results > 0:
+            depth_latent_ls = []
+            inter_steps = []
+            _idx = (
+                -1
+                * (
+                    np.arange(0, num_output_inter_results)
+                    * num_inference_steps
+                    / num_output_inter_results
+                )
+                .round()
+                .astype(int)
+                - 1
+            )
+            steps_to_output = timesteps[_idx]
+
+        # Denoising loop
+        if show_pbar:
+            iterable = tqdm(enumerate(timesteps), total=len(timesteps), leave=False, desc="denoising")
+        else:
+            iterable = enumerate(timesteps)
+        for i, t in iterable:
+            unet_input = torch.cat(
+                [rgb_latent, depth_latent], dim=1
+            )  # this order is important
+            unet_input = unet_input.to(dtype=precision)
+            # predict the noise residual
+            noise_pred = self.unet(
+                unet_input, t, encoder_hidden_states=batch_empty_text_embed
+            ).sample  # [B, 4, h, w]
+            # compute the previous noisy sample x_t -> x_t-1
+            depth_latent = self.noise_scheduler.step(
+                noise_pred, t, depth_latent
+            ).prev_sample.to(dtype=precision)
+
+
+            if num_output_inter_results > 0 and t in steps_to_output:
+                depth_latent_ls.append(depth_latent.detach().clone())
+                # depth_latent_ls = depth_latent_ls.to(dtype=precision)
+                inter_steps.append(t - 1)
+
+        # Decode depth latent
+        if num_output_inter_results > 0:
+            assert 0 in inter_steps
+            depth = [self.decode_depth(lat) for lat in depth_latent_ls]
+            if return_depth_latent:
+                return depth, inter_steps, depth_latent_ls
+            else:
+                return depth, inter_steps
+        else:
+            depth = self.decode_depth(depth_latent)
+            if return_depth_latent:
+                return depth, depth_latent
+            else:
+                return depth
+
+    def encode_rgb(self, rgb_in):
+        rgb_latent = self.rgb_encoder(rgb_in)  # [B, 4, h, w]
+        rgb_latent = rgb_latent * self.rgb_latent_scale_factor
+        return rgb_latent
+
+    def encode_depth(self, depth_in):
+        depth_latent = self.depth_ae.encode(depth_in)
+        depth_latent = depth_latent * self.depth_latent_scale_factor
+        return depth_latent
+
+    def decode_depth(self, depth_latent):
+        # depth_latent = depth_latent.to(dtype=torch.float16)
+        depth_latent = depth_latent / self.depth_latent_scale_factor
+        depth = self.depth_ae.decode(depth_latent)  # [B, 1, H, W]
+        return depth
+
+    @staticmethod
+    def _encode_text(prompt, tokenizer, text_encoder):
+        text_inputs = tokenizer(
+            prompt,
+            padding="do_not_pad",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids.to(text_encoder.device)
+        text_embed = text_encoder(text_input_ids)[0]
+        return text_embed
diff --git a/extensions-builtin/forge_preprocessor_marigold/marigold/model/rgb_encoder.py b/extensions-builtin/forge_preprocessor_marigold/marigold/model/rgb_encoder.py
new file mode 100644
index 00000000..ea229471
--- /dev/null
+++ b/extensions-builtin/forge_preprocessor_marigold/marigold/model/rgb_encoder.py
@@ -0,0 +1,36 @@
+# Author: Bingxin Ke
+# Last modified: 2023-12-05
+
+import torch
+import torch.nn as nn
+import logging
+from diffusers import AutoencoderKL
+
+
+class RGBEncoder(nn.Module):
+    """
+    The encoder of pretrained Stable Diffusion VAE
+    """
+
+    def __init__(self, pretrained_path, subfolder=None) -> None:
+        super().__init__()
+
+        vae: AutoencoderKL = AutoencoderKL.from_pretrained(pretrained_path, subfolder=subfolder)
+        logging.info(f"pretrained AutoencoderKL loaded from: {pretrained_path}")
+
+        self.rgb_encoder = nn.Sequential(
+            vae.encoder,
+            vae.quant_conv,
+        )
+
+    def to(self, *args, **kwargs):
+        self.rgb_encoder.to(*args, **kwargs)
+
+    def forward(self, rgb_in):
+        return self.encode(rgb_in)
+
+    def encode(self, rgb_in):
+        moments = self.rgb_encoder(rgb_in)  # [B, 8, H/8, W/8]
+        mean, logvar = torch.chunk(moments, 2, dim=1)
+        rgb_latent = mean
+        return rgb_latent
\ No newline at end of file
diff --git a/extensions-builtin/forge_preprocessor_marigold/marigold/model/stacked_depth_AE.py b/extensions-builtin/forge_preprocessor_marigold/marigold/model/stacked_depth_AE.py
new file mode 100644
index 00000000..9b85155e
--- /dev/null
+++ b/extensions-builtin/forge_preprocessor_marigold/marigold/model/stacked_depth_AE.py
@@ -0,0 +1,52 @@
+# Author: Bingxin Ke
+# Last modified: 2023-12-05
+
+import torch
+import torch.nn as nn
+import logging
+from diffusers import AutoencoderKL
+
+
+class StackedDepthAE(nn.Module):
+    """
+    Tailored pretrained image VAE for depth map.
+    Encode: Depth images are repeated into 3 channels.
+    Decode: The average of the 3 channels is taken as output.
+    """
+
+    def __init__(self, pretrained_path, subfolder=None) -> None:
+        super().__init__()
+
+        self.vae: AutoencoderKL = AutoencoderKL.from_pretrained(pretrained_path, subfolder=subfolder)
+        logging.info(f"pretrained AutoencoderKL loaded from: {pretrained_path}")
+
+    def forward(self, depth_in):
+        depth_latent = self.encode(depth_in)
+        depth_out = self.decode(depth_latent)
+        return depth_out
+
+    def to(self, *args, **kwargs):
+        self.vae.to(*args, **kwargs)
+
+    @staticmethod
+    def _stack_depth_images(depth_in):
+        if 4 == len(depth_in.shape):
+            stacked = depth_in.repeat(1, 3, 1, 1)
+        elif 3 == len(depth_in.shape):
+            stacked = depth_in.unsqueeze(1)
+            stacked = stacked.repeat(1, 3, 1, 1)
+        return stacked
+
+    def encode(self, depth_in):
+        stacked = self._stack_depth_images(depth_in)
+        h = self.vae.encoder(stacked)
+        moments = self.vae.quant_conv(h)
+        mean, logvar = torch.chunk(moments, 2, dim=1)
+        depth_latent = mean
+        return depth_latent
+
+    def decode(self, depth_latent):
+        z = self.vae.post_quant_conv(depth_latent)
+        stacked = self.vae.decoder(z)
+        depth_mean = stacked.mean(dim=1, keepdim=True)
+        return depth_mean
\ No newline at end of file
diff --git a/extensions-builtin/forge_preprocessor_marigold/marigold/util/batchsize.py b/extensions-builtin/forge_preprocessor_marigold/marigold/util/batchsize.py
new file mode 100644
index 00000000..7740518e
--- /dev/null
+++ b/extensions-builtin/forge_preprocessor_marigold/marigold/util/batchsize.py
@@ -0,0 +1,38 @@
+# Author: Bingxin Ke
+# Last modified: 2023-12-11
+
+import torch
+import math
+
+
+# Search table for suggested max. inference batch size
+bs_search_table = [
+    # tested on A100-PCIE-80GB
+    {"res": 768, "total_vram": 79, "bs": 35},
+    {"res": 1024, "total_vram": 79, "bs": 20},
+    # tested on A100-PCIE-40GB
+    {"res": 768, "total_vram": 39, "bs": 15},
+    {"res": 1024, "total_vram": 39, "bs": 8},
+    # tested on RTX3090, RTX4090
+    {"res": 512, "total_vram": 23, "bs": 20},
+    {"res": 768, "total_vram": 23, "bs": 7},
+    {"res": 1024, "total_vram": 23, "bs": 3},
+    # tested on GTX1080Ti
+    {"res": 512, "total_vram": 10, "bs": 5},
+    {"res": 768, "total_vram": 10, "bs": 2},
+]
+
+
+
+def find_batch_size(n_repeat, input_res):
+    total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
+
+    for settings in sorted(bs_search_table, key=lambda k: (k['res'], -k['total_vram'])):
+        if input_res <= settings['res'] and total_vram >= settings['total_vram']:
+            bs = settings['bs']
+            if bs > n_repeat:
+                bs = n_repeat
+            elif bs > math.ceil(n_repeat / 2) and bs < n_repeat:
+                bs = math.ceil(n_repeat / 2)
+            return bs
+    return 1
\ No newline at end of file
diff --git a/extensions-builtin/forge_preprocessor_marigold/marigold/util/ensemble.py b/extensions-builtin/forge_preprocessor_marigold/marigold/util/ensemble.py
new file mode 100644
index 00000000..85c8f772
--- /dev/null
+++ b/extensions-builtin/forge_preprocessor_marigold/marigold/util/ensemble.py
@@ -0,0 +1,103 @@
+# Test align depth images
+# Author: Bingxin Ke
+# Last modified: 2023-12-11
+
+import numpy as np
+import torch
+
+from scipy.optimize import minimize
+
+def inter_distances(tensors):
+    """
+    Calculate the distance between every pair of depth maps.
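+
+    Expects a tensor of shape [N, ...]; returns the signed pairwise
+    differences for all N*(N-1)/2 index combinations, stacked along dim 0.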
+    """
+    distances = []
+    for i, j in torch.combinations(torch.arange(tensors.shape[0])):
+        arr1 = tensors[i:i+1]
+        arr2 = tensors[j:j+1]
+        distances.append(arr1 - arr2)
+    dist = torch.concatenate(distances, dim=0)
+    return dist
+
+
+def ensemble_depths(input_images, regularizer_strength=0.02, max_iter=2, tol=1e-3, reduction='median', max_res=None, disp=False, device='cuda'):
+    """
+    Ensemble multiple affine-invariant depth images (each defined only up to scale and shift)
+    by jointly estimating a per-image scale and shift that aligns them.
+    """
+    device = input_images.device
+    original_input = input_images.clone()
+    n_img = input_images.shape[0]
+    ori_shape = input_images.shape
+
+    if max_res is not None:
+        scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
+        if scale_factor < 1:
+            downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode='nearest')
+            input_images = downscaler(input_images)
+
+    # init guess
+    _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
+    _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
+    s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
+    t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
+    x = np.concatenate([s_init, t_init]).reshape(-1)
+
+    input_images = input_images.to(device)
+
+    # objective function
+    def closure(x):
+        x = x.astype(np.float32)
+        l = len(x)
+        s = x[:int(l/2)]
+        t = x[int(l/2):]
+        s = torch.from_numpy(s).to(device)
+        t = torch.from_numpy(t).to(device)
+
+        transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
+        dists = inter_distances(transformed_arrays)
+        sqrt_dist = torch.sqrt(torch.mean(dists**2))
+
+        if 'mean' == reduction:
+            pred = torch.mean(transformed_arrays, dim=0)
+        elif 'median' == reduction:
+            pred = torch.median(transformed_arrays, dim=0).values
+        else:
+            raise ValueError
+
+        near_err = torch.sqrt((0 - torch.min(pred))**2)
+        far_err = torch.sqrt((1 - torch.max(pred))**2)
+
+        err = sqrt_dist + (near_err + far_err) * regularizer_strength
+        err = err.detach().cpu().numpy()
+        return err
+
+    res = minimize(closure, x, method='BFGS', tol=tol, options={'maxiter': max_iter, 'disp': disp})
+    x = res.x
+    l = len(x)
+    s = x[:int(l/2)]
+    t = x[int(l/2):]
+
+    # Prediction
+    s = torch.from_numpy(s).to(device)
+    t = torch.from_numpy(t).to(device)
+    transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
+    if 'mean' == reduction:
+        aligned_images = torch.mean(transformed_arrays, dim=0)
+        std = torch.std(transformed_arrays, dim=0)
+        uncertainty = std
+    elif 'median' == reduction:
+        aligned_images = torch.median(transformed_arrays, dim=0).values
+        # MAD (median absolute deviation) as uncertainty indicator
+        abs_dev = torch.abs(transformed_arrays - aligned_images)
+        mad = torch.median(abs_dev, dim=0).values
+        uncertainty = mad
+    else:
+        raise ValueError
+
+    # Scale and shift to [0, 1]
+    _min = torch.min(aligned_images)
+    _max = torch.max(aligned_images)
+    aligned_images = (aligned_images - _min) / (_max - _min)
+    uncertainty /= (_max - _min)
+    return aligned_images, uncertainty
diff --git a/extensions-builtin/forge_preprocessor_marigold/marigold/util/image_util.py b/extensions-builtin/forge_preprocessor_marigold/marigold/util/image_util.py
new file mode 100644
index 00000000..6a06d52b
--- /dev/null
+++ b/extensions-builtin/forge_preprocessor_marigold/marigold/util/image_util.py
@@ -0,0 +1,66 @@
+
+import matplotlib
+import numpy as np
+import torch
+from PIL import Image
+
+def colorize_depth_maps(depth_map, min_depth, max_depth, cmap='Spectral', valid_mask=None):
+    """
+    Colorize depth maps.
+    """
+    assert len(depth_map.shape) >= 2, "Invalid dimension"
+
+    if isinstance(depth_map, torch.Tensor):
+        depth = depth_map.detach().clone().squeeze().numpy()
+    elif isinstance(depth_map, np.ndarray):
+        depth = depth_map.copy().squeeze()
+    # reshape to [ (B,) H, W ]
+    if depth.ndim < 3:
+        depth = depth[np.newaxis, :, :]
+
+    # colorize
+    cm = matplotlib.colormaps[cmap]
+    depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
+    img_colored_np = cm(depth, bytes=False)[:,:,:,0:3]  # value from 0 to 1
+    img_colored_np = np.rollaxis(img_colored_np, 3, 1)
+
+    if valid_mask is not None:
+        if isinstance(depth_map, torch.Tensor):
+            valid_mask = valid_mask.detach().numpy()
+        valid_mask = valid_mask.squeeze()  # [H, W] or [B, H, W]
+        if valid_mask.ndim < 3:
+            valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
+        else:
+            valid_mask = valid_mask[:, np.newaxis, :, :]
+        valid_mask = np.repeat(valid_mask, 3, axis=1)
+        img_colored_np[~valid_mask] = 0
+
+    if isinstance(depth_map, torch.Tensor):
+        img_colored = torch.from_numpy(img_colored_np).float()
+    elif isinstance(depth_map, np.ndarray):
+        img_colored = img_colored_np
+
+    return img_colored
+
+
+def chw2hwc(chw):
+    assert 3 == len(chw.shape)
+    if isinstance(chw, torch.Tensor):
+        hwc = torch.permute(chw, (1, 2, 0))
+    elif isinstance(chw, np.ndarray):
+        hwc = np.moveaxis(chw, 0, -1)
+    return hwc
+
+
+def resize_max_res(img: Image.Image, max_edge_resolution):
+    original_width, original_height = img.size
+    downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height)
+
+    new_width = int(original_width * downscale_factor)
+    new_height = int(original_height * downscale_factor)
+
+    resized_img = img.resize((new_width, new_height))
+    return resized_img
+
+
+
\ No newline at end of file
diff --git a/extensions-builtin/forge_preprocessor_marigold/marigold/util/seed_all.py b/extensions-builtin/forge_preprocessor_marigold/marigold/util/seed_all.py
new file mode 100644
index 00000000..588ef798
--- /dev/null
+++ b/extensions-builtin/forge_preprocessor_marigold/marigold/util/seed_all.py
@@ -0,0 +1,14 @@
+
+import numpy as np
+import random
+import torch
+
+
+def seed_all(seed: int = 0):
+    """
+    Set random seeds of all components.
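+
+    Seeds Python's random module, NumPy, and PyTorch (CPU and all visible
+    CUDA devices).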
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
\ No newline at end of file
diff --git a/extensions-builtin/forge_preprocessor_marigold/scripts/preprocessor_marigold.py b/extensions-builtin/forge_preprocessor_marigold/scripts/preprocessor_marigold.py
new file mode 100644
index 00000000..97b61666
--- /dev/null
+++ b/extensions-builtin/forge_preprocessor_marigold/scripts/preprocessor_marigold.py
@@ -0,0 +1,75 @@
+from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter
+from modules_forge.shared import preprocessor_dir, add_supported_preprocessor
+from modules_forge.forge_util import resize_image_with_pad
+
+
+import os
+import torch
+import numpy as np
+
+from marigold.model.marigold_pipeline import MarigoldPipeline
+from einops import rearrange
+from huggingface_hub import snapshot_download
+from modules_forge.diffusers_patcher import DiffusersModelPatcher
+from modules_forge.forge_util import numpy_to_pytorch
+from ldm_patched.modules import model_management
+
+
+class PreprocessorMarigold(Preprocessor):
+    def __init__(self):
+        super().__init__()
+        self.name = 'depth_marigold'
+        self.tags = ['Depth']
+        self.model_filename_filters = ['depth']
+        self.slider_resolution = PreprocessorParameter(
+            label='Resolution', minimum=128, maximum=2048, value=768, step=8, visible=True)
+        self.slider_1 = PreprocessorParameter(visible=False)
+        self.slider_2 = PreprocessorParameter(visible=False)
+        self.slider_3 = PreprocessorParameter(visible=False)
+        self.show_control_mode = True
+        self.do_not_need_model = False
+        self.sorting_priority = 100  # higher goes to top in the list
+        self.diffusers_patcher = None
+
+    def load_model(self):
+        if self.model_patcher is not None:
+            return
+
+        checkpoint_path = os.path.join(preprocessor_dir, 'marigold')
+
+        if not os.path.exists(checkpoint_path):
+            snapshot_download(repo_id="Bingxin/Marigold",
+                              ignore_patterns=["*.bin"],
+                              local_dir=checkpoint_path,
+                              local_dir_use_symlinks=False)
+
+        self.diffusers_patcher = DiffusersModelPatcher(
+            pipeline_class=MarigoldPipeline,
+            pretrained_path=checkpoint_path,
+            enable_xformers=False,
+            noise_scheduler_type='DDIMScheduler')
+
+        return
+
+    def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs):
+        input_image, remove_pad = resize_image_with_pad(input_image, resolution)
+
+        self.load_model()
+
+        model_management.load_models_gpu([self.diffusers_patcher.patcher])
+
+        with torch.no_grad():
+            img = numpy_to_pytorch(input_image).movedim(-1, 1).to(
+                device=self.diffusers_patcher.patcher.current_device,
+                dtype=self.diffusers_patcher.dtype)
+
+            img = img * 2.0 - 1.0
+            depth = self.diffusers_patcher.patcher.model(img, num_inference_steps=20, show_pbar=False)
+            depth = depth * 0.5 + 0.5
+            depth = depth.movedim(1, -1)[0].cpu().numpy()
+            depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)
+
+        return remove_pad(depth_image)
+
+
+add_supported_preprocessor(PreprocessorMarigold())
diff --git a/modules_forge/diffusers_patcher.py b/modules_forge/diffusers_patcher.py
index dfcad136..531ed8d0 100644
--- a/modules_forge/diffusers_patcher.py
+++ b/modules_forge/diffusers_patcher.py
@@ -27,7 +27,11 @@ class DiffusersModelPatcher:
             self.pipeline.unet.set_attn_processor(AttnProcessor2_0())
             print('Attention optimization applied to DiffusersModelPatcher')
 
-        self.pipeline = self.pipeline.to(device=offload_device, dtype=dtype)
+        self.pipeline = self.pipeline.to(device=offload_device)
+
+        if self.dtype == torch.float16:
+            self.pipeline = self.pipeline.half()
+
         self.pipeline.eval()
 
         self.patcher = ModelPatcher(
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 05c2908d..fe217969 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -29,3 +29,4 @@ torchsde==0.2.6
 transformers==4.30.2
 httpx==0.24.1
 basicsr==1.4.2
+diffusers==0.25.0
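A minimal sketch of how the new preprocessor can be exercised once the patch is applied, assuming the Forge runtime is initialized and the extension's scripts directory is importable; the import path and the synthetic input image are illustrative, not part of the patch:

    import numpy as np
    from preprocessor_marigold import PreprocessorMarigold  # hypothetical import path

    # Any HWC uint8 RGB image works; a random image stands in for a real photo here.
    rgb = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)

    preprocessor = PreprocessorMarigold()
    # Downloads Bingxin/Marigold on first use, pads/resizes to `resolution`,
    # runs 20 DDIM denoising steps, and returns an HWC uint8 depth map.
    depth_map = preprocessor(rgb, resolution=768)
    print(depth_map.shape, depth_map.dtype)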