marigold_ini
parent 24db0e241a
commit 7d7094b6ec
@ -0,0 +1,313 @@
# Author: Bingxin Ke
# Last modified: 2023-12-11

import logging
from typing import Dict

import numpy as np
import torch
from diffusers import (
    DDIMScheduler,
    DDPMScheduler,
    PNDMScheduler,
    DEISMultistepScheduler,
    SchedulerMixin,
    UNet2DConditionModel,
)
from torch import nn
from torch.nn import Conv2d
from torch.nn.parameter import Parameter
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from .rgb_encoder import RGBEncoder
from .stacked_depth_AE import StackedDepthAE


class MarigoldPipeline(nn.Module):
    """
    Marigold monocular depth estimator.
    """

    def __init__(
        self,
        unet_pretrained_path: Dict,  # {path: xxx, subfolder: xxx}
        rgb_encoder_pretrained_path: Dict,
        depth_ae_pretrained_path: Dict,
        noise_scheduler_pretrained_path: Dict,
        tokenizer_pretrained_path: Dict,
        text_encoder_pretrained_path: Dict,
        empty_text_embed=None,
        trainable_unet=False,
        rgb_latent_scale_factor=0.18215,
        depth_latent_scale_factor=0.18215,
        noise_scheduler_type=None,
        enable_gradient_checkpointing=False,
        enable_xformers=True,
    ) -> None:
        super().__init__()

        self.rgb_latent_scale_factor = rgb_latent_scale_factor
        self.depth_latent_scale_factor = depth_latent_scale_factor
        self.device = "cpu"

        # ******* Initialize modules *******
        # Trainable modules
        self.trainable_module_dic: Dict[str, nn.Module] = {}
        self.trainable_unet = trainable_unet

        # Denoising UNet
        self.unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
            unet_pretrained_path["path"], subfolder=unet_pretrained_path["subfolder"]
        )
        logging.info(f"pretrained UNet loaded from: {unet_pretrained_path}")
        if 8 != self.unet.config["in_channels"]:
            self._replace_unet_conv_in()
            logging.warning("UNet conv_in layer is replaced")
        if enable_xformers:
            self.unet.enable_xformers_memory_efficient_attention()
        else:
            self.unet.disable_xformers_memory_efficient_attention()

        # Image encoder
        self.rgb_encoder = RGBEncoder(
            pretrained_path=rgb_encoder_pretrained_path["path"],
            subfolder=rgb_encoder_pretrained_path["subfolder"],
        )
        logging.info(
            f"pretrained RGBEncoder loaded from: {rgb_encoder_pretrained_path}"
        )
        self.rgb_encoder.requires_grad_(False)

        # Depth encoder-decoder
        self.depth_ae = StackedDepthAE(
pretrained_path=depht_ae_pretrained_path["path"],
|
||||||
|
subfolder=depht_ae_pretrained_path["subfolder"],
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
f"pretrained Depth Autoencoder loaded from: {rgb_encoder_pretrained_path}"
|
||||||
|
)
|
||||||
|

        # Trainability
        # unet
        if self.trainable_unet:
            self.unet.requires_grad_(True)
            self.trainable_module_dic["unet"] = self.unet
            logging.debug("UNet is set to trainable")
        else:
            self.unet.requires_grad_(False)
            logging.debug("UNet is set to frozen")

        # Gradient checkpointing
        if enable_gradient_checkpointing:
            self.unet.enable_gradient_checkpointing()
            self.depth_ae.vae.enable_gradient_checkpointing()

        # Noise scheduler
        if "DDPMScheduler" == noise_scheduler_type:
            self.noise_scheduler: SchedulerMixin = DDPMScheduler.from_pretrained(
                noise_scheduler_pretrained_path["path"],
                subfolder=noise_scheduler_pretrained_path["subfolder"],
            )
        elif "DDIMScheduler" == noise_scheduler_type:
            self.noise_scheduler: SchedulerMixin = DDIMScheduler.from_pretrained(
                noise_scheduler_pretrained_path["path"],
                subfolder=noise_scheduler_pretrained_path["subfolder"],
            )
        elif "PNDMScheduler" == noise_scheduler_type:
            self.noise_scheduler: SchedulerMixin = PNDMScheduler.from_pretrained(
                noise_scheduler_pretrained_path["path"],
                subfolder=noise_scheduler_pretrained_path["subfolder"],
            )
        elif "DEISMultistepScheduler" == noise_scheduler_type:
            self.noise_scheduler: SchedulerMixin = DEISMultistepScheduler.from_pretrained(
                noise_scheduler_pretrained_path["path"],
                subfolder=noise_scheduler_pretrained_path["subfolder"],
            )
        else:
            raise NotImplementedError(
                f"Unsupported noise scheduler type: {noise_scheduler_type}"
            )

        # Text embed for empty prompt (always on CPU)
        if empty_text_embed is None:
            tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                tokenizer_pretrained_path["path"],
                subfolder=tokenizer_pretrained_path["subfolder"],
            )
            text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(
                text_encoder_pretrained_path["path"],
                subfolder=text_encoder_pretrained_path["subfolder"],
            )
            with torch.no_grad():
                self.empty_text_embed = self._encode_text(
                    "", tokenizer, text_encoder
                ).detach()  # [1, 2, 1024]
        else:
            self.empty_text_embed = empty_text_embed

    @staticmethod
    def from_pretrained(pretrained_path, **kwargs):
        return MarigoldPipeline(
            unet_pretrained_path={"path": pretrained_path, "subfolder": "unet"},
            rgb_encoder_pretrained_path={"path": pretrained_path, "subfolder": "vae"},
            depth_ae_pretrained_path={"path": pretrained_path, "subfolder": "vae"},
            noise_scheduler_pretrained_path={
                "path": pretrained_path,
                "subfolder": "scheduler",
            },
            tokenizer_pretrained_path={
                "path": pretrained_path,
                "subfolder": "tokenizer",
            },
            text_encoder_pretrained_path={
                "path": pretrained_path,
                "subfolder": "text_encoder",
            },
            **kwargs,
        )
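
    # Illustrative usage (a sketch, not part of the pipeline itself): build the
    # model from a Marigold checkpoint directory and run one prediction.
    # The checkpoint path below is a placeholder assumption.
    #
    #   pipe = MarigoldPipeline.from_pretrained("path/to/Marigold").to("cuda")
    #   with torch.no_grad():
    #       depth = pipe(rgb_norm, num_inference_steps=50)  # rgb_norm in [-1, 1]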

    def _replace_unet_conv_in(self):
        # Replace the first layer to accept 8 in_channels.
        # Only applied when loading a pretrained SD U-Net.
        _weight = self.unet.conv_in.weight.clone()  # [320, 4, 3, 3]
        _bias = self.unet.conv_in.bias.clone()  # [320]
        _weight = _weight.repeat((1, 2, 1, 1))  # duplicate along the input-channel dim
        # half the activation magnitude
        _weight *= 0.5
        _bias *= 0.5
        # new conv_in with 8 input channels
        _n_convin_out_channel = self.unet.conv_in.out_channels
        _new_conv_in = Conv2d(
            8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        _new_conv_in.weight = Parameter(_weight)
        _new_conv_in.bias = Parameter(_bias)
        self.unet.conv_in = _new_conv_in
        # update config accordingly
        self.unet.config["in_channels"] = 8
        return

    def to(self, device):
        self.rgb_encoder.to(device)
        self.depth_ae.to(device)
        self.unet.to(device)
        self.empty_text_embed = self.empty_text_embed.to(device)
        self.device = device
        return self

    def forward(
        self,
        rgb_in,
        num_inference_steps: int = 50,
        num_output_inter_results: int = 0,
        show_pbar=False,
        init_depth_latent=None,
        return_depth_latent=False,
    ):
        device = rgb_in.device
        precision = self.unet.dtype
        # Set timesteps
        self.noise_scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.noise_scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)

        # Initial depth map (noise)
        if init_depth_latent is not None:
            init_depth_latent = init_depth_latent.to(dtype=precision)
            assert (
                init_depth_latent.shape == rgb_latent.shape
            ), "initial depth latent should be of shape [B, 4, H/8, W/8]"
            depth_latent = init_depth_latent
        else:
            depth_latent = torch.randn(rgb_latent.shape, device=device)  # [B, 4, h, w]

        # Expand text embedding for batch
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        ).to(device=device, dtype=precision)  # [B, 2, 1024]

        # Export intermediate denoising steps
        if num_output_inter_results > 0:
            depth_latent_ls = []
            inter_steps = []
            _idx = (
                -1
                * (
                    np.arange(0, num_output_inter_results)
                    * num_inference_steps
                    / num_output_inter_results
                )
                .round()
                .astype(int)
                - 1
            )
            steps_to_output = timesteps[_idx]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps), total=len(timesteps), leave=False, desc="denoising"
            )
        else:
            iterable = enumerate(timesteps)
        for i, t in iterable:
            unet_input = torch.cat(
                [rgb_latent, depth_latent], dim=1
            )  # this order is important
            unet_input = unet_input.to(dtype=precision)
            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4, h, w]
            # compute the previous noisy sample x_t -> x_t-1
            depth_latent = self.noise_scheduler.step(
                noise_pred, t, depth_latent
            ).prev_sample.to(dtype=precision)

            if num_output_inter_results > 0 and t in steps_to_output:
                depth_latent_ls.append(depth_latent.detach().clone())
                inter_steps.append(t - 1)

        # Decode depth latent
        if num_output_inter_results > 0:
            assert 0 in inter_steps
            depth = [self.decode_depth(lat) for lat in depth_latent_ls]
            if return_depth_latent:
                return depth, inter_steps, depth_latent_ls
            else:
                return depth, inter_steps
        else:
            depth = self.decode_depth(depth_latent)
            if return_depth_latent:
                return depth, depth_latent
            else:
                return depth

    def encode_rgb(self, rgb_in):
        rgb_latent = self.rgb_encoder(rgb_in)  # [B, 4, h, w]
        rgb_latent = rgb_latent * self.rgb_latent_scale_factor
        return rgb_latent

    def encode_depth(self, depth_in):
        depth_latent = self.depth_ae.encode(depth_in)
        depth_latent = depth_latent * self.depth_latent_scale_factor
        return depth_latent

    def decode_depth(self, depth_latent):
        depth_latent = depth_latent / self.depth_latent_scale_factor
        depth = self.depth_ae.decode(depth_latent)  # [B, 1, H, W]
        return depth

    @staticmethod
    def _encode_text(prompt, tokenizer, text_encoder):
        text_inputs = tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(text_encoder.device)
        text_embed = text_encoder(text_input_ids)[0]
        return text_embed
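
    # Note (added for clarity): with an empty prompt and padding="do_not_pad",
    # the CLIP tokenizer emits only the BOS and EOS tokens, which is why the
    # empty_text_embed cached in __init__ has shape [1, 2, 1024] for the SD2
    # text encoder.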
@ -0,0 +1,36 @@
# Author: Bingxin Ke
# Last modified: 2023-12-05

import logging

import torch
import torch.nn as nn
from diffusers import AutoencoderKL


class RGBEncoder(nn.Module):
    """
    The encoder of a pretrained Stable Diffusion VAE.
    """

    def __init__(self, pretrained_path, subfolder=None) -> None:
        super().__init__()

        vae: AutoencoderKL = AutoencoderKL.from_pretrained(
            pretrained_path, subfolder=subfolder
        )
        logging.info(f"pretrained AutoencoderKL loaded from: {pretrained_path}")

        self.rgb_encoder = nn.Sequential(
            vae.encoder,
            vae.quant_conv,
        )

    def to(self, *args, **kwargs):
        self.rgb_encoder.to(*args, **kwargs)

    def forward(self, rgb_in):
        return self.encode(rgb_in)

    def encode(self, rgb_in):
        moments = self.rgb_encoder(rgb_in)  # [B, 8, H/8, W/8]
        mean, logvar = torch.chunk(moments, 2, dim=1)
        rgb_latent = mean  # deterministic encoding: use the mean, skip sampling
        return rgb_latent
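
# Minimal usage sketch (illustrative; the checkpoint id is an assumption, any
# Stable Diffusion checkpoint with a "vae" subfolder works):
#
#   encoder = RGBEncoder("stabilityai/stable-diffusion-2", subfolder="vae")
#   rgb_latent = encoder(rgb)  # rgb: [B, 3, H, W] in [-1, 1] -> [B, 4, H/8, W/8]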
@ -0,0 +1,52 @@
# Author: Bingxin Ke
# Last modified: 2023-12-05

import logging

import torch
import torch.nn as nn
from diffusers import AutoencoderKL


class StackedDepthAE(nn.Module):
    """
    Pretrained image VAE tailored to depth maps.
    Encode: the single-channel depth map is repeated into 3 channels.
    Decode: the average of the 3 output channels is taken as the depth map.
    """

    def __init__(self, pretrained_path, subfolder=None) -> None:
        super().__init__()

        self.vae: AutoencoderKL = AutoencoderKL.from_pretrained(
            pretrained_path, subfolder=subfolder
        )
        logging.info(f"pretrained AutoencoderKL loaded from: {pretrained_path}")

    def forward(self, depth_in):
        depth_latent = self.encode(depth_in)
        depth_out = self.decode(depth_latent)
        return depth_out

    def to(self, *args, **kwargs):
        self.vae.to(*args, **kwargs)

    @staticmethod
    def _stack_depth_images(depth_in):
        if 4 == len(depth_in.shape):
            stacked = depth_in.repeat(1, 3, 1, 1)
        elif 3 == len(depth_in.shape):
            stacked = depth_in.unsqueeze(1)
            stacked = stacked.repeat(1, 3, 1, 1)
        return stacked

    def encode(self, depth_in):
        stacked = self._stack_depth_images(depth_in)
        h = self.vae.encoder(stacked)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        depth_latent = mean
        return depth_latent

    def decode(self, depth_latent):
        z = self.vae.post_quant_conv(depth_latent)
        stacked = self.vae.decoder(z)
        depth_mean = stacked.mean(dim=1, keepdim=True)
        return depth_mean
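
# Round-trip sketch (illustrative only; the checkpoint id is an assumption):
# encoding then decoding a depth map through the stacked VAE should
# approximately reproduce the input.
#
#   ae = StackedDepthAE("stabilityai/stable-diffusion-2", subfolder="vae")
#   latent = ae.encode(depth)      # depth: [B, 1, H, W] -> [B, 4, H/8, W/8]
#   depth_rec = ae.decode(latent)  # back to [B, 1, H, W]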
@ -0,0 +1,38 @@
# Author: Bingxin Ke
# Last modified: 2023-12-11

import math

import torch


# Search table for suggested max. inference batch size
bs_search_table = [
    # tested on A100-PCIE-80GB
    {"res": 768, "total_vram": 79, "bs": 35},
    {"res": 1024, "total_vram": 79, "bs": 20},
    # tested on A100-PCIE-40GB
    {"res": 768, "total_vram": 39, "bs": 15},
    {"res": 1024, "total_vram": 39, "bs": 8},
    # tested on RTX3090, RTX4090
    {"res": 512, "total_vram": 23, "bs": 20},
    {"res": 768, "total_vram": 23, "bs": 7},
    {"res": 1024, "total_vram": 23, "bs": 3},
    # tested on GTX1080Ti
    {"res": 512, "total_vram": 10, "bs": 5},
    {"res": 768, "total_vram": 10, "bs": 2},
]


def find_batch_size(n_repeat, input_res):
    total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3

    for settings in sorted(
        bs_search_table, key=lambda k: (k["res"], -k["total_vram"])
    ):
        if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
            bs = settings["bs"]
            if bs > n_repeat:
                bs = n_repeat
            elif bs > math.ceil(n_repeat / 2) and bs < n_repeat:
                bs = math.ceil(n_repeat / 2)
            return bs
    return 1
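
# Worked example (illustrative): on a 24 GB card (total_vram ~ 23) with
# input_res=768 and n_repeat=10, the first matching row is
# {"res": 768, "total_vram": 23, "bs": 7}; since 7 > ceil(10 / 2) = 5 and
# 7 < 10, the suggested batch size is reduced to 5.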
@ -0,0 +1,103 @@
# Align and ensemble depth maps
# Author: Bingxin Ke
# Last modified: 2023-12-11

import numpy as np
import torch
from scipy.optimize import minimize


def inter_distances(tensors):
    """
    Calculate the pairwise differences between depth maps.
    For N input maps this returns C(N, 2) difference maps stacked on dim 0.
    """
    distances = []
    for i, j in torch.combinations(torch.arange(tensors.shape[0])):
        arr1 = tensors[i : i + 1]
        arr2 = tensors[j : j + 1]
        distances.append(arr1 - arr2)
    dist = torch.concatenate(distances, dim=0)
    return dist


def ensemble_depths(
    input_images,
    regularizer_strength=0.02,
    max_iter=2,
    tol=1e-3,
    reduction='median',
    max_res=None,
    disp=False,
    device='cuda',
):
    """
    Ensemble multiple affine-invariant depth maps (each defined only up to
    scale and shift) by jointly estimating a per-map scale and shift that
    aligns them.
    """
    device = input_images.device  # follow the input tensor's device
    original_input = input_images.clone()
    n_img = input_images.shape[0]
    ori_shape = input_images.shape

    if max_res is not None:
        scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
        if scale_factor < 1:
            downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode='nearest')
            input_images = downscaler(input_images)

    # init guess
    _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
    _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
    s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
    t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
    x = np.concatenate([s_init, t_init]).reshape(-1)

    input_images = input_images.to(device)

    # objective function
    def closure(x):
        x = x.astype(np.float32)
        l = len(x)
        s = x[: int(l / 2)]
        t = x[int(l / 2) :]
        s = torch.from_numpy(s).to(device)
        t = torch.from_numpy(t).to(device)

        transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
        dists = inter_distances(transformed_arrays)
        sqrt_dist = torch.sqrt(torch.mean(dists**2))

        if 'mean' == reduction:
            pred = torch.mean(transformed_arrays, dim=0)
        elif 'median' == reduction:
            pred = torch.median(transformed_arrays, dim=0).values
        else:
            raise ValueError

        near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
        far_err = torch.sqrt((1 - torch.max(pred)) ** 2)

        err = sqrt_dist + (near_err + far_err) * regularizer_strength
        err = err.detach().cpu().numpy()
        return err

    res = minimize(
        closure, x, method='BFGS', tol=tol, options={'maxiter': max_iter, 'disp': disp}
    )
    x = res.x
    l = len(x)
    s = x[: int(l / 2)]
    t = x[int(l / 2) :]

    # Prediction
    s = torch.from_numpy(s).to(device)
    t = torch.from_numpy(t).to(device)
    transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
    if 'mean' == reduction:
        aligned_images = torch.mean(transformed_arrays, dim=0)
        std = torch.std(transformed_arrays, dim=0)
        uncertainty = std
    elif 'median' == reduction:
        aligned_images = torch.median(transformed_arrays, dim=0).values
        # MAD (median absolute deviation) as uncertainty indicator
        abs_dev = torch.abs(transformed_arrays - aligned_images)
        mad = torch.median(abs_dev, dim=0).values
        uncertainty = mad
    else:
        raise ValueError

    # Scale and shift to [0, 1]
    _min = torch.min(aligned_images)
    _max = torch.max(aligned_images)
    aligned_images = (aligned_images - _min) / (_max - _min)
    uncertainty /= _max - _min

    return aligned_images, uncertainty
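
# Usage sketch (illustrative): given N single-image predictions stacked as an
# [N, H, W] tensor, align and fuse them into one depth map plus an
# uncertainty map.
#
#   depth_pred, uncert = ensemble_depths(preds, reduction='median')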
@ -0,0 +1,66 @@

import matplotlib
import numpy as np
import torch
from PIL import Image


def colorize_depth_maps(depth_map, min_depth, max_depth, cmap='Spectral', valid_mask=None):
    """
    Colorize depth maps.
    """
    assert len(depth_map.shape) >= 2, "Invalid dimension"

    if isinstance(depth_map, torch.Tensor):
        depth = depth_map.detach().clone().squeeze().numpy()
    elif isinstance(depth_map, np.ndarray):
        depth = depth_map.copy().squeeze()
    # reshape to [ (B,) H, W ]
    if depth.ndim < 3:
        depth = depth[np.newaxis, :, :]

    # colorize
    cm = matplotlib.colormaps[cmap]
    depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
    img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]  # values in [0, 1]
    img_colored_np = np.rollaxis(img_colored_np, 3, 1)

    if valid_mask is not None:
        if isinstance(depth_map, torch.Tensor):
            valid_mask = valid_mask.detach().numpy()
        valid_mask = valid_mask.squeeze()  # [H, W] or [B, H, W]
        if valid_mask.ndim < 3:
            valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
        else:
            valid_mask = valid_mask[:, np.newaxis, :, :]
        valid_mask = np.repeat(valid_mask, 3, axis=1)
        img_colored_np[~valid_mask] = 0

    if isinstance(depth_map, torch.Tensor):
        img_colored = torch.from_numpy(img_colored_np).float()
    elif isinstance(depth_map, np.ndarray):
        img_colored = img_colored_np

    return img_colored


def chw2hwc(chw):
    assert 3 == len(chw.shape)
    if isinstance(chw, torch.Tensor):
        hwc = torch.permute(chw, (1, 2, 0))
    elif isinstance(chw, np.ndarray):
        hwc = np.moveaxis(chw, 0, -1)
    return hwc


def resize_max_res(img: Image.Image, max_edge_resolution):
    original_width, original_height = img.size
    downscale_factor = min(
        max_edge_resolution / original_width, max_edge_resolution / original_height
    )

    new_width = int(original_width * downscale_factor)
    new_height = int(original_height * downscale_factor)

    resized_img = img.resize((new_width, new_height))
    return resized_img
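
# Worked example (for illustration): a 1920x1080 image with
# max_edge_resolution=768 gives downscale_factor = min(768/1920, 768/1080)
# = 0.4, so the image is resized to 768x432, preserving the aspect ratio.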
@ -0,0 +1,14 @@

import random

import numpy as np
import torch


def seed_all(seed: int = 0):
    """
    Set random seeds of all components.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
@ -0,0 +1,75 @@
import os

import numpy as np
import torch
from huggingface_hub import snapshot_download

from ldm_patched.modules import model_management
from marigold.model.marigold_pipeline import MarigoldPipeline
from modules_forge.diffusers_patcher import DiffusersModelPatcher
from modules_forge.forge_util import numpy_to_pytorch, resize_image_with_pad
from modules_forge.shared import preprocessor_dir, add_supported_preprocessor
from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter


class PreprocessorMarigold(Preprocessor):
    def __init__(self):
        super().__init__()
        self.name = 'depth_marigold'
        self.tags = ['Depth']
        self.model_filename_filters = ['depth']
        self.slider_resolution = PreprocessorParameter(
            label='Resolution', minimum=128, maximum=2048, value=768, step=8, visible=True)
        self.slider_1 = PreprocessorParameter(visible=False)
        self.slider_2 = PreprocessorParameter(visible=False)
        self.slider_3 = PreprocessorParameter(visible=False)
        self.show_control_mode = True
        self.do_not_need_model = False
        self.sorting_priority = 100  # higher goes to top in the list
        self.diffusers_patcher = None

    def load_model(self):
        if self.diffusers_patcher is not None:
            return

        checkpoint_path = os.path.join(preprocessor_dir, 'marigold')

        if not os.path.exists(checkpoint_path):
            snapshot_download(repo_id="Bingxin/Marigold",
                              ignore_patterns=["*.bin"],
                              local_dir=checkpoint_path,
                              local_dir_use_symlinks=False)

        self.diffusers_patcher = DiffusersModelPatcher(
            pipeline_class=MarigoldPipeline,
            pretrained_path=checkpoint_path,
            enable_xformers=False,
            noise_scheduler_type='DDIMScheduler')

        return

    def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs):
        input_image, remove_pad = resize_image_with_pad(input_image, resolution)

        self.load_model()

        model_management.load_models_gpu([self.diffusers_patcher.patcher])

        with torch.no_grad():
            img = numpy_to_pytorch(input_image).movedim(-1, 1).to(
                device=self.diffusers_patcher.patcher.current_device,
                dtype=self.diffusers_patcher.dtype)

            img = img * 2.0 - 1.0  # map [0, 1] to the [-1, 1] range the pipeline expects
            depth = self.diffusers_patcher.patcher.model(img, num_inference_steps=20, show_pbar=False)
            depth = depth * 0.5 + 0.5  # map [-1, 1] back to [0, 1]
            depth = depth.movedim(1, -1)[0].cpu().numpy()
            depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)

        return remove_pad(depth_image)


add_supported_preprocessor(PreprocessorMarigold())
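
# Illustrative call (a sketch; assumes the Forge preprocessor interface above):
#
#   proc = PreprocessorMarigold()
#   depth_u8 = proc(rgb_u8, resolution=768)  # HWC uint8 in, single-channel HWC uint8 depth out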
@ -27,7 +27,11 @@ class DiffusersModelPatcher:
         self.pipeline.unet.set_attn_processor(AttnProcessor2_0())
         print('Attention optimization applied to DiffusersModelPatcher')

-        self.pipeline = self.pipeline.to(device=offload_device, dtype=dtype)
+        self.pipeline = self.pipeline.to(device=offload_device)
+
+        if self.dtype == torch.float16:
+            self.pipeline = self.pipeline.half()

         self.pipeline.eval()

         self.patcher = ModelPatcher(
@ -29,3 +29,4 @@ torchsde==0.2.6
 transformers==4.30.2
 httpx==0.24.1
 basicsr==1.4.2
+diffusers==0.25.0