Update deepbooru.py

This commit is contained in:
lllyasviel 2024-01-25 13:01:21 -08:00
parent 497645b2f8
commit e2169c75be

View File

@ -4,7 +4,10 @@ import re
import torch import torch
import numpy as np import numpy as np
from modules import modelloader, paths, deepbooru_model, devices, images, shared from modules import modelloader, paths, deepbooru_model, images, shared
from ldm_patched.modules import model_management
from ldm_patched.modules.model_patcher import ModelPatcher
re_special = re.compile(r'([\\()])') re_special = re.compile(r'([\\()])')
@ -12,6 +15,14 @@ re_special = re.compile(r'([\\()])')
class DeepDanbooru: class DeepDanbooru:
def __init__(self): def __init__(self):
self.model = None self.model = None
self.load_device = model_management.text_encoder_device()
self.offload_device = model_management.text_encoder_offload_device()
self.dtype = torch.float32
if model_management.should_use_fp16(device=self.load_device):
self.dtype = torch.float16
self.patcher = None
def load(self): def load(self):
if self.model is not None: if self.model is not None:
@ -28,16 +39,16 @@ class DeepDanbooru:
self.model.load_state_dict(torch.load(files[0], map_location="cpu")) self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
self.model.eval() self.model.eval()
self.model.to(devices.cpu, devices.dtype) self.model.to(self.offload_device, self.dtype)
self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=self.offload_device)
def start(self): def start(self):
self.load() self.load()
self.model.to(devices.device) model_management.load_models_gpu([self.patcher])
def stop(self): def stop(self):
if not shared.opts.interrogate_keep_models_in_memory: pass
self.model.to(devices.cpu)
devices.torch_gc()
def tag(self, pil_image): def tag(self, pil_image):
self.start() self.start()
@ -56,8 +67,8 @@ class DeepDanbooru:
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512) pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255 a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
with torch.no_grad(), devices.autocast(): with torch.no_grad():
x = torch.from_numpy(a).to(devices.device) x = torch.from_numpy(a).to(self.load_device, self.dtype)
y = self.model(x)[0].detach().cpu().numpy() y = self.model(x)[0].detach().cpu().numpy()
probability_dict = {} probability_dict = {}