diff --git a/modules/safe.py b/modules/safe.py
index ff5a4ca4..fe771c1b 100644
--- a/modules/safe.py
+++ b/modules/safe.py
@@ -1,39 +1,194 @@
+# this code is adapted from the script contributed by anon from /h/
+
+import pickle
+import collections
+
 import torch
-import contextlib
+import numpy
+import _codecs
+import zipfile
+import re
 
-TypedStorage = None
+from modules import errors
+
+# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
+TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
 
 
 def encode(*args):
-    pass
+    out = _codecs.encode(*args)
+    return out
 
 
-class RestrictedUnpickler:
-    pass
+class RestrictedUnpickler(pickle.Unpickler):
+    extra_handler = None
+
+    def persistent_load(self, saved_id):
+        assert saved_id[0] == 'storage'
+
+        try:
+            return TypedStorage(_internal=True)
+        except TypeError:
+            return TypedStorage()  # PyTorch before 2.0 does not have the _internal argument
+
+    def find_class(self, module, name):
+        if self.extra_handler is not None:
+            res = self.extra_handler(module, name)
+            if res is not None:
+                return res
+
+        if module == 'collections' and name == 'OrderedDict':
+            return getattr(collections, name)
+        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
+            return getattr(torch._utils, name)
+        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
+            return getattr(torch, name)
+        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
+            return getattr(torch.nn.modules.container, name)
+        if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
+            return getattr(numpy.core.multiarray, name)
+        if module == 'numpy' and name in ['dtype', 'ndarray']:
+            return getattr(numpy, name)
+        if module == '_codecs' and name == 'encode':
+            return encode
+        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
+            import pytorch_lightning.callbacks
+            return pytorch_lightning.callbacks.model_checkpoint
+        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
+            import pytorch_lightning.callbacks.model_checkpoint
+            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
+        if module == "__builtin__" and name == 'set':
+            return set
+
+        # Forbid everything else.
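+        # Unpickling can run arbitrary code (any global can be looked up and
+        # called via a crafted __reduce__), so anything that is not explicitly
+        # allowlisted above must be rejected rather than imported.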
+        raise Exception(f"global '{module}/{name}' is forbidden")
 
 
-allowed_zip_names_re = None
-data_pkl_re = None
-
+# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
+allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
+data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
+
 
 def check_zip_filenames(filename, names):
-    pass
+    for name in names:
+        if allowed_zip_names_re.match(name):
+            continue
+
+        raise Exception(f"bad file inside {filename}: {name}")
 
 
 def check_pt(filename, extra_handler):
-    pass
+    try:
+        # new pytorch format is a zip file
+        with zipfile.ZipFile(filename) as z:
+            check_zip_filenames(filename, z.namelist())
+
+            # find filename of data.pkl in zip file: '<directory name>/data.pkl'
+            data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
+            if len(data_pkl_filenames) == 0:
+                raise Exception(f"data.pkl not found in {filename}")
+            if len(data_pkl_filenames) > 1:
+                raise Exception(f"Multiple data.pkl found in {filename}")
+            with z.open(data_pkl_filenames[0]) as file:
+                unpickler = RestrictedUnpickler(file)
+                unpickler.extra_handler = extra_handler
+                unpickler.load()
+
+    except zipfile.BadZipfile:
+        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
+        with open(filename, "rb") as file:
+            unpickler = RestrictedUnpickler(file)
+            unpickler.extra_handler = extra_handler
+            for _ in range(5):
+                unpickler.load()
 
 
 def load(filename, *args, **kwargs):
-    pass
+    return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
 
 
 def load_with_extra(filename, extra_handler=None, *args, **kwargs):
-    pass
+    """
+    this function is intended to be used by extensions that want to load models with
+    some extra classes in them that the usual unpickler would find suspicious.
+
+    Use the extra_handler argument to specify a function that takes module and field name as text,
+    and returns that field's value:
+
+    ```python
+    def extra(module, name):
+        if module == 'collections' and name == 'OrderedDict':
+            return collections.OrderedDict
+
+        return None
+
+    safe.load_with_extra('model.pt', extra_handler=extra)
+    ```
+
+    The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which, as the name implies,
+    is definitely unsafe.
+    """
+
+    from modules import shared
+
+    try:
+        if not shared.cmd_opts.disable_safe_unpickle:
+            check_pt(filename, extra_handler)
+
+    except pickle.UnpicklingError:
+        errors.report(
+            f"Error verifying pickled file from {filename}\n"
+            "-----> !!!! The file is most likely corrupted !!!! <-----\n"
+            "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
+            exc_info=True,
+        )
+        return None
+    except Exception:
+        errors.report(
+            f"Error verifying pickled file from {filename}\n"
+            "The file may be malicious, so the program is not going to read it.\n"
+            "You can skip this check with --disable-safe-unpickle commandline argument.\n\n",
+            exc_info=True,
+        )
+        return None
+
+    return unsafe_torch_load(filename, *args, **kwargs)
 
 
-def Extra(*args, **kwargs):
-    return contextlib.nullcontext()
+class Extra:
+    """
+    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra
+    (because it's not your code making the torch.load call).
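+
+    Only one Extra() block may be active at a time (nesting trips the assert in
+    __enter__), and the handler is cleared again when the block exits.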
+
+    The intended use is like this:
+
+```
+import torch
+from modules import safe
+
+def handler(module, name):
+    if module == 'torch' and name in ['float64', 'float16']:
+        return getattr(torch, name)
+
+    return None
+
+with safe.Extra(handler):
+    x = torch.load('model.pt')
+```
+    """
+
+    def __init__(self, handler):
+        self.handler = handler
+
+    def __enter__(self):
+        global global_extra_handler
+
+        assert global_extra_handler is None, 'already inside an Extra() block'
+        global_extra_handler = self.handler
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global global_extra_handler
+
+        global_extra_handler = None
 
 
 unsafe_torch_load = torch.load
+# route every torch.load call through the restricted loader; the extra handler
+# starts unset (load() and Extra() both reference global_extra_handler)
+torch.load = load
+global_extra_handler = None
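
For reviewers who want to poke at the allowlist, here is a minimal sketch (not part of the patch; it assumes the webui root is on sys.path so `modules.safe` imports, and it builds a hostile payload in memory instead of reading a checkpoint file):

```python
import io
import pickle

from modules import safe

# pickle stores functions by reference (module + name), so this payload simply
# asks the unpickler to import builtins.eval -- the primitive a malicious
# checkpoint would use for code execution.
payload = pickle.dumps(eval)

# The stock unpickler hands the callable straight back...
assert pickle.loads(payload) is eval

# ...while RestrictedUnpickler refuses anything outside its allowlist.
try:
    safe.RestrictedUnpickler(io.BytesIO(payload)).load()
except Exception as e:
    print(e)  # global 'builtins/eval' is forbidden
```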