diff --git a/modules/processing.py b/modules/processing.py index 4528866c..64e564e0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -633,10 +633,10 @@ class DecodedSamples(list): def decode_latent_batch(model, batch, target_device=None, check_for_nans=False): samples = DecodedSamples() + samples_pytorch = decode_first_stage(model, batch).to(target_device) - for i in range(batch.shape[0]): - sample = decode_first_stage(model, batch[i:i + 1])[0] - samples.append(sample.to(target_device)) + for x in samples_pytorch: + samples.append(x) return samples