Update deepbooru.py

This commit is contained in:
lllyasviel 2024-01-25 13:01:21 -08:00
parent 497645b2f8
commit e2169c75be

View File

@ -4,7 +4,10 @@ import re
import torch
import numpy as np
from modules import modelloader, paths, deepbooru_model, devices, images, shared
from modules import modelloader, paths, deepbooru_model, images, shared
from ldm_patched.modules import model_management
from ldm_patched.modules.model_patcher import ModelPatcher
re_special = re.compile(r'([\\()])')
@ -12,6 +15,14 @@ re_special = re.compile(r'([\\()])')
class DeepDanbooru:
def __init__(self):
self.model = None
self.load_device = model_management.text_encoder_device()
self.offload_device = model_management.text_encoder_offload_device()
self.dtype = torch.float32
if model_management.should_use_fp16(device=self.load_device):
self.dtype = torch.float16
self.patcher = None
def load(self):
if self.model is not None:
@ -28,16 +39,16 @@ class DeepDanbooru:
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
self.model.eval()
self.model.to(devices.cpu, devices.dtype)
self.model.to(self.offload_device, self.dtype)
self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=self.offload_device)
def start(self):
self.load()
self.model.to(devices.device)
model_management.load_models_gpu([self.patcher])
def stop(self):
if not shared.opts.interrogate_keep_models_in_memory:
self.model.to(devices.cpu)
devices.torch_gc()
pass
def tag(self, pil_image):
self.start()
@ -56,8 +67,8 @@ class DeepDanbooru:
pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512)
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
with torch.no_grad(), devices.autocast():
x = torch.from_numpy(a).to(devices.device)
with torch.no_grad():
x = torch.from_numpy(a).to(self.load_device, self.dtype)
y = self.model(x)[0].detach().cpu().numpy()
probability_dict = {}