Update rng.py

This commit is contained in:
lllyasviel 2024-02-05 14:08:34 -08:00
parent 1b9734c45b
commit b174caa275

View File

@ -8,12 +8,12 @@ def randn(seed, shape, generator=None):
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
manual_seed(seed)
if generator is not None:
# if generator is not none, we must generate a noise with and without
# generator together to avoid future 'randn' get same noise again
torch.randn(shape, device=devices.device)
# If generator is not none, we must use another seed to
# avoid global torch.rand to get same noise again.
manual_seed((seed + 262144) % 65536)
else:
manual_seed(seed)
if shared.opts.randn_source == "NV":
return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)