rework sigma device mapping
This commit is contained in:
parent
d11c9d7506
commit
257ac2653a
@ -144,12 +144,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||||
|
|
||||||
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.load_device)
|
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
|
||||||
self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.load_device)
|
self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
|
||||||
|
|
||||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||||
|
|
||||||
sigmas = self.get_sigmas(p, steps).to(shared.device)
|
sigmas = self.get_sigmas(p, steps).to(x.device)
|
||||||
sigma_sched = sigmas[steps - t_enc - 1:]
|
sigma_sched = sigmas[steps - t_enc - 1:]
|
||||||
|
|
||||||
x = x.to(noise)
|
x = x.to(noise)
|
||||||
@ -206,12 +206,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||||
|
|
||||||
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(unet_patcher.load_device)
|
self.model_wrap.log_sigmas = self.model_wrap.log_sigmas.to(x.device)
|
||||||
self.model_wrap.sigmas = self.model_wrap.sigmas.to(unet_patcher.load_device)
|
self.model_wrap.sigmas = self.model_wrap.sigmas.to(x.device)
|
||||||
|
|
||||||
steps = steps or p.steps
|
steps = steps or p.steps
|
||||||
|
|
||||||
sigmas = self.get_sigmas(p, steps).to(shared.device)
|
sigmas = self.get_sigmas(p, steps).to(x.device)
|
||||||
|
|
||||||
if opts.sgm_noise_multiplier:
|
if opts.sgm_noise_multiplier:
|
||||||
p.extra_generation_params["SGM noise multiplier"] = True
|
p.extra_generation_params["SGM noise multiplier"] = True
|
||||||
|
@ -27,7 +27,7 @@ class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
|
|||||||
start = self.sigma_to_t(self.sigma_max)
|
start = self.sigma_to_t(self.sigma_max)
|
||||||
end = self.sigma_to_t(self.sigma_min)
|
end = self.sigma_to_t(self.sigma_min)
|
||||||
|
|
||||||
t = torch.linspace(start, end, n, device=shared.sd_model.forge_objects.unet.current_device)
|
t = torch.linspace(start, end, n, device=self.sigmas.device)
|
||||||
|
|
||||||
return sampling.append_zero(self.t_to_sigma(t))
|
return sampling.append_zero(self.t_to_sigma(t))
|
||||||
|
|
||||||
|
@ -101,11 +101,11 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
|||||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||||
|
|
||||||
self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.load_device)
|
self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(x.device)
|
||||||
|
|
||||||
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
|
||||||
|
|
||||||
timesteps = self.get_timesteps(p, steps).to(shared.device)
|
timesteps = self.get_timesteps(p, steps).to(x.device)
|
||||||
timesteps_sched = timesteps[:t_enc]
|
timesteps_sched = timesteps[:t_enc]
|
||||||
|
|
||||||
alphas_cumprod = shared.sd_model.alphas_cumprod
|
alphas_cumprod = shared.sd_model.alphas_cumprod
|
||||||
@ -151,10 +151,10 @@ class CompVisSampler(sd_samplers_common.Sampler):
|
|||||||
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
unet_patcher = self.model_wrap.inner_model.forge_objects.unet
|
||||||
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
sampling_prepare(self.model_wrap.inner_model.forge_objects.unet, x=x)
|
||||||
|
|
||||||
self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(unet_patcher.load_device)
|
self.model_wrap.inner_model.alphas_cumprod = self.model_wrap.inner_model.alphas_cumprod.to(x.device)
|
||||||
|
|
||||||
steps = steps or p.steps
|
steps = steps or p.steps
|
||||||
timesteps = self.get_timesteps(p, steps).to(shared.device)
|
timesteps = self.get_timesteps(p, steps).to(x.device)
|
||||||
|
|
||||||
extra_params_kwargs = self.initialize(p)
|
extra_params_kwargs = self.initialize(p)
|
||||||
parameters = inspect.signature(self.func).parameters
|
parameters = inspect.signature(self.func).parameters
|
||||||
|
Loading…
Reference in New Issue
Block a user