Merge branch 'master' of github.com:AUTOMATIC1111/stable-diffusion-webui

This commit is contained in:
unknown 2022-12-25 02:03:55 -06:00
commit 876da12599
46 changed files with 628 additions and 209 deletions

View File

@ -82,8 +82,8 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- Use VAEs
- Estimated completion time in progress bar
- API
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
## Installation and Running

View File

@ -1,3 +1,4 @@
import os
import gc
import time
import warnings
@ -8,6 +9,7 @@ import torchvision
from PIL import Image
from einops import rearrange, repeat
from omegaconf import OmegaConf
import safetensors.torch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
@ -24,12 +26,16 @@ class LDSR:
global cached_ldsr_model
if shared.opts.ldsr_cached and cached_ldsr_model is not None:
print(f"Loading model from cache")
print("Loading model from cache")
model: torch.nn.Module = cached_ldsr_model
else:
print(f"Loading model from {self.modelPath}")
pl_sd = torch.load(self.modelPath, map_location="cpu")
sd = pl_sd["state_dict"]
_, extension = os.path.splitext(self.modelPath)
if extension.lower() == ".safetensors":
pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
else:
pl_sd = torch.load(self.modelPath, map_location="cpu")
sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
config = OmegaConf.load(self.yamlPath)
config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
model: torch.nn.Module = instantiate_from_config(config.model)

View File

@ -25,6 +25,7 @@ class UpscalerLDSR(Upscaler):
yaml_path = os.path.join(self.model_path, "project.yaml")
old_model_path = os.path.join(self.model_path, "model.pth")
new_model_path = os.path.join(self.model_path, "model.ckpt")
safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
if os.path.exists(yaml_path):
statinfo = os.stat(yaml_path)
if statinfo.st_size >= 10485760:
@ -33,8 +34,11 @@ class UpscalerLDSR(Upscaler):
if os.path.exists(old_model_path):
print("Renaming model from model.pth to model.ckpt")
os.rename(old_model_path, new_model_path)
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
file_name="model.ckpt", progress=True)
if os.path.exists(safetensors_model_path):
model = safetensors_model_path
else:
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
file_name="model.ckpt", progress=True)
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
file_name="project.yaml", progress=True)

View File

@ -9,7 +9,7 @@ contextMenuInit = function(){
function showContextMenu(event,element,menuEntries){
let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
let oldMenu = gradioApp().querySelector('#context-menu')
if(oldMenu){
@ -61,15 +61,15 @@ contextMenuInit = function(){
}
function appendContextMenuOption(targetEmementSelector,entryName,entryFunction){
currentItems = menuSpecs.get(targetEmementSelector)
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
currentItems = menuSpecs.get(targetElementSelector)
if(!currentItems){
currentItems = []
menuSpecs.set(targetEmementSelector,currentItems);
menuSpecs.set(targetElementSelector,currentItems);
}
let newItem = {'id':targetEmementSelector+'_'+uid(),
let newItem = {'id':targetElementSelector+'_'+uid(),
'name':entryName,
'func':entryFunction,
'isNew':true}
@ -97,7 +97,7 @@ contextMenuInit = function(){
if(source.id && source.id.indexOf('check_progress')>-1){
return
}
let oldMenu = gradioApp().querySelector('#context-menu')
if(oldMenu){
oldMenu.remove()
@ -117,7 +117,7 @@ contextMenuInit = function(){
})
});
eventListenerApplied=true
}
return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
@ -152,8 +152,8 @@ addContextMenuEventListener = initResponse[2];
generateOnRepeat('#img2img_generate','#img2img_interrupt');
})
let cancelGenerateForever = function(){
clearInterval(window.generateOnRepeatInterval)
let cancelGenerateForever = function(){
clearInterval(window.generateOnRepeatInterval)
}
appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
@ -162,7 +162,7 @@ addContextMenuEventListener = initResponse[2];
appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
appendContextMenuOption('#roll','Roll three',
function(){
function(){
let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
setTimeout(function(){rollbutton.click()},100)
setTimeout(function(){rollbutton.click()},200)

View File

@ -6,6 +6,7 @@ titles = {
"GFPGAN": "Restore low quality faces using GFPGAN neural network",
"Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps to higher than 30-40 does not help",
"DDIM": "Denoising Diffusion Implicit Models - best at inpainting",
"DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution",
"Batch count": "How many batches of images to create",
"Batch size": "How many image to create in a single batch",
@ -17,7 +18,7 @@ titles = {
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
"\u{1f4c2}": "Open images output directory",
"\u{1f4be}": "Save style",
"\U0001F5D1": "Clear prompt"
"\U0001F5D1": "Clear prompt",
"\u{1f4cb}": "Apply selected styles to current prompt",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
@ -96,7 +97,10 @@ titles = {
"Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc."
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
"Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.",
"Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality."
}

View File

@ -15,7 +15,7 @@ onUiUpdate(function(){
}
}
const galleryPreviews = gradioApp().querySelectorAll('img.h-full.w-full.overflow-hidden');
const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"][style*="display: block"] img.h-full.w-full.overflow-hidden');
if (galleryPreviews == null) return;

View File

@ -3,7 +3,7 @@ global_progressbars = {}
galleries = {}
galleryObservers = {}
// this tracks laumnches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
timeoutIds = {}
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
@ -20,21 +20,21 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
var interrupt = gradioApp().getElementById(id_interrupt)
if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
if(progressbar.innerText){
let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion';
if(document.title != newtitle){
document.title = newtitle;
document.title = newtitle;
}
}else{
let newtitle = 'Stable Diffusion'
if(document.title != newtitle){
document.title = newtitle;
document.title = newtitle;
}
}
}
if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
global_progressbars[id_progressbar] = progressbar
@ -63,7 +63,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
skip.style.display = "none"
}
interrupt.style.display = "none"
//disconnect observer once generation finished, so user can close selected image if they want
if (galleryObservers[id_gallery]) {
galleryObservers[id_gallery].disconnect();

View File

@ -100,7 +100,7 @@ function create_submit_args(args){
// As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.
// This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.
// I don't know why gradio is seding outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
// I don't know why gradio is sending outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.
// If gradio at some point stops sending outputs, this may break something
if(Array.isArray(res[res.length - 3])){
res[res.length - 3] = null

BIN
models/VAE-approx/model.pt Normal file

Binary file not shown.

View File

@ -10,13 +10,17 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest
import modules.shared as shared
from modules import sd_samplers, deepbooru
from modules import sd_samplers, deepbooru, sd_hijack
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.extras import run_extras, run_pnginfo
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from typing import List
def upscaler_to_index(name: str):
@ -67,10 +71,10 @@ def encode_pil_to_base64(image):
class Api:
def __init__(self, app: FastAPI, queue_lock: Lock):
if shared.cmd_opts.api_auth:
self.credenticals = dict()
self.credentials = dict()
for auth in shared.cmd_opts.api_auth.split(","):
user, password = auth.split(":")
self.credenticals[user] = password
self.credentials[user] = password
self.router = APIRouter()
self.app = app
@ -93,18 +97,24 @@ class Api:
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
def add_api_route(self, path: str, endpoint, **kwargs):
if shared.cmd_opts.api_auth:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
return self.app.add_api_route(path, endpoint, **kwargs)
def auth(self, credenticals: HTTPBasicCredentials = Depends(HTTPBasic())):
if credenticals.username in self.credenticals:
if compare_digest(credenticals.password, self.credenticals[credenticals.username]):
def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
if credentials.username in self.credentials:
if compare_digest(credentials.password, self.credentials[credentials.username]):
return True
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
@ -180,7 +190,7 @@ class Api:
reqDict['image'] = decode_base64_to_image(reqDict['image'])
with self.queue_lock:
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict)
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
@ -196,7 +206,7 @@ class Api:
reqDict.pop('imageList')
with self.queue_lock:
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
@ -239,7 +249,7 @@ class Api:
def interrogateapi(self, interrogatereq: InterrogateRequest):
image_b64 = interrogatereq.image
if image_b64 is None:
raise HTTPException(status_code=404, detail="Image not found")
raise HTTPException(status_code=404, detail="Image not found")
img = decode_base64_to_image(image_b64)
img = img.convert('RGB')
@ -252,7 +262,7 @@ class Api:
processed = deepbooru.model.tag(img)
else:
raise HTTPException(status_code=404, detail="Model not found")
return InterrogateResponse(caption=processed)
def interruptapi(self):
@ -308,7 +318,7 @@ class Api:
def get_realesrgan_models(self):
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
def get_promp_styles(self):
def get_prompt_styles(self):
styleList = []
for k in shared.prompt_styles.styles:
style = shared.prompt_styles.styles[k]
@ -322,6 +332,92 @@ class Api:
def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
def refresh_checkpoints(self):
shared.refresh_checkpoints()
def create_embedding(self, args: dict):
try:
shared.state.begin()
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
shared.state.end()
return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
except AssertionError as e:
shared.state.end()
return TrainResponse(info = "create embedding error: {error}".format(error = e))
def create_hypernetwork(self, args: dict):
try:
shared.state.begin()
filename = create_hypernetwork(**args) # create empty embedding
shared.state.end()
return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
except AssertionError as e:
shared.state.end()
return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
def preprocess(self, args: dict):
try:
shared.state.begin()
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
shared.state.end()
return PreprocessResponse(info = 'preprocess complete')
except KeyError as e:
shared.state.end()
return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
except AssertionError as e:
shared.state.end()
return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
except FileNotFoundError as e:
shared.state.end()
return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
def train_embedding(self, args: dict):
try:
shared.state.begin()
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
filename = ''
if not apply_optimizations:
sd_hijack.undo_optimizations()
try:
embedding, filename = train_embedding(**args) # can take a long time to complete
except Exception as e:
error = e
finally:
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
except AssertionError as msg:
shared.state.end()
return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
def train_hypernetwork(self, args: dict):
try:
shared.state.begin()
initial_hypernetwork = shared.loaded_hypernetwork
apply_optimizations = shared.opts.training_xattention_optimizations
error = None
filename = ''
if not apply_optimizations:
sd_hijack.undo_optimizations()
try:
hypernetwork, filename = train_hypernetwork(*args)
except Exception as e:
error = e
finally:
shared.loaded_hypernetwork = initial_hypernetwork
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
except AssertionError as msg:
shared.state.end()
return TrainResponse(info = "train embedding error: {error}".format(error = error))
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)

View File

@ -128,7 +128,7 @@ class ExtrasBaseRequest(BaseModel):
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?")
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
@ -175,6 +175,15 @@ class InterrogateRequest(BaseModel):
class InterrogateResponse(BaseModel):
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
class TrainResponse(BaseModel):
info: str = Field(title="Train info", description="Response string from train embedding or hypernetwork task.")
class CreateResponse(BaseModel):
info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
class PreprocessResponse(BaseModel):
info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
fields = {}
for key, metadata in opts.data_labels.items():
value = opts.data.get(key)

View File

@ -382,7 +382,7 @@ class VQAutoEncoder(nn.Module):
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
logger.info(f'vqgan is loaded from: {model_path} [params]')
else:
raise ValueError(f'Wrong params!')
raise ValueError('Wrong params!')
def forward(self, x):
@ -431,7 +431,7 @@ class VQGANDiscriminator(nn.Module):
elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
else:
raise ValueError(f'Wrong params!')
raise ValueError('Wrong params!')
def forward(self, x):
return self.main(x)

View File

@ -79,7 +79,9 @@ class DeepDanbooru:
res = []
for tag in tags:
filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
for tag in [x for x in tags if x not in filtertags]:
probability = probability_dict[tag]
tag_outformat = tag
if use_spaces:

View File

@ -125,7 +125,16 @@ def layer_norm_fix(*args, **kwargs):
return orig_layer_norm(*args, **kwargs)
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
orig_tensor_numpy = torch.Tensor.numpy
def numpy_fix(self, *args, **kwargs):
if self.requires_grad:
self = self.detach()
return orig_tensor_numpy(self, *args, **kwargs)
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
torch.Tensor.to = tensor_to_fix
torch.nn.functional.layer_norm = layer_norm_fix
torch.Tensor.numpy = numpy_fix

View File

@ -55,7 +55,7 @@ class LruCache(OrderedDict):
cached_images: LruCache = LruCache(max_size=5)
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool):
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
devices.torch_gc()
imageArr = []
@ -188,13 +188,20 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
for op in extras_ops:
image, info = op(image, info)
if opts.use_original_name_batch and image_name != None:
if opts.use_original_name_batch and image_name is not None:
basename = os.path.splitext(os.path.basename(image_name))[0]
else:
basename = ''
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
if save_output:
# Add upscaler name as a suffix.
suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
# Add second upscaler if applicable.
if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
if opts.enable_pnginfo:
image.info = existing_pnginfo

View File

@ -14,6 +14,7 @@ re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
type_of_gr_update = type(gr.update())
paste_fields = {}
bind_list = []
@ -139,6 +140,30 @@ def run_bind():
)
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
"""Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
If the infotext has no hash, then a hypernet with the same name will be selected instead.
"""
hypernet_name = hypernet_name.lower()
if hypernet_hash is not None:
# Try to match the hash in the name
for hypernet_key in shared.hypernetworks.keys():
result = re_hypernet_hash.search(hypernet_key)
if result is not None and result[1] == hypernet_hash:
return hypernet_key
else:
# Fall back to a hypernet with the same name
for hypernet_key in shared.hypernetworks.keys():
if hypernet_key.lower().startswith(hypernet_name):
return hypernet_key
return None
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
@ -188,6 +213,14 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "Clip skip" not in res:
res["Clip skip"] = "1"
if "Hypernet strength" not in res:
res["Hypernet strength"] = "1"
if "Hypernet" in res:
hypernet_name = res["Hypernet"]
hypernet_hash = res.get("Hypernet hash", None)
res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
return res

View File

@ -277,7 +277,7 @@ def load_hypernetwork(filename):
print(traceback.format_exc(), file=sys.stderr)
else:
if shared.loaded_hypernetwork is not None:
print(f"Unloading hypernetwork")
print("Unloading hypernetwork")
shared.loaded_hypernetwork = None
@ -378,6 +378,32 @@ def report_statistics(loss_info:dict):
print(e)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
if type(layer_structure) == str:
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
name=name,
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
activation_func=activation_func,
weight_init=weight_init,
add_layer_norm=add_layer_norm,
use_dropout=use_dropout,
)
hypernet.save(fn)
shared.reload_hypernetworks()
return fn
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
@ -417,7 +443,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
initial_step = hypernetwork.step or 0
if initial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
shared.state.textinfo = "Model has already been trained beyond specified max steps"
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

View File

@ -3,39 +3,16 @@ import os
import re
import gradio as gr
import modules.textual_inversion.preprocess
import modules.textual_inversion.textual_inversion
import modules.hypernetworks.hypernetwork
from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
not_available = ["hardswish", "multiheadattention"]
keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout)
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
if type(layer_structure) == str:
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
name=name,
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
activation_func=activation_func,
weight_init=weight_init,
add_layer_norm=add_layer_norm,
use_dropout=use_dropout,
)
hypernet.save(fn)
shared.reload_hypernetworks()
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", ""
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
def train_hypernetwork(*args):

View File

@ -136,8 +136,19 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
lines.append(word)
return lines
def draw_texts(drawing, draw_x, draw_y, lines):
def get_font(fontsize):
try:
return ImageFont.truetype(opts.font or Roboto, fontsize)
except Exception:
return ImageFont.truetype(Roboto, fontsize)
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
for i, line in enumerate(lines):
fnt = initial_fnt
fontsize = initial_fontsize
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
fontsize -= 1
fnt = get_font(fontsize)
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
if not line.is_active:
@ -148,10 +159,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
fontsize = (width + height) // 25
line_spacing = fontsize // 2
try:
fnt = ImageFont.truetype(opts.font or Roboto, fontsize)
except Exception:
fnt = ImageFont.truetype(Roboto, fontsize)
fnt = get_font(fontsize)
color_active = (0, 0, 0)
color_inactive = (153, 153, 153)
@ -178,6 +186,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
for line in texts:
bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
line.allowed_width = allowed_width
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
@ -194,13 +203,13 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
x = pad_left + width * col + width / 2
y = pad_top / 2 - hor_text_heights[col] / 2
draw_texts(d, x, y, hor_texts[col])
draw_texts(d, x, y, hor_texts[col], fnt, fontsize)
for row in range(rows):
x = pad_left / 2
y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
draw_texts(d, x, y, ver_texts[row])
draw_texts(d, x, y, ver_texts[row], fnt, fontsize)
return result
@ -429,7 +438,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory.
basename (`str`):
The base filename which will be applied to `filename pattern`.
seed, prompt, short_filename,
seed, prompt, short_filename,
extension (`str`):
Image file extension, default is `png`.
pngsectionname (`str`):
@ -590,7 +599,7 @@ def read_info_from_image(image):
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
print(f"Error parsing NovelAI iamge generation parameters:", file=sys.stderr)
print("Error parsing NovelAI image generation parameters:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return geninfo, items
@ -613,3 +622,14 @@ def image_data(data):
pass
return '', None
def flatten(img, bgcolor):
"""replaces transparency with bgcolor (example: "#ffffff"), returning an RGB mode image with no transparency"""
if img.mode == "RGBA":
background = Image.new('RGBA', img.size, bgcolor)
background.paste(img, mask=img)
img = background
return img.convert('RGB')

5
modules/import_hook.py Normal file
View File

@ -0,0 +1,5 @@
import sys
# this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it
if "--xformers" not in "".join(sys.argv):
sys.modules["xformers"] = None

View File

@ -172,7 +172,7 @@ class InterrogateModels:
res += ", " + match
except Exception:
print(f"Error interrogating", file=sys.stderr)
print("Error interrogating", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
res += "<error>"

View File

@ -55,18 +55,20 @@ def setup_for_low_vram(sd_model, use_medvram):
if hasattr(sd_model.cond_stage_model, 'model'):
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
# remove three big modules, cond, first_stage, and unet from the model and then
# remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None
sd_model.to(devices.device)
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored
# register hooks for those the first two models
# register hooks for those the first three models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
if sd_model.depth_model:
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if hasattr(sd_model.cond_stage_model, 'model'):

View File

@ -2,7 +2,7 @@ from pyngrok import ngrok, conf, exception
def connect(token, port, region):
account = None
if token == None:
if token is None:
token = 'None'
else:
if ':' in token:
@ -14,7 +14,7 @@ def connect(token, port, region):
auth_token=token, region=region
)
try:
if account == None:
if account is None:
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
else:
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url

View File

@ -27,6 +27,7 @@ from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
@ -39,17 +40,19 @@ def setup_color_correction(image):
return correction_target
def apply_color_correction(correction, image):
def apply_color_correction(correction, original_image):
logging.info("Applying color correction.")
image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor(
np.asarray(image),
np.asarray(original_image),
cv2.COLOR_RGB2LAB
),
correction,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8"))
image = blendLayers(image, original_image, BlendType.LUMINOSITY)
return image
@ -77,7 +80,7 @@ class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@ -118,6 +121,7 @@ class StableDiffusionProcessing():
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = s_noise or opts.s_noise
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
if not seed_enable_extras:
@ -147,11 +151,11 @@ class StableDiffusionProcessing():
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
image_conditioning = image_conditioning.to(x.dtype)
return image_conditioning
@ -199,7 +203,7 @@ class StableDiffusionProcessing():
source_image * (1.0 - conditioning_mask),
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
)
# Encode the new masked image using first stage of network.
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
@ -314,7 +318,7 @@ class Processed:
return json.dumps(obj)
def infotext(self, p: StableDiffusionProcessing, index):
def infotext(self, p: StableDiffusionProcessing, index):
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
@ -429,6 +433,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
"Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)),
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
@ -446,7 +451,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
@ -463,12 +468,14 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
res = process_images_inner(p)
finally: # restore opts to original state
for k, v in stored_opts.items():
setattr(opts, k, v)
if k == 'sd_hypernetwork': shared.reload_hypernetworks()
if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
if k == 'sd_vae': sd_vae.reload_vae_weights()
finally:
# restore opts to original state
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
setattr(opts, k, v)
if k == 'sd_hypernetwork': shared.reload_hypernetworks()
if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
if k == 'sd_vae': sd_vae.reload_vae_weights()
return res
@ -537,7 +544,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
for n in range(p.n_iter):
if state.skipped:
state.skipped = False
if state.interrupted:
break
@ -612,7 +619,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image.info["parameters"] = text
output_images.append(image)
del x_samples_ddim
del x_samples_ddim
devices.torch_gc()
@ -704,7 +711,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
"""saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
def save_intermediate(image, index):
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
return
@ -720,7 +727,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
# Avoid making the inpainting conditioning unless necessary as
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
@ -829,9 +836,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.color_corrections = []
imgs = []
for img in self.init_images:
image = img.convert("RGB")
image = images.flatten(img, opts.img2img_background_color)
if crop_region is None:
if crop_region is None and self.resize_mode != 3:
image = images.resize_image(self.resize_mode, image, self.width, self.height)
if image_mask is not None:
@ -840,6 +847,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.overlay_images.append(image_masked.convert('RGBA'))
# crop_region is not None if we are doing inpaint full res
if crop_region is not None:
image = image.crop(crop_region)
image = images.resize_image(2, image, self.width, self.height)
@ -876,6 +884,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
if self.resize_mode == 3:
self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
if image_mask is not None:
init_mask = latent_mask
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))

View File

@ -37,16 +37,16 @@ class RestrictedUnpickler(pickle.Unpickler):
if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
return getattr(torch._utils, name)
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name)
if module == 'numpy.core.multiarray' and name == 'scalar':
return numpy.core.multiarray.scalar
if module == 'numpy' and name == 'dtype':
return numpy.dtype
if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
return getattr(numpy.core.multiarray, name)
if module == 'numpy' and name in ['dtype', 'ndarray']:
return getattr(numpy, name)
if module == '_codecs' and name == 'encode':
return encode
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
@ -80,7 +80,7 @@ def check_pt(filename, extra_handler):
# new pytorch format is a zip file
with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist())
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
if len(data_pkl_filenames) == 0:
@ -103,12 +103,12 @@ def check_pt(filename, extra_handler):
def load(filename, *args, **kwargs):
return load_with_extra(filename, *args, **kwargs)
return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
"""
this functon is intended to be used by extensions that want to load models with
this function is intended to be used by extensions that want to load models with
some extra classes in them that the usual unpickler would find suspicious.
Use the extra_handler argument to specify a function that takes module and field name as text,
@ -137,19 +137,56 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
return None
except Exception:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
return None
return unsafe_torch_load(filename, *args, **kwargs)
class Extra:
"""
A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
(because it's not your code making the torch.load call). The intended use is like this:
```
import torch
from modules import safe
def handler(module, name):
if module == 'torch' and name in ['float64', 'float16']:
return getattr(torch, name)
return None
with safe.Extra(handler):
x = torch.load('model.pt')
```
"""
def __init__(self, handler):
self.handler = handler
def __enter__(self):
global global_extra_handler
assert global_extra_handler is None, 'already inside an Extra() block'
global_extra_handler = self.handler
def __exit__(self, exc_type, exc_val, exc_tb):
global global_extra_handler
global_extra_handler = None
unsafe_torch_load = torch.load
torch.load = load
global_extra_handler = None

View File

@ -36,7 +36,7 @@ class Script:
def ui(self, is_img2img):
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
The return value should be an array of all components that are used in processing.
Values of those returned componenbts will be passed to run() and process() functions.
Values of those returned components will be passed to run() and process() functions.
"""
pass
@ -47,7 +47,7 @@ class Script:
This function should return:
- False if the script should not be shown in UI at all
- True if the script should be shown in UI if it's scelected in the scripts drowpdown
- True if the script should be shown in UI if it's selected in the scripts dropdown
- script.AlwaysVisible if the script should be shown in UI at all times
"""

View File

@ -1,3 +1,4 @@
import os
import torch
from einops import repeat
@ -209,7 +210,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
@ -278,7 +279,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t
# =================================================================================================
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
# Adapted from:
@ -319,17 +320,18 @@ class LatentInpaintDiffusion(LatentDiffusion):
def should_hijack_inpainting(checkpoint_info):
return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
cfg_basename = os.path.basename(checkpoint_info.config).lower()
return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename
def do_inpainting_hijack():
# most of this stuff seems to no longer be needed because it is already included into SD2.0
# LatentInpaintDiffusion remains because SD2.0's LatentInpaintDiffusion can't be loaded without specifying a checkpoint
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
# this file should be cleaned up later if weverything tuens out to work fine
# this file should be cleaned up later if everything turns out to work fine
# ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
# ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
# ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
# ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim

View File

@ -127,7 +127,7 @@ def check_for_psutil():
invokeAI_mps_available = check_for_psutil()
# -- Taken from https://github.com/invoke-ai/InvokeAI --
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
if invokeAI_mps_available:
import psutil
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size):
return r
def einsum_op_mps_v1(q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
return einsum_op_compvis(q, k, v)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
if slice_size % 4096 == 0:
slice_size -= 1
return einsum_op_slice_1(q, k, v, slice_size)
def einsum_op_mps_v2(q, k, v):
if mem_total_gb > 8 and q.shape[1] <= 4096:
if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
return einsum_op_compvis(q, k, v)
else:
return einsum_op_slice_0(q, k, v, 1)
@ -188,7 +190,7 @@ def einsum_op(q, k, v):
return einsum_op_cuda(q, k, v)
if q.device.type == 'mps':
if mem_total_gb >= 32:
if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
return einsum_op_mps_v1(q, k, v)
return einsum_op_mps_v2(q, k, v)

View File

@ -4,7 +4,7 @@ import torch
class TorchHijackForUnet:
"""
This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
this makes it possible to create pictures with dimensions that are muliples of 8 rather than 64
this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
"""
def __getattr__(self, item):

View File

@ -111,18 +111,19 @@ def model_hash(filename):
def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint
checkpoint_info = checkpoints_list.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
if len(checkpoints_list) == 0:
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
if shared.cmd_opts.ckpt is not None:
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
print(f" - directory {model_path}", file=sys.stderr)
if shared.cmd_opts.ckpt_dir is not None:
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
exit(1)
checkpoint_info = next(iter(checkpoints_list.values()))
@ -293,13 +294,16 @@ def load_model(checkpoint_info=None):
if should_hijack_inpainting(checkpoint_info):
# Hardcoded config for now...
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
sd_config.model.params.use_ema = False
sd_config.model.params.conditioning_key = "hybrid"
sd_config.model.params.unet_config.params.in_channels = 9
sd_config.model.params.finetune_keys = None
# Create a "fake" config with a different name so that we know to unload it when switching models.
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False
do_inpainting_hijack()
if shared.cmd_opts.no_half:
@ -320,7 +324,7 @@ def load_model(checkpoint_info=None):
script_callbacks.model_loaded_callback(sd_model)
print(f"Model loaded.")
print("Model loaded.")
return sd_model
@ -355,5 +359,5 @@ def reload_model_weights(sd_model=None, info=None):
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
print(f"Weights loaded.")
print("Weights loaded.")
return sd_model

View File

@ -9,7 +9,7 @@ import k_diffusion.sampling
import torchsde._brownian.brownian_interval
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from modules import prompt_parser, devices, processing, images
from modules import prompt_parser, devices, processing, images, sd_vae_approx
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@ -23,16 +23,16 @@ samplers_k_diffusion = [
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
('Heun', 'sample_heun', ['k_heun'], {}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
@ -106,20 +106,32 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
def single_sample_to_image(sample):
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
def single_sample_to_image(sample, approximation=None):
if approximation is None:
approximation = approximation_indexes.get(opts.show_progress_type, 0)
if approximation == 2:
x_sample = sd_vae_approx.cheap_approximation(sample)
elif approximation == 1:
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
else:
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
def sample_to_image(samples, index=0):
return single_sample_to_image(samples[index])
def sample_to_image(samples, index=0, approximation=None):
return single_sample_to_image(samples[index], approximation)
def samples_to_image_grid(samples):
return images.image_grid([single_sample_to_image(sample) for sample in samples])
def samples_to_image_grid(samples, approximation=None):
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
def store_latent(decoded):
@ -288,6 +300,16 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None
self.step = 0
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
return denoised
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise InterruptedException
@ -329,12 +351,7 @@ class CFGDenoiser(torch.nn.Module):
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
@ -444,9 +461,7 @@ class KDiffusionSampler:
return extra_params_kwargs
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
def get_sigmas(self, p, steps):
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
@ -454,6 +469,16 @@ class KDiffusionSampler:
else:
sigmas = self.model_wrap.get_sigmas(steps)
if self.config is not None and self.config.options.get('discard_next_to_last_sigma', False):
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
return sigmas
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
sigmas = self.get_sigmas(p, steps)
sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0]
@ -485,12 +510,7 @@ class KDiffusionSampler:
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
steps = steps or p.steps
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
else:
sigmas = self.model_wrap.get_sigmas(steps)
sigmas = self.get_sigmas(p, steps)
x = x * sigmas[0]

View File

@ -208,5 +208,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
print(f"VAE Weights loaded.")
print("VAE Weights loaded.")
return sd_model

58
modules/sd_vae_approx.py Normal file
View File

@ -0,0 +1,58 @@
import os
import torch
from torch import nn
from modules import devices, paths
sd_vae_approx_model = None
class VAEApprox(nn.Module):
def __init__(self):
super(VAEApprox, self).__init__()
self.conv1 = nn.Conv2d(4, 8, (7, 7))
self.conv2 = nn.Conv2d(8, 16, (5, 5))
self.conv3 = nn.Conv2d(16, 32, (3, 3))
self.conv4 = nn.Conv2d(32, 64, (3, 3))
self.conv5 = nn.Conv2d(64, 32, (3, 3))
self.conv6 = nn.Conv2d(32, 16, (3, 3))
self.conv7 = nn.Conv2d(16, 8, (3, 3))
self.conv8 = nn.Conv2d(8, 3, (3, 3))
def forward(self, x):
extra = 11
x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
x = nn.functional.pad(x, (extra, extra, extra, extra))
for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
x = layer(x)
x = nn.functional.leaky_relu(x, 0.1)
return x
def model():
global sd_vae_approx_model
if sd_vae_approx_model is None:
sd_vae_approx_model = VAEApprox()
sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
sd_vae_approx_model.eval()
sd_vae_approx_model.to(devices.device, devices.dtype)
return sd_vae_approx_model
def cheap_approximation(sample):
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
coefs = torch.tensor([
[0.298, 0.207, 0.208],
[0.187, 0.286, 0.173],
[-0.158, 0.189, 0.264],
[-0.184, -0.271, -0.473],
]).to(sample.device)
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
return x_sample

View File

@ -5,6 +5,7 @@ import os
import sys
import time
from PIL import Image
import gradio as gr
import tqdm
@ -293,6 +294,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
@ -362,6 +364,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
@ -383,11 +386,13 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
"deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
"deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"),
}))
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
"show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),

View File

@ -28,9 +28,9 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
self.width = width
@ -50,14 +50,14 @@ class PersonalizedBase(Dataset):
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
if shared.state.interrupted:
raise Exception("inturrupted")
raise Exception("interrupted")
try:
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
except Exception:
@ -144,7 +144,7 @@ class PersonalizedDataLoader(DataLoader):
self.collate_fn = collate_wrapper_random
else:
self.collate_fn = collate_wrapper
class BatchLoader:
def __init__(self, data):

View File

@ -133,7 +133,7 @@ class EmbeddingDatabase:
process_file(fullfn, fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(f"Error loading embedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
@ -194,7 +194,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
csv_writer.writeheader()
epoch = (step - 1) // epoch_len
epoch_step = (step - 1) % epoch_len
epoch_step = (step - 1) % epoch_len
csv_writer.writerow({
"step": step,
@ -263,16 +263,16 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
initial_step = embedding.step or 0
if initial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
latent_sampling_method = ds.latent_sampling_method
@ -295,12 +295,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
loss_step = 0
_loss_step = 0 #internal
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
embedding_yet_to_be_embedded = False
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps-initial_step) * gradient_step):
@ -327,10 +327,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
c = shared.sd_model.cond_stage_model(batch.cond_text)
loss = shared.sd_model(x, c)[0] / gradient_step
del x
_loss_step += loss.item()
scaler.scale(loss).backward()
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue

View File

@ -49,10 +49,14 @@ if not cmd_opts.share and not cmd_opts.listen:
gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
if cmd_opts.ngrok != None:
if cmd_opts.ngrok is not None:
import modules.ngrok as ngrok
print('ngrok authtoken detected, trying to connect...')
ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860, cmd_opts.ngrok_region)
ngrok.connect(
cmd_opts.ngrok,
cmd_opts.port if cmd_opts.port is not None else 7860,
cmd_opts.ngrok_region
)
def gr_show(visible=True):
@ -266,7 +270,7 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name):
def interrogate(image):
prompt = shared.interrogator.interrogate(image)
prompt = shared.interrogator.interrogate(image.convert("RGB"))
return gr_show(True) if prompt is None else prompt
@ -653,7 +657,7 @@ def create_ui():
setup_progressbar(progressbar, txt2img_preview, 'txt2img')
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
with gr.Column(variant='panel', elem_id="txt2img_settings"):
steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
@ -808,11 +812,11 @@ def create_ui():
setup_progressbar(progressbar, img2img_preview, 'img2img')
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
with gr.Column(variant='panel', elem_id="img2img_settings"):
with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode:
with gr.TabItem('img2img', id='img2img'):
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480)
with gr.TabItem('Inpaint', id='inpaint'):
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
@ -853,7 +857,7 @@ def create_ui():
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs)
with gr.Row():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20)
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")

View File

@ -9,6 +9,8 @@ import git
import gradio as gr
import html
import shutil
import errno
from modules import extensions, shared, paths
@ -138,7 +140,18 @@ def install_extension_from_url(dirname, url):
repo = git.Repo.clone_from(url, tmpdir)
repo.remote().fetch()
os.rename(tmpdir, target_dir)
try:
os.rename(tmpdir, target_dir)
except OSError as err:
# TODO what does this do on windows? I think it'll be a different error code but I don't have a system to check it
# Shouldn't cause any new issues at least but we probably want to handle it there too.
if err.errno == errno.EXDEV:
# Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems
# Since we can't use a rename, do the slower but more versitile shutil.move()
shutil.move(tmpdir, target_dir)
else:
# Something else, not enough free space, permissions, etc. rethrow it so that it gets handled.
raise(err)
import launch
launch.run_extension_installer(target_dir)

View File

@ -1,3 +1,4 @@
blendmodes
accelerate
basicsr
fairscale==0.4.4

View File

@ -1,3 +1,4 @@
blendmodes==2022
transformers==4.19.2
accelerate==0.12.0
basicsr==1.4.2
@ -26,3 +27,4 @@ inflection==0.5.1
GitPython==3.1.27
torchsde==0.2.5
safetensors==0.2.5
httpcore<=0.15

View File

@ -18,7 +18,7 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell):
ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]
first_pocessed = None
first_processed = None
state.job_count = len(xs) * len(ys)
@ -27,17 +27,17 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell):
state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
processed = cell(x, y)
if first_pocessed is None:
first_pocessed = processed
if first_processed is None:
first_processed = processed
res.append(processed.images[0])
grid = images.image_grid(res, rows=len(ys))
grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
first_pocessed.images = [grid]
first_processed.images = [grid]
return first_pocessed
return first_processed
class Script(scripts.Script):

View File

@ -9,6 +9,7 @@ import shlex
import modules.scripts as scripts
import gradio as gr
from modules import sd_samplers
from modules.processing import Processed, process_images
from PIL import Image
from modules.shared import opts, cmd_opts, state
@ -44,6 +45,7 @@ prompt_tags = {
"seed_resize_from_h": process_int_tag,
"seed_resize_from_w": process_int_tag,
"sampler_index": process_int_tag,
"sampler_name": process_string_tag,
"batch_size": process_int_tag,
"n_iter": process_int_tag,
"steps": process_int_tag,
@ -66,14 +68,28 @@ def cmdargs(line):
arg = args[pos]
assert arg.startswith("--"), f'must start with "--": {arg}'
assert pos+1 < len(args), f'missing argument for command line option {arg}'
tag = arg[2:]
if tag == "prompt" or tag == "negative_prompt":
pos += 1
prompt = args[pos]
pos += 1
while pos < len(args) and not args[pos].startswith("--"):
prompt += " "
prompt += args[pos]
pos += 1
res[tag] = prompt
continue
func = prompt_tags.get(tag, None)
assert func, f'unknown commandline option: {arg}'
assert pos+1 < len(args), f'missing argument for command line option {arg}'
val = args[pos+1]
if tag == "sampler_name":
val = sd_samplers.samplers_map.get(val.lower(), None)
res[tag] = func(val)
@ -124,7 +140,7 @@ class Script(scripts.Script):
try:
args = cmdargs(line)
except Exception:
print(f"Error parsing line [line] as commandline:", file=sys.stderr)
print(f"Error parsing line {line} as commandline:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
args = {"prompt": line}
else:

View File

@ -35,8 +35,9 @@ class Script(scripts.Script):
seed = p.seed
init_img = p.init_images[0]
init_img = images.flatten(init_img, opts.img2img_background_color)
if (upscaler.name != "None"):
if upscaler.name != "None":
img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path)
else:
img = init_img

View File

@ -10,13 +10,16 @@ import numpy as np
import modules.scripts as scripts
import gradio as gr
from modules import images, sd_samplers
from modules import images, paths, sd_samplers
from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.sd_samplers
import modules.sd_models
import modules.sd_vae
import glob
import os
import re
@ -114,6 +117,38 @@ def apply_clip_skip(p, x, xs):
opts.data["CLIP_stop_at_last_layers"] = x
def apply_upscale_latent_space(p, x, xs):
if x.lower().strip() != '0':
opts.data["use_scale_latent_for_hires_fix"] = True
else:
opts.data["use_scale_latent_for_hires_fix"] = False
def find_vae(name: str):
if name.lower() in ['auto', 'none']:
return name
else:
vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE'))
found = glob.glob(os.path.join(vae_path, f'**/{name}.*pt'), recursive=True)
if found:
return found[0]
else:
return 'auto'
def apply_vae(p, x, xs):
if x.lower().strip() == 'none':
modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file='None')
else:
found = find_vae(x)
if found:
v = modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=found)
def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
p.styles = x.split(',')
def format_value_add_label(p, opt, x):
if type(x) == float:
x = round(x, 8)
@ -167,7 +202,10 @@ axis_options = [
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
AxisOption("Upscale latent space for hires.", str, apply_upscale_latent_space, format_value_add_label, None),
AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
AxisOption("VAE", str, apply_vae, format_value_add_label, None),
AxisOption("Styles", str, apply_styles, format_value_add_label, None),
]
@ -229,14 +267,18 @@ class SharedSettingsStackHelper(object):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
self.hypernetwork = opts.sd_hypernetwork
self.model = shared.sd_model
self.use_scale_latent_for_hires_fix = opts.use_scale_latent_for_hires_fix
self.vae = opts.sd_vae
def __exit__(self, exc_type, exc_value, tb):
modules.sd_models.reload_model_weights(self.model)
modules.sd_vae.reload_vae_weights(self.model, vae_file=find_vae(self.vae))
hypernetwork.load_hypernetwork(self.hypernetwork)
hypernetwork.apply_strength()
opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
opts.data["use_scale_latent_for_hires_fix"] = self.use_scale_latent_for_hires_fix
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")

View File

@ -8,6 +8,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from modules import import_hook
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.paths import script_path
@ -153,8 +154,8 @@ def webui():
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the
# running web ui and do whatever the attcker wants, including installing an extension and
# runnnig its code. We disable this here. Suggested by RyotaK.
# running web ui and do whatever the attacker wants, including installing an extension and
# running its code. We disable this here. Suggested by RyotaK.
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
setup_cors(app)

View File

@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
#################################################
# Please do not make any changes to this file, #
# change the variables in webui-user.sh instead #