Update patch_basic.py
This commit is contained in:
parent
7ac9a06e9c
commit
bdfa1dac7b
@ -1,7 +1,10 @@
|
||||
import torch
|
||||
import ldm_patched.modules.samplers
|
||||
|
||||
from ldm_patched.modules.controlnet import ControlBase
|
||||
from ldm_patched.modules.model_patcher import ModelPatcher
|
||||
from ldm_patched.modules.samplers import get_area_and_mult, can_concat_cond, cond_cat
|
||||
from ldm_patched.modules import model_management
|
||||
|
||||
|
||||
og_model_patcher_init = ModelPatcher.__init__
|
||||
@ -84,10 +87,129 @@ def patched_control_merge(self, control_input, control_output, control_prev, out
|
||||
return out
|
||||
|
||||
|
||||
def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
|
||||
out_cond = torch.zeros_like(x_in)
|
||||
out_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
out_uncond = torch.zeros_like(x_in)
|
||||
out_uncond_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
COND = 0
|
||||
UNCOND = 1
|
||||
|
||||
to_run = []
|
||||
for x in cond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
to_run += [(p, COND)]
|
||||
if uncond is not None:
|
||||
for x in uncond:
|
||||
p = get_area_and_mult(x, x_in, timestep)
|
||||
if p is None:
|
||||
continue
|
||||
|
||||
to_run += [(p, UNCOND)]
|
||||
|
||||
while len(to_run) > 0:
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
to_batch_temp = []
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]):
|
||||
to_batch_temp += [x]
|
||||
|
||||
to_batch_temp.reverse()
|
||||
to_batch = to_batch_temp[:1]
|
||||
|
||||
free_memory = model_management.get_free_memory(x_in.device)
|
||||
for i in range(1, len(to_batch_temp) + 1):
|
||||
batch_amount = to_batch_temp[:len(to_batch_temp) // i]
|
||||
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
|
||||
if model.memory_required(input_shape) < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
|
||||
input_x = []
|
||||
mult = []
|
||||
c = []
|
||||
cond_or_uncond = []
|
||||
area = []
|
||||
control = None
|
||||
patches = None
|
||||
for x in to_batch:
|
||||
o = to_run.pop(x)
|
||||
p = o[0]
|
||||
input_x.append(p.input_x)
|
||||
mult.append(p.mult)
|
||||
c.append(p.conditioning)
|
||||
area.append(p.area)
|
||||
cond_or_uncond.append(o[1])
|
||||
control = p.control
|
||||
patches = p.patches
|
||||
|
||||
batch_chunks = len(cond_or_uncond)
|
||||
input_x = torch.cat(input_x)
|
||||
c = cond_cat(c)
|
||||
timestep_ = torch.cat([timestep] * batch_chunks)
|
||||
|
||||
if control is not None:
|
||||
control.current_cond_or_uncond = cond_or_uncond
|
||||
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
|
||||
|
||||
transformer_options = {}
|
||||
if 'transformer_options' in model_options:
|
||||
transformer_options = model_options['transformer_options'].copy()
|
||||
|
||||
if patches is not None:
|
||||
if "patches" in transformer_options:
|
||||
cur_patches = transformer_options["patches"].copy()
|
||||
for p in patches:
|
||||
if p in cur_patches:
|
||||
cur_patches[p] = cur_patches[p] + patches[p]
|
||||
else:
|
||||
cur_patches[p] = patches[p]
|
||||
else:
|
||||
transformer_options["patches"] = patches
|
||||
|
||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||
transformer_options["sigmas"] = timestep
|
||||
|
||||
c['transformer_options'] = transformer_options
|
||||
|
||||
if 'model_function_wrapper' in model_options:
|
||||
output = model_options['model_function_wrapper'](model.apply_model,
|
||||
{"input": input_x, "timestep": timestep_, "c": c,
|
||||
"cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
|
||||
else:
|
||||
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
|
||||
del input_x
|
||||
|
||||
for o in range(batch_chunks):
|
||||
if cond_or_uncond[o] == COND:
|
||||
out_cond[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += output[o] * \
|
||||
mult[o]
|
||||
out_count[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += mult[o]
|
||||
else:
|
||||
out_uncond[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += output[o] * \
|
||||
mult[o]
|
||||
out_uncond_count[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += mult[
|
||||
o]
|
||||
del mult
|
||||
|
||||
out_cond /= out_count
|
||||
del out_count
|
||||
out_uncond /= out_uncond_count
|
||||
del out_uncond_count
|
||||
return out_cond, out_uncond
|
||||
|
||||
|
||||
def patch_all_basics():
|
||||
ModelPatcher.__init__ = patched_model_patcher_init
|
||||
ModelPatcher.clone = patched_model_patcher_clone
|
||||
ModelPatcher.add_patched_controlnet = model_patcher_add_patched_controlnet
|
||||
ModelPatcher.list_controlnets = model_patcher_list_controlnets
|
||||
ControlBase.control_merge = patched_control_merge
|
||||
ldm_patched.modules.samplers.calc_cond_uncond_batch = patched_calc_cond_uncond_batch
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user