3d039591fe
Before this commit, credits were already in the entry and licenses were already in the root. This commit makes that information clearer.
143 lines
5.2 KiB
Python
# Taken from https://github.com/comfyanonymous/ComfyUI

# This file is only for reference, and not used in the backend or runtime.

import torch

class LatentRebatch:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "latents": ("LATENT",),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                              }}

    RETURN_TYPES = ("LATENT",)
    INPUT_IS_LIST = True
    OUTPUT_IS_LIST = (True, )

    FUNCTION = "rebatch"

    CATEGORY = "latent/batch"
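
    # Because INPUT_IS_LIST is True, ComfyUI hands this node its inputs as
    # Python lists (all queued LATENT dicts at once, and batch_size as a
    # one-element list -- hence batch_size[0] below); OUTPUT_IS_LIST marks the
    # returned LATENT as a list of batches rather than a single value.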

    @staticmethod
    def get_batch(latents, list_ind, offset):
        '''prepare a batch out of the list of latents'''
        samples = latents[list_ind]['samples']
        shape = samples.shape
        mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
        # resize the mask to pixel-space dimensions if it does not already match the samples
        if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2] * 8:
            mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")
        if mask.shape[0] < samples.shape[0]:
            mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]

        if 'batch_index' in latents[list_ind]:
            batch_inds = latents[list_ind]['batch_index']
        else:
            batch_inds = [x+offset for x in range(shape[0])]
        return samples, mask, batch_inds

    @staticmethod
    def get_slices(indexable, num, batch_size):
        '''divides an indexable object into num slices of length batch_size, and a remainder'''
        slices = []
        for i in range(num):
            slices.append(indexable[i*batch_size:(i+1)*batch_size])
        if num * batch_size < len(indexable):
            return slices, indexable[num * batch_size:]
        else:
            return slices, None
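
    # Example (hypothetical values): get_slices(list(range(10)), 2, 4) returns
    # ([[0, 1, 2, 3], [4, 5, 6, 7]], [8, 9]) -- two full slices plus the remainder.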

    @staticmethod
    def slice_batch(batch, num, batch_size):
        result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
        return list(zip(*result))
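
    # slice_batch applies get_slices to each field of a (samples, mask,
    # batch_inds) tuple and transposes the results, returning a pair of
    # (per-field lists of slices, per-field remainders) that callers unpack
    # as `sliced, remainder`.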

    @staticmethod
    def cat_batch(batch1, batch2):
        if batch1[0] is None:
            return batch2
        result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
        return result
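
    # cat_batch merges two (samples, mask, batch_inds) tuples field by field:
    # tensors are concatenated with torch.cat, while the plain-list batch
    # indices are joined with +.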

    def rebatch(self, latents, batch_size):
        batch_size = batch_size[0]

        output_list = []
        current_batch = (None, None, None)
        processed = 0

        for i in range(len(latents)):
            # fetch new entry of list
            #samples, masks, indices = self.get_batch(latents, i)
            next_batch = self.get_batch(latents, i, processed)
            processed += len(next_batch[2])
            # set to current if current is None
            if current_batch[0] is None:
                current_batch = next_batch
            # add previous to list if dimensions do not match
            elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
                sliced, _ = self.slice_batch(current_batch, 1, batch_size)
                output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
                current_batch = next_batch
            # cat if everything checks out
            else:
                current_batch = self.cat_batch(current_batch, next_batch)

            # add to list if the accumulated batch has grown above the target batch size
            if current_batch[0].shape[0] > batch_size:
                num = current_batch[0].shape[0] // batch_size
                sliced, remainder = self.slice_batch(current_batch, num, batch_size)

                for i in range(num):
                    output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})

                current_batch = remainder

        # add remainder
        if current_batch[0] is not None:
            sliced, _ = self.slice_batch(current_batch, 1, batch_size)
            output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})

        # get rid of masks that do not mask anything (all ones)
        for s in output_list:
            if s['noise_mask'].mean() == 1.0:
                del s['noise_mask']

        return (output_list,)


class ImageRebatch:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "images": ("IMAGE",),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                              }}

    RETURN_TYPES = ("IMAGE",)
    INPUT_IS_LIST = True
    OUTPUT_IS_LIST = (True, )

    FUNCTION = "rebatch"

    CATEGORY = "image/batch"

    def rebatch(self, images, batch_size):
        batch_size = batch_size[0]

        output_list = []
        all_images = []
        for img in images:
            for i in range(img.shape[0]):
                all_images.append(img[i:i+1])

        for i in range(0, len(all_images), batch_size):
            output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))

        return (output_list,)


NODE_CLASS_MAPPINGS = {
    "RebatchLatents": LatentRebatch,
    "RebatchImages": ImageRebatch,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "RebatchLatents": "Rebatch Latents",
    "RebatchImages": "Rebatch Images",
}
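

# Illustrative usage sketch -- not part of the upstream ComfyUI file. Since the
# module above is kept for reference only, this block shows how the two nodes
# could be exercised directly with dummy tensors; shapes follow the LATENT
# (B, 4, H//8, W//8) and IMAGE (B, H, W, 3) conventions assumed by the code.
if __name__ == "__main__":
    # Three latent entries of batch size 2 each, regrouped into batches of 4.
    latents = [{'samples': torch.randn(2, 4, 64, 64)} for _ in range(3)]
    (rebatched,) = LatentRebatch().rebatch(latents, [4])
    print([l['samples'].shape for l in rebatched])
    # -> [torch.Size([4, 4, 64, 64]), torch.Size([2, 4, 64, 64])]

    # Five single images regrouped into batches of 2.
    images = [torch.rand(1, 512, 512, 3) for _ in range(5)]
    (grouped,) = ImageRebatch().rebatch(images, [2])
    print([img.shape for img in grouped])
    # -> [torch.Size([2, 512, 512, 3]), torch.Size([2, 512, 512, 3]), torch.Size([1, 512, 512, 3])]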