backend: add manual_cast support to should_use_fp16()/unet_dtype(), prefer fp16 on ROCm and compute capability 8.x+, and clamp tile offsets in tiled_scale()
commit 40afb9dfb0
parent affb40340e
@@ -510,7 +510,7 @@ def unet_dtype(device=None, model_params=0):
         return torch.float8_e4m3fn
     if args.unet_in_fp8_e5m2:
         return torch.float8_e5m2
-    if should_use_fp16(device=device, model_params=model_params):
+    if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         return torch.float16
     return torch.float32
 
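The only functional change in this hunk is the added manual_cast=True argument: unet_dtype() now reports float16 as the storage dtype even on devices where fp16 is not the native fast path, on the assumption that weights are cast per operation at run time. A hypothetical caller, not from this commit, might combine the two queries like this (names and numbers are purely illustrative):

    import torch

    # Illustrative sketch: pick a storage dtype, then decide whether to cast per-op.
    load_device = torch.device("cuda")
    param_count = 860_000_000  # roughly an SD1.x-sized UNet

    storage_dtype = unet_dtype(device=load_device, model_params=param_count)
    native_fp16 = should_use_fp16(device=load_device, model_params=param_count)

    if storage_dtype == torch.float16 and not native_fp16:
        # Weights stay in fp16 to save memory, but each op runs after
        # casting to a dtype the device computes well (e.g. fp32).
        compute_dtype = torch.float32
    else:
        compute_dtype = storage_dtype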
@@ -710,7 +710,7 @@ def is_device_mps(device):
             return True
     return False
 
-def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
+def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     global directml_enabled
 
     if device is not None:
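Because manual_cast defaults to False, call sites that do not pass it keep the previous behavior; within the hunks shown here only the unet_dtype() change above opts in. A quick comparison, with illustrative argument values:

    should_use_fp16(device=dev, model_params=n_params)                    # unchanged semantics
    should_use_fp16(device=dev, model_params=n_params, manual_cast=True)  # may also return True when fp16
                                                                          # only works via per-op casting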
@@ -736,10 +736,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
     if is_intel_xpu():
         return True
 
-    if torch.cuda.is_bf16_supported():
+    if torch.version.hip:
         return True
 
     props = torch.cuda.get_device_properties("cuda")
+    if props.major >= 8:
+        return True
+
     if props.major < 6:
         return False
 
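The fp16 fast path is now keyed on the build and GPU architecture rather than on bf16 support: ROCm builds (where torch.version.hip is set) and CUDA devices with compute capability 8.x or newer (Ampere and later) return True immediately. A standalone sketch of just these two conditions, assuming a CUDA-capable PyTorch build (not code from this commit):

    import torch

    def fp16_fast_path():
        # ROCm/HIP builds report a version string here; CUDA builds report None.
        if torch.version.hip:
            return True
        # Compute capability 8.x and newer (Ampere, Ada, Hopper) take the fp16 path unconditionally.
        props = torch.cuda.get_device_properties("cuda")
        return props.major >= 8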
@@ -752,7 +755,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True):
         if x in props.name.lower():
             fp16_works = True
 
-    if fp16_works:
+    if fp16_works or manual_cast:
         free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
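Here model_params * 4 approximates the fp32 footprint of the weights in bytes (4 bytes per parameter), so the branch returns True (use fp16) when an fp32 copy would not fit into roughly 90% of free memory minus the inference reserve, or when the caller deprioritizes raw performance. With the new "or manual_cast" condition, the same memory-pressure heuristic also applies when the caller asked for manual casting, even if the GPU name is not on the fp16_works list. Rough arithmetic with made-up numbers:

    # Illustrative values only: would an fp32 copy of the weights fit?
    model_params = 860_000_000                  # ~0.86B parameters
    fp32_bytes = model_params * 4               # ~3.4 GB as fp32
    free_model_memory = 8e9 * 0.9 - 1e9         # assume 8 GB free, ~1 GB inference reserve
    use_fp16 = fp32_bytes > free_model_memory   # False here; True on a tighter GPU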
@@ -413,6 +413,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
         out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
         for y in range(0, s.shape[2], tile_y - overlap):
             for x in range(0, s.shape[3], tile_x - overlap):
+                x = max(0, min(s.shape[-1] - overlap, x))
+                y = max(0, min(s.shape[-2] - overlap, y))
                 s_in = s[:,:,y:y+tile_y,x:x+tile_x]
 
                 ps = function(s_in).to(output_device)
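The two added lines clamp each tile's start offset so a tile never begins inside the last overlap pixels of the image; without them, the final tile in a row or column can be only a sliver wide. For example, with a width of 60, tile_x = 64 and overlap = 8, the loop yields x = 0 and x = 56, and the second tile would be just 4 columns; the clamp moves it back to x = 52 so it spans the full overlap. A small check of that arithmetic (standalone, illustrative values):

    # Tile start offsets before and after the clamp (illustrative widths).
    width, tile_x, overlap = 60, 64, 8
    for x in range(0, width, tile_x - overlap):        # raw starts: 0, 56
        x = max(0, min(width - overlap, x))            # clamped:    0, 52
        print(x, min(x + tile_x, width) - x)           # tile widths: 60, 8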