my-sd/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

139 lines
4.4 KiB
Python
Raw Normal View History

import os.path
import stat
from collections import OrderedDict
from modules import shared, sd_models
from lib_controlnet.enums import StableDiffusionVersion
from modules_forge.shared import controlnet_dir, supported_preprocessors
# File extensions treated as ControlNet model weight files.
CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin"]


def traverse_all_files(curr_path, model_list):
    """Recursively collect ControlNet model files under *curr_path*.

    Every entry whose extension is in ``CN_MODEL_EXTS`` is appended to
    *model_list* as a ``(path, os.stat_result)`` tuple. Other directories
    are descended into, and all remaining entries are ignored. Entries that
    cannot be stat'ed (e.g. dangling symlinks) are skipped.

    Args:
        curr_path: Directory to scan. If it is not an existing directory,
            nothing is scanned and *model_list* is returned unchanged.
        model_list: Accumulator list. It is appended to in place and returned.

    Returns:
        *model_list*, with any model files found under *curr_path* appended.
    """
    # os.scandir() raises on a missing path or a regular file, so check once
    # up front instead of filtering per entry after the scan has started.
    if not os.path.isdir(curr_path):
        return model_list

    f_list = []
    # The context manager closes the directory handle even if iteration fails.
    with os.scandir(curr_path) as entries:
        for entry in entries:
            try:
                # stat() follows symlinks, so a dangling link raises here.
                # Skip that entry rather than aborting the whole listing.
                fstat = entry.stat()
            except OSError:
                continue
            f_list.append((os.path.join(curr_path, entry.name), fstat))

    for f_info in f_list:
        fname, fstat = f_info
        if os.path.splitext(fname)[1] in CN_MODEL_EXTS:
            model_list.append(f_info)
        elif stat.S_ISDIR(fstat.st_mode):
            model_list = traverse_all_files(fname, model_list)
    return model_list
def get_all_models(sort_by, filter_by, path):
    """List ControlNet model files under *path* as ``{display name: file path}``.

    Args:
        sort_by: "name" (file basename), "date" (newest modification first)
            or "path name" (full path). Any other value keeps traversal order.
        filter_by: Case-insensitive substring that a file's basename must
            contain. Surrounding spaces are ignored, and an empty value
            disables filtering.
        path: Directory scanned recursively with traverse_all_files().

    Returns:
        OrderedDict keyed by "<stem> [<hash>]" (hash from
        sd_models.model_hash) whose values are the file paths, in sorted order.
        Files whose stem is exactly "None" are left out.
    """
    candidates = traverse_all_files(path, [])

    needle = filter_by.strip(" ").lower()
    if needle:
        candidates = [
            info for info in candidates
            if needle in os.path.basename(info[0]).lower()
        ]

    # Python's sort is stable, also with reverse=True, so ties keep traversal order.
    if sort_by == "name":
        candidates.sort(key=lambda info: os.path.basename(info[0]))
    elif sort_by == "date":
        candidates.sort(key=lambda info: info[1].st_mtime, reverse=True)
    elif sort_by == "path name":
        candidates.sort()

    listing = OrderedDict()
    for full_path, _stat in candidates:
        stem = os.path.splitext(os.path.basename(full_path))[0]
        if stem == "None":
            # Keep a file literally named "None.*" out of the list, presumably so
            # it cannot be confused with the 'None' choice.
            continue
        listing[f"{stem} [{sd_models.model_hash(full_path)}]"] = full_path
    return listing
# Model registry: display name -> file path. It always contains the 'None'
# choice, which maps to a placeholder file name. update_controlnet_filenames()
# rebuilds it.
controlnet_filename_dict = {'None': 'model.safetensors'}
# Display names, in the same order as the registry's keys.
controlnet_names = list(controlnet_filename_dict)
def get_preprocessor(name):
    """Look up a registered preprocessor by *name*; return None if it is unknown."""
    return supported_preprocessors.get(name)
def get_sorted_preprocessors():
    """Return all registered preprocessors in display order.

    'None' always comes first. The remaining preprocessors are ordered by
    ``sorting_priority`` (highest first), with ties ordered by ``name``
    descending.

    Returns:
        OrderedDict mapping preprocessor name -> preprocessor object.
    """
    preprocessors = [p for k, p in supported_preprocessors.items() if k != 'None']
    # Compare priorities numerically. A zero-padded string key such as
    # str(priority).zfill(8) + name misorders negative values
    # ("-0000005" > "-0000001"), floats ("000001.5" > "00000002") and any value
    # with more than 8 digits. Sorting ascending and then reversing keeps the
    # existing tie-break (names descending) unchanged.
    preprocessors = sorted(preprocessors, key=lambda p: (p.sorting_priority, p.name))[::-1]
    results = OrderedDict()
    results['None'] = supported_preprocessors['None']
    for p in preprocessors:
        results[p.name] = p
    return results
def get_all_controlnet_names():
    """Return the current ControlNet model display names.

    The live module-level list is returned, not a copy. Its first entry is 'None'.
    """
    names = controlnet_names
    return names
def get_controlnet_filename(controlnet_name):
    """Map a model display name to its file path in the current registry.

    Raises:
        KeyError: if *controlnet_name* is not in the registry.
    """
    filename = controlnet_filename_dict[controlnet_name]
    return filename
def get_all_preprocessor_names():
    """Return every preprocessor name in display order, starting with 'None'."""
    ordered = get_sorted_preprocessors()
    return list(ordered)
def get_all_preprocessor_tags():
    """Return 'All' followed by every distinct preprocessor tag, sorted."""
    unique_tags = {
        tag
        for preprocessor in supported_preprocessors.values()
        for tag in preprocessor.tags
    }
    return ['All', *sorted(unique_tags)]
def get_filtered_preprocessors(tag):
    """Return the preprocessors to offer for *tag*, in display order.

    Args:
        tag: A tag such as one from get_all_preprocessor_tags(). 'All'
            disables tag filtering.

    Returns:
        Mapping of name -> preprocessor, ordered as by get_sorted_preprocessors().
        'None' is always included.
    """
    # Use the sorted copy for 'All' too. Ordering then matches the other tags
    # and get_all_preprocessor_names(), and callers never receive the live
    # supported_preprocessors registry, which they could mutate by accident.
    ordered = get_sorted_preprocessors()
    if tag == 'All':
        return ordered
    return {k: v for k, v in ordered.items() if tag in v.tags or k == 'None'}
def get_filtered_preprocessor_names(tag):
    """Return the names of the preprocessors offered for *tag*, in order."""
    matching = get_filtered_preprocessors(tag)
    return list(matching)
def get_filtered_controlnet_names(tag):
    """Return model names relevant to the preprocessors offered for *tag*.

    A model name is kept when it contains, case-insensitively, any
    ``model_filename_filers`` substring of those preprocessors. 'None' is
    always kept. Order follows ``controlnet_names``.
    """
    needles = [
        pattern.lower()
        for preprocessor in get_filtered_preprocessors(tag).values()
        for pattern in preprocessor.model_filename_filers
    ]

    def _is_relevant(model_name):
        if model_name == 'None':
            return True
        lowered = model_name.lower()
        return any(needle in lowered for needle in needles)

    return [name for name in controlnet_names if _is_relevant(name)]
def update_controlnet_filenames():
    """Rescan the ControlNet model directories and rebuild the module registry.

    Directories scanned:
    - ``controlnet_dir``;
    - the ``control_net_models_path`` option, when set and the path exists;
    - the ``controlnet_dir`` command-line option, when set and the path exists.

    Afterwards ``controlnet_filename_dict`` holds 'None' plus every model
    found, and ``controlnet_names`` holds its keys in the same order. If a
    scan raises, the previously published registry is left untouched.
    """
    global controlnet_filename_dict, controlnet_names

    # Listing options are the same for every directory; read them once.
    sort_by = shared.opts.data.get("control_net_models_sort_models_by", "name")
    filter_by = shared.opts.data.get("control_net_models_name_filter", "")

    extra_dirs = (
        shared.opts.data.get("control_net_models_path", None),
        getattr(shared.cmd_opts, 'controlnet_dir', None),
    )
    paths = [controlnet_dir]
    paths += [d for d in extra_dirs if d is not None and os.path.exists(d)]

    # Build the new registry locally and publish it only once the scan has
    # finished. Otherwise the globals would be emptied for the whole scan, and
    # a failing scan would leave the dict and the name list inconsistent.
    new_filename_dict = {'None': 'model.safetensors'}
    for path in paths:
        new_filename_dict.update(get_all_models(sort_by, filter_by, path))

    controlnet_filename_dict = new_filename_dict
    controlnet_names = list(new_filename_dict.keys())
def get_sd_version() -> StableDiffusionVersion:
    """Classify the currently loaded ``shared.sd_model`` by architecture.

    The model's flags are checked in priority order: ``is_sdxl``, then
    ``is_sd2``, then ``is_sd1``. The first truthy flag decides the result;
    UNKNOWN is returned when none is set.
    """
    model = shared.sd_model
    checks = (
        ("is_sdxl", StableDiffusionVersion.SDXL),
        ("is_sd2", StableDiffusionVersion.SD2x),
        ("is_sd1", StableDiffusionVersion.SD1x),
    )
    for flag_name, version in checks:
        if getattr(model, flag_name):
            return version
    return StableDiffusionVersion.UNKNOWN