ini mask support

lllyasviel 2024-02-01 21:21:41 -08:00
parent 07977a1fec
commit 9f6ee2a688
9 changed files with 109 additions and 236 deletions

View File

@@ -29,10 +29,11 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
self.image = None
self.mask = None
self.latent = None
self.fill_mask_with_one_when_resize_and_fill = True
def process_before_every_sampling(self, process, cond, *args, **kwargs):
self.image = kwargs['cond_before_inpaint_fix'][:, 0:3]
self.mask = kwargs['cond_before_inpaint_fix'][:, 3:]
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
self.image = cond
self.mask = mask
vae = process.sd_model.forge_objects.vae
# This is a powerful VAE with integrated memory management, bf16, and tiled fallback.
@@ -58,7 +59,10 @@ class PreprocessorInpaintOnly(PreprocessorInpaint):
process.sd_model.forge_objects.unet = unet
self.latent = latent_image
return
mixed_cond = cond * (1.0 - mask) - mask
return mixed_cond, None
def process_after_every_sampling(self, process, params, *args, **kwargs):
a1111_batch_result = args[0]
@@ -95,19 +99,21 @@ class PreprocessorInpaintLama(PreprocessorInpaintOnly):
self.setup_model_patcher(model)
return
def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, **kwargs):
def __call__(self, input_image, resolution, slider_1=None, slider_2=None, slider_3=None, input_mask=None, **kwargs):
H, W, C = input_image.shape
raw_color = input_image[:, :, 0:3].copy()
raw_mask = input_image[:, :, 3:4].copy()
raw_color = input_image.copy()
raw_mask = input_mask.copy()
input_image, remove_pad = resize_image_with_pad(input_image, 256)
input_mask, remove_pad = resize_image_with_pad(input_mask, 256)
input_mask = input_mask[..., :1]
self.load_model()
self.move_all_model_patchers_to_gpu()
color = np.ascontiguousarray(input_image[:, :, 0:3]).astype(np.float32) / 255.0
mask = np.ascontiguousarray(input_image[:, :, 3:4]).astype(np.float32) / 255.0
color = np.ascontiguousarray(input_image).astype(np.float32) / 255.0
mask = np.ascontiguousarray(input_mask).astype(np.float32) / 255.0
with torch.no_grad():
color = self.send_tensor_to_model_device(torch.from_numpy(color))
mask = self.send_tensor_to_model_device(torch.from_numpy(mask))
@@ -128,15 +134,14 @@ class PreprocessorInpaintLama(PreprocessorInpaintOnly):
fin_color = prd_color.astype(np.float32) * alpha + raw_color.astype(np.float32) * (1 - alpha)
fin_color = fin_color.clip(0, 255).astype(np.uint8)
result = np.concatenate([fin_color, raw_mask], axis=2)
return result
return fin_color
def process_before_every_sampling(self, process, cond, *args, **kwargs):
super().process_before_every_sampling(process, cond, *args, **kwargs)
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
cond, mask = super().process_before_every_sampling(process, cond, mask, *args, **kwargs)
sigma_max = process.sd_model.forge_objects.unet.model.model_sampling.sigma_max
original_noise = kwargs['noise']
process.modified_noise = original_noise + self.latent.to(original_noise) / sigma_max.to(original_noise)
return
return cond, mask
add_supported_preprocessor(PreprocessorInpaint())
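The substance of this file's change: the inpaint preprocessors now receive the mask as a separate argument instead of unpacking it from the fourth channel of `cond`, and they build the ControlNet inpaint condition themselves via `mixed_cond = cond * (1.0 - mask) - mask`. A minimal sketch of that encoding, assuming `cond` in [0, 1] and a binary mask (tensor names are illustrative); masked pixels collapse to exactly -1, the sentinel value the inpaint ControlNet is trained on:

```python
import torch

def mix_inpaint_cond(cond: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # cond: [B, 3, H, W] image in 0..1; mask: [B, 1, H, W] with 1 = masked.
    # Unmasked pixels keep their value; masked pixels become -1.
    return cond * (1.0 - mask) - mask

cond = torch.rand(1, 3, 64, 64)
mask = (torch.rand(1, 1, 64, 64) > 0.5).float()
mixed = mix_inpaint_cond(cond, mask)
assert torch.all(mixed[(mask > 0).expand_as(mixed)] == -1.0)
```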

View File

@@ -38,9 +38,9 @@ class PreprocessorRecolor(Preprocessor):
result = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB)
return result
def process_before_every_sampling(self, process, cond, *args, **kwargs):
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
self.current_cond = cond
return
return cond, mask
def process_after_every_sampling(self, process, params, *args, **kwargs):
a1111_batch_result = args[0]

View File

@@ -47,7 +47,7 @@ class PreprocessorReference(Preprocessor):
self.recorded_attn1 = {}
self.recorded_h = {}
def process_before_every_sampling(self, process, cond, *args, **kwargs):
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
unit = kwargs['unit']
weight = float(unit.weight)
style_fidelity = float(unit.threshold_a)
@@ -190,7 +190,7 @@ class PreprocessorReference(Preprocessor):
process.sd_model.forge_objects.unet = unet
return
return cond, mask
add_supported_preprocessor(PreprocessorReference(

View File

@@ -63,7 +63,7 @@ class PreprocessorClipVisionForRevision(PreprocessorClipVision):
self.slider_1 = PreprocessorParameter(
label="Noise Augmentation", minimum=0.0, maximum=1.0, value=0.0, visible=True)
def process_before_every_sampling(self, process, cond, *args, **kwargs):
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
unit = kwargs['unit']
weight = float(unit.weight)
@@ -84,7 +84,8 @@ class PreprocessorClipVisionForRevision(PreprocessorClipVision):
unet.add_conditioning_modifier(revision_conditioning_modifier, ensure_uniqueness=True)
process.sd_model.forge_objects.unet = unet
return
return cond, mask
add_supported_preprocessor(PreprocessorClipVisionForRevision(

View File

@@ -37,7 +37,7 @@ class PreprocessorTileColorFix(PreprocessorTile):
self.variation = 8
self.sharpness = None
def process_before_every_sampling(self, process, cond, *args, **kwargs):
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
self.variation = int(kwargs['unit'].threshold_a)
latent = self.register_latent(process, cond)
@@ -77,7 +77,7 @@ class PreprocessorTileColorFix(PreprocessorTile):
process.sd_model.forge_objects.unet = unet
return
return cond, mask
class PreprocessorTileColorFixSharp(PreprocessorTileColorFix):
@@ -86,10 +86,9 @@ class PreprocessorTileColorFixSharp(PreprocessorTileColorFix):
self.name = 'tile_colorfix+sharp'
self.slider_2 = PreprocessorParameter(label='Sharpness', value=1.0, minimum=0.0, maximum=2.0, step=0.01, visible=True)
def process_before_every_sampling(self, process, cond, *args, **kwargs):
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
self.sharpness = float(kwargs['unit'].threshold_b)
super().process_before_every_sampling(process, cond, *args, **kwargs)
return
return super().process_before_every_sampling(process, cond, mask, *args, **kwargs)
add_supported_preprocessor(PreprocessorTile())

View File

@@ -215,7 +215,6 @@ class ControlNetUiGroup(object):
self.image_upload_panel = None
self.save_detected_map = None
self.input_mode = gr.State(InputMode.SIMPLE)
self.inpaint_crop_input_image = None
self.hr_option = None
self.batch_image_dir_state = None
self.output_dir_state = None
@@ -443,15 +442,6 @@ class ControlNetUiGroup(object):
else:
self.upload_independent_img_in_img2img = None
# Note: The checkbox needs to exist for both img2img and txt2img as infotext
# needs the checkbox value.
self.inpaint_crop_input_image = gr.Checkbox(
label="Crop input image based on A1111 mask",
value=False,
elem_classes=["cnet-crop-input-image"],
visible=False,
)
with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]):
self.type_filter = gr.Radio(
global_state.get_all_preprocessor_tags(),
@@ -805,18 +795,10 @@ class ControlNetUiGroup(object):
)
img = HWC3(image["image"])
has_mask = not (
(image["mask"][:, :, 0] <= 5).all()
or (image["mask"][:, :, 0] >= 250).all()
)
if "inpaint" in module:
color = HWC3(image["image"])
alpha = image["mask"][:, :, 0:1]
img = np.concatenate([color, alpha], axis=2)
elif has_mask and not shared.opts.data.get(
"controlnet_ignore_noninpaint_mask", False
):
img = HWC3(image["mask"][:, :, 0])
mask = HWC3(image["mask"])
if not (mask > 5).any():
mask = None
preprocessor = global_state.get_preprocessor(module)
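This hunk replaces the old trick of packing the editor mask into an RGBA image with an explicit mask that is handed to the preprocessor via the new `input_mask` argument added below. The `(mask > 5).any()` check guards against gradio returning a not-quite-black canvas when nothing was painted. A rough standalone equivalent (the `HWC3` channel promotion is sketched inline; the threshold is taken from the diff):

```python
from typing import Optional

import numpy as np

def extract_mask(image_dict: dict) -> Optional[np.ndarray]:
    """Return an HWC mask from a gradio sketch dict, or None when the
    canvas is effectively empty (nothing painted above the noise floor)."""
    mask = image_dict.get("mask")
    if mask is None:
        return None
    if mask.ndim == 2:                 # HWC3-style promotion to 3 channels
        mask = np.stack([mask] * 3, axis=2)
    if not (mask > 5).any():           # empty-canvas guard from the diff
        return None
    return mask
```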
@@ -854,6 +836,7 @@ class ControlNetUiGroup(object):
resolution=pres,
slider_1=pthr_a,
slider_2=pthr_b,
input_mask=mask,
low_vram=shared.opts.data.get("controlnet_clip_detector_on_cpu", False),
json_pose_callback=json_acceptor.accept
if is_openpose(module)
@@ -976,52 +959,9 @@ class ControlNetUiGroup(object):
)
def register_shift_crop_input_image(self):
# A1111 < 1.7.0 compatibility.
if any(c is None for c in ControlNetUiGroup.a1111_context.img2img_inpaint_tabs):
self.inpaint_crop_input_image.visible = True
self.inpaint_crop_input_image.value = True
return
is_inpaint_tab = gr.State(False)
def shift_crop_input_image(is_inpaint: bool, inpaint_area: int):
# Note: inpaint_area (0: Whole picture, 1: Only masked)
# By default set value to True, as most preprocessors need cropped result.
return gr.update(value=True, visible=is_inpaint and inpaint_area == 1)
gradio_kwargs = dict(
fn=shift_crop_input_image,
inputs=[
is_inpaint_tab,
ControlNetUiGroup.a1111_context.img2img_inpaint_area,
],
outputs=[self.inpaint_crop_input_image],
show_progress=False,
)
for elem in ControlNetUiGroup.a1111_context.img2img_inpaint_tabs:
elem.select(fn=lambda: True, inputs=[], outputs=[is_inpaint_tab]).then(
**gradio_kwargs
)
for elem in ControlNetUiGroup.a1111_context.img2img_non_inpaint_tabs:
elem.select(fn=lambda: False, inputs=[], outputs=[is_inpaint_tab]).then(
**gradio_kwargs
)
ControlNetUiGroup.a1111_context.img2img_inpaint_area.change(**gradio_kwargs)
return
def register_shift_hr_options(self):
# # A1111 version < 1.6.0.
# if not ControlNetUiGroup.a1111_context.txt2img_enable_hr:
# return
#
# ControlNetUiGroup.a1111_context.txt2img_enable_hr.change(
# fn=lambda checked: gr.update(visible=checked),
# inputs=[ControlNetUiGroup.a1111_context.txt2img_enable_hr],
# outputs=[self.hr_option],
# show_progress=False,
# )
return
def register_shift_upload_mask(self):

View File

@@ -317,11 +317,6 @@ def high_quality_resize(x, size):
# Written by lvmin
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
inpaint_mask = None
if x.ndim == 3 and x.shape[2] == 4:
inpaint_mask = x[:, :, 3]
x = x[:, :, 0:3]
if x.shape[0] != size[1] or x.shape[1] != size[0]:
new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
@@ -346,8 +341,6 @@ def high_quality_resize(x, size):
interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS
y = cv2.resize(x, size, interpolation=interpolation)
if inpaint_mask is not None:
inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
if is_binary:
y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
@@ -361,15 +354,10 @@ def high_quality_resize(x, size):
else:
y = x
if inpaint_mask is not None:
inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
y = np.concatenate([y, inpaint_mask], axis=2)
return y
def crop_and_resize_image(detected_map, resize_mode, h, w):
def crop_and_resize_image(detected_map, resize_mode, h, w, fill_border_with_255=False):
if resize_mode == external_code.ResizeMode.RESIZE:
detected_map = high_quality_resize(detected_map, (w, h))
detected_map = safe_numpy(detected_map)
@@ -387,9 +375,8 @@ def crop_and_resize_image(detected_map, resize_mode, h, w):
k = min(k0, k1)
borders = np.concatenate([detected_map[0, :, :], detected_map[-1, :, :], detected_map[:, 0, :], detected_map[:, -1, :]], axis=0)
high_quality_border_color = np.median(borders, axis=0).astype(detected_map.dtype)
if len(high_quality_border_color) == 4:
# Inpaint hijack
high_quality_border_color[3] = 255
if fill_border_with_255:
high_quality_border_color = np.zeros_like(high_quality_border_color) + 255
high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
detected_map = high_quality_resize(detected_map, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = detected_map.shape
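The removed blocks strip the old "inpaint hijack" out of `high_quality_resize` (the mask no longer rides along as an alpha channel), and `crop_and_resize_image` gains `fill_border_with_255`, driven by the new `fill_mask_with_one_when_resize_and_fill` preprocessor attribute. The flag matters for "Resize and Fill": a control map's padding is normally filled with the median border color, but a mask padded that way would read as unmasked. A simplified sketch of just the fill step (the real function also performs the high-quality resize, and this assumes the resized map fits the canvas):

```python
import numpy as np

def pad_to_canvas(detected_map: np.ndarray, h: int, w: int,
                  fill_border_with_255: bool = False) -> np.ndarray:
    # Collect the outermost rows/columns and take their median color.
    borders = np.concatenate([detected_map[0], detected_map[-1],
                              detected_map[:, 0], detected_map[:, -1]], axis=0)
    color = np.median(borders, axis=0).astype(detected_map.dtype)
    if fill_border_with_255:
        color = np.full_like(color, 255)   # padding counts as "masked"
    canvas = np.tile(color[None, None], [h, w, 1])
    nh, nw = detected_map.shape[:2]
    y0, x0 = (h - nh) // 2, (w - nw) // 2  # center the map on the canvas
    canvas[y0:y0 + nh, x0:x0 + nw] = detected_map
    return canvas
```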

View File

@@ -1,6 +1,7 @@
import os
from typing import Dict, Optional, Tuple, List, Union
import cv2
import torch
import modules.scripts as scripts
@@ -46,6 +47,8 @@ class ControlNetCachedParameters:
self.model = None
self.control_cond = None
self.control_cond_for_hr_fix = None
self.control_mask = None
self.control_mask_for_hr_fix = None
class ControlNetForForgeOfficial(scripts.Script):
@@ -95,105 +98,6 @@ class ControlNetForForgeOfficial(scripts.Script):
enabled_units = [x for x in units if x.enabled]
return enabled_units
def choose_input_image(
self,
p: processing.StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
) -> Tuple[np.ndarray, external_code.ResizeMode]:
""" Choose input image from following sources with descending priority:
- p.image_control: [Deprecated] Lagacy way to pass image to controlnet.
- p.control_net_input_image: [Deprecated] Lagacy way to pass image to controlnet.
- unit.image: ControlNet tab input image.
- p.init_images: A1111 img2img tab input image.
Returns:
- The input image in ndarray form.
- The resize mode.
"""
def parse_unit_image(unit: external_code.ControlNetUnit) -> Union[
List[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
unit_has_multiple_images = (
isinstance(unit.image, list) and
len(unit.image) > 0 and
"image" in unit.image[0]
)
if unit_has_multiple_images:
return [
d
for img in unit.image
for d in (image_dict_from_any(img),)
if d is not None
]
return image_dict_from_any(unit.image)
def decode_image(img) -> np.ndarray:
"""Need to check the image for API compatibility."""
if isinstance(img, str):
return np.asarray(decode_base64_to_image(img))
else:
assert isinstance(img, np.ndarray)
return img
# 4 input image sources.
image = parse_unit_image(unit)
a1111_image = getattr(p, "init_images", [None])[0]
resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
if image is not None:
if isinstance(image, list):
# Add mask logic if later there is a processor that accepts mask
# on multiple inputs.
input_image = [HWC3(decode_image(img['image'])) for img in image]
else:
input_image = HWC3(decode_image(image['image']))
if 'mask' in image and image['mask'] is not None:
while len(image['mask'].shape) < 3:
image['mask'] = image['mask'][..., np.newaxis]
if 'inpaint' in unit.module:
logger.info("using inpaint as input")
color = HWC3(image['image'])
alpha = image['mask'][:, :, 0:1]
input_image = np.concatenate([color, alpha], axis=2)
elif (
not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False) and
# There is a weird gradio issue that can produce a mask that is
# not pure color when no scribble is made on the canvas.
# See https://github.com/Mikubill/sd-webui-controlnet/issues/1638.
not (
(image['mask'][:, :, 0] <= 5).all() or
(image['mask'][:, :, 0] >= 250).all()
)
):
logger.info("using mask as input")
input_image = HWC3(image['mask'][:, :, 0])
unit.module = 'none' # Always use black bg and white line
elif a1111_image is not None:
input_image = HWC3(np.asarray(a1111_image))
a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
assert a1111_i2i_resize_mode is not None
resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)
a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None)
if 'inpaint' in unit.module:
if a1111_mask_image is not None:
a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
assert a1111_mask.ndim == 2
assert a1111_mask.shape[0] == input_image.shape[0]
assert a1111_mask.shape[1] == input_image.shape[1]
input_image = np.concatenate([input_image[:, :, 0:3], a1111_mask[:, :, None]], axis=2)
else:
input_image = np.concatenate([
input_image[:, :, 0:3],
np.zeros_like(input_image, dtype=np.uint8)[:, :, 0:1],
], axis=2)
else:
raise ValueError("controlnet is enabled but no input image is given")
assert isinstance(input_image, (np.ndarray, list))
return input_image, resize_mode
@staticmethod
def try_crop_image_with_a1111_mask(
p: StableDiffusionProcessing,
@@ -202,19 +106,6 @@ class ControlNetForForgeOfficial(scripts.Script):
resize_mode: external_code.ResizeMode,
preprocessor
) -> np.ndarray:
"""
Crop ControlNet input image based on A1111 inpaint mask given.
This logic is crucial in upscale scripts, as they use A1111 mask + inpaint_full_res
to crop tiles.
"""
# Note: The method determining whether the active script is an upscale script is purely
# based on `extra_generation_params` these scripts attach on `p`, and subject to change
# in the future.
# TODO: Change this to a more robust condition once A1111 offers a way to verify script name.
is_upscale_script = any("upscale" in k.lower() for k in getattr(p, "extra_generation_params", {}).keys())
logger.debug(f"is_upscale_script={is_upscale_script}")
# Note: `inpaint_full_res` is "inpaint area" on UI. The flag is `True` when "Only masked"
# option is selected.
a1111_mask_image: Optional[Image.Image] = getattr(p, "image_mask", None)
is_only_masked_inpaint = (
issubclass(type(p), StableDiffusionProcessingImg2Img) and
@@ -224,9 +115,8 @@ class ControlNetForForgeOfficial(scripts.Script):
if (
preprocessor.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab
and is_only_masked_inpaint
and (is_upscale_script or unit.inpaint_crop_input_image)
):
logger.debug("Crop input image based on A1111 mask.")
logger.info("Crop input image based on A1111 mask.")
input_image = [input_image[:, :, i] for i in range(input_image.shape[2])]
input_image = [Image.fromarray(x) for x in input_image]
@@ -251,12 +141,47 @@ class ControlNetForForgeOfficial(scripts.Script):
return input_image
def get_input_data(self, p, unit, preprocessor):
mask = None
input_image, resize_mode = self.choose_input_image(p, unit)
assert isinstance(input_image, np.ndarray), 'Invalid input image!'
input_image = self.try_crop_image_with_a1111_mask(p, unit, input_image, resize_mode, preprocessor)
input_image = np.ascontiguousarray(input_image.copy()).copy() # safe numpy
return input_image, mask, resize_mode
a1111_i2i_image = getattr(p, "init_images", [None])[0]
a1111_i2i_mask = getattr(p, "image_mask", None)
using_a1111_data = False
resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
if unit.use_preview_as_input and unit.generated_image is not None:
image = unit.generated_image
elif unit.image is None:
resize_mode = external_code.resize_mode_from_value(p.resize_mode)
image = HWC3(np.asarray(a1111_i2i_image))
using_a1111_data = True
elif (unit.image['image'] < 5).all() and (unit.image['mask'] > 5).any():
image = unit.image['mask']
else:
image = unit.image['image']
if not isinstance(image, np.ndarray):
raise ValueError("controlnet is enabled but no input image is given")
image = HWC3(image)
if using_a1111_data:
mask = HWC3(np.asarray(a1111_i2i_mask))
elif unit.mask_image is not None and (unit.mask_image['image'] > 5).any():
mask = unit.mask_image['image']
elif unit.mask_image is not None and (unit.mask_image['mask'] > 5).any():
mask = unit.mask_image['mask']
elif unit.image is not None and (unit.image['mask'] > 5).any():
mask = unit.image['mask']
else:
mask = None
image = self.try_crop_image_with_a1111_mask(p, unit, image, resize_mode, preprocessor)
if mask is not None:
mask = cv2.resize(HWC3(mask), (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
mask = self.try_crop_image_with_a1111_mask(p, unit, mask, resize_mode, preprocessor)
return image, mask, resize_mode
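A mask can now come from several places, resolved in a fixed priority: the A1111 img2img mask, an image uploaded to the unit's mask slot, a mask painted on that slot's canvas, and finally a mask painted directly on the input canvas. Condensed into a hypothetical helper (`_any_painted` is illustrative; field names follow the diff):

```python
def _any_painted(img) -> bool:
    # True when anything was drawn above gradio's noise floor.
    return img is not None and (img > 5).any()

def resolve_mask(unit, a1111_i2i_mask, using_a1111_data):
    if using_a1111_data:                              # A1111 img2img mask
        return a1111_i2i_mask
    if unit.mask_image is not None:
        if _any_painted(unit.mask_image['image']):    # uploaded mask image
            return unit.mask_image['image']
        if _any_painted(unit.mask_image['mask']):     # painted mask canvas
            return unit.mask_image['mask']
    if unit.image is not None and _any_painted(unit.image['mask']):
        return unit.image['mask']                     # input-canvas mask
    return None
```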
@staticmethod
def get_target_dimensions(p: StableDiffusionProcessing) -> Tuple[int, int, int, int]:
@@ -296,9 +221,13 @@ class ControlNetForForgeOfficial(scripts.Script):
and getattr(p, 'enable_hr', False)
)
if unit.use_preview_as_input:
unit.module = 'None'
preprocessor = global_state.get_preprocessor(unit.module)
input_image, input_mask, resize_mode = self.get_input_data(p, unit, preprocessor)
# p.extra_result_images.append(input_image)
if unit.pixel_perfect:
unit.processor_res = external_code.pixel_perfect_resolution(
@@ -339,10 +268,24 @@ class ControlNetForForgeOfficial(scripts.Script):
params.control_cond_for_hr_fix = preprocessor_output
p.extra_result_images.append(input_image)
if input_mask is not None:
fill_border = preprocessor.fill_mask_with_one_when_resize_and_fill
params.control_mask = crop_and_resize_image(input_mask, resize_mode, h, w, fill_border)
p.extra_result_images.append(params.control_mask)
params.control_mask = numpy_to_pytorch(params.control_mask).movedim(-1, 1)[:, :1]
if has_high_res_fix:
params.control_mask_for_hr_fix = crop_and_resize_image(input_mask, resize_mode, hr_y, hr_x, fill_border)
p.extra_result_images.append(params.control_mask_for_hr_fix)
params.control_mask_for_hr_fix = numpy_to_pytorch(params.control_mask_for_hr_fix).movedim(-1, 1)[:, :1]
else:
params.control_mask_for_hr_fix = params.control_mask
if preprocessor.do_not_need_model:
model_filename = 'Not Needed'
params.model = ControlModelPatcher()
else:
assert unit.model != 'None', 'You have not selected any control model!'
model_filename = global_state.get_controlnet_filename(unit.model)
params.model = cached_controlnet_loader(model_filename)
assert params.model is not None, logger.error(f"Recognizing Control Model failed: {model_filename}")
@@ -366,16 +309,13 @@ class ControlNetForForgeOfficial(scripts.Script):
if is_hr_pass:
cond = params.control_cond_for_hr_fix
mask = params.control_mask_for_hr_fix
else:
cond = params.control_cond
mask = params.control_mask
kwargs.update(dict(unit=unit, params=params))
# CN inpaint fix
if isinstance(cond, torch.Tensor) and cond.ndim == 4 and cond.shape[1] == 4:
kwargs['cond_before_inpaint_fix'] = cond.clone()
cond = cond[:, :3] * (1.0 - cond[:, 3:]) - cond[:, 3:]
params.model.strength = float(unit.weight)
params.model.start_percent = float(unit.guidance_start)
params.model.end_percent = float(unit.guidance_end)
@@ -411,8 +351,8 @@ class ControlNetForForgeOfficial(scripts.Script):
params.model.positive_advanced_weighting = soft_weighting.copy()
params.model.negative_advanced_weighting = soft_weighting.copy()
params.preprocessor.process_before_every_sampling(p, cond, *args, **kwargs)
params.model.process_before_every_sampling(p, cond, *args, **kwargs)
cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs)
params.model.process_before_every_sampling(p, cond, mask, *args, **kwargs)
logger.info(f"ControlNet Method {params.preprocessor.name} patched.")
return
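Downstream, each cached mask is converted once into the layout the patched hooks expect: `.movedim(-1, 1)[:, :1]` turns an image-style tensor into `[B, 1, H, W]` and keeps a single channel. A shape walk-through, assuming a forge-style `numpy_to_pytorch` that maps an HWC uint8 array to a `[1, H, W, C]` float tensor in 0..1 (sketched here rather than imported):

```python
import numpy as np
import torch

def numpy_to_pytorch(x: np.ndarray) -> torch.Tensor:
    # Forge-style conversion: HWC uint8 -> [1, H, W, C] float in 0..1.
    return torch.from_numpy(x.astype(np.float32) / 255.0)[None]

mask_np = np.zeros((512, 512, 3), dtype=np.uint8)  # HWC; white = masked
mask = numpy_to_pytorch(mask_np)                   # [1, 512, 512, 3]
mask = mask.movedim(-1, 1)[:, :1]                  # [1, 1, 512, 512]
assert mask.shape == (1, 1, 512, 512)
```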

View File

@@ -31,6 +31,7 @@ class Preprocessor:
self.do_not_need_model = False
self.sorting_priority = 0 # higher goes to top in the list
self.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab = True
self.fill_mask_with_one_when_resize_and_fill = False
def setup_model_patcher(self, model, load_device=None, offload_device=None, dtype=torch.float32, **kwargs):
if load_device is None:
@@ -59,8 +60,8 @@ class Preprocessor:
def process_after_running_preprocessors(self, process, params, *args, **kwargs):
return
def process_before_every_sampling(self, process, cond, *args, **kwargs):
return
def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
return cond, mask
def process_after_every_sampling(self, process, params, *args, **kwargs):
return
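With the base hook now defined as `process_before_every_sampling(self, process, cond, mask, ...) -> (cond, mask)`, every subclass in this commit either passes the pair through unchanged or transforms it and chains via `super()`. A minimal hypothetical subclass illustrating the new contract:

```python
class PreprocessorInvertMaskExample(Preprocessor):
    """Hypothetical example: flips the mask before sampling."""

    def __init__(self):
        super().__init__()
        self.name = 'invert_mask_example'

    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        cond, mask = super().process_before_every_sampling(
            process, cond, mask, *args, **kwargs)
        if mask is not None:
            mask = 1.0 - mask        # masked <-> unmasked
        return cond, mask
```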