# my-sd/modules_forge/patch_basic.py
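"""Runtime monkey-patches for Forge's ldm_patched backend.

Installs replacements for ControlNet residual merging
(ControlBase.control_merge), batched cond/uncond denoising
(samplers.calc_cond_uncond_batch), timed model-to-GPU loading
(model_management.load_models_gpu), and the torch / safetensors file
loaders, which are wrapped to quarantine corrupted files. Call
patch_all_basics() once at startup to apply everything.
"""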

import torch
import os
import time
import safetensors.torch
import ldm_patched.modules.samplers

from ldm_patched.modules.controlnet import ControlBase
from ldm_patched.modules.samplers import get_area_and_mult, can_concat_cond, cond_cat
from ldm_patched.modules import model_management
from modules_forge.controlnet import compute_controlnet_weighting
from modules_forge.forge_util import compute_cond_mark


def patched_control_merge(self, control_input, control_output, control_prev, output_dtype):
    """Merge raw ControlNet residuals into per-block lists, apply Forge's
    advanced weighting, then fold in results from any chained controlnet."""
    out = {'input': [], 'middle': [], 'output': []}

    if control_input is not None:
        for i in range(len(control_input)):
            key = 'input'
            x = control_input[i]
            if x is not None:
                x *= self.strength
                if x.dtype != output_dtype:
                    x = x.to(output_dtype)
            out[key].insert(0, x)

    if control_output is not None:
        for i in range(len(control_output)):
            if i == (len(control_output) - 1):
                key = 'middle'
            else:
                key = 'output'
            x = control_output[i]
            if x is not None:
                if self.global_average_pooling:
                    x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
                x *= self.strength
                if x.dtype != output_dtype:
                    x = x.to(output_dtype)
            out[key].append(x)

    out = compute_controlnet_weighting(
        out,
        positive_advanced_weighting=self.positive_advanced_weighting,
        negative_advanced_weighting=self.negative_advanced_weighting,
        advanced_frame_weighting=self.advanced_frame_weighting,
        advanced_sigma_weighting=self.advanced_sigma_weighting,
        transformer_options=self.transformer_options
    )

    if control_prev is not None:
        # Fold in the output of a chained (previous) controlnet.
        for x in ['input', 'middle', 'output']:
            o = out[x]
            for i in range(len(control_prev[x])):
                prev_val = control_prev[x][i]
                if i >= len(o):
                    o.append(prev_val)
                elif prev_val is not None:
                    if o[i] is None:
                        o[i] = prev_val
                    else:
                        if o[i].shape[0] < prev_val.shape[0]:
                            # Add in this order so broadcasting follows the larger batch.
                            o[i] = prev_val + o[i]
                        else:
                            o[i] += prev_val
    return out
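
# Note: a sketch of the structure patched_control_merge returns; the three lists
# line up with the UNet's injection points (names here are illustrative):
#
#   out = {
#       'input':  [t1, t2, ...],   # residuals for the encoder (input) blocks
#       'middle': [tm],            # residual for the middle block
#       'output': [d1, d2, ...],   # residuals for the decoder (output) blocks
#   }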


def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
    """Run the model on batched positive/negative conditions and return the
    accumulated cond and uncond predictions for x_in."""
    out_cond = torch.zeros_like(x_in)
    out_count = torch.ones_like(x_in) * 1e-37
    out_uncond = torch.zeros_like(x_in)
    out_uncond_count = torch.ones_like(x_in) * 1e-37

    COND = 0
    UNCOND = 1

    to_run = []
    for x in cond:
        p = get_area_and_mult(x, x_in, timestep)
        if p is None:
            continue
        to_run += [(p, COND)]
    if uncond is not None:
        for x in uncond:
            p = get_area_and_mult(x, x_in, timestep)
            if p is None:
                continue
            to_run += [(p, UNCOND)]

    while len(to_run) > 0:
        first = to_run[0]
        first_shape = first[0][0].shape
        to_batch_temp = []
        for x in range(len(to_run)):
            if can_concat_cond(to_run[x][0], first[0]):
                to_batch_temp += [x]

        to_batch_temp.reverse()
        to_batch = to_batch_temp[:1]

        # Pick the largest batch of compatible conds that fits in free memory.
        free_memory = model_management.get_free_memory(x_in.device)
        for i in range(1, len(to_batch_temp) + 1):
            batch_amount = to_batch_temp[:len(to_batch_temp) // i]
            input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
            if model.memory_required(input_shape) < free_memory:
                to_batch = batch_amount
                break

        input_x = []
        mult = []
        c = []
        cond_or_uncond = []
        area = []
        control = None
        patches = None
        for x in to_batch:
            o = to_run.pop(x)
            p = o[0]
            input_x.append(p.input_x)
            mult.append(p.mult)
            c.append(p.conditioning)
            area.append(p.area)
            cond_or_uncond.append(o[1])
            control = p.control
            patches = p.patches

        batch_chunks = len(cond_or_uncond)
        input_x = torch.cat(input_x)
        c = cond_cat(c)
        timestep_ = torch.cat([timestep] * batch_chunks)

        transformer_options = {}
        if 'transformer_options' in model_options:
            transformer_options = model_options['transformer_options'].copy()

        if patches is not None:
            if "patches" in transformer_options:
                cur_patches = transformer_options["patches"].copy()
                for p in patches:
                    if p in cur_patches:
                        cur_patches[p] = cur_patches[p] + patches[p]
                    else:
                        cur_patches[p] = patches[p]
                # Store the merged dict; without this the merge above is discarded.
                transformer_options["patches"] = cur_patches
            else:
                transformer_options["patches"] = patches

        transformer_options["cond_or_uncond"] = cond_or_uncond[:]
        transformer_options["sigmas"] = timestep

        cond_mark = compute_cond_mark(cond_or_uncond=cond_or_uncond, sigmas=timestep)
        transformer_options["cond_mark"] = cond_mark

        c['transformer_options'] = transformer_options

        if control is not None:
            control.transformer_options = transformer_options
            c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))

        if 'model_function_wrapper' in model_options:
            output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
        else:
            output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
        del input_x

        # Accumulate each chunk into its (possibly masked) area of the canvas.
        for o in range(batch_chunks):
            if cond_or_uncond[o] == COND:
                out_cond[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
                out_count[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += mult[o]
            else:
                out_uncond[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
                out_uncond_count[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += mult[o]
        del mult

    out_cond /= out_count
    del out_count
    out_uncond /= out_uncond_count
    del out_uncond_count
    return out_cond, out_uncond
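
# Note: downstream samplers combine the pair returned above via classifier-free
# guidance. A minimal sketch, assuming a plain `cond_scale` float (the name is
# illustrative, not this module's API):
#
#   def cfg_result(out_cond, out_uncond, cond_scale):
#       return out_uncond + (out_cond - out_uncond) * cond_scale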


def patched_load_models_gpu(*args, **kwargs):
    """Wrap model_management.load_models_gpu and report slow weight moves."""
    execution_start_time = time.perf_counter()
    y = model_management.load_models_gpu_origin(*args, **kwargs)
    moving_time = time.perf_counter() - execution_start_time
    if moving_time > 0.1:
        print(f'Moving model(s) has taken {moving_time:.2f} seconds')
    return y


def build_loaded(module, loader_name):
    """Wrap module.<loader_name> so that a failed load quarantines the file.

    The original loader is saved as `<loader_name>_origin`, which also makes
    repeated calls idempotent: re-patching keeps wrapping the true original
    rather than stacking wrappers.
    """
    original_loader_name = loader_name + '_origin'

    if not hasattr(module, original_loader_name):
        setattr(module, original_loader_name, getattr(module, loader_name))

    original_loader = getattr(module, original_loader_name)

    def loader(*args, **kwargs):
        result = None
        try:
            result = original_loader(*args, **kwargs)
        except Exception as e:
            result = None
            exp = str(e) + '\n'
            # Rename any file argument that still exists to `<path>.corrupted`
            # so the next run re-downloads it instead of failing again.
            for path in list(args) + list(kwargs.values()):
                if isinstance(path, str):
                    if os.path.exists(path):
                        exp += f'File corrupted: {path} \n'
                        corrupted_backup_file = path + '.corrupted'
                        if os.path.exists(corrupted_backup_file):
                            os.remove(corrupted_backup_file)
                        os.replace(path, corrupted_backup_file)
                        if os.path.exists(path):
                            os.remove(path)
                        exp += f'Forge has tried to move the corrupted file to {corrupted_backup_file} \n'
                        exp += f'You may try again now and Forge will download models again. \n'
            raise ValueError(exp)
        return result

    setattr(module, loader_name, loader)
    return
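
# A minimal usage sketch for build_loaded, assuming a hypothetical module `m`
# exposing `load_file(path)`; after patching, a corrupted file is renamed and
# the raised ValueError explains the recovery step:
#
#   build_loaded(m, 'load_file')   # wraps m.load_file, keeps m.load_file_origin
#   weights = m.load_file('model.safetensors')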


def patch_all_basics():
    """Install every patch defined in this module. Safe to call more than once."""
    if not hasattr(model_management, 'load_models_gpu_origin'):
        model_management.load_models_gpu_origin = model_management.load_models_gpu

    model_management.load_models_gpu = patched_load_models_gpu
    ControlBase.control_merge = patched_control_merge
    ldm_patched.modules.samplers.calc_cond_uncond_batch = patched_calc_cond_uncond_batch
    build_loaded(safetensors.torch, 'load_file')
    build_loaded(torch, 'load')
    return
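
# A minimal sketch of installing these patches at startup (the import path is
# inferred from this file's location and may differ in practice):
#
#   from modules_forge.patch_basic import patch_all_basics
#   patch_all_basics()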