Update deepbooru.py
This commit is contained in:
parent
497645b2f8
commit
e2169c75be
@ -4,7 +4,10 @@ import re
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from modules import modelloader, paths, deepbooru_model, devices, images, shared
|
from modules import modelloader, paths, deepbooru_model, images, shared
|
||||||
|
from ldm_patched.modules import model_management
|
||||||
|
from ldm_patched.modules.model_patcher import ModelPatcher
|
||||||
|
|
||||||
|
|
||||||
re_special = re.compile(r'([\\()])')
|
re_special = re.compile(r'([\\()])')
|
||||||
|
|
||||||
@ -12,6 +15,14 @@ re_special = re.compile(r'([\\()])')
|
|||||||
class DeepDanbooru:
|
class DeepDanbooru:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.model = None
|
self.model = None
|
||||||
|
self.load_device = model_management.text_encoder_device()
|
||||||
|
self.offload_device = model_management.text_encoder_offload_device()
|
||||||
|
self.dtype = torch.float32
|
||||||
|
|
||||||
|
if model_management.should_use_fp16(device=self.load_device):
|
||||||
|
self.dtype = torch.float16
|
||||||
|
|
||||||
|
self.patcher = None
|
||||||
|
|
||||||
def load(self):
|
def load(self):
|
||||||
if self.model is not None:
|
if self.model is not None:
|
||||||
@ -28,16 +39,16 @@ class DeepDanbooru:
|
|||||||
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
|
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
|
||||||
|
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
self.model.to(devices.cpu, devices.dtype)
|
self.model.to(self.offload_device, self.dtype)
|
||||||
|
|
||||||
|
self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=self.offload_device)
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
self.load()
|
self.load()
|
||||||
self.model.to(devices.device)
|
model_management.load_models_gpu([self.patcher])
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
if not shared.opts.interrogate_keep_models_in_memory:
|
pass
|
||||||
self.model.to(devices.cpu)
|
|
||||||
devices.torch_gc()
|
|
||||||
|
|
||||||
def tag(self, pil_image):
|
def tag(self, pil_image):
|
||||||
self.start()
|
self.start()
|
||||||
@ -56,8 +67,8 @@ class DeepDanbooru:
|
|||||||
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
|
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
|
||||||
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
|
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
|
||||||
|
|
||||||
with torch.no_grad(), devices.autocast():
|
with torch.no_grad():
|
||||||
x = torch.from_numpy(a).to(devices.device)
|
x = torch.from_numpy(a).to(self.load_device, self.dtype)
|
||||||
y = self.model(x)[0].detach().cpu().numpy()
|
y = self.model(x)[0].detach().cpu().numpy()
|
||||||
|
|
||||||
probability_dict = {}
|
probability_dict = {}
|
||||||
|
Loading…
Reference in New Issue
Block a user