fp8 for TE

This commit is contained in:
Kohaku-Blueleaf 2023-10-25 11:36:43 +08:00
parent 9c1eba2af3
commit 1df6c8bfec

View File

@ -407,6 +407,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
module.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet for cpu")
else:
if model.is_sdxl:
cond_stage = model.conditioner
else:
cond_stage = model.cond_stage_model
for module in cond_stage.modules():
if isinstance(module, torch.nn.Linear):
module.to(torch.float8_e4m3fn)
model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet")