From e2169c75be495ffca186b341e7850e1f4ad0bb8f Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Thu, 25 Jan 2024 13:01:21 -0800 Subject: [PATCH] Update deepbooru.py --- modules/deepbooru.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 547e1b4c..246c9b25 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -4,7 +4,10 @@ import re import torch import numpy as np -from modules import modelloader, paths, deepbooru_model, devices, images, shared +from modules import modelloader, paths, deepbooru_model, images, shared +from ldm_patched.modules import model_management +from ldm_patched.modules.model_patcher import ModelPatcher + re_special = re.compile(r'([\\()])') @@ -12,6 +15,14 @@ re_special = re.compile(r'([\\()])') class DeepDanbooru: def __init__(self): self.model = None + self.load_device = model_management.text_encoder_device() + self.offload_device = model_management.text_encoder_offload_device() + self.dtype = torch.float32 + + if model_management.should_use_fp16(device=self.load_device): + self.dtype = torch.float16 + + self.patcher = None def load(self): if self.model is not None: @@ -28,16 +39,16 @@ class DeepDanbooru: self.model.load_state_dict(torch.load(files[0], map_location="cpu")) self.model.eval() - self.model.to(devices.cpu, devices.dtype) + self.model.to(self.offload_device, self.dtype) + + self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=self.offload_device) def start(self): self.load() - self.model.to(devices.device) + model_management.load_models_gpu([self.patcher]) def stop(self): - if not shared.opts.interrogate_keep_models_in_memory: - self.model.to(devices.cpu) - devices.torch_gc() + pass def tag(self, pil_image): self.start() @@ -56,8 +67,8 @@ class DeepDanbooru: pic = images.resize_image(2, pil_image.convert("RGB"), 512, 512) a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255 - with torch.no_grad(), devices.autocast(): - x = torch.from_numpy(a).to(devices.device) + with torch.no_grad(): + x = torch.from_numpy(a).to(self.load_device, self.dtype) y = self.model(x)[0].detach().cpu().numpy() probability_dict = {}