2022-10-04 09:32:22 +00:00
|
|
|
import contextlib
|
2022-09-11 05:11:27 +00:00
|
|
|
import torch
|
2024-01-24 18:51:36 +00:00
|
|
|
import ldm_patched.modules.model_management as model_management
|
2023-12-02 06:00:46 +00:00
|
|
|
|
|
|
|
|
|
|
|
def has_xpu() -> bool:
    """Report the ``xpu_available`` flag exposed by ``model_management``.

    No device probing happens here; the module attribute is returned as-is.
    """
    xpu_flag = model_management.xpu_available
    return xpu_flag
|
2023-12-02 06:00:46 +00:00
|
|
|
|
2022-11-12 07:00:49 +00:00
|
|
|
|
2022-11-12 03:02:40 +00:00
|
|
|
def has_mps() -> bool:
    """Report whether ``model_management`` is in MPS mode.

    Delegates entirely to ``model_management.mps_mode()``.
    """
    in_mps_mode = model_management.mps_mode()
    return in_mps_mode
|
2022-09-11 15:48:36 +00:00
|
|
|
|
2022-11-12 07:00:49 +00:00
|
|
|
|
2023-10-28 07:24:26 +00:00
|
|
|
def cuda_no_autocast(device_id=None) -> bool:
    """Stub that always answers ``False``.

    ``device_id`` is accepted so existing call sites keep working, but it is
    not consulted.
    """
    disable_autocast = False
    return disable_autocast
|
2023-10-28 07:24:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
def get_cuda_device_id():
    """Return the ``index`` of the device selected by ``model_management``.

    This is ``torch.device.index`` of ``model_management.get_torch_device()``.
    It can be ``None`` when that device has no explicit index (for example
    ``torch.device("cpu")``).
    """
    torch_device = model_management.get_torch_device()
    return torch_device.index
|
2023-10-28 07:24:26 +00:00
|
|
|
|
|
|
|
|
2022-11-27 10:08:54 +00:00
|
|
|
def get_cuda_device_string():
    """Return ``str()`` of the device selected by ``model_management``.

    Despite the name, the string describes whatever device
    ``get_torch_device()`` returns, which is not necessarily a CUDA device.
    """
    torch_device = model_management.get_torch_device()
    return str(torch_device)
|
2022-10-22 11:04:14 +00:00
|
|
|
|
2022-11-27 10:08:54 +00:00
|
|
|
|
2023-01-27 08:28:12 +00:00
|
|
|
def get_optimal_device_name():
    """Return only the ``type`` string of the device selected by ``model_management``.

    No ``:index`` suffix is included; this is ``torch.device.type``.
    """
    selected = model_management.get_torch_device()
    return selected.type
|
2022-09-11 15:48:36 +00:00
|
|
|
|
2023-01-27 08:28:12 +00:00
|
|
|
|
|
|
|
def get_optimal_device():
    """Return the ``torch.device`` selected by ``model_management.get_torch_device()``."""
    selected = model_management.get_torch_device()
    return selected
|
2022-09-11 20:24:24 +00:00
|
|
|
|
|
|
|
|
2022-12-03 15:06:33 +00:00
|
|
|
def get_device_for(task):
    """Return the device to use for ``task``.

    ``task`` is currently ignored; every task is routed to
    ``get_optimal_device()``.
    """
    optimal = get_optimal_device()
    return optimal
|
|
|
|
|
|
|
|
|
2022-09-11 20:24:24 +00:00
|
|
|
def torch_gc():
    """Invoke ``model_management.soft_empty_cache()``; returns ``None``."""
    model_management.soft_empty_cache()
    return None
|
2023-12-02 06:00:46 +00:00
|
|
|
|
2022-09-12 13:34:13 +00:00
|
|
|
|
|
|
|
def enable_tf32():
    """No-op: this implementation changes no TF32/backend settings and returns ``None``."""
    return None
|
2022-12-03 13:01:23 +00:00
|
|
|
|
2022-09-12 17:09:32 +00:00
|
|
|
|
2023-08-03 04:18:55 +00:00
|
|
|
# Shared CPU device object; reused below for several per-model device globals.
cpu: torch.device = torch.device("cpu")

fp8: bool = False

# Primary device, as chosen by model_management at import time.
device: torch.device = model_management.get_torch_device()

# Per-model device globals. Order of the model_management calls below is kept
# as-is since the helpers are opaque from this file.
device_interrogate: torch.device = cpu  # not used
device_gfpgan: torch.device = cpu
device_esrgan: torch.device = model_management.get_torch_device()  # will be managed in special way
device_codeformer: torch.device = cpu

# Dtypes reported by model_management; each helper is called once, at import.
dtype: torch.dtype = model_management.unet_dtype()
dtype_vae: torch.dtype = model_management.vae_dtype()
dtype_unet: torch.dtype = model_management.unet_dtype()
dtype_inference: torch.dtype = model_management.unet_dtype()

unet_needs_upcast = False
|
2022-09-12 17:09:32 +00:00
|
|
|
|
2022-11-12 07:00:49 +00:00
|
|
|
|
2023-01-27 15:19:43 +00:00
|
|
|
def cond_cast_unet(input):
    """Identity pass-through: return ``input`` unchanged (no dtype cast is applied)."""
    unchanged = input
    return unchanged
|
2023-01-27 15:19:43 +00:00
|
|
|
|
|
|
|
|
|
|
|
def cond_cast_float(input):
    """Identity pass-through: return ``input`` unchanged (no float cast is applied)."""
    unchanged = input
    return unchanged
|
2023-01-27 15:19:43 +00:00
|
|
|
|
|
|
|
|
2023-08-02 21:00:23 +00:00
|
|
|
# RNG placeholder; initialised to None and never assigned in this file.
nv_rng = None

# Module-level list that starts empty.
# NOTE(review): presumably filled with modules to patch by other code — confirm.
patch_module_list = []
|
2023-10-28 07:24:26 +00:00
|
|
|
|
2023-11-19 07:50:06 +00:00
|
|
|
|
2024-01-09 14:11:44 +00:00
|
|
|
def manual_cast_forward(target_dtype):
    """Stub: ignores ``target_dtype`` and returns ``None`` (no forward wrapper is built)."""
    return None
|
2023-11-19 07:50:06 +00:00
|
|
|
|
|
|
|
|
2023-10-28 07:24:26 +00:00
|
|
|
@contextlib.contextmanager
def manual_cast(target_dtype):
    """No-op context manager standing in for a manual dtype cast.

    ``target_dtype`` is ignored and nothing is patched; the ``with`` body just
    runs and the context value is ``None``. Exceptions raised inside the body
    propagate unchanged.
    """
    # @contextmanager requires a generator function. A bare ``return`` made this
    # a plain function returning None, so entering the ``with`` raised
    # TypeError ("'NoneType' object is not an iterator"). Yield exactly once.
    yield
|
2023-10-28 07:24:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
def autocast(disable=False):
    """Return a do-nothing context manager; ``disable`` is accepted but ignored."""
    no_op = contextlib.nullcontext()
    return no_op
|
2022-10-25 06:01:57 +00:00
|
|
|
|
2022-11-12 07:00:49 +00:00
|
|
|
|
2023-01-25 05:23:10 +00:00
|
|
|
def without_autocast(disable=False):
    """Return a do-nothing context manager; ``disable`` is accepted but ignored."""
    no_op = contextlib.nullcontext()
    return no_op
|
2023-01-25 05:23:10 +00:00
|
|
|
|
|
|
|
|
2023-01-16 19:59:46 +00:00
|
|
|
class NansException(Exception):
    """Marker exception type that adds nothing to ``Exception``.

    NOTE(review): presumably meant for NaN-detection failures; no code in this
    file raises it.
    """
|
|
|
|
|
|
|
|
|
|
|
|
def test_for_nans(x, where):
    """No-op check: ignores ``x`` and ``where``, never raises, returns ``None``."""
    return None
|
2023-05-21 18:55:14 +00:00
|
|
|
|
|
|
|
|
|
|
|
def first_time_calculation():
    """No-op in this implementation: performs no computation and returns ``None``."""
    return None
|
2023-08-09 05:43:31 +00:00
|
|
|
|