add block inner modifiers

This commit is contained in:
lllyasviel 2024-02-03 00:57:49 -08:00
parent 905a027fc1
commit 95c9ed52fc
2 changed files with 13 additions and 1 deletions

View File

@ -30,7 +30,12 @@ class TimestepBlock(nn.Module):
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
for layer in ts:
block_inner_modifiers = transformer_options.get("block_inner_modifiers", [])
for layer_index, layer in enumerate(ts):
for modifier in block_inner_modifiers:
x = modifier(x, 'before', layer, layer_index, ts, transformer_options)
if isinstance(layer, VideoResBlock):
x = layer(x, emb, num_video_frames, image_only_indicator)
elif isinstance(layer, TimestepBlock):
@ -47,6 +52,9 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
for modifier in block_inner_modifiers:
x = modifier(x, 'after', layer, layer_index, ts, transformer_options)
return x
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):

View File

@ -88,6 +88,10 @@ class UnetPatcher(ModelPatcher):
self.append_transformer_option('block_modifiers', modifier, ensure_uniqueness)
return
def add_block_inner_modifier(self, modifier, ensure_uniqueness=False):
self.append_transformer_option('block_inner_modifiers', modifier, ensure_uniqueness)
return
def add_controlnet_conditioning_modifier(self, modifier, ensure_uniqueness=False):
self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness)
return