marigold_ini

This commit is contained in:
lllyasviel 2024-02-02 00:14:37 -08:00
parent 24db0e241a
commit 7d7094b6ec
11 changed files with 703 additions and 1 deletions

View File

@ -0,0 +1,313 @@
# Author: Bingxin Ke
# Last modified: 2023-12-11
import logging
from typing import Dict
import numpy as np
import torch
from diffusers import (
DDIMScheduler,
DDPMScheduler,
PNDMScheduler,
DEISMultistepScheduler,
SchedulerMixin,
UNet2DConditionModel,
)
from torch import nn
from torch.nn import Conv2d
from torch.nn.parameter import Parameter
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from .rgb_encoder import RGBEncoder
from .stacked_depth_AE import StackedDepthAE
class MarigoldPipeline(nn.Module):
"""
Marigold monocular depth estimator.
"""
def __init__(
self,
unet_pretrained_path: Dict, # {path: xxx, subfolder: xxx}
rgb_encoder_pretrained_path: Dict,
depht_ae_pretrained_path: Dict,
noise_scheduler_pretrained_path: Dict,
tokenizer_pretrained_path: Dict,
text_encoder_pretrained_path: Dict,
empty_text_embed=None,
trainable_unet=False,
rgb_latent_scale_factor=0.18215,
depth_latent_scale_factor=0.18215,
noise_scheduler_type=None,
enable_gradient_checkpointing=False,
enable_xformers=True,
) -> None:
super().__init__()
self.rgb_latent_scale_factor = rgb_latent_scale_factor
self.depth_latent_scale_factor = depth_latent_scale_factor
self.device = "cpu"
# ******* Initialize modules *******
# Trainable modules
self.trainable_module_dic: Dict[str, nn.Module] = {}
self.trainable_unet = trainable_unet
# Denoising UNet
self.unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
unet_pretrained_path["path"], subfolder=unet_pretrained_path["subfolder"]
)
logging.info(f"pretrained UNet loaded from: {unet_pretrained_path}")
if 8 != self.unet.config["in_channels"]:
self._replace_unet_conv_in()
logging.warning("Unet conv_in layer is replaced")
if enable_xformers:
self.unet.enable_xformers_memory_efficient_attention()
else:
self.unet.disable_xformers_memory_efficient_attention()
# Image encoder
self.rgb_encoder = RGBEncoder(
pretrained_path=rgb_encoder_pretrained_path["path"],
subfolder=rgb_encoder_pretrained_path["subfolder"],
)
logging.info(
f"pretrained RGBEncoder loaded from: {rgb_encoder_pretrained_path}"
)
self.rgb_encoder.requires_grad_(False)
# Depth encoder-decoder
self.depth_ae = StackedDepthAE(
pretrained_path=depht_ae_pretrained_path["path"],
subfolder=depht_ae_pretrained_path["subfolder"],
)
logging.info(
f"pretrained Depth Autoencoder loaded from: {rgb_encoder_pretrained_path}"
)
# Trainability
# unet
if self.trainable_unet:
self.unet.requires_grad_(True)
self.trainable_module_dic["unet"] = self.unet
logging.debug(f"UNet is set to trainable")
else:
self.unet.requires_grad_(False)
logging.debug(f"UNet is set to frozen")
# Gradient checkpointing
if enable_gradient_checkpointing:
self.unet.enable_gradient_checkpointing()
self.depth_ae.vae.enable_gradient_checkpointing()
# Noise scheduler
if "DDPMScheduler" == noise_scheduler_type:
self.noise_scheduler: SchedulerMixin = DDPMScheduler.from_pretrained(
noise_scheduler_pretrained_path["path"],
subfolder=noise_scheduler_pretrained_path["subfolder"],
)
elif "DDIMScheduler" == noise_scheduler_type:
self.noise_scheduler: SchedulerMixin = DDIMScheduler.from_pretrained(
noise_scheduler_pretrained_path["path"],
subfolder=noise_scheduler_pretrained_path["subfolder"],
)
elif "PNDMScheduler" == noise_scheduler_type:
self.noise_scheduler: SchedulerMixin = PNDMScheduler.from_pretrained(
noise_scheduler_pretrained_path["path"],
subfolder=noise_scheduler_pretrained_path["subfolder"],
)
elif "DEISMultistepScheduler" == noise_scheduler_type:
self.noise_scheduler: SchedulerMixin = DEISMultistepScheduler.from_pretrained(
noise_scheduler_pretrained_path["path"],
subfolder=noise_scheduler_pretrained_path["subfolder"],
)
else:
raise NotImplementedError
# Text embed for empty prompt (always in CPU)
if empty_text_embed is None:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
tokenizer_pretrained_path["path"],
subfolder=tokenizer_pretrained_path["subfolder"],
)
text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(
text_encoder_pretrained_path["path"],
subfolder=text_encoder_pretrained_path["subfolder"],
)
with torch.no_grad():
self.empty_text_embed = self._encode_text(
"", tokenizer, text_encoder
).detach()#.to(dtype=precision) # [1, 2, 1024]
else:
self.empty_text_embed = empty_text_embed
def from_pretrained(pretrained_path, **kwargs):
return __class__(
unet_pretrained_path={"path": pretrained_path, "subfolder": "unet"},
rgb_encoder_pretrained_path={"path": pretrained_path, "subfolder": "vae"},
depht_ae_pretrained_path={"path": pretrained_path, "subfolder": "vae"},
noise_scheduler_pretrained_path={
"path": pretrained_path,
"subfolder": "scheduler",
},
tokenizer_pretrained_path={
"path": pretrained_path,
"subfolder": "tokenizer",
},
text_encoder_pretrained_path={
"path": pretrained_path,
"subfolder": "text_encoder",
},
**kwargs,
)
def _replace_unet_conv_in(self):
# Replace the first layer to accept 8 in_channels. Only applied when loading pretrained SD U-Net
_weight = self.unet.conv_in.weight.clone() # [320, 4, 3, 3]
_bias = self.unet.conv_in.bias.clone() # [320]
_weight = _weight.repeat((1, 2, 1, 1)) # Keep selected channel(s)
# half the activation magnitude
_weight *= 0.5
_bias *= 0.5
# new conv_in channel
_n_convin_out_channel = self.unet.conv_in.out_channels
_new_conv_in = Conv2d(
8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
)
_new_conv_in.weight = Parameter(_weight)
_new_conv_in.bias = Parameter(_bias)
self.unet.conv_in = _new_conv_in
# replace config
self.unet.config["in_channels"] = 8
return
def to(self, device):
self.rgb_encoder.to(device)
self.depth_ae.to(device)
self.unet.to(device)
self.empty_text_embed = self.empty_text_embed.to(device)
self.device = device
return self
def forward(
self,
rgb_in,
num_inference_steps: int = 50,
num_output_inter_results: int = 0,
show_pbar=False,
init_depth_latent=None,
return_depth_latent=False,
):
device = rgb_in.device
precision = self.unet.dtype
# Set timesteps
self.noise_scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.noise_scheduler.timesteps # [T]
# Encode image
rgb_latent = self.encode_rgb(rgb_in)
# Initial depth map (noise)
if init_depth_latent is not None:
init_depth_latent = init_depth_latent.to(dtype=precision)
assert (
init_depth_latent.shape == rgb_latent.shape
), "initial depth latent should be the size of [B, 4, H/8, W/8]"
depth_latent = init_depth_latent
depth_latent = torch.randn(rgb_latent.shape, device=device, dtype=precision)
else:
depth_latent = torch.randn(rgb_latent.shape, device=device) # [B, 4, h, w]
# Expand text embeding for batch
batch_empty_text_embed = self.empty_text_embed.repeat(
(rgb_latent.shape[0], 1, 1)
).to(device=device, dtype=precision) # [B, 2, 1024]
# Export intermediate denoising steps
if num_output_inter_results > 0:
depth_latent_ls = []
inter_steps = []
_idx = (
-1
* (
np.arange(0, num_output_inter_results)
* num_inference_steps
/ num_output_inter_results
)
.round()
.astype(int)
- 1
)
steps_to_output = timesteps[_idx]
# Denoising loop
if show_pbar:
iterable = tqdm(enumerate(timesteps), total=len(timesteps), leave=False, desc="denoising")
else:
iterable = enumerate(timesteps)
for i, t in iterable:
unet_input = torch.cat(
[rgb_latent, depth_latent], dim=1
) # this order is important
unet_input = unet_input.to(dtype=precision)
# predict the noise residual
noise_pred = self.unet(
unet_input, t, encoder_hidden_states=batch_empty_text_embed
).sample # [B, 4, h, w]
# compute the previous noisy sample x_t -> x_t-1
depth_latent = self.noise_scheduler.step(
noise_pred, t, depth_latent
).prev_sample.to(dtype=precision)
if num_output_inter_results > 0 and t in steps_to_output:
depth_latent_ls.append(depth_latent.detach().clone())
#depth_latent_ls = depth_latent_ls.to(dtype=precision)
inter_steps.append(t - 1)
# Decode depth latent
if num_output_inter_results > 0:
assert 0 in inter_steps
depth = [self.decode_depth(lat) for lat in depth_latent_ls]
if return_depth_latent:
return depth, inter_steps, depth_latent_ls
else:
return depth, inter_steps
else:
depth = self.decode_depth(depth_latent)
if return_depth_latent:
return depth, depth_latent
else:
return depth
def encode_rgb(self, rgb_in):
rgb_latent = self.rgb_encoder(rgb_in) # [B, 4, h, w]
rgb_latent = rgb_latent * self.rgb_latent_scale_factor
return rgb_latent
def encode_depth(self, depth_in):
depth_latent = self.depth_ae.encode(depth_in)
depth_latent = depth_latent * self.depth_latent_scale_factor
return depth_latent
def decode_depth(self, depth_latent):
#depth_latent = depth_latent.to(dtype=torch.float16)
depth_latent = depth_latent / self.depth_latent_scale_factor
depth = self.depth_ae.decode(depth_latent) # [B, 1, H, W]
return depth
@staticmethod
def _encode_text(prompt, tokenizer, text_encoder):
text_inputs = tokenizer(
prompt,
padding="do_not_pad",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(text_encoder.device)
text_embed = text_encoder(text_input_ids)[0]
return text_embed

View File

@ -0,0 +1,36 @@
# Author: Bingxin Ke
# Last modified: 2023-12-05
import torch
import torch.nn as nn
import logging
from diffusers import AutoencoderKL
class RGBEncoder(nn.Module):
"""
The encoder of pretrained Stable Diffusion VAE
"""
def __init__(self, pretrained_path, subfolder=None) -> None:
super().__init__()
vae: AutoencoderKL = AutoencoderKL.from_pretrained(pretrained_path, subfolder=subfolder)
logging.info(f"pretrained AutoencoderKL loaded from: {pretrained_path}")
self.rgb_encoder = nn.Sequential(
vae.encoder,
vae.quant_conv,
)
def to(self, *args, **kwargs):
self.rgb_encoder.to(*args, **kwargs)
def forward(self, rgb_in):
return self.encode(rgb_in)
def encode(self, rgb_in):
moments = self.rgb_encoder(rgb_in) # [B, 8, H/8, W/8]
mean, logvar = torch.chunk(moments, 2, dim=1)
rgb_latent = mean
return rgb_latent

View File

@ -0,0 +1,52 @@
# Author: Bingxin Ke
# Last modified: 2023-12-05
import torch
import torch.nn as nn
import logging
from diffusers import AutoencoderKL
class StackedDepthAE(nn.Module):
"""
Tailored pretrained image VAE for depth map.
Encode: Depth images are repeated into 3 channels.
Decode: The average of 3 chennels are taken as output.
"""
def __init__(self, pretrained_path, subfolder=None) -> None:
super().__init__()
self.vae: AutoencoderKL = AutoencoderKL.from_pretrained(pretrained_path, subfolder=subfolder)
logging.info(f"pretrained AutoencoderKL loaded from: {pretrained_path}")
def forward(self, depth_in):
depth_latent = self.encode(depth_in)
depth_out = self.decode(depth_latent)
return depth_out
def to(self, *args, **kwargs):
self.vae.to(*args, **kwargs)
@staticmethod
def _stack_depth_images(depth_in):
if 4 == len(depth_in.shape):
stacked = depth_in.repeat(1, 3, 1, 1)
elif 3 == len(depth_in.shape):
stacked = depth_in.unsqueeze(1)
stacked = depth_in.repeat(1, 3, 1, 1)
return stacked
def encode(self, depth_in):
stacked = self._stack_depth_images(depth_in)
h = self.vae.encoder(stacked)
moments = self.vae.quant_conv(h)
mean, logvar = torch.chunk(moments, 2, dim=1)
depth_latent = mean
return depth_latent
def decode(self, depth_latent):
z = self.vae.post_quant_conv(depth_latent)
stacked = self.vae.decoder(z)
depth_mean = stacked.mean(dim=1, keepdim=True)
return depth_mean

View File

@ -0,0 +1,38 @@
# Author: Bingxin Ke
# Last modified: 2023-12-11
import torch
import math
# Search table for suggested max. inference batch size
bs_search_table = [
# tested on A100-PCIE-80GB
{"res": 768, "total_vram": 79, "bs": 35},
{"res": 1024, "total_vram": 79, "bs": 20},
# tested on A100-PCIE-40GB
{"res": 768, "total_vram": 39, "bs": 15},
{"res": 1024, "total_vram": 39, "bs": 8},
# tested on RTX3090, RTX4090
{"res": 512, "total_vram": 23, "bs": 20},
{"res": 768, "total_vram": 23, "bs": 7},
{"res": 1024, "total_vram": 23, "bs": 3},
# tested on GTX1080Ti
{"res": 512, "total_vram": 10, "bs": 5},
{"res": 768, "total_vram": 10, "bs": 2},
]
def find_batch_size(n_repeat, input_res):
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
for settings in sorted(bs_search_table, key=lambda k: (k['res'], -k['total_vram'])):
if input_res <= settings['res'] and total_vram >= settings['total_vram']:
bs = settings['bs']
if bs > n_repeat:
bs = n_repeat
elif bs > math.ceil(n_repeat / 2) and bs < n_repeat:
bs = math.ceil(n_repeat / 2)
return bs
return 1

View File

@ -0,0 +1,103 @@
# Test align depth images
# Author: Bingxin Ke
# Last modified: 2023-12-11
import numpy as np
import torch
from scipy.optimize import minimize
def inter_distances(tensors):
"""
To calculate the distance between each two depth maps.
"""
distances = []
for i, j in torch.combinations(torch.arange(tensors.shape[0])):
arr1 = tensors[i:i+1]
arr2 = tensors[j:j+1]
distances.append(arr1 - arr2)
dist = torch.concatenate(distances, dim=0)
return dist
def ensemble_depths(input_images, regularizer_strength=0.02, max_iter=2, tol=1e-3, reduction='median', max_res=None, disp=False, device='cuda'):
"""
To ensemble multiple affine-invariant depth images (up to scale and shift),
by aligning estimating the scale and shift
"""
device = input_images.device
original_input = input_images.clone()
n_img = input_images.shape[0]
ori_shape = input_images.shape
if max_res is not None:
scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
if scale_factor < 1:
downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode='nearest')
input_images = downscaler(torch.from_numpy(input_images)).numpy()
# init guess
_min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
_max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
x = np.concatenate([s_init, t_init]).reshape(-1)
input_images = input_images.to(device)
# objective function
def closure(x):
x = x.astype(np.float32)
l = len(x)
s = x[:int(l/2)]
t = x[int(l/2):]
s = torch.from_numpy(s).to(device)
t = torch.from_numpy(t).to(device)
transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
dists = inter_distances(transformed_arrays)
sqrt_dist = torch.sqrt(torch.mean(dists**2))
if 'mean' == reduction:
pred = torch.mean(transformed_arrays, dim=0)
elif 'median' == reduction:
pred = torch.median(transformed_arrays, dim=0).values
else:
raise ValueError
near_err = torch.sqrt((0 - torch.min(pred))**2)
far_err = torch.sqrt((1 - torch.max(pred))**2)
err = sqrt_dist + (near_err + far_err) * regularizer_strength
err = err.detach().cpu().numpy()
return err
res = minimize(closure, x, method='BFGS', tol=tol, options={'maxiter': max_iter, 'disp': disp})
x = res.x
l = len(x)
s = x[:int(l/2)]
t = x[int(l/2):]
# Prediction
s = torch.from_numpy(s).to(device)
t = torch.from_numpy(t).to(device)
transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
if 'mean' == reduction:
aligned_images = torch.mean(transformed_arrays, dim=0)
std = torch.std(transformed_arrays, dim=0)
uncertainty = std
elif 'median' == reduction:
aligned_images = torch.median(transformed_arrays, dim=0).values
# MAD (median absolute deviation) as uncertainty indicator
abs_dev = torch.abs(transformed_arrays - aligned_images)
mad = torch.median(abs_dev, dim=0).values
uncertainty = mad
else:
raise ValueError
# Scale and shift to [0, 1]
_min = torch.min(aligned_images)
_max = torch.max(aligned_images)
aligned_images = (aligned_images - _min) / (_max - _min)
uncertainty /= (_max - _min)
return aligned_images, uncertainty

View File

@ -0,0 +1,66 @@
import matplotlib
import numpy as np
import torch
from PIL import Image
def colorize_depth_maps(depth_map, min_depth, max_depth, cmap='Spectral', valid_mask=None):
"""
Colorize depth maps.
"""
assert len(depth_map.shape) >= 2, "Invalid dimension"
if isinstance(depth_map, torch.Tensor):
depth = depth_map.detach().clone().squeeze().numpy()
elif isinstance(depth_map, np.ndarray):
depth = depth_map.copy().squeeze()
# reshape to [ (B,) H, W ]
if depth.ndim < 3:
depth = depth[np.newaxis, :, :]
# colorize
cm = matplotlib.colormaps[cmap]
depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
img_colored_np = cm(depth, bytes=False)[:,:,:,0:3] # value from 0 to 1
img_colored_np = np.rollaxis(img_colored_np, 3, 1)
if valid_mask is not None:
if isinstance(depth_map, torch.Tensor):
valid_mask = valid_mask.detach().numpy()
valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
if valid_mask.ndim < 3:
valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
else:
valid_mask = valid_mask[:, np.newaxis, :, :]
valid_mask = np.repeat(valid_mask, 3, axis=1)
img_colored_np[~valid_mask] = 0
if isinstance(depth_map, torch.Tensor):
img_colored = torch.from_numpy(img_colored_np).float()
elif isinstance(depth_map, np.ndarray):
img_colored = img_colored_np
return img_colored
def chw2hwc(chw):
assert 3 == len(chw.shape)
if isinstance(chw, torch.Tensor):
hwc = torch.permute(chw, (1, 2, 0))
elif isinstance(chw, np.ndarray):
hwc = np.moveaxis(chw, 0, -1)
return hwc
def resize_max_res(img: Image.Image, max_edge_resolution):
original_width, original_height = img.size
downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height)
new_width = int(original_width * downscale_factor)
new_height = int(original_height * downscale_factor)
resized_img = img.resize((new_width, new_height))
return resized_img

View File

@ -0,0 +1,14 @@
import numpy as np
import random
import torch
def seed_all(seed: int = 0):
"""
Set random seeds of all components.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

View File

@ -0,0 +1,75 @@
from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter
from modules_forge.shared import preprocessor_dir, add_supported_preprocessor
from modules_forge.forge_util import resize_image_with_pad
import os
import torch
import numpy as np
from marigold.model.marigold_pipeline import MarigoldPipeline
from einops import rearrange
from huggingface_hub import snapshot_download
from modules_forge.diffusers_patcher import DiffusersModelPatcher
from modules_forge.forge_util import numpy_to_pytorch
from ldm_patched.modules import model_management
class PreprocessorMarigold(Preprocessor):
def __init__(self):
super().__init__()
self.name = 'depth_marigold'
self.tags = ['Depth']
self.model_filename_filters = ['depth']
self.slider_resolution = PreprocessorParameter(
label='Resolution', minimum=128, maximum=2048, value=768, step=8, visible=True)
self.slider_1 = PreprocessorParameter(visible=False)
self.slider_2 = PreprocessorParameter(visible=False)
self.slider_3 = PreprocessorParameter(visible=False)
self.show_control_mode = True
self.do_not_need_model = False
self.sorting_priority = 100 # higher goes to top in the list
self.diffusers_patcher = None
def load_model(self):
if self.model_patcher is not None:
return
checkpoint_path = os.path.join(preprocessor_dir, 'marigold')
if not os.path.exists(checkpoint_path):
snapshot_download(repo_id="Bingxin/Marigold",
ignore_patterns=["*.bin"],
local_dir=checkpoint_path,
local_dir_use_symlinks=False)
self.diffusers_patcher = DiffusersModelPatcher(
pipeline_class=MarigoldPipeline,
pretrained_path=checkpoint_path,
enable_xformers=False,
noise_scheduler_type='DDIMScheduler')
return
def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs):
input_image, remove_pad = resize_image_with_pad(input_image, resolution)
self.load_model()
model_management.load_models_gpu([self.diffusers_patcher.patcher])
with torch.no_grad():
img = numpy_to_pytorch(input_image).movedim(-1, 1).to(
device=self.diffusers_patcher.patcher.current_device,
dtype=self.diffusers_patcher.dtype)
img = img * 2.0 - 1.0
depth = self.diffusers_patcher.patcher.model(img, num_inference_steps=20, show_pbar=False)
depth = depth * 0.5 + 0.5
depth = depth.movedim(1, -1)[0].cpu().numpy()
depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)
return remove_pad(depth_image)
add_supported_preprocessor(PreprocessorMarigold())

View File

@ -27,7 +27,11 @@ class DiffusersModelPatcher:
self.pipeline.unet.set_attn_processor(AttnProcessor2_0()) self.pipeline.unet.set_attn_processor(AttnProcessor2_0())
print('Attention optimization applied to DiffusersModelPatcher') print('Attention optimization applied to DiffusersModelPatcher')
self.pipeline = self.pipeline.to(device=offload_device, dtype=dtype) self.pipeline = self.pipeline.to(device=offload_device)
if self.dtype == torch.float16:
self.pipeline = self.pipeline.half()
self.pipeline.eval() self.pipeline.eval()
self.patcher = ModelPatcher( self.patcher = ModelPatcher(

View File

@ -29,3 +29,4 @@ torchsde==0.2.6
transformers==4.30.2 transformers==4.30.2
httpx==0.24.1 httpx==0.24.1
basicsr==1.4.2 basicsr==1.4.2
diffusers==0.25.0