UniPC progress bar adjustment

This commit is contained in:
Sakura-Luna 2023-05-11 12:26:04 +08:00
parent 22bcc7be42
commit ae17e97898

View File

@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
import math
from tqdm.auto import trange
import tqdm
class NoiseScheduleVP:
@ -757,40 +757,44 @@ class UniPC:
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
t_prev_list = [vec_t]
# Init the first `order` values by lower order multistep DPM-Solver.
for init_order in range(1, order):
vec_t = timesteps[init_order].expand(x.shape[0])
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
if model_x is None:
model_x = self.model_fn(x, vec_t)
if self.after_update is not None:
self.after_update(x, model_x)
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
for step in trange(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)
else:
step_order = order
#print('this step order:', step_order)
if step == steps:
#print('do not run corrector at the last step')
use_corrector = False
else:
use_corrector = True
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
if self.after_update is not None:
self.after_update(x, model_x)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
with tqdm.tqdm(total=steps) as pbar:
# Init the first `order` values by lower order multistep DPM-Solver.
for init_order in range(1, order):
vec_t = timesteps[init_order].expand(x.shape[0])
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
if self.after_update is not None:
self.after_update(x, model_x)
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
pbar.update()
for step in range(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)
else:
step_order = order
#print('this step order:', step_order)
if step == steps:
#print('do not run corrector at the last step')
use_corrector = False
else:
use_corrector = True
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
if self.after_update is not None:
self.after_update(x, model_x)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
pbar.update()
else:
raise NotImplementedError()
if denoise_to_zero: