Refactor and instead check if mps is being used, not availability

This commit is contained in:
brkirch 2022-11-12 02:17:55 -05:00
parent 0b5dcb3d7c
commit 98ca437edf

View File

@ -182,11 +182,7 @@ def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != devices.device:
if devices.has_mps():
attr = attr.to(device="mps", dtype=torch.float32)
else:
attr = attr.to(devices.device)
attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))
setattr(self, name, attr)