remove shits
This commit is contained in:
parent
d08594fa2f
commit
98b0d7a999
@ -1,12 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from copy import copy
|
||||
from typing import List, Any, Optional, Union, Tuple, Dict
|
||||
from typing import List, Optional, Union, Tuple, Dict
|
||||
import numpy as np
|
||||
from modules import scripts, processing, shared
|
||||
from lib_controlnet import global_state
|
||||
from modules import shared
|
||||
from lib_controlnet.logging import logger
|
||||
from lib_controlnet.enums import HiResFixOption, InputMode
|
||||
from lib_controlnet.enums import InputMode
|
||||
|
||||
from modules.api import api
|
||||
|
||||
@ -181,90 +179,6 @@ def to_base64_nparray(encoding: str):
|
||||
return np.array(api.decode_base64_to_image(encoding)).astype('uint8')
|
||||
|
||||
|
||||
def get_all_units_in_processing(p: processing.StableDiffusionProcessing) -> List[ControlNetUnit]:
|
||||
"""
|
||||
Fetch ControlNet processing units from a StableDiffusionProcessing.
|
||||
"""
|
||||
|
||||
return get_all_units(p.scripts, p.script_args)
|
||||
|
||||
|
||||
def get_all_units(script_runner: scripts.ScriptRunner, script_args: List[Any]) -> List[ControlNetUnit]:
|
||||
"""
|
||||
Fetch ControlNet processing units from an existing script runner.
|
||||
Use this function to fetch units from the list of all scripts arguments.
|
||||
"""
|
||||
|
||||
cn_script = find_cn_script(script_runner)
|
||||
if cn_script:
|
||||
return get_all_units_from(script_args[cn_script.args_from:cn_script.args_to])
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def get_all_units_from(script_args: List[Any]) -> List[ControlNetUnit]:
|
||||
"""
|
||||
Fetch ControlNet processing units from ControlNet script arguments.
|
||||
Use `external_code.get_all_units` to fetch units from the list of all scripts arguments.
|
||||
"""
|
||||
|
||||
def is_stale_unit(script_arg: Any) -> bool:
|
||||
""" Returns whether the script_arg is potentially an stale version of
|
||||
ControlNetUnit created before module reload."""
|
||||
return (
|
||||
'ControlNetUnit' in type(script_arg).__name__ and
|
||||
not isinstance(script_arg, ControlNetUnit)
|
||||
)
|
||||
|
||||
def is_controlnet_unit(script_arg: Any) -> bool:
|
||||
""" Returns whether the script_arg is ControlNetUnit or anything that
|
||||
can be treated like ControlNetUnit. """
|
||||
return (
|
||||
isinstance(script_arg, (ControlNetUnit, dict)) or
|
||||
(
|
||||
hasattr(script_arg, '__dict__') and
|
||||
set(vars(ControlNetUnit()).keys()).issubset(
|
||||
set(vars(script_arg).keys()))
|
||||
)
|
||||
)
|
||||
|
||||
all_units = [
|
||||
to_processing_unit(script_arg)
|
||||
for script_arg in script_args
|
||||
if is_controlnet_unit(script_arg)
|
||||
]
|
||||
if not all_units:
|
||||
logger.warning(
|
||||
"No ControlNetUnit detected in args. It is very likely that you are having an extension conflict."
|
||||
f"Here are args received by ControlNet: {script_args}.")
|
||||
if any(is_stale_unit(script_arg) for script_arg in script_args):
|
||||
logger.debug(
|
||||
"Stale version of ControlNetUnit detected. The ControlNetUnit received"
|
||||
"by ControlNet is created before the newest load of ControlNet extension."
|
||||
"They will still be used by ControlNet as long as they provide same fields"
|
||||
"defined in the newest version of ControlNetUnit."
|
||||
)
|
||||
|
||||
return all_units
|
||||
|
||||
|
||||
def get_single_unit_from(script_args: List[Any], index: int = 0) -> Optional[ControlNetUnit]:
|
||||
"""
|
||||
Fetch a single ControlNet processing unit from ControlNet script arguments.
|
||||
The list must not contain script positional arguments. It must only contain processing units.
|
||||
"""
|
||||
|
||||
i = 0
|
||||
while i < len(script_args) and index >= 0:
|
||||
if index == 0 and script_args[i] is not None:
|
||||
return to_processing_unit(script_args[i])
|
||||
i += 1
|
||||
|
||||
index -= 1
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_max_models_num():
|
||||
"""
|
||||
Fetch the maximum number of allowed ControlNet models.
|
||||
@ -272,152 +186,3 @@ def get_max_models_num():
|
||||
|
||||
max_models_num = shared.opts.data.get("control_net_unit_count", 3)
|
||||
return max_models_num
|
||||
|
||||
|
||||
def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNetUnit:
|
||||
"""
|
||||
Convert different types to processing unit.
|
||||
If `unit` is a dict, alternative keys are supported. See `ext_compat_keys` in implementation for details.
|
||||
"""
|
||||
|
||||
ext_compat_keys = {
|
||||
'guessmode': 'guess_mode',
|
||||
'guidance': 'guidance_end',
|
||||
'lowvram': 'low_vram',
|
||||
'input_image': 'image'
|
||||
}
|
||||
|
||||
if isinstance(unit, dict):
|
||||
unit = {ext_compat_keys.get(k, k): v for k, v in unit.items()}
|
||||
|
||||
mask = None
|
||||
if 'mask' in unit:
|
||||
mask = unit['mask']
|
||||
del unit['mask']
|
||||
|
||||
if 'image' in unit and not isinstance(unit['image'], dict):
|
||||
unit['image'] = {'image': unit['image'], 'mask': mask} if mask is not None else unit['image'] if unit[
|
||||
'image'] else None
|
||||
|
||||
if 'guess_mode' in unit:
|
||||
logger.warning('Guess Mode is removed since 1.1.136. Please use Control Mode instead.')
|
||||
|
||||
unit = ControlNetUnit(**{k: v for k, v in unit.items() if k in vars(ControlNetUnit).keys()})
|
||||
|
||||
# temporary, check #602
|
||||
# assert isinstance(unit, ControlNetUnit), f'bad argument to controlnet extension: {unit}\nexpected Union[dict[str, Any], ControlNetUnit]'
|
||||
return unit
|
||||
|
||||
|
||||
def update_cn_script_in_processing(
|
||||
p: processing.StableDiffusionProcessing,
|
||||
cn_units: List[ControlNetUnit],
|
||||
**_kwargs, # for backwards compatibility
|
||||
):
|
||||
"""
|
||||
Update the arguments of the ControlNet script in `p.script_args` in place, reading from `cn_units`.
|
||||
`cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want.
|
||||
|
||||
Does not update `p.script_args` if any of the folling is true:
|
||||
- ControlNet is not present in `p.scripts`
|
||||
- `p.script_args` is not filled with script arguments for scripts that are processed before ControlNet
|
||||
"""
|
||||
p.script_args = update_cn_script(p.scripts, p.script_args_value, cn_units)
|
||||
|
||||
|
||||
def update_cn_script(
|
||||
script_runner: scripts.ScriptRunner,
|
||||
script_args: Union[Tuple[Any], List[Any]],
|
||||
cn_units: List[ControlNetUnit],
|
||||
) -> Union[Tuple[Any], List[Any]]:
|
||||
"""
|
||||
Returns: The updated `script_args` with given `cn_units` used as ControlNet
|
||||
script args.
|
||||
|
||||
Does not update `script_args` if any of the folling is true:
|
||||
- ControlNet is not present in `script_runner`
|
||||
- `script_args` is not filled with script arguments for scripts that are
|
||||
processed before ControlNet
|
||||
"""
|
||||
script_args_type = type(script_args)
|
||||
assert script_args_type in (tuple, list), script_args_type
|
||||
updated_script_args = list(copy(script_args))
|
||||
|
||||
cn_script = find_cn_script(script_runner)
|
||||
|
||||
if cn_script is None or len(script_args) < cn_script.args_from:
|
||||
return script_args
|
||||
|
||||
# fill in remaining parameters to satisfy max models, just in case script needs it.
|
||||
max_models = shared.opts.data.get("control_net_unit_count", 3)
|
||||
cn_units = cn_units + [ControlNetUnit(enabled=False)] * max(max_models - len(cn_units), 0)
|
||||
|
||||
cn_script_args_diff = 0
|
||||
for script in script_runner.alwayson_scripts:
|
||||
if script is cn_script:
|
||||
cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from)
|
||||
updated_script_args[script.args_from:script.args_to] = cn_units
|
||||
script.args_to = script.args_from + len(cn_units)
|
||||
else:
|
||||
script.args_from += cn_script_args_diff
|
||||
script.args_to += cn_script_args_diff
|
||||
|
||||
return script_args_type(updated_script_args)
|
||||
|
||||
|
||||
def update_cn_script_in_place(
|
||||
script_runner: scripts.ScriptRunner,
|
||||
script_args: List[Any],
|
||||
cn_units: List[ControlNetUnit],
|
||||
**_kwargs, # for backwards compatibility
|
||||
):
|
||||
"""
|
||||
@Deprecated(Raises assertion error if script_args passed in is Tuple)
|
||||
|
||||
Update the arguments of the ControlNet script in `script_args` in place, reading from `cn_units`.
|
||||
`cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want.
|
||||
|
||||
Does not update `script_args` if any of the folling is true:
|
||||
- ControlNet is not present in `script_runner`
|
||||
- `script_args` is not filled with script arguments for scripts that are processed before ControlNet
|
||||
"""
|
||||
assert isinstance(script_args, list), type(script_args)
|
||||
|
||||
cn_script = find_cn_script(script_runner)
|
||||
if cn_script is None or len(script_args) < cn_script.args_from:
|
||||
return
|
||||
|
||||
# fill in remaining parameters to satisfy max models, just in case script needs it.
|
||||
max_models = shared.opts.data.get("control_net_unit_count", 3)
|
||||
cn_units = cn_units + [ControlNetUnit(enabled=False)] * max(max_models - len(cn_units), 0)
|
||||
|
||||
cn_script_args_diff = 0
|
||||
for script in script_runner.alwayson_scripts:
|
||||
if script is cn_script:
|
||||
cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from)
|
||||
script_args[script.args_from:script.args_to] = cn_units
|
||||
script.args_to = script.args_from + len(cn_units)
|
||||
else:
|
||||
script.args_from += cn_script_args_diff
|
||||
script.args_to += cn_script_args_diff
|
||||
|
||||
|
||||
def find_cn_script(script_runner: scripts.ScriptRunner) -> Optional[scripts.Script]:
|
||||
"""
|
||||
Find the ControlNet script in `script_runner`. Returns `None` if `script_runner` does not contain a ControlNet script.
|
||||
"""
|
||||
|
||||
if script_runner is None:
|
||||
return None
|
||||
|
||||
for script in script_runner.alwayson_scripts:
|
||||
if is_cn_script(script):
|
||||
return script
|
||||
|
||||
|
||||
def is_cn_script(script: scripts.Script) -> bool:
|
||||
"""
|
||||
Determine whether `script` is a ControlNet script.
|
||||
"""
|
||||
|
||||
return script.title().lower() == 'controlnet'
|
||||
|
@ -93,8 +93,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
self.paste_field_names = infotext.paste_field_names
|
||||
return tuple(controls)
|
||||
|
||||
def get_enabled_units(self, p):
|
||||
units = external_code.get_all_units_in_processing(p)
|
||||
def get_enabled_units(self, units):
|
||||
enabled_units = [x for x in units if x.enabled]
|
||||
return enabled_units
|
||||
|
||||
@ -397,7 +396,7 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
@torch.no_grad()
|
||||
def process(self, p, *args, **kwargs):
|
||||
self.current_params = {}
|
||||
for i, unit in enumerate(self.get_enabled_units(p)):
|
||||
for i, unit in enumerate(self.get_enabled_units(args)):
|
||||
self.bound_check_params(unit)
|
||||
params = ControlNetCachedParameters()
|
||||
self.process_unit_after_click_generate(p, unit, params, *args, **kwargs)
|
||||
@ -406,14 +405,14 @@ class ControlNetForForgeOfficial(scripts.Script):
|
||||
|
||||
@torch.no_grad()
|
||||
def process_before_every_sampling(self, p, *args, **kwargs):
|
||||
for i, unit in enumerate(self.get_enabled_units(p)):
|
||||
for i, unit in enumerate(self.get_enabled_units(args)):
|
||||
self.process_unit_before_every_sampling(p, unit, self.current_params[i], *args, **kwargs)
|
||||
return
|
||||
|
||||
@torch.no_grad()
|
||||
def postprocess_batch_list(self, p, *args, **kwargs):
|
||||
for i, unit in enumerate(self.get_enabled_units(p)):
|
||||
self.process_unit_after_every_sampling(p, unit, self.current_params[i], *args, **kwargs)
|
||||
def postprocess_batch_list(self, p, pp, *args, **kwargs):
|
||||
for i, unit in enumerate(self.get_enabled_units(args)):
|
||||
self.process_unit_after_every_sampling(p, unit, self.current_params[i], pp, *args, **kwargs)
|
||||
self.current_params = {}
|
||||
return
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user