diff --git a/modules/processing.py b/modules/processing.py index a23607e7..2453bf99 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -34,6 +34,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType from modules.sd_models import apply_token_merging +from modules_forge.forge_util import apply_circular_forge # some of those options should not be changed at all because they would break the model, so I removed them from options. @@ -791,7 +792,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.sd_vae_name = sd_vae.get_loaded_vae_name() p.sd_vae_hash = sd_vae.get_loaded_vae_hash() - modules.sd_hijack.model_hijack.apply_circular(p.tiling) + apply_circular_forge(p.sd_model, p.tiling) modules.sd_hijack.model_hijack.clear_comments() p.setup_prompts() diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index ff53e835..4fa32be4 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -250,6 +250,7 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): sd_model.decode_first_stage = patched_decode_first_stage sd_model.encode_first_stage = patched_encode_first_stage sd_model.clip = sd_model.cond_stage_model + sd_model.tiling_enabled = False timer.record("forge finalize") sd_model.current_lora_hash = str([]) diff --git a/modules_forge/forge_util.py b/modules_forge/forge_util.py index 4f24d85d..df47b571 100644 --- a/modules_forge/forge_util.py +++ b/modules_forge/forge_util.py @@ -21,6 +21,27 @@ def prepare_free_memory(aggressive=False): return +def apply_circular_forge(model, tiling_enabled=False): + if model.tiling_enabled == tiling_enabled: + return + + print(f'Tiling: {tiling_enabled}') + model.tiling_enabled = tiling_enabled + + def flatten(el): + flattened = [flatten(children) for children in el.children()] + res = [el] + for c in flattened: + res += c + return res + + layers = flatten(model) + + for layer in [layer for layer in layers if 'Conv' in type(layer).__name__]: + layer.padding_mode = 'circular' if tiling_enabled else 'zeros' + return + + def HWC3(x): assert x.dtype == np.uint8 if x.ndim == 2: