vae already sliced in inner loop

This commit is contained in:
lllyasviel 2024-03-08 00:40:33 -08:00
parent e48533bdcd
commit 10b5ca2541

View File

@ -633,10 +633,10 @@ class DecodedSamples(list):
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
samples = DecodedSamples()
samples_pytorch = decode_first_stage(model, batch).to(target_device)
for i in range(batch.shape[0]):
sample = decode_first_stage(model, batch[i:i + 1])[0]
samples.append(sample.to(target_device))
for x in samples_pytorch:
samples.append(x)
return samples