46 lines
1.0 KiB
Python
46 lines
1.0 KiB
Python
import torch
|
|
import numpy as np
|
|
|
|
from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn
|
|
|
|
|
|
def cond_from_a1111_to_patched_ldm(cond):
    """Translate A1111-style conditioning into the ldm_patched conditioning format.

    Two kinds of input are accepted:

    * a ``torch.Tensor`` -- used directly as the cross-attention embedding;
    * anything else -- indexed with the keys ``'crossattn'`` (cross-attention
      embedding) and ``'vector'`` (pooled output); presumably the dict form of
      SDXL-style conditioning -- confirm against callers.

    Args:
        cond: The conditioning to convert.

    Returns:
        A one-element list holding a dict with ``cross_attn`` (plus
        ``pooled_output`` for non-tensor input) and ``model_conds``.
        ``model_conds`` wraps the cross-attention embedding in
        ``CONDCrossAttn`` under ``c_crossattn`` and, for non-tensor input, the
        pooled output in ``CONDRegular`` under ``y``.

    Raises:
        KeyError: If a mapping ``cond`` lacks ``'crossattn'`` or ``'vector'``.
    """
    if not isinstance(cond, torch.Tensor):
        cross_attn = cond['crossattn']
        pooled_output = cond['vector']
        model_conds = {
            'c_crossattn': CONDCrossAttn(cross_attn),
            'y': CONDRegular(pooled_output),
        }
        return [{
            'cross_attn': cross_attn,
            'pooled_output': pooled_output,
            'model_conds': model_conds,
        }]

    # A bare tensor carries only the cross-attention embedding; there is no
    # pooled vector, so neither ``pooled_output`` nor ``y`` is emitted.
    return [{
        'cross_attn': cond,
        'model_conds': {'c_crossattn': CONDCrossAttn(cond)},
    }]
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def pytorch_to_numpy(x):
    """Convert each tensor in ``x`` into a ``uint8`` NumPy array.

    Every element ``y`` of ``x`` is moved to the CPU, multiplied by 255,
    clipped to ``[0, 255]`` and cast to ``np.uint8``. The cast truncates
    rather than rounds. Inputs are presumably float images in ``[0, 1]``
    (TODO confirm with callers); out-of-range values are clipped, not rejected.

    Args:
        x: An iterable of tensors -- e.g. a list, or a batched tensor, which is
            iterated along its first dimension.

    Returns:
        A list with one ``np.uint8`` array per element of ``x``, each with the
        same shape as the corresponding element.
    """
    images = []
    for y in x:
        if y.dtype == torch.bfloat16:
            # NumPy has no bfloat16 dtype, so ``.numpy()`` would raise a
            # TypeError; upcasting to float32 represents bf16 values exactly.
            y = y.float()
        images.append(np.clip(255. * y.cpu().numpy(), 0, 255).astype(np.uint8))
    return images
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def numpy_to_pytorch(x):
    """Convert an image array into a float32 tensor with a leading batch axis.

    Values are divided by 255, so ``uint8`` data in ``[0, 255]`` maps to
    ``[0, 1]``. A new axis of length 1 is prepended, so an array of shape
    ``(H, W, C)`` becomes a tensor of shape ``(1, H, W, C)``.

    Args:
        x: A NumPy array, or anything ``np.asarray`` accepts; presumably
            8-bit image data -- confirm with callers.

    Returns:
        A C-contiguous ``torch.float32`` tensor. It shares memory with a freshly
        allocated intermediate array, never with ``x``.
    """
    # The division always allocates a new array, so ``x`` itself is never
    # aliased by the returned tensor or modified.
    y = np.asarray(x, dtype=np.float32) / 255.0
    # ``torch.from_numpy`` needs non-negative strides. One contiguity pass
    # copies only when the layout requires it (e.g. transposed or
    # reversed-stride input), instead of an unconditional ``.copy()``.
    y = np.ascontiguousarray(y[None])
    # ``y`` is already float32, so no further dtype conversion is needed.
    return torch.from_numpy(y)
|