diff --git a/ldm_patched/modules/model_sampling.py b/ldm_patched/modules/model_sampling.py index da5cc3a6..6bb08fd2 100644 --- a/ldm_patched/modules/model_sampling.py +++ b/ldm_patched/modules/model_sampling.py @@ -107,9 +107,11 @@ class ModelSamplingContinuousEDM(torch.nn.Module): sigma_min = sampling_settings.get("sigma_min", 0.002) sigma_max = sampling_settings.get("sigma_max", 120.0) - self.set_sigma_range(sigma_min, sigma_max) + sigma_data = sampling_settings.get("sigma_data", 1.0) + self.set_sigma_range(sigma_min, sigma_max, sigma_data) - def set_sigma_range(self, sigma_min, sigma_max): + def set_sigma_range(self, sigma_min, sigma_max, sigma_data): + self.sigma_data = sigma_data sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp() self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers diff --git a/ldm_patched/modules/supported_models.py b/ldm_patched/modules/supported_models.py index de21c10d..3e1889bf 100644 --- a/ldm_patched/modules/supported_models.py +++ b/ldm_patched/modules/supported_models.py @@ -169,6 +169,11 @@ class SDXL(supported_models_base.BASE): def model_type(self, state_dict, prefix=""): if "v_pred" in state_dict: return model_base.ModelType.V_PREDICTION + elif "edm_vpred.sigma_max" in state_dict: + self.sampling_settings["sigma_max"] = round(float(state_dict["edm_vpred.sigma_max"].item()),3) + if "edm_vpred.sigma_min" in state_dict: + self.sampling_settings["sigma_min"] = round(float(state_dict["edm_vpred.sigma_min"].item()),3) + return model_base.ModelType.V_PREDICTION_EDM else: return model_base.ModelType.EPS