Update swinir.py

This commit is contained in:
C43H66N12O12S2 2022-09-20 20:09:13 +03:00 committed by AUTOMATIC1111
parent 5f71ecfe6f
commit d8ed699839

View File

@ -1,48 +1,70 @@
import sys import sys
import traceback import traceback
import cv2 import cv2
from collections import OrderedDict
import os import os
import requests import contextlib
from collections import namedtuple
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import torch import torch
import modules.images import modules.images
from modules.shared import cmd_opts, opts, device from modules.shared import cmd_opts, opts, device
from modules.swinir_arch import SwinIR as net from modules.swinir_arch import SwinIR as net
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
def load_model(task = "realsr", large_model = True, model_path="C:/sd/ESRGANn/4x-large.pth", scale=4): precision_scope = (
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
)
def load_model(filename, scale=4):
model = net(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
embed_dim=240,
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="3conv",
)
pretrained_model = torch.load(filename)
model.load_state_dict(pretrained_model["params_ema"], strict=True)
if not cmd_opts.no_half:
model = model.half()
return model
def load_models(dirname):
for file in os.listdir(dirname):
path = os.path.join(dirname, file)
model_name, extension = os.path.splitext(file)
if extension != ".pt" and extension != ".pth":
continue
try: try:
modules.shared.sd_upscalers.append(UpscalerSwin("McSwinnySwin")) modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name))
except Exception: except Exception:
print(f"Error loading ESRGAN model", file=sys.stderr) print(f"Error loading SwinIR model: {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
if not large_model:
# use 'nearest+conv' to avoid block artifacts
model = net(upscale=scale, in_chans=3, img_size=64, window_size=8,
img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2, upsampler='nearest+conv', resi_connection='1conv')
else:
# larger model size; use '3conv' to save parameters and memory; use ema for GAN training
model = net(upscale=scale, in_chans=3, img_size=64, window_size=8,
img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240,
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv')
pretrained_model = torch.load(model_path)
model.load_state_dict(pretrained_model["params_ema"], strict=True)
return model.half().to(device) def upscale(
img,
def upscale(img, tile=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, window_size = 8, scale = 4): model,
tile=opts.GAN_tile,
tile_overlap=opts.GAN_tile_overlap,
window_size=8,
scale=4,
):
img = np.array(img) img = np.array(img)
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(device) img = img.unsqueeze(0).to(device)
model = load_model()
with torch.no_grad(), precision_scope("cuda"): with torch.no_grad(), precision_scope("cuda"):
_, _, h_old, w_old = img.size() _, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old h_pad = (h_old // window_size + 1) * window_size - h_old
@ -53,9 +75,11 @@ def upscale(img, tile=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, w
output = output[..., : h_old * scale, : w_old * scale] output = output[..., : h_old * scale, : w_old * scale]
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3: if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR output = np.transpose(
output[[2, 1, 0], :, :], (1, 2, 0)
) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
return Image.fromarray(output, 'RGB') return Image.fromarray(output, "RGB")
def inference(img, model, tile, tile_overlap, window_size, scale): def inference(img, model, tile, tile_overlap, window_size, scale):
@ -77,16 +101,23 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
out_patch = model(in_patch) out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch) out_patch_mask = torch.ones_like(out_patch)
E[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch) E[
W[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch_mask) ..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
].add_(out_patch_mask)
output = E.div_(W) output = E.div_(W)
return output return output
class UpscalerSwin(modules.images.Upscaler): class UpscalerSwin(modules.images.Upscaler):
def __init__(self, title): def __init__(self, filename, title):
self.name = title self.name = title
self.model = load_model(filename)
def do_upscale(self, img): def do_upscale(self, img):
img = upscale(img) model = self.model.to(device)
img = upscale(img, model)
return img return img