diff --git a/README.md b/README.md index 22676add..09e3560b 100644 --- a/README.md +++ b/README.md @@ -75,17 +75,17 @@ Pick out of three sampling methods for txt2img: DDIM, PLMS, k-diffusion: ### Prompt matrix Separate multiple prompts using the `|` character, and the system will produce an image for every combination of them. -For example, if you use `a house in a field of grass|at dawn|illustration` prompt, there are four combinations possible (first part of prompt is always kept): +For example, if you use `a busy city street in a modern city|illustration|cinematic lighting` prompt, there are four combinations possible (first part of prompt is always kept): -- `a house in a field of grass` -- `a house in a field of grass, at dawn` -- `a house in a field of grass, illustration` -- `a house in a field of grass, at dawn, illustration` +- `a busy city street in a modern city` +- `a busy city street in a modern city, illustration` +- `a busy city street in a modern city, cinematic lighting` +- `a busy city street in a modern city, illustration, cinematic lighting` Four images will be produced, in this order, all with same seed and each with corresponding prompt: ![](images/prompt-matrix.png) -Another example, this time with 5 prompts and 16 variations, (text added manually): +Another example, this time with 5 prompts and 16 variations: ![](images/prompt_matrix.jpg) ### Flagging diff --git a/images/prompt-matrix.png b/images/prompt-matrix.png index 53ba92b6..99791330 100644 Binary files a/images/prompt-matrix.png and b/images/prompt-matrix.png differ diff --git a/images/prompt_matrix.jpg b/images/prompt_matrix.jpg index a9749c01..570c8c0e 100644 Binary files a/images/prompt_matrix.jpg and b/images/prompt_matrix.jpg differ diff --git a/webui.py b/webui.py index 6f8efa84..95dcc751 100644 --- a/webui.py +++ b/webui.py @@ -1,11 +1,10 @@ -import PIL import argparse, os, sys, glob import torch import torch.nn as nn import numpy as np import gradio as gr from omegaconf import OmegaConf -from PIL import Image +from PIL import Image, ImageFont, ImageDraw from itertools import islice from einops import rearrange, repeat from torch import autocast @@ -76,23 +75,6 @@ def load_model_from_config(config, ckpt, verbose=False): return model -def load_img_pil(img_pil): - image = img_pil.convert("RGB") - w, h = image.size - print(f"loaded input image of size ({w}, {h})") - w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) - print(f"cropped image to size ({w}, {h})") - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2. * image - 1. - - -def load_img(path): - return load_img_pil(Image.open(path)) - - class CFGDenoiser(nn.Module): def __init__(self, model): super().__init__() @@ -179,6 +161,71 @@ def image_grid(imgs, batch_size, round_down=False): return grid +def draw_prompt_matrix(im, width, height, all_prompts): + def wrap(text, d, font, line_length): + lines = [''] + for word in text.split(): + line = f'{lines[-1]} {word}'.strip() + if d.textlength(line, font=font) <= line_length: + lines[-1] = line + else: + lines.append(word) + return '\n'.join(lines) + + def draw_texts(pos, x, y, texts, sizes): + for i, (text, size) in enumerate(zip(texts, sizes)): + active = pos & (1 << i) != 0 + + if not active: + text = '\u0336'.join(text) + '\u0336' + + d.multiline_text((x, y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center") + + y += size[1] + line_spacing + + fontsize = (width + height) // 25 + line_spacing = fontsize // 2 + fnt = ImageFont.truetype("arial.ttf", fontsize) + color_active = (0, 0, 0) + color_inactive = (153, 153, 153) + + pad_top = height // 4 + pad_left = width * 3 // 4 + + cols = im.width // width + rows = im.height // height + + prompts = all_prompts[1:] + + result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white") + result.paste(im, (pad_left, pad_top)) + + d = ImageDraw.Draw(result) + + boundary = math.ceil(len(prompts) / 2) + prompts_horiz = [wrap(x, d, fnt, width) for x in prompts[:boundary]] + prompts_vert = [wrap(x, d, fnt, pad_left) for x in prompts[boundary:]] + + sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]] + sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]] + hor_text_height = sum([x[1] + line_spacing for x in sizes_hor]) - line_spacing + ver_text_height = sum([x[1] + line_spacing for x in sizes_ver]) - line_spacing + + for col in range(cols): + x = pad_left + width * col + width / 2 + y = pad_top / 2 - hor_text_height / 2 + + draw_texts(col, x, y, prompts_horiz, sizes_hor) + + for row in range(rows): + x = pad_left / 2 + y = pad_top + height * row + height / 2 - ver_text_height / 2 + + draw_texts(row, x, y, prompts_vert, sizes_ver) + + return result + + def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int): torch.cuda.empty_cache() @@ -212,30 +259,23 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro grid_count = len(os.listdir(outpath)) - 1 prompt_matrix_prompts = [] - comment = "" + prompt_matrix_parts = [] if prompt_matrix: keep_same_seed = True - comment = "Image prompts:\n\n" - items = prompt.split("|") - combination_count = 2 ** (len(items)-1) + prompt_matrix_parts = prompt.split("|") + combination_count = 2 ** (len(prompt_matrix_parts)-1) for combination_num in range(combination_count): - current = items[0] - label = 'A' + current = prompt_matrix_parts[0] - for n, text in enumerate(items[1:]): + for n, text in enumerate(prompt_matrix_parts[1:]): if combination_num & (2**n) > 0: current += ("" if text.strip().startswith(",") else ", ") + text - label += chr(ord('B') + n) - - comment += " - " + label + "\n" prompt_matrix_prompts.append(current) n_iter = math.ceil(len(prompt_matrix_prompts) / batch_size) - comment += "\nwhere:\n" - for n, text in enumerate(items): - comment += " " + chr(ord('A') + n) + " = " + items[n] + "\n" + print(f"Prompt matrix will create {len(prompt_matrix_prompts)} images using a total of {n_iter} batches.") precision_scope = autocast if opt.precision == "autocast" else nullcontext output_images = [] @@ -262,7 +302,7 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - if not opt.skip_save or not opt.skip_grid: + if prompt_matrix or not opt.skip_save or not opt.skip_grid: for i, x_sample in enumerate(x_samples_ddim): x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = x_sample.astype(np.uint8) @@ -279,24 +319,23 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, pro output_images.append(image) base_count += 1 - if not opt.skip_grid: - # additionally, save as grid + if prompt_matrix or not opt.skip_grid: grid = image_grid(output_images, batch_size, round_down=prompt_matrix) + + if prompt_matrix: + grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts) + output_images.insert(0, grid) + grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) grid_count += 1 - - if sampler is not None: - del sampler + del sampler info = f""" {prompt} Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''} """.strip() - if len(comment) > 0: - info += "\n\n" + comment - return output_images, seed, info class Flagging(gr.FlaggingCallback): @@ -350,7 +389,7 @@ dream_interface = gr.Interface( gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False), gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False), gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1), - gr.Slider(minimum=1, maximum=4, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1), + gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1), gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0), gr.Number(label='Seed', value=-1), gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512), @@ -389,7 +428,7 @@ def translation(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, ddim_e grid_count = len(os.listdir(outpath)) - 1 image = init_img.convert("RGB") - image = image.resize((width, height), resample=PIL.Image.Resampling.LANCZOS) + image = image.resize((width, height), resample=Image.Resampling.LANCZOS) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) @@ -466,7 +505,7 @@ img2img_interface = gr.Interface( gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None), gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False), gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1), - gr.Slider(minimum=1, maximum=4, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1), + gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1), gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0), gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75), gr.Number(label='Seed', value=-1), @@ -494,7 +533,7 @@ def run_GFPGAN(image, strength): res = Image.fromarray(restored_img) if strength < 1.0: - res = PIL.Image.blend(image, res, strength) + res = Image.blend(image, res, strength) return res