Update safe.py

This commit is contained in:
lllyasviel 2024-02-02 22:43:27 -08:00
parent 2129c64aed
commit a5de9d46f3

View File

@ -1,39 +1,194 @@
# this code is adapted from the script contributed by anon from /h/
import pickle
import collections
import torch
import contextlib
import numpy
import _codecs
import zipfile
import re
TypedStorage = None
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
from modules import errors
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
def encode(*args):
pass
out = _codecs.encode(*args)
return out
class RestrictedUnpickler:
pass
class RestrictedUnpickler(pickle.Unpickler):
extra_handler = None
def persistent_load(self, saved_id):
assert saved_id[0] == 'storage'
try:
return TypedStorage(_internal=True)
except TypeError:
return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
def find_class(self, module, name):
if self.extra_handler is not None:
res = self.extra_handler(module, name)
if res is not None:
return res
if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
return getattr(torch._utils, name)
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name)
if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
return getattr(numpy.core.multiarray, name)
if module == 'numpy' and name in ['dtype', 'ndarray']:
return getattr(numpy, name)
if module == '_codecs' and name == 'encode':
return encode
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
import pytorch_lightning.callbacks
return pytorch_lightning.callbacks.model_checkpoint
if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
import pytorch_lightning.callbacks.model_checkpoint
return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
if module == "__builtin__" and name == 'set':
return set
# Forbid everything else.
raise Exception(f"global '{module}/{name}' is forbidden")
allowed_zip_names_re = None
data_pkl_re = None
# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
def check_zip_filenames(filename, names):
pass
for name in names:
if allowed_zip_names_re.match(name):
continue
raise Exception(f"bad file inside {filename}: {name}")
def check_pt(filename, extra_handler):
pass
try:
# new pytorch format is a zip file
with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist())
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
if len(data_pkl_filenames) == 0:
raise Exception(f"data.pkl not found in {filename}")
if len(data_pkl_filenames) > 1:
raise Exception(f"Multiple data.pkl found in {filename}")
with z.open(data_pkl_filenames[0]) as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
unpickler.load()
except zipfile.BadZipfile:
# if it's not a zip file, it's an old pytorch format, with five objects written to pickle
with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
for _ in range(5):
unpickler.load()
def load(filename, *args, **kwargs):
pass
return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
pass
"""
this function is intended to be used by extensions that want to load models with
some extra classes in them that the usual unpickler would find suspicious.
Use the extra_handler argument to specify a function that takes module and field name as text,
and returns that field's value:
```python
def extra(module, name):
if module == 'collections' and name == 'OrderedDict':
return collections.OrderedDict
return None
safe.load_with_extra('model.pt', extra_handler=extra)
```
The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
definitely unsafe.
"""
from modules import shared
try:
if not shared.cmd_opts.disable_safe_unpickle:
check_pt(filename, extra_handler)
except pickle.UnpicklingError:
errors.report(
f"Error verifying pickled file from {filename}\n"
"-----> !!!! The file is most likely corrupted !!!! <-----\n"
"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
exc_info=True,
)
return None
except Exception:
errors.report(
f"Error verifying pickled file from {filename}\n"
f"The file may be malicious, so the program is not going to read it.\n"
f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
exc_info=True,
)
return None
return unsafe_torch_load(filename, *args, **kwargs)
def Extra(*args, **kwargs):
return contextlib.nullcontext()
class Extra:
"""
A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
(because it's not your code making the torch.load call). The intended use is like this:
```
import torch
from modules import safe
def handler(module, name):
if module == 'torch' and name in ['float64', 'float16']:
return getattr(torch, name)
return None
with safe.Extra(handler):
x = torch.load('model.pt')
```
"""
def __init__(self, handler):
self.handler = handler
def __enter__(self):
global global_extra_handler
assert global_extra_handler is None, 'already inside an Extra() block'
global_extra_handler = self.handler
def __exit__(self, exc_type, exc_val, exc_tb):
global global_extra_handler
global_extra_handler = None
unsafe_torch_load = torch.load