my-sd/extensions-builtin/sd_forge_controlnet/lib_controlnet/infotext.py
lllyasviel ac374e0b97
Cnet (#22)
* ini

* remove shit

* Create control_model.py

* i

* i

* Update controlnet_supported.py

* Update controlnet_supported.py

* Update controlnet_supported.py

* i

* i

* Update controlnet_supported.py

* i

* Update controlnet_supported.py

* remove shits

* remove shit

* Update global_state.py

* i

* i

* Update legacy_preprocessors.py

* Update legacy_preprocessors.py

* remove shit

* Update batch_hijack.py

* remove shit

* remove shit

* i

* i

* i

* Update external_code.py

* Update global_state.py

* Update infotext.py

* Update utils.py

* Update external_code.py

* i

* i

* i

* Update controlnet_ui_group.py

* remove shit

* remove shit

* i

* Update controlnet.py

* Update controlnet.py

* Update controlnet.py

* Update controlnet.py

* Update controlnet.py

* i

* Update global_state.py

* Update global_state.py

* i

* Update global_state.py

* Update global_state.py

* Update global_state.py

* Update global_state.py

* Update controlnet_ui_group.py

* i

* Update global_state.py

* Update controlnet_ui_group.py

* Update controlnet_ui_group.py

* i

* Update controlnet_ui_group.py

* Update controlnet_ui_group.py

* Update controlnet_ui_group.py

* Update controlnet_ui_group.py
2024-01-29 14:25:03 -08:00

136 lines
4.6 KiB
Python

from typing import List, Tuple, Union
import gradio as gr
from modules.processing import StableDiffusionProcessing
from lib_controlnet import external_code
from lib_controlnet.logging import logger
def field_to_displaytext(fieldname: str) -> str:
return " ".join([word.capitalize() for word in fieldname.split("_")])
def displaytext_to_field(text: str) -> str:
return "_".join([word.lower() for word in text.split(" ")])
def parse_value(value: str) -> Union[str, float, int, bool]:
if value in ("True", "False"):
return value == "True"
try:
return int(value)
except ValueError:
try:
return float(value)
except ValueError:
return value # Plain string.
def serialize_unit(unit: external_code.ControlNetUnit) -> str:
excluded_fields = (
"image",
"enabled",
# Note: "advanced_weighting" is excluded as it is an API-only field.
"advanced_weighting",
# Note: "inpaint_crop_image" is img2img inpaint only flag, which does not
# provide much information when restoring the unit.
"inpaint_crop_input_image",
)
log_value = {
field_to_displaytext(field): getattr(unit, field)
for field in vars(external_code.ControlNetUnit()).keys()
if field not in excluded_fields and getattr(unit, field) != -1
# Note: exclude hidden slider values.
}
if not all("," not in str(v) and ":" not in str(v) for v in log_value.values()):
logger.error(f"Unexpected tokens encountered:\n{log_value}")
return ""
return ", ".join(f"{field}: {value}" for field, value in log_value.items())
def parse_unit(text: str) -> external_code.ControlNetUnit:
return external_code.ControlNetUnit(
enabled=True,
**{
displaytext_to_field(key): parse_value(value)
for item in text.split(",")
for (key, value) in (item.strip().split(": "),)
},
)
class Infotext(object):
def __init__(self) -> None:
self.infotext_fields: List[Tuple[gr.components.IOComponent, str]] = []
self.paste_field_names: List[str] = []
@staticmethod
def unit_prefix(unit_index: int) -> str:
return f"ControlNet {unit_index}"
def register_unit(self, unit_index: int, uigroup) -> None:
"""Register the unit's UI group. By regsitering the unit, A1111 will be
able to paste values from infotext to IOComponents.
Args:
unit_index: The index of the ControlNet unit
uigroup: The ControlNetUiGroup instance that contains all gradio
iocomponents.
"""
unit_prefix = Infotext.unit_prefix(unit_index)
for field in vars(external_code.ControlNetUnit()).keys():
# Exclude image for infotext.
if field == "image":
continue
# Every field in ControlNetUnit should have a cooresponding
# IOComponent in ControlNetUiGroup.
io_component = getattr(uigroup, field)
component_locator = f"{unit_prefix} {field}"
self.infotext_fields.append((io_component, component_locator))
self.paste_field_names.append(component_locator)
@staticmethod
def write_infotext(
units: List[external_code.ControlNetUnit], p: StableDiffusionProcessing
):
"""Write infotext to `p`."""
p.extra_generation_params.update(
{
Infotext.unit_prefix(i): serialize_unit(unit)
for i, unit in enumerate(units)
if unit.enabled
}
)
@staticmethod
def on_infotext_pasted(infotext: str, results: dict) -> None:
"""Parse ControlNet infotext string and write result to `results` dict."""
updates = {}
for k, v in results.items():
if not k.startswith("ControlNet"):
continue
assert isinstance(v, str), f"Expect string but got {v}."
try:
for field, value in vars(parse_unit(v)).items():
if field == "image":
continue
if value is None:
logger.debug(f"InfoText: Skipping {field} because value is None.")
continue
component_locator = f"{k} {field}"
updates[component_locator] = value
logger.debug(f"InfoText: Setting {component_locator} = {value}")
except Exception as e:
logger.warn(
f"Failed to parse infotext, legacy format infotext is no longer supported:\n{v}\n{e}"
)
results.update(updates)