add block inner modifiers
This commit is contained in:
parent
905a027fc1
commit
95c9ed52fc
@ -30,7 +30,12 @@ class TimestepBlock(nn.Module):
|
||||
|
||||
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
|
||||
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
|
||||
for layer in ts:
|
||||
block_inner_modifiers = transformer_options.get("block_inner_modifiers", [])
|
||||
|
||||
for layer_index, layer in enumerate(ts):
|
||||
for modifier in block_inner_modifiers:
|
||||
x = modifier(x, 'before', layer, layer_index, ts, transformer_options)
|
||||
|
||||
if isinstance(layer, VideoResBlock):
|
||||
x = layer(x, emb, num_video_frames, image_only_indicator)
|
||||
elif isinstance(layer, TimestepBlock):
|
||||
@ -47,6 +52,9 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
|
||||
x = layer(x, output_shape=output_shape)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
for modifier in block_inner_modifiers:
|
||||
x = modifier(x, 'after', layer, layer_index, ts, transformer_options)
|
||||
return x
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
|
@ -88,6 +88,10 @@ class UnetPatcher(ModelPatcher):
|
||||
self.append_transformer_option('block_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def add_block_inner_modifier(self, modifier, ensure_uniqueness=False):
|
||||
self.append_transformer_option('block_inner_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def add_controlnet_conditioning_modifier(self, modifier, ensure_uniqueness=False):
|
||||
self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user