From 77bdb9208d019e562f9f629647356dca2b2d5ef1 Mon Sep 17 00:00:00 2001 From: GavChap Date: Thu, 23 May 2024 21:42:56 +0100 Subject: [PATCH] Add cos xl (#710) * Add V_PREDICTION_EDM handing for CosXL models Add V_PREDICTION_EDM handing for CosXL models * Get correct sigmas from checkpoint. * Round to 3 sig digs in order to make compatible with comfy implementation * Add sigma data like ComfyUI has --------- Co-authored-by: Gavin Chapman --- ldm_patched/modules/model_sampling.py | 6 ++++-- ldm_patched/modules/supported_models.py | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/ldm_patched/modules/model_sampling.py b/ldm_patched/modules/model_sampling.py index da5cc3a6..6bb08fd2 100644 --- a/ldm_patched/modules/model_sampling.py +++ b/ldm_patched/modules/model_sampling.py @@ -107,9 +107,11 @@ class ModelSamplingContinuousEDM(torch.nn.Module): sigma_min = sampling_settings.get("sigma_min", 0.002) sigma_max = sampling_settings.get("sigma_max", 120.0) - self.set_sigma_range(sigma_min, sigma_max) + sigma_data = sampling_settings.get("sigma_data", 1.0) + self.set_sigma_range(sigma_min, sigma_max, sigma_data) - def set_sigma_range(self, sigma_min, sigma_max): + def set_sigma_range(self, sigma_min, sigma_max, sigma_data): + self.sigma_data = sigma_data sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp() self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers diff --git a/ldm_patched/modules/supported_models.py b/ldm_patched/modules/supported_models.py index de21c10d..3e1889bf 100644 --- a/ldm_patched/modules/supported_models.py +++ b/ldm_patched/modules/supported_models.py @@ -169,6 +169,11 @@ class SDXL(supported_models_base.BASE): def model_type(self, state_dict, prefix=""): if "v_pred" in state_dict: return model_base.ModelType.V_PREDICTION + elif "edm_vpred.sigma_max" in state_dict: + self.sampling_settings["sigma_max"] = round(float(state_dict["edm_vpred.sigma_max"].item()),3) + if "edm_vpred.sigma_min" in state_dict: + self.sampling_settings["sigma_min"] = round(float(state_dict["edm_vpred.sigma_min"].item()),3) + return model_base.ModelType.V_PREDICTION_EDM else: return model_base.ModelType.EPS