return old functions + refactoring + new JSON from api

This commit is contained in:
Mikhail Guseletov 2023-06-16 23:22:45 +07:00
parent 06078a31bd
commit ad3473c63e

429
bot.py
View File

@ -25,65 +25,413 @@ from typing import Union
from PIL import Image from PIL import Image
from transformers import GPT2Tokenizer, GPT2LMHeadModel from transformers import GPT2Tokenizer, GPT2LMHeadModel
import inspect import inspect
from translate import Translator
API_TOKEN = "TOKEN_HERE" API_TOKEN = "TOKEN_HERE"
bot = Bot(token=API_TOKEN) bot = Bot(token=API_TOKEN)
dp = Dispatcher(bot) dp = Dispatcher(bot)
# Получаем список аргументов функции api.txt2img и возвращаем JSON {"/prompt": "","/seed": "-1",...}
def getAttrtxt2img():
#argspec = inspect.getfullargspec(api.txt2img)
#defaults = argspec.defaults or []
#args = argspec.args[1:]
#values = list(defaults) + [None] * (len(args) - len(defaults))
#params = {arg: str(value) if value is not None else "" for arg, value in zip(args, values)}
spec = inspect.getfullargspec(api.txt2img)
arguments = spec.args
values = [spec.defaults[i] if i >= (len(spec.defaults) or 0)*-1 else None for i in range(-1, (-1)*(len(arguments)+1), -1)][::-1]
params = {arg: value for arg, value in zip(arguments, values) if value is not None}
params = {arg: json.loads(value) if isinstance(value, str) and value.startswith(('{', '[')) else json.loads(json.dumps(value)) if value is not None else None for arg, value in params.items()}
return params
def get_argspec_json2(function):
spec = inspect.getfullargspec(function)
arguments = spec.args
# Формирование списка значений аргументов функции
# значения берутся из списка defaults (если они определены)
# если значение не определено, используется тип по умолчанию
values = [spec.defaults[i] if i >= (len(spec.defaults) or 0)*-1 else None for i in range(-1, (-1)*(len(arguments)+1), -1)][::-1]
# Формирование словаря аргументов и их значений
params = {arg: value for arg, value in zip(arguments, values) if value is not None}
# Конвертация значений в соответствующие типы
params = {arg: json.loads(value) if isinstance(value, str) and value.startswith(('{', '[')) else json.loads(json.dumps(value)) if value is not None else None for arg, value in params.items()}
# Формирование словаря типов данных аргументов
types = {arg: str(type(value).__name__) for arg, value in params.items()}
# Конвертация словарей в JSON
return json.dumps({'params': params, 'arg_types': types})
#print(get_argspec_json(api.txt2img))
# -------- GLOBAL ---------- # -------- GLOBAL ----------
host = "127.0.0.1" host = "127.0.0.1"
port = "7861" port = "7861"
# https://github.com/mix1009/sdwebuiapi
api = webuiapi.WebUIApi(host=host, port=port) api = webuiapi.WebUIApi(host=host, port=port)
# TODO --share used shared link. https://123456.gradio.live/docs does not work # TODO --share used shared link. https://123456.gradio.live/docs does not work
local = "http://" + host + ":" + port local = "http://" + host + ":" + port
process = None process = None
sd = "" sd = ""
data = getAttrtxt2img()
dataParams = {"img_thumb": "true", "img_tg": "true", "img_real": "true"} dataParams = {"img_thumb": "true", "img_tg": "true", "img_real": "true"}
dataOld = data.copy()
dataOldParams = dataParams.copy()
dataOrig = data.copy()
# -------- FUNCTIONS ----------
# Запуск SD через subprocess и запись в глобальную переменную process
def start_sd():
global process
if not process:
print("start_process sd")
try:
# ../../ launch.py = смотрим из stable-diffusion-webui/extensions/sd_telegram в корень папки stable-diffusion-webui
process = subprocess.Popen(
["python", "../../launch.py", "--nowebui", "--xformers"]
)
# TODO stderr, stdout выводить в сообщение телеграм
except subprocess.CalledProcessError as e:
print("e:", e)
# Остановка SD
def stop_sd():
global process, sd
if process:
print("stop_process sd")
process.terminate()
process = None
sd = ""
def pilToImages(res, typeImages="tg"):
media_group = []
imagesAll = res.images
if len(res.images) == 1:
i = 0
if len(res.images) > 1:
i = -1
for image in imagesAll:
# костыль для отсечения первой картинки с гридами
if i == -1:
i = i + 1
continue
seed = str(res.info["all_seeds"][i])
image_buffer = io.BytesIO()
image.save(image_buffer, format="PNG")
image_buffer.seek(0)
# картинка в телеге
if typeImages == "tg":
media_group.append(types.InputMediaPhoto(media=image_buffer, caption=seed))
# оригинал
if typeImages == "real":
media_group.append(
types.InputMediaDocument(
media=InputFile(image_buffer, filename=seed + ".png"), caption=seed
)
)
# превью
if typeImages == "thumbs":
img = Image.open(image_buffer)
width, height = img.size
# пропорции
ratio = min(256 / width, 256 / height)
new_size = (round(width * ratio), round(height * ratio))
img = img.resize(new_size)
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
media_group.append(types.InputMediaPhoto(media=img_byte_arr, caption=seed))
i = i + 1
return media_group
# генератор промптов https://huggingface.co/FredZhang7/distilgpt2-stable-diffusion-v2
def get_random_prompt():
text = data["prompt"] # from JSON
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model = GPT2LMHeadModel.from_pretrained("FredZhang7/distilgpt2-stable-diffusion-v2")
input_ids = tokenizer(text, return_tensors="pt").input_ids
txt = model.generate(
input_ids,
do_sample=True,
temperature=0.8,
top_k=8,
max_length=120,
num_return_sequences=1,
repetition_penalty=1.2,
penalty_alpha=0.6,
no_repeat_ngram_size=0,
early_stopping=True,
)
prompt = tokenizer.decode(txt[0], skip_special_tokens=True)
return prompt
# -------- MENU ----------
# Стартовое меню
def getKeyboard(keysArr, returnAll):
keys = keysArr
keyAll = InlineKeyboardMarkup(inline_keyboard=[keys])
if returnAll == 1:
return keyAll
else:
return keys
def getStart(returnAll=1) -> InlineKeyboardMarkup:
keysArr = [
InlineKeyboardButton("sd" + sd, callback_data="sd"),
InlineKeyboardButton("opt", callback_data="opt"),
InlineKeyboardButton("gen", callback_data="gen"),
InlineKeyboardButton("skip", callback_data="skip"),
InlineKeyboardButton("progress", callback_data="progress"),
InlineKeyboardButton("help", callback_data="help"),
]
return (getKeyboard(keysArr, returnAll))
# Меню генераций
def getGen(returnAll=1) -> InlineKeyboardMarkup:
keysArr = [
InlineKeyboardButton("gen1", callback_data="gen1"),
InlineKeyboardButton("gen4", callback_data="gen4"),
InlineKeyboardButton("gen10", callback_data="gen10"),
InlineKeyboardButton("gen_hr", callback_data="gen_hr"),
InlineKeyboardButton("gen_hr4", callback_data="gen_hr4"),
]
return (getKeyboard(keysArr, returnAll))
# Меню опций
def getOpt(returnAll=1) -> InlineKeyboardMarkup:
keysArr = [
InlineKeyboardButton("settings", callback_data="settings"),
InlineKeyboardButton("scripts", callback_data="scripts"),
InlineKeyboardButton("prompt", callback_data="prompt"),
]
return (getKeyboard(keysArr, returnAll))
# Меню скриптов
def getScripts(returnAll=1) -> InlineKeyboardMarkup:
keysArr = [
InlineKeyboardButton("get_lora", callback_data="get_lora"),
InlineKeyboardButton("seed2img", callback_data="seed2img"),
]
return (getKeyboard(keysArr, returnAll))
# Меню настроек
def getSet(returnAll=1) -> InlineKeyboardMarkup:
keysArr = [
InlineKeyboardButton("change_param", callback_data="change_param"),
InlineKeyboardButton("reset_param", callback_data="reset_param"),
]
return (getKeyboard(keysArr, returnAll))
# Меню настроек
def getPrompt(returnAll=1) -> InlineKeyboardMarkup:
keysArr = [InlineKeyboardButton("random_prompt", callback_data="random_prompt")]
return (getKeyboard(keysArr, returnAll))
# Меню генераций
def getGen(returnAll=1) -> InlineKeyboardMarkup:
keysArr = [
InlineKeyboardButton("gen1", callback_data="gen1"),
InlineKeyboardButton("gen4", callback_data="gen4"),
InlineKeyboardButton("gen10", callback_data="gen10"),
InlineKeyboardButton("gen_hr", callback_data="gen_hr"),
InlineKeyboardButton("gen_hr4", callback_data="gen_hr4"),
]
return (getKeyboard(keysArr, returnAll))
# Меню текста
def getTxt():
return "/start /opt /gen /skip /status /seed2img /help"
# Проверка связи до запущенной локальной SD с nowebui
def ping(status: str):
url = local + "/docs"
while True:
try:
r = requests.get(url, timeout=3)
r.raise_for_status()
code = r.status_code
print(code)
if (status == "stop" and code != 200) or (status == "start" and code == 200):
return True
except requests.exceptions.RequestException as err:
print("Error:", err)
time.sleep(3)
# -------- COMMANDS ----------
# start или help
@dp.callback_query_handler(text="help")
@dp.message_handler(commands=["help"])
@dp.message_handler(commands=["start"])
async def cmd_start(message: Union[types.Message, types.CallbackQuery]) -> None:
print("cmd_start")
txt = "Это бот для локального запуска SD\n" + getTxt()
if hasattr(message, "content_type"):
await message.reply(txt, reply_markup=getStart())
else:
await message.message.edit_text(txt, reply_markup=getStart())
# Запуск/Остановка SD. Завязываемся на глобальную иконку sd
@dp.callback_query_handler(text="sd")
async def inl_sd(callback: types.CallbackQuery) -> None:
print("inl_sd")
global sd
if sd == "":
stop_sd()
sd = ""
await callback.message.edit_text(
"Останавливаем SD\n" + getTxt(), reply_markup=getStart()
)
ping("stop")
sd = ""
await callback.message.edit_text(
"SD остановлена\n" + getTxt(), reply_markup=getStart()
)
else:
start_sd()
sd = ""
await callback.message.edit_text(
"Запускаем SD\n" + getTxt(), reply_markup=getStart()
)
ping("start")
sd = ""
await callback.message.edit_text(
"SD запущена\n" + getTxt(), reply_markup=getStart()
)
# Вызов меню генераций getGen
@dp.message_handler(commands=["gen"])
@dp.callback_query_handler(text="gen")
async def inl_gen(message: Union[types.Message, types.CallbackQuery]) -> None:
print("inl_gen")
keyboard = InlineKeyboardMarkup(inline_keyboard=[getGen(0), getStart(0)])
# Если команда /gen
if hasattr(message, "content_type"):
await bot.send_message(
chat_id=message.from_user.id, text="Виды генераций", reply_markup=keyboard
)
else:
await bot.edit_message_text(
chat_id=message.message.chat.id,
message_id=message.message.message_id,
text="Виды генераций",
reply_markup=keyboard,
)
# Генерация изображений
# TODO gen4/gen10
@dp.callback_query_handler(text="gen1")
@dp.callback_query_handler(text="gen4")
@dp.callback_query_handler(text="gen10")
@dp.callback_query_handler(text="gen_hr")
@dp.callback_query_handler(text="gen_hr4")
async def inl_gen1(callback: types.CallbackQuery) -> None:
print("inl_gen1")
keyboard = InlineKeyboardMarkup(inline_keyboard=[getGen(0), getStart(0)])
if callback.data == "gen1":
dataOrig["batch_size"] = 1
if callback.data == "gen4" or callback.data == "gen_hr4":
dataOrig["batch_size"] = 4
if callback.data == "gen10":
dataOrig["batch_size"] = 10
if callback.data == "gen_hr" or callback.data == "gen_hr4":
dataOrig["enable_hr"] = "true"
dataOrig["hr_resize_x"] = dataOrig["width"] * 2
dataOrig["hr_resize_y"] = dataOrig["height"] * 2
print(dataOrig)
res = api.txt2img(**dataOrig) # TODO заменить dataOrig на data, исправить костыль
if dataParams["img_thumb"] == "true" or dataParams["img_thumb"] == "True":
await bot.send_media_group(
chat_id=callback.message.chat.id, media=pilToImages(res, "thumbs")
)
if dataParams["img_tg"] == "true" or dataParams["img_tg"] == "True":
await bot.send_media_group(
chat_id=callback.message.chat.id, media=pilToImages(res, "tg")
)
if dataParams["img_real"] == "true" or dataParams["img_real"] == "True":
await bot.send_media_group(
chat_id=callback.message.chat.id, media=pilToImages(res, "real")
)
print(data["prompt"])
await bot.send_message(
chat_id=callback.message.chat.id,
text=data["prompt"] + "\n" + str(res.info["all_seeds"]),
reply_markup=keyboard,
parse_mode="Markdown",
)
@dp.callback_query_handler(text="random_prompt")
async def random_prompt(callback: types.CallbackQuery) -> None:
await bot.send_message(chat_id=callback.from_user.id, text=get_random_prompt())
# Получить LORA
@dp.message_handler(commands=["get_lora"])
@dp.callback_query_handler(text="get_lora")
async def getLora(message: Union[types.Message, types.CallbackQuery]) -> None:
print("getLora")
keyboard = InlineKeyboardMarkup(inline_keyboard=[getScripts(0), getStart(0)])
# Путь к папке "Lora"
path = "../../models/Lora"
# Получаем список файлов в папке
file_list = os.listdir(path)
# Фильтруем файлы, выбирая только те, которые заканчиваются на ".safetensors"
lora_files = [
file_name for file_name in file_list if file_name.endswith(".safetensors")
]
# Выводим список файлов, отформатированный в нужном формате
arr = ""
for file_name in lora_files:
name = file_name.replace(".safetensors", "")
arr = arr + f"`<lora:{name}:1>`\n\n"
if hasattr(message, "content_type"):
await bot.send_message(
chat_id=message.from_user.id,
text=arr,
reply_markup=keyboard,
parse_mode="Markdown",
)
else:
await bot.edit_message_text(
chat_id=message.message.chat.id,
message_id=message.message.message_id,
text="Список LORA\n" + arr,
reply_markup=keyboard,
parse_mode="Markdown",
)
@dp.message_handler(commands=["test"]) @dp.message_handler(commands=["test"])
async def cmd_test(message: Union[types.Message, types.CallbackQuery]) -> None: async def cmd_test(message: Union[types.Message, types.CallbackQuery]) -> None:
print("cmd_test") print("cmd_test")
options = {} #print(api.get_options())
options['outdir_txt2img_samples'] = '../../outputs/txt2img-images' #options = {}
api.set_options(options) #options['outdir_txt2img_samples'] = '../../outputs/txt2img-images'
#api.set_options(options)
#print(api.get_options())
res = await api.txt2img(prompt="cat in car", #getAttrtxt2img() # options default
negative_prompt="ugly, out of frame", #options = {}
seed=-1, #options['ttt'] = '11'
styles=["anime"], #api.set_options(options)
cfg_scale=7, #print(api.get_options())
sampler_index='DDIM', translator = Translator(to_lang="en")
steps=15,
use_async=True,
save_images=True,
# enable_hr=True,
# hr_scale=2,
# hr_upscaler=webuiapi.HiResUpscaler.Latent,
# hr_second_pass_steps=20,
# hr_resize_x=1536,
# hr_resize_y=1024,
# denoising_strength=0.4,
)
# images contains the returned images (PIL images)
# result1.images
# image is shorthand for images[0] text = input("Введите текст на русском языке: ")
# result1.image
# info contains text info about the api call translation = translator.translate(text)
print(res.info)
# info contains paramteres of the api call print("Перевод на английский язык: ", translation)
print(res.parameters)
print('api.get_options()') @dp.message_handler(commands=["start2"])
print(api.get_options())
print('api.util_get_current_model()')
print(api.util_get_current_model())
@dp.message_handler(commands=["start"])
async def cmd_start(message: Union[types.Message, types.CallbackQuery]) -> None: async def cmd_start(message: Union[types.Message, types.CallbackQuery]) -> None:
print("cmd_start") print("cmd_start")
subprocess.Popen( subprocess.Popen(
@ -93,17 +441,10 @@ async def cmd_start(message: Union[types.Message, types.CallbackQuery]) -> None:
@dp.message_handler(commands=["test2"]) @dp.message_handler(commands=["test2"])
async def cmd_test2(message: Union[types.Message, types.CallbackQuery]) -> None: async def cmd_test2(message: Union[types.Message, types.CallbackQuery]) -> None:
print("cmd_test2") print("cmd_test2")
print(getAttrtxt2img()) print(getAttrtxt2img()['enable_hr'])
print(data['enable_hr'])
def getAttrtxt2img():
# Получаем список аргументов функции api.txt2img и возвращаем JSON
argspec = inspect.getfullargspec(api.txt2img)
defaults = argspec.defaults or []
args = argspec.args[1:]
values = list(defaults) + [None] * (len(args) - len(defaults))
params = {"/" + arg: str(value) if value is not None else "" for arg, value in zip(args, values)}
result = json.dumps(params, indent=4)
return result
# -------- BOT POLLING ---------- # -------- BOT POLLING ----------
if __name__ == "__main__": if __name__ == "__main__":