i
This commit is contained in:
parent
70948938fb
commit
a3dffecb3f
@ -16,6 +16,7 @@ from modules import sd_hijack
|
||||
from modules.sd_models_xl import extend_sdxl
|
||||
from ldm.util import instantiate_from_config
|
||||
from modules_forge import forge_clip
|
||||
from modules_forge.unet_patcher import UnetPatcher
|
||||
|
||||
import open_clip
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
@ -116,7 +117,7 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
print("left over keys:", left_over)
|
||||
|
||||
if output_model:
|
||||
model_patcher = ldm_patched.modules.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||
model_patcher = UnetPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
|
||||
if inital_load_device != torch.device("cpu"):
|
||||
print("loaded straight to GPU")
|
||||
model_management.load_model_gpu(model_patcher)
|
||||
|
@ -2,42 +2,10 @@ import torch
|
||||
import ldm_patched.modules.samplers
|
||||
|
||||
from ldm_patched.modules.controlnet import ControlBase
|
||||
from ldm_patched.modules.model_patcher import ModelPatcher
|
||||
from ldm_patched.modules.samplers import get_area_and_mult, can_concat_cond, cond_cat
|
||||
from ldm_patched.modules import model_management
|
||||
|
||||
|
||||
og_model_patcher_init = ModelPatcher.__init__
|
||||
og_model_patcher_clone = ModelPatcher.clone
|
||||
|
||||
|
||||
def patched_model_patcher_init(self, *args, **kwargs):
|
||||
h = og_model_patcher_init(self, *args, **kwargs)
|
||||
self.controlnet_linked_list = None
|
||||
return h
|
||||
|
||||
|
||||
def patched_model_patcher_clone(self):
|
||||
cloned = og_model_patcher_clone(self)
|
||||
cloned.controlnet_linked_list = self.controlnet_linked_list
|
||||
return cloned
|
||||
|
||||
|
||||
def model_patcher_add_patched_controlnet(self, cnet):
|
||||
cnet.set_previous_controlnet(self.controlnet_linked_list)
|
||||
self.controlnet_linked_list = cnet
|
||||
return
|
||||
|
||||
|
||||
def model_patcher_list_controlnets(self):
|
||||
results = []
|
||||
pointer = self.controlnet_linked_list
|
||||
while pointer is not None:
|
||||
results.append(pointer)
|
||||
pointer = pointer.previous_controlnet
|
||||
return results
|
||||
|
||||
|
||||
def patched_control_merge(self, control_input, control_output, control_prev, output_dtype):
|
||||
out = {'input': [], 'middle': [], 'output': []}
|
||||
|
||||
@ -208,10 +176,6 @@ def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_op
|
||||
|
||||
|
||||
def patch_all_basics():
|
||||
ModelPatcher.__init__ = patched_model_patcher_init
|
||||
ModelPatcher.clone = patched_model_patcher_clone
|
||||
ModelPatcher.add_patched_controlnet = model_patcher_add_patched_controlnet
|
||||
ModelPatcher.list_controlnets = model_patcher_list_controlnets
|
||||
ControlBase.control_merge = patched_control_merge
|
||||
ldm_patched.modules.samplers.calc_cond_uncond_batch = patched_calc_cond_uncond_batch
|
||||
return
|
||||
|
35
modules_forge/unet_patcher.py
Normal file
35
modules_forge/unet_patcher.py
Normal file
@ -0,0 +1,35 @@
|
||||
import copy
|
||||
from ldm_patched.modules.model_patcher import ModelPatcher
|
||||
|
||||
|
||||
class UnetPatcher(ModelPatcher):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.controlnet_linked_list = None
|
||||
|
||||
def clone(self):
|
||||
n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
|
||||
weight_inplace_update=self.weight_inplace_update)
|
||||
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.model_keys = self.model_keys
|
||||
n.controlnet_linked_list = self.controlnet_linked_list
|
||||
return n
|
||||
|
||||
def add_patched_controlnet(self, cnet):
|
||||
cnet.set_previous_controlnet(self.controlnet_linked_list)
|
||||
self.controlnet_linked_list = cnet
|
||||
return
|
||||
|
||||
def list_controlnets(self):
|
||||
results = []
|
||||
pointer = self.controlnet_linked_list
|
||||
while pointer is not None:
|
||||
results.append(pointer)
|
||||
pointer = pointer.previous_controlnet
|
||||
return results
|
Loading…
Reference in New Issue
Block a user