3d039591fe
Before this commit, credits were already in entry and licenses were already in root. This commit makes the info clearer.
81 lines
3.0 KiB
Python
81 lines
3.0 KiB
Python
# Taken from https://github.com/comfyanonymous/ComfyUI
|
|
# This file is only for reference, and not used in the backend or runtime.
|
|
|
|
#!/usr/bin/env python3
|
|
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
import ldm_patched.modules.utils
|
|
import ldm_patched.modules.ops
|
|
|
|
def conv(n_in, n_out, **kwargs):
    """Build a 3x3, padding-1 2D convolution from ``n_in`` to ``n_out`` channels.

    The layer class is ``ldm_patched.modules.ops.disable_weight_init.Conv2d``.
    Any extra keyword arguments (the callers in this file pass ``stride`` and
    ``bias``) are forwarded to that constructor unchanged.
    """
    conv_cls = ldm_patched.modules.ops.disable_weight_init.Conv2d
    return conv_cls(n_in, n_out, 3, padding=1, **kwargs)
|
|
|
|
class Clamp(nn.Module):
    """Smoothly squash values into the open interval (-3, 3) using tanh."""

    def forward(self, x):
        """Return ``tanh(x / 3) * 3`` computed element-wise on ``x``."""
        limit = 3
        squashed = torch.tanh(x / limit)
        return squashed * limit
|
|
|
|
class Block(nn.Module):
    """Residual block: three 3x3 convs (ReLU between them) plus a skip path.

    The skip path is a bias-free 1x1 convolution when ``n_in != n_out`` and an
    identity otherwise. The sum of both paths goes through a final ReLU.
    """

    def __init__(self, n_in, n_out):
        """Create the conv stack, the skip path and the fusing activation.

        Args:
            n_in: number of input channels.
            n_out: number of output channels.
        """
        super().__init__()
        # Attribute names (conv / skip / fuse) are part of the state_dict keys,
        # so they must stay exactly as they are.
        main_path = [
            conv(n_in, n_out),
            nn.ReLU(),
            conv(n_out, n_out),
            nn.ReLU(),
            conv(n_out, n_out),
        ]
        self.conv = nn.Sequential(*main_path)
        if n_in != n_out:
            # Channel counts differ: project the input with a 1x1 conv.
            self.skip = ldm_patched.modules.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False)
        else:
            self.skip = nn.Identity()
        self.fuse = nn.ReLU()

    def forward(self, x):
        """Return ``fuse(conv(x) + skip(x))``."""
        transformed = self.conv(x)
        shortcut = self.skip(x)
        return self.fuse(transformed + shortcut)
|
|
|
|
def Encoder():
    """Build the encoder as an ``nn.Sequential``.

    Layout: a 3->64 conv and one Block, then three stages each made of a
    stride-2 bias-free 64->64 conv followed by three Blocks, and finally a
    64->4 conv. Layer order (and therefore the Sequential indices used as
    state_dict keys) is fixed.
    """
    layers = [conv(3, 64), Block(64, 64)]
    for _ in range(3):
        # One downsampling stage: strided conv, then three residual blocks.
        layers.append(conv(64, 64, stride=2, bias=False))
        for _ in range(3):
            layers.append(Block(64, 64))
    layers.append(conv(64, 4))
    return nn.Sequential(*layers)
|
|
|
|
def Decoder():
    """Build the decoder as an ``nn.Sequential``.

    Layout: Clamp, a 4->64 conv and a ReLU, then three stages each made of
    three Blocks, an ``nn.Upsample(scale_factor=2)`` and a bias-free 64->64
    conv, and finally one Block and a 64->3 conv. Layer order (and therefore
    the Sequential indices used as state_dict keys) is fixed.
    """
    layers = [Clamp(), conv(4, 64), nn.ReLU()]
    for _ in range(3):
        # One upsampling stage: three residual blocks, 2x upsample, conv.
        for _ in range(3):
            layers.append(Block(64, 64))
        layers.append(nn.Upsample(scale_factor=2))
        layers.append(conv(64, 64, bias=False))
    layers.append(Block(64, 64))
    layers.append(conv(64, 3))
    return nn.Sequential(*layers)
|
|
|
|
class TAESD(nn.Module):
    """Encoder/decoder pair with an additional learnable scalar ``vae_scale``."""

    # Used by scale_latents / unscale_latents: raw latents are divided by
    # (2 * latent_magnitude) and shifted by latent_shift.
    latent_magnitude = 3
    latent_shift = 0.5

    def __init__(self, encoder_path=None, decoder_path=None):
        """Build the encoder and decoder and optionally load their weights.

        Args:
            encoder_path: checkpoint path for the encoder weights, or None to
                leave the encoder as constructed.
            decoder_path: checkpoint path for the decoder weights, or None to
                leave the decoder as constructed.
        """
        super().__init__()
        self.taesd_encoder = Encoder()
        self.taesd_decoder = Decoder()
        self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))

        # Encoder first, then decoder; a None path skips loading for that part.
        checkpoints = (
            (self.taesd_encoder, encoder_path),
            (self.taesd_decoder, decoder_path),
        )
        for module, path in checkpoints:
            if path is None:
                continue
            weights = ldm_patched.modules.utils.load_torch_file(path, safe_load=True)
            module.load_state_dict(weights)

    @staticmethod
    def scale_latents(x):
        """raw latents -> [0, 1] (result is clamped to that range)"""
        span = 2 * TAESD.latent_magnitude
        shifted = x.div(span).add(TAESD.latent_shift)
        return shifted.clamp(0, 1)

    @staticmethod
    def unscale_latents(x):
        """[0, 1] -> raw latents (inverse of the affine part of scale_latents)"""
        span = 2 * TAESD.latent_magnitude
        centered = x.sub(TAESD.latent_shift)
        return centered.mul(span)

    def decode(self, x):
        """Decode ``x * vae_scale`` and map the result ``y`` to ``(y - 0.5) * 2``.

        NOTE(review): presumably this converts a [0, 1] image to [-1, 1];
        confirm against callers.
        """
        decoded = self.taesd_decoder(x * self.vae_scale)
        return decoded.sub(0.5).mul(2)

    def encode(self, x):
        """Encode ``x * 0.5 + 0.5`` and divide the result by ``vae_scale``.

        NOTE(review): presumably the input is a [-1, 1] image mapped to
        [0, 1] before encoding; confirm against callers.
        """
        rescaled = x * 0.5 + 0.5
        latents = self.taesd_encoder(rescaled)
        return latents / self.vae_scale
|