Add utility to inspect a model's parameters (to get dtype/device)

This commit is contained in:
Aarni Koskela 2023-12-31 00:20:30 +02:00
parent a84e842189
commit 5768afc776
8 changed files with 53 additions and 7 deletions

View File

@ -4,6 +4,7 @@ from functools import lru_cache
import torch import torch
from modules import errors, shared from modules import errors, shared
from modules.torch_utils import get_param
if sys.platform == "darwin": if sys.platform == "darwin":
from modules import mac_specific from modules import mac_specific
@ -131,7 +132,7 @@ patch_module_list = [
def manual_cast_forward(self, *args, **kwargs): def manual_cast_forward(self, *args, **kwargs):
org_dtype = next(self.parameters()).dtype org_dtype = get_param(self).dtype
self.to(dtype) self.to(dtype)
args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}

View File

@ -11,6 +11,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from modules import devices, paths, shared, lowvram, modelloader, errors from modules import devices, paths, shared, lowvram, modelloader, errors
from modules.torch_utils import get_param
blip_image_eval_size = 384 blip_image_eval_size = 384
clip_model_name = 'ViT-L/14' clip_model_name = 'ViT-L/14'
@ -131,7 +132,7 @@ class InterrogateModels:
self.clip_model = self.clip_model.to(devices.device_interrogate) self.clip_model = self.clip_model.to(devices.device_interrogate)
self.dtype = next(self.clip_model.parameters()).dtype self.dtype = get_param(self.clip_model).dtype
def send_clip_to_ram(self): def send_clip_to_ram(self):
if not shared.opts.interrogate_keep_models_in_memory: if not shared.opts.interrogate_keep_models_in_memory:

View File

@ -6,6 +6,7 @@ import sgm.models.diffusion
import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer import sgm.modules.diffusionmodules.discretizer
from modules import devices, shared, prompt_parser from modules import devices, shared, prompt_parser
from modules.torch_utils import get_param
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
@ -90,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
def extend_sdxl(model): def extend_sdxl(model):
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
dtype = next(model.model.diffusion_model.parameters()).dtype dtype = get_param(model.model.diffusion_model).dtype
model.model.diffusion_model.dtype = dtype model.model.diffusion_model.dtype = dtype
model.model.conditioning_key = 'crossattn' model.model.conditioning_key = 'crossattn'
model.cond_stage_key = 'txt' model.cond_stage_key = 'txt'

17
modules/torch_utils.py Normal file
View File

@ -0,0 +1,17 @@
from __future__ import annotations
import torch.nn
def get_param(model) -> torch.nn.Parameter:
    """
    Return the first parameter yielded by a model or module.

    If the object has a ``model`` attribute that itself exposes
    ``parameters``, that inner object is inspected instead of the
    outer one.

    Raises:
        ValueError: if the inspected object yields no parameters.
    """
    # Unwrap a descriptor-style object so we look at the actual Torch module.
    inner = getattr(model, "model", None)
    if hasattr(inner, "parameters"):
        model = inner

    # A private sentinel distinguishes "nothing yielded" from any real value.
    nothing = object()
    first = next(iter(model.parameters()), nothing)
    if first is nothing:
        raise ValueError(f"No parameters found in model {model!r}")
    return first

View File

@ -7,6 +7,7 @@ import tqdm
from PIL import Image from PIL import Image
from modules import images, shared from modules import images, shared
from modules.torch_utils import get_param
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,8 +18,8 @@ def upscale_without_tiling(model, img: Image.Image):
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
model_weight = next(iter(model.model.parameters())) param = get_param(model)
img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
with torch.no_grad(): with torch.no_grad():
output = model(img) output = model(img)

View File

@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
from transformers import XLMRobertaModel,XLMRobertaTokenizer from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional from typing import Optional
from modules.torch_utils import get_param
class BertSeriesConfig(BertConfig): class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
@ -62,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init() self.post_init()
def encode(self,c): def encode(self,c):
device = next(self.parameters()).device device = get_param(self).device
text = self.tokenizer(c, text = self.tokenizer(c,
truncation=True, truncation=True,
max_length=77, max_length=77,

View File

@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
from transformers import XLMRobertaModel,XLMRobertaTokenizer from transformers import XLMRobertaModel,XLMRobertaTokenizer
from typing import Optional from typing import Optional
from modules.torch_utils import get_param
class BertSeriesConfig(BertConfig): class BertSeriesConfig(BertConfig):
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
@ -68,7 +71,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self.post_init() self.post_init()
def encode(self,c): def encode(self,c):
device = next(self.parameters()).device device = get_param(self).device
text = self.tokenizer(c, text = self.tokenizer(c,
truncation=True, truncation=True,
max_length=77, max_length=77,

19
test/test_torch_utils.py Normal file
View File

@ -0,0 +1,19 @@
import types
import pytest
import torch
from modules.torch_utils import get_param
@pytest.mark.parametrize("wrapped", [True, False])
def test_get_param(wrapped):
    """get_param reports dtype/device for a bare module and for a wrapped one."""
    cpu = torch.device("cpu")
    half = torch.float16
    linear = torch.nn.Linear(1, 1).to(dtype=half, device=cpu)
    # more or less how spandrel wraps a thing
    subject = types.SimpleNamespace(model=linear) if wrapped else linear
    found = get_param(subject)
    assert found.dtype == half
    assert found.device == cpu