# Taken from https://github.com/comfyanonymous/ComfyUI
import torch

from . import model_base
from . import utils

from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip

from . import supported_models_base
from . import latent_formats

from . import diffusers_convert

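# Each class below describes one family of checkpoints: unet_config holds the UNet
# hyperparameters used to recognize that family from a loaded state dict, and the
# process_clip_state_dict*/get_model/clip_target hooks say how to remap its text-encoder
# weights and which tokenizer/CLIP classes to build. (Descriptive summary inferred from
# how these attributes are used in this file and in supported_models_base/model_base;
# the detection logic itself lives elsewhere in the upstream ComfyUI repo.)
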
class SD15(supported_models_base.BASE):
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    unet_extra_config = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    latent_format = latent_formats.SD15

    def process_clip_state_dict(self, state_dict):
        k = list(state_dict.keys())
        for x in k:
            if x.startswith("cond_stage_model.transformer.") and not x.startswith("cond_stage_model.transformer.text_model."):
                y = x.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")
                state_dict[y] = state_dict.pop(x)

        if 'cond_stage_model.transformer.text_model.embeddings.position_ids' in state_dict:
            ids = state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids']
            if ids.dtype == torch.float32:
                state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()

        replace_prefix = {}
        replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {"clip_l.": "cond_stage_model."}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)

    def clip_target(self):
        return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)

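# A minimal sketch of what the key renaming in SD15.process_clip_state_dict accomplishes,
# assuming utils.state_dict_prefix_replace simply rewrites the leading prefix of matching
# keys (the real helper lives in utils and is not shown here). _rename_prefix is a
# hypothetical stand-in, kept in a comment so it does not affect this module:
#
#     def _rename_prefix(sd, old, new):
#         return {(new + k[len(old):]) if k.startswith(old) else k: v for k, v in sd.items()}
#
#     _rename_prefix({"cond_stage_model.transformer.text_model.final_layer_norm.weight": 0},
#                    "cond_stage_model.", "cond_stage_model.clip_l.")
#     # -> {"cond_stage_model.clip_l.transformer.text_model.final_layer_norm.weight": 0}
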
class SD20(supported_models_base.BASE):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SD15

    def model_type(self, state_dict, prefix=""):
        if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
            k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
            out = state_dict[k]
            if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
                return model_base.ModelType.V_PREDICTION
        return model_base.ModelType.EPS

    def process_clip_state_dict(self, state_dict):
        replace_prefix = {}
        replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)

        state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        replace_prefix["clip_h"] = "cond_stage_model.model"
        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
        return state_dict

    def clip_target(self):
        return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)

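# Note on SD20.model_type above: SD 2.x checkpoints do not record whether they were
# trained with epsilon or v-prediction, so the statistics of one UNet bias tensor are
# used as a fingerprint (std > 0.09 -> treat as v-prediction). The 0.09 threshold is an
# empirical guess carried over from upstream, as the original inline comment admits.
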
class SD21UnclipL(SD20):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 1536,
        "use_temporal_attention": False,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 768}

class SD21UnclipH(SD20):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "adm_in_channels": 2048,
        "use_temporal_attention": False,
    }

    clip_vision_prefix = "embedder.model.visual."
    noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1024}

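# The two unCLIP variants above differ only in the CLIP vision embedding size they were
# trained against: the "L" variant uses 768-dim image embeddings (adm_in_channels 1536)
# and the "H" variant 1024-dim (adm_in_channels 2048), which is why timestep_dim in
# noise_aug_config matches 768 vs 1024. (adm_in_channels being twice the embedding size
# is an observation about these configs, not something asserted elsewhere in this file.)
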
class SDXLRefiner(supported_models_base.BASE):
    unet_config = {
        "model_channels": 384,
        "use_linear_in_transformer": True,
        "context_dim": 1280,
        "adm_in_channels": 2560,
        "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL

    def get_model(self, state_dict, prefix="", device=None):
        return model_base.SDXLRefiner(self, device=device)

    def process_clip_state_dict(self, state_dict):
        keys_to_replace = {}
        replace_prefix = {}

        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.0.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
        keys_to_replace["conditioner.embedders.0.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
        keys_to_replace["conditioner.embedders.0.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"

        state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
            state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
        replace_prefix["clip_g"] = "conditioner.embedders.0.model"
        state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)

class SDXL(supported_models_base.BASE):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 10, 10],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

    latent_format = latent_formats.SDXL

    def model_type(self, state_dict, prefix=""):
        if "v_pred" in state_dict:
            return model_base.ModelType.V_PREDICTION
        else:
            return model_base.ModelType.EPS

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SDXL(self, model_type=self.model_type(state_dict, prefix), device=device)
        if self.inpaint_model():
            out.set_inpaint()
        return out

    def process_clip_state_dict(self, state_dict):
        keys_to_replace = {}
        replace_prefix = {}

        replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
        state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
        keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
        keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
        keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"

        state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
        state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
        return state_dict

    def process_clip_state_dict_for_saving(self, state_dict):
        replace_prefix = {}
        keys_to_replace = {}
        state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
        if "clip_g.transformer.text_model.embeddings.position_ids" in state_dict_g:
            state_dict_g.pop("clip_g.transformer.text_model.embeddings.position_ids")
        for k in state_dict:
            if k.startswith("clip_l"):
                state_dict_g[k] = state_dict[k]

        replace_prefix["clip_g"] = "conditioner.embedders.1.model"
        replace_prefix["clip_l"] = "conditioner.embedders.0"
        state_dict_g = utils.state_dict_prefix_replace(state_dict_g, replace_prefix)
        return state_dict_g

    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)

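# SDXL checkpoints carry two text encoders: conditioner.embedders.0 is the CLIP ViT-L
# encoder (remapped to clip_l above) and conditioner.embedders.1 is the OpenCLIP
# ViT-bigG encoder (remapped to clip_g), which is why process_clip_state_dict and its
# saving counterpart handle both prefixes.
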
class SSD1B(SDXL):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 4, 4],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

class Segmind_Vega(SDXL):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 1, 1, 2, 2],
        "context_dim": 2048,
        "adm_in_channels": 2816,
        "use_temporal_attention": False,
    }

class SVD_img2vid(supported_models_base.BASE):
    unet_config = {
        "model_channels": 320,
        "in_channels": 8,
        "use_linear_in_transformer": True,
        "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0],
        "context_dim": 1024,
        "adm_in_channels": 768,
        "use_temporal_attention": True,
        "use_temporal_resblock": True
    }

    clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."

    latent_format = latent_formats.SD15

    sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002}

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SVD_img2vid(self, device=device)
        return out

    def clip_target(self):
        return None

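# SVD_img2vid.clip_target returns None because the image-to-video model is not
# conditioned on a text encoder; instead clip_vision_prefix above points at the CLIP
# vision weights used to embed the conditioning image.
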
class Stable_Zero123(supported_models_base.BASE):
    unet_config = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
        "in_channels": 8,
    }

    unet_extra_config = {
        "num_heads": 8,
        "num_head_channels": -1,
    }

    clip_vision_prefix = "cond_stage_model.model.visual."

    latent_format = latent_formats.SD15

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.Stable_Zero123(self, device=device, cc_projection_weight=state_dict["cc_projection.weight"], cc_projection_bias=state_dict["cc_projection.bias"])
        return out

    def clip_target(self):
        return None

class SD_X4Upscaler(SD20):
    unet_config = {
        "context_dim": 1024,
        "model_channels": 256,
        "in_channels": 7,
        "use_linear_in_transformer": True,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }

    unet_extra_config = {
        "disable_self_attentions": [True, True, True, False],
        "num_classes": 1000,
        "num_heads": 8,
        "num_head_channels": -1,
    }

    latent_format = latent_formats.SD_X4

    sampling_settings = {
        "linear_start": 0.0001,
        "linear_end": 0.02,
    }

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.SD_X4Upscaler(self, device=device)
        return out

models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler]
models += [SVD_img2vid]
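
# How this list is presumably consumed (a sketch under the assumption that the
# checkpoint-loading code, which lives outside this file, infers a UNet config from the
# state dict and then picks the first entry whose unet_config is a subset of it;
# _match_config is hypothetical and kept in a comment):
#
#     def _match_config(detected_unet_config):
#         for model_config in models:
#             if all(detected_unet_config.get(k) == v for k, v in model_config.unet_config.items()):
#                 return model_config
#         return None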