diff --git a/ldm_patched/modules/model_patcher.py b/ldm_patched/modules/model_patcher.py index fabbfaeb..4ae462cd 100644 --- a/ldm_patched/modules/model_patcher.py +++ b/ldm_patched/modules/model_patcher.py @@ -75,6 +75,12 @@ class ModelPatcher: def set_model_unet_function_wrapper(self, unet_wrapper_function): self.model_options["model_function_wrapper"] = unet_wrapper_function + def set_model_vae_encode_wrapper(self, wrapper_function): + self.model_options["model_vae_encode_wrapper"] = wrapper_function + + def set_model_vae_decode_wrapper(self, wrapper_function): + self.model_options["model_vae_decode_wrapper"] = wrapper_function + def set_model_patch(self, patch, name): to = self.model_options["transformer_options"] if "patches" not in to: @@ -242,7 +248,17 @@ class ModelPatcher: w1 = v[0] if alpha != 0.0: if w1.shape != weight.shape: - print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) + if w1.ndim == weight.ndim == 4: + new_shape = [max(n, m) for n, m in zip(weight.shape, w1.shape)] + print(f'Merged with {key} channel changed to {new_shape}') + new_diff = alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype) + new_weight = torch.zeros(size=new_shape).to(weight) + new_weight[:weight.shape[0], :weight.shape[1], :weight.shape[2], :weight.shape[3]] = weight + new_weight[:new_diff.shape[0], :new_diff.shape[1], :new_diff.shape[2], :new_diff.shape[3]] += new_diff + new_weight = new_weight.contiguous().clone() + weight = new_weight + else: + print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) else: weight += alpha * ldm_patched.modules.model_management.cast_to_device(w1, weight.device, weight.dtype) elif patch_type == "lora": #lora/locon diff --git a/ldm_patched/modules/sd.py b/ldm_patched/modules/sd.py index 02f166e8..2830cc72 100644 --- a/ldm_patched/modules/sd.py +++ b/ldm_patched/modules/sd.py @@ -163,7 +163,10 @@ class CLIP: return self.patcher.get_key_patches() class VAE: - def __init__(self, sd=None, device=None, config=None, dtype=None): + def __init__(self, sd=None, device=None, config=None, dtype=None, no_init=False): + if no_init: + return + if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) @@ -215,6 +218,19 @@ class VAE: self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + def clone(self): + n = VAE(no_init=True) + n.patcher = self.patcher.clone() + n.memory_used_encode = self.memory_used_encode + n.memory_used_decode = self.memory_used_decode + n.downscale_ratio = self.downscale_ratio + n.latent_channels = self.latent_channels + n.first_stage_model = self.first_stage_model + n.device = self.device + n.vae_dtype = self.vae_dtype + n.output_device = self.output_device + return n + def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps += samples.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) @@ -242,7 +258,7 @@ class VAE: samples /= 3.0 return samples - def decode(self, samples_in): + def decode_inner(self, samples_in): if model_management.VAE_ALWAYS_TILED: return self.decode_tiled(samples_in).to(self.output_device) @@ -264,12 +280,19 @@ class VAE: pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples + def decode(self, samples_in): + wrapper = self.patcher.model_options.get('model_vae_decode_wrapper', None) + if wrapper is None: + return self.decode_inner(samples_in) + else: + return wrapper(self.decode_inner, samples_in) + def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): model_management.load_model_gpu(self.patcher) output = self.decode_tiled_(samples, tile_x, tile_y, overlap) return output.movedim(1,-1) - def encode(self, pixel_samples): + def encode_inner(self, pixel_samples): if model_management.VAE_ALWAYS_TILED: return self.encode_tiled(pixel_samples) @@ -291,6 +314,13 @@ class VAE: return samples + def encode(self, pixel_samples): + wrapper = self.patcher.model_options.get('model_vae_encode_wrapper', None) + if wrapper is None: + return self.encode_inner(pixel_samples) + else: + return wrapper(self.encode_inner, pixel_samples) + def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): model_management.load_model_gpu(self.patcher) pixel_samples = pixel_samples.movedim(-1,1) diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index 4fa32be4..d3ea94a8 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -235,14 +235,14 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): @torch.inference_mode() def patched_decode_first_stage(x): - sample = forge_objects.unet.model.model_config.latent_format.process_out(x) - sample = forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0 + sample = sd_model.forge_objects.unet.model.model_config.latent_format.process_out(x) + sample = sd_model.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0 return sample.to(x) @torch.inference_mode() def patched_encode_first_stage(x): - sample = forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5) - sample = forge_objects.unet.model.model_config.latent_format.process_in(sample) + sample = sd_model.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5) + sample = sd_model.forge_objects.unet.model.model_config.latent_format.process_in(sample) return sample.to(x) sd_model.ema_scope = lambda *args, **kwargs: contextlib.nullcontext() diff --git a/modules_forge/forge_sampler.py b/modules_forge/forge_sampler.py index 79524dff..1b950bdf 100644 --- a/modules_forge/forge_sampler.py +++ b/modules_forge/forge_sampler.py @@ -56,6 +56,7 @@ def cond_from_a1111_to_patched_ldm_weighted(cond, weights): def forge_sample(self, denoiser_params, cond_scale, cond_composition): model = self.inner_model.inner_model.forge_objects.unet.model control = self.inner_model.inner_model.forge_objects.unet.controlnet_linked_list + extra_concat_condition = self.inner_model.inner_model.forge_objects.unet.extra_concat_condition x = denoiser_params.x timestep = denoiser_params.sigma uncond = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond) @@ -63,7 +64,11 @@ def forge_sample(self, denoiser_params, cond_scale, cond_composition): model_options = self.inner_model.inner_model.forge_objects.unet.model_options seed = self.p.seeds[0] - image_cond_in = denoiser_params.image_cond + if extra_concat_condition is not None: + image_cond_in = extra_concat_condition + else: + image_cond_in = denoiser_params.image_cond + if isinstance(image_cond_in, torch.Tensor): if image_cond_in.shape[0] == x.shape[0] \ and image_cond_in.shape[2] == x.shape[2] \ diff --git a/modules_forge/forge_version.py b/modules_forge/forge_version.py index 11772f7d..c630142e 100644 --- a/modules_forge/forge_version.py +++ b/modules_forge/forge_version.py @@ -1 +1 @@ -version = '0.0.16v1.8.0rc' +version = '0.0.17v1.8.0rc' diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py index 6c0f0e61..275e0e96 100644 --- a/modules_forge/unet_patcher.py +++ b/modules_forge/unet_patcher.py @@ -12,6 +12,7 @@ class UnetPatcher(ModelPatcher): self.controlnet_linked_list = None self.extra_preserved_memory_during_sampling = 0 self.extra_model_patchers_during_sampling = [] + self.extra_concat_condition = None def clone(self): n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, @@ -27,6 +28,7 @@ class UnetPatcher(ModelPatcher): n.controlnet_linked_list = self.controlnet_linked_list n.extra_preserved_memory_during_sampling = self.extra_preserved_memory_during_sampling n.extra_model_patchers_during_sampling = self.extra_model_patchers_during_sampling.copy() + n.extra_concat_condition = self.extra_concat_condition return n def add_extra_preserved_memory_during_sampling(self, memory_in_bytes: int): @@ -176,3 +178,21 @@ class UnetPatcher(ModelPatcher): device=noise.device, prompt_type=prompt_type ) + + def load_frozen_patcher(self, state_dict, strength): + patch_dict = {} + for k, w in state_dict.items(): + model_key, patch_type, weight_index = k.split('::') + if model_key not in patch_dict: + patch_dict[model_key] = {} + if patch_type not in patch_dict[model_key]: + patch_dict[model_key][patch_type] = [None] * 16 + patch_dict[model_key][patch_type][int(weight_index)] = w + + patch_flat = {} + for model_key, v in patch_dict.items(): + for patch_type, weight_list in v.items(): + patch_flat[model_key] = (patch_type, weight_list) + + self.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0) + return