my-sd/modules_forge/unet_patcher.py
lllyasviel a3dffecb3f i
2024-01-28 06:21:53 -08:00

36 lines
1.2 KiB
Python

import copy
from ldm_patched.modules.model_patcher import ModelPatcher
class UnetPatcher(ModelPatcher):
    """A ModelPatcher that also remembers a chain of attached ControlNets.

    The chain is stored as a singly linked list: ``controlnet_linked_list``
    holds the most recently added ControlNet (or ``None`` when empty), and
    each entry is walked through its ``previous_controlnet`` attribute.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ModelPatcher and start with no ControlNets."""
        super().__init__(*args, **kwargs)
        # Head of the ControlNet chain; None means nothing has been attached.
        self.controlnet_linked_list = None

    def clone(self):
        """Return a new UnetPatcher built from this one's model and devices.

        Copy depth per attribute:
          * ``patches``       - new dict; each value is duplicated with a
                                slice, so the items inside are shared.
          * ``object_patches``- copied with ``.copy()``.
          * ``model_options`` - deep-copied.
          * ``model_keys`` and ``controlnet_linked_list`` - shared by
            reference with the original.
        """
        duplicate = UnetPatcher(
            self.model,
            self.load_device,
            self.offload_device,
            self.size,
            self.current_device,
            weight_inplace_update=self.weight_inplace_update,
        )

        # Slice every per-key patch sequence so appending to the clone's
        # entries does not touch the original's.
        duplicate.patches = {
            key: entries[:] for key, entries in self.patches.items()
        }
        duplicate.object_patches = self.object_patches.copy()
        duplicate.model_options = copy.deepcopy(self.model_options)
        duplicate.model_keys = self.model_keys
        duplicate.controlnet_linked_list = self.controlnet_linked_list
        return duplicate

    def add_patched_controlnet(self, cnet):
        """Make ``cnet`` the new head of the ControlNet chain.

        The current head is handed to ``cnet.set_previous_controlnet`` before
        ``cnet`` replaces it. Returns ``None``.
        """
        # NOTE(review): presumably set_previous_controlnet stores its argument
        # as cnet.previous_controlnet, which list_controlnets relies on -
        # confirm against the ControlNet implementation.
        cnet.set_previous_controlnet(self.controlnet_linked_list)
        self.controlnet_linked_list = cnet

    def list_controlnets(self):
        """Return the chained ControlNets as a list, most recently added first."""
        chain = []
        node = self.controlnet_linked_list
        while node is not None:
            chain.append(node)
            node = node.previous_controlnet
        return chain