From aec5053938ba57f5019c7c848285e099780bbb4e Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sat, 27 Jan 2024 09:44:44 -0800 Subject: [PATCH] i --- ldm_patched/contrib/external_latent.py | 24 +++++++++++++++++++ .../contrib/external_post_processing.py | 1 + ldm_patched/contrib/external_stable3d.py | 2 +- ldm_patched/modules/args_parser.py | 4 ++-- modules/cmd_args.py | 3 ++- modules/shared_cmd_options.py | 5 +++- 6 files changed, 34 insertions(+), 5 deletions(-) diff --git a/ldm_patched/contrib/external_latent.py b/ldm_patched/contrib/external_latent.py index c6f874e1..6d753d0f 100644 --- a/ldm_patched/contrib/external_latent.py +++ b/ldm_patched/contrib/external_latent.py @@ -124,10 +124,34 @@ class LatentBatch: samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) return (samples_out,) +class LatentBatchSeedBehavior: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "seed_behavior": (["random", "fixed"],),}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples, seed_behavior): + samples_out = samples.copy() + latent = samples["samples"] + if seed_behavior == "random": + if 'batch_index' in samples_out: + samples_out.pop('batch_index') + elif seed_behavior == "fixed": + batch_number = samples_out.get("batch_index", [0])[0] + samples_out["batch_index"] = [batch_number] * latent.shape[0] + + return (samples_out,) + NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, "LatentMultiply": LatentMultiply, "LatentInterpolate": LatentInterpolate, "LatentBatch": LatentBatch, + "LatentBatchSeedBehavior": LatentBatchSeedBehavior, } diff --git a/ldm_patched/contrib/external_post_processing.py b/ldm_patched/contrib/external_post_processing.py index 432c53fb..93cb1212 100644 --- a/ldm_patched/contrib/external_post_processing.py +++ b/ldm_patched/contrib/external_post_processing.py @@ -35,6 +35,7 @@ class Blend: CATEGORY = "image/postprocessing" def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + image2 = image2.to(image1.device) if image1.shape != image2.shape: image2 = image2.permute(0, 3, 1, 2) image2 = ldm_patched.modules.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') diff --git a/ldm_patched/contrib/external_stable3d.py b/ldm_patched/contrib/external_stable3d.py index e56cd0fa..bae2623f 100644 --- a/ldm_patched/contrib/external_stable3d.py +++ b/ldm_patched/contrib/external_stable3d.py @@ -48,7 +48,7 @@ class StableZero123_Conditioning: encode_pixels = pixels[:,:,:,:3] t = vae.encode(encode_pixels) cam_embeds = camera_embeddings(elevation, azimuth) - cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1) + cond = torch.cat([pooled, cam_embeds.to(pooled.device).repeat((pooled.shape[0], 1, 1))], dim=-1) positive = [[cond, {"concat_latent_image": t}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] diff --git a/ldm_patched/modules/args_parser.py b/ldm_patched/modules/args_parser.py index e5b84dc1..a0d32c7a 100644 --- a/ldm_patched/modules/args_parser.py +++ b/ldm_patched/modules/args_parser.py @@ -33,8 +33,8 @@ class EnumAction(argparse.Action): parser = argparse.ArgumentParser() -parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0") -parser.add_argument("--port", type=int, default=8188) +#parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0") +#parser.add_argument("--port", type=int, default=8188) parser.add_argument("--disable-header-check", type=str, default=None, metavar="ORIGIN", nargs="?", const="*") parser.add_argument("--web-upload-size", type=float, default=100) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index f1251b6c..ef2ecd7b 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -2,8 +2,9 @@ import argparse import json import os from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401 +from ldm_patched.modules import args_parser -parser = argparse.ArgumentParser() +parser = args_parser.parser parser.add_argument("-f", action='store_true', help=argparse.SUPPRESS) # allows running as root; implemented outside of webui parser.add_argument("--update-all-extensions", action='store_true', help="launch.py argument: download updates for all extensions when starting the program") diff --git a/modules/shared_cmd_options.py b/modules/shared_cmd_options.py index 1efeb74a..c9626667 100644 --- a/modules/shared_cmd_options.py +++ b/modules/shared_cmd_options.py @@ -9,7 +9,10 @@ parser = cmd_args.parser script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file)) script_loading.preload_extensions(extensions_builtin_dir, parser) -cmd_opts, _ = parser.parse_known_args() +if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None: + cmd_opts = parser.parse_args() +else: + cmd_opts, _ = parser.parse_known_args() cmd_opts.webui_is_non_local = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name]) cmd_opts.disable_extension_access = cmd_opts.webui_is_non_local and not cmd_opts.enable_insecure_extension_access