my-sd/modules_forge/shared.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

39 lines
1.1 KiB
Python
Raw Normal View History

2024-01-28 17:08:33 +00:00
import os
import ldm_patched.modules.utils
2024-01-28 17:08:33 +00:00
2024-01-31 23:22:06 +00:00
from modules.paths_internal import models_path
2024-01-28 17:08:33 +00:00
# Folders 'ControlNet' and 'ControlNetPreprocessor' under ``models_path``.
# Both are created at import time; ``exist_ok=True`` makes creation a no-op
# when a folder is already present.
controlnet_dir = os.path.join(models_path, 'ControlNet')
preprocessor_dir = os.path.join(models_path, 'ControlNetPreprocessor')

for _required_dir in (controlnet_dir, preprocessor_dir):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir  # keep the loop variable out of the module namespace
2024-01-28 14:44:41 +00:00
# Module-level registries filled by the registration helpers in this module:
#   supported_preprocessors  - dict mapping each preprocessor's ``name`` to the object
#   supported_control_models - list of control-model types, tried in registration order
supported_preprocessors, supported_control_models = {}, []
2024-01-28 14:44:41 +00:00
def add_supported_preprocessor(preprocessor):
    """Register ``preprocessor`` in ``supported_preprocessors`` under its name.

    Args:
        preprocessor: Any object exposing a ``name`` attribute; that value is
            used as the registry key.

    A later registration with the same ``name`` replaces the earlier entry.
    """
    # Only the dict's contents change; the module-level name is never
    # rebound, so no ``global`` declaration is needed.
    supported_preprocessors[preprocessor.name] = preprocessor
def add_supported_control_model(control_model):
    """Append ``control_model`` to ``supported_control_models``.

    Entries are tried in registration order by
    ``try_load_supported_control_model``; duplicates are not filtered out.
    """
    # Mutating the list in place does not rebind the module-level name,
    # so no ``global`` declaration is needed.
    supported_control_models.append(control_model)
2024-01-28 14:44:41 +00:00
def try_load_supported_control_model(ckpt_path):
    """Load a checkpoint and build a model from the first type that accepts it.

    The checkpoint is loaded once with
    ``ldm_patched.modules.utils.load_torch_file(ckpt_path, safe_load=True)``.
    Each entry of ``supported_control_models`` is then asked, in registration
    order, to ``try_build_from_state_dict``.

    Args:
        ckpt_path: Path of the checkpoint file; also forwarded to each
            ``try_build_from_state_dict`` call.

    Returns:
        The first non-``None`` model produced, or ``None`` if no registered
        type accepts the checkpoint (including when none are registered).
    """
    state_dict = ldm_patched.modules.utils.load_torch_file(ckpt_path, safe_load=True)
    for supported_type in supported_control_models:
        # Each candidate gets its own shallow copy. If a builder mutates the
        # dict it receives (e.g. pops keys), later candidates still see the
        # full set of keys.
        model = supported_type.try_build_from_state_dict(dict(state_dict), ckpt_path)
        if model is not None:
            return model
    return None