change patcher method

This commit is contained in:
lllyasviel 2024-03-07 00:26:17 -08:00
parent b9705c58f6
commit e48533bdcd
2 changed files with 9 additions and 3 deletions

View File

@ -190,10 +190,10 @@ class ModelPatcher:
def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches:
old = getattr(self.model, k)
old = ldm_patched.modules.utils.get_attr(self.model, k)
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
setattr(self.model, k, self.object_patches[k])
ldm_patched.modules.utils.set_attr_raw(self.model, k, self.object_patches[k])
if patch_weights:
model_sd = self.model_state_dict()
@ -378,6 +378,6 @@ class ModelPatcher:
keys = list(self.object_patches_backup.keys())
for k in keys:
setattr(self.model, k, self.object_patches_backup[k])
ldm_patched.modules.utils.set_attr_raw(self.model, k, self.object_patches_backup[k])
self.object_patches_backup = {}

View File

@ -286,6 +286,12 @@ def set_attr(obj, attr, value):
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
del prev
def set_attr_raw(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
setattr(obj, attrs[-1], value)
def copy_to_param(obj, attr, value):
# inplace update tensor instead of replacing it
attrs = attr.split(".")