load_spandrel_model: always return a model descriptor

This commit is contained in:
Aarni Koskela 2023-12-31 00:04:47 +02:00
parent 3be9074031
commit c0ca6348e8

View File

@ -1,8 +1,9 @@
from __future__ import annotations
import importlib
import logging
import os
import importlib
from typing import TYPE_CHECKING
from urllib.parse import urlparse
import torch
@ -10,6 +11,8 @@ import torch
from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
if TYPE_CHECKING:
import spandrel
logger = logging.getLogger(__name__)
@ -142,17 +145,17 @@ def load_spandrel_model(
half: bool = False,
dtype: str | None = None,
expected_architecture: str | None = None,
):
) -> spandrel.ModelDescriptor:
import spandrel
model = spandrel.ModelLoader(device=device).load_from_file(path)
if expected_architecture and model.architecture != expected_architecture:
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
if expected_architecture and model_descriptor.architecture != expected_architecture:
logger.warning(
f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})",
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
)
if half:
model = model.model.half()
model_descriptor.model.half()
if dtype:
model = model.model.to(dtype=dtype)
model.eval()
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype)
return model
model_descriptor.model.to(dtype=dtype)
model_descriptor.model.eval()
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
return model_descriptor