refactoring
params dont save
This commit is contained in:
parent
ad3473c63e
commit
23755ee917
204
bot.py
204
bot.py
@ -27,18 +27,13 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||
import inspect
|
||||
from translate import Translator
|
||||
|
||||
API_TOKEN = "TOKEN_HERE"
|
||||
API_TOKEN = "900510503:AAG5Xug_JEERhKlf7dpOpzxXcJIzlTbWX1M"
|
||||
|
||||
bot = Bot(token=API_TOKEN)
|
||||
dp = Dispatcher(bot)
|
||||
|
||||
# Получаем список аргументов функции api.txt2img и возвращаем JSON {"/prompt": "","/seed": "-1",...}
|
||||
def getAttrtxt2img():
|
||||
#argspec = inspect.getfullargspec(api.txt2img)
|
||||
#defaults = argspec.defaults or []
|
||||
#args = argspec.args[1:]
|
||||
#values = list(defaults) + [None] * (len(args) - len(defaults))
|
||||
#params = {arg: str(value) if value is not None else "" for arg, value in zip(args, values)}
|
||||
spec = inspect.getfullargspec(api.txt2img)
|
||||
arguments = spec.args
|
||||
values = [spec.defaults[i] if i >= (len(spec.defaults) or 0)*-1 else None for i in range(-1, (-1)*(len(arguments)+1), -1)][::-1]
|
||||
@ -46,27 +41,8 @@ def getAttrtxt2img():
|
||||
params = {arg: json.loads(value) if isinstance(value, str) and value.startswith(('{', '[')) else json.loads(json.dumps(value)) if value is not None else None for arg, value in params.items()}
|
||||
return params
|
||||
|
||||
def get_argspec_json2(function):
|
||||
spec = inspect.getfullargspec(function)
|
||||
arguments = spec.args
|
||||
# Формирование списка значений аргументов функции
|
||||
# значения берутся из списка defaults (если они определены)
|
||||
# если значение не определено, используется тип по умолчанию
|
||||
values = [spec.defaults[i] if i >= (len(spec.defaults) or 0)*-1 else None for i in range(-1, (-1)*(len(arguments)+1), -1)][::-1]
|
||||
# Формирование словаря аргументов и их значений
|
||||
params = {arg: value for arg, value in zip(arguments, values) if value is not None}
|
||||
|
||||
# Конвертация значений в соответствующие типы
|
||||
params = {arg: json.loads(value) if isinstance(value, str) and value.startswith(('{', '[')) else json.loads(json.dumps(value)) if value is not None else None for arg, value in params.items()}
|
||||
# Формирование словаря типов данных аргументов
|
||||
types = {arg: str(type(value).__name__) for arg, value in params.items()}
|
||||
|
||||
# Конвертация словарей в JSON
|
||||
return json.dumps({'params': params, 'arg_types': types})
|
||||
|
||||
#print(get_argspec_json(api.txt2img))
|
||||
|
||||
# -------- GLOBAL ----------
|
||||
formatted_date = datetime.today().strftime("%Y-%m-%d")
|
||||
host = "127.0.0.1"
|
||||
port = "7861"
|
||||
# https://github.com/mix1009/sdwebuiapi
|
||||
@ -77,11 +53,26 @@ process = None
|
||||
sd = "❌"
|
||||
|
||||
data = getAttrtxt2img()
|
||||
data['prompt'] = 'cat in space' # Ý
|
||||
data['steps'] = 15
|
||||
dataParams = {"img_thumb": "true", "img_tg": "true", "img_real": "true"}
|
||||
dataOld = data.copy()
|
||||
dataOldParams = dataParams.copy()
|
||||
dataOrig = data.copy()
|
||||
|
||||
# -------- CLASSES ----------
|
||||
|
||||
# https://aiogram-birdi7.readthedocs.io/en/latest/examples/finite_state_machine_example.html
|
||||
# Dynamically create a new class with the desired attributes
|
||||
state_classes = {}
|
||||
for key in data:
|
||||
state_classes[key] = State()
|
||||
for key in dataParams:
|
||||
state_classes[key] = State()
|
||||
|
||||
# Inherit from the dynamically created class
|
||||
Form = type("Form", (StatesGroup,), state_classes)
|
||||
|
||||
# -------- FUNCTIONS ----------
|
||||
# Запуск SD через subprocess и запись в глобальную переменную process
|
||||
def start_sd():
|
||||
@ -115,9 +106,9 @@ def pilToImages(res, typeImages="tg"):
|
||||
i = -1
|
||||
for image in imagesAll:
|
||||
# костыль для отсечения первой картинки с гридами
|
||||
if i == -1:
|
||||
i = i + 1
|
||||
continue
|
||||
#if i == -1:
|
||||
# i = i + 1
|
||||
# continue
|
||||
seed = str(res.info["all_seeds"][i])
|
||||
image_buffer = io.BytesIO()
|
||||
image.save(image_buffer, format="PNG")
|
||||
@ -147,6 +138,15 @@ def pilToImages(res, typeImages="tg"):
|
||||
i = i + 1
|
||||
return media_group
|
||||
|
||||
def getJson(params=0):
|
||||
if params == 0:
|
||||
d = data
|
||||
else:
|
||||
d = dataParams
|
||||
json_list = [f"/{key} = {value}" for key, value in d.items()]
|
||||
json_str = "\n".join(json_list)
|
||||
return json_str
|
||||
|
||||
# генератор промптов https://huggingface.co/FredZhang7/distilgpt2-stable-diffusion-v2
|
||||
def get_random_prompt():
|
||||
text = data["prompt"] # from JSON
|
||||
@ -169,6 +169,11 @@ def get_random_prompt():
|
||||
prompt = tokenizer.decode(txt[0], skip_special_tokens=True)
|
||||
return prompt
|
||||
|
||||
# Translate
|
||||
def translateRuToEng(text):
|
||||
translator = Translator(from_lang="ru", to_lang="en")
|
||||
return translator.translate(text)
|
||||
|
||||
# -------- MENU ----------
|
||||
# Стартовое меню
|
||||
def getKeyboard(keysArr, returnAll):
|
||||
@ -179,6 +184,21 @@ def getKeyboard(keysArr, returnAll):
|
||||
else:
|
||||
return keys
|
||||
|
||||
# Стандартное меню
|
||||
async def getKeyboardUnion(txt, message, keyboard):
|
||||
# Если команда /settings
|
||||
if hasattr(message, "content_type"):
|
||||
await bot.send_message(
|
||||
chat_id=message.from_user.id, text=txt, reply_markup=keyboard
|
||||
)
|
||||
else:
|
||||
await bot.edit_message_text(
|
||||
chat_id=message.message.chat.id,
|
||||
message_id=message.message.message_id,
|
||||
text=txt,
|
||||
reply_markup=keyboard,
|
||||
)
|
||||
|
||||
def getStart(returnAll=1) -> InlineKeyboardMarkup:
|
||||
keysArr = [
|
||||
InlineKeyboardButton("sd" + sd, callback_data="sd"),
|
||||
@ -215,7 +235,6 @@ def getOpt(returnAll=1) -> InlineKeyboardMarkup:
|
||||
def getScripts(returnAll=1) -> InlineKeyboardMarkup:
|
||||
keysArr = [
|
||||
InlineKeyboardButton("get_lora", callback_data="get_lora"),
|
||||
InlineKeyboardButton("seed2img", callback_data="seed2img"),
|
||||
]
|
||||
return (getKeyboard(keysArr, returnAll))
|
||||
|
||||
@ -246,7 +265,7 @@ def getGen(returnAll=1) -> InlineKeyboardMarkup:
|
||||
|
||||
# Меню текста
|
||||
def getTxt():
|
||||
return "/start /opt /gen /skip /status /seed2img /help"
|
||||
return "/start /opt /gen /skip /status /help"
|
||||
|
||||
# Проверка связи до запущенной локальной SD с nowebui
|
||||
def ping(status: str):
|
||||
@ -300,6 +319,9 @@ async def inl_sd(callback: types.CallbackQuery) -> None:
|
||||
await callback.message.edit_text(
|
||||
"Запускаем SD\n" + getTxt(), reply_markup=getStart()
|
||||
)
|
||||
#options = {}
|
||||
#options['outdir_txt2img_samples'] = '../../outputs/txt2img-images'
|
||||
#api.set_options(options)
|
||||
ping("start")
|
||||
sd = "✅"
|
||||
await callback.message.edit_text(
|
||||
@ -336,17 +358,19 @@ async def inl_gen1(callback: types.CallbackQuery) -> None:
|
||||
print("inl_gen1")
|
||||
keyboard = InlineKeyboardMarkup(inline_keyboard=[getGen(0), getStart(0)])
|
||||
if callback.data == "gen1":
|
||||
dataOrig["batch_size"] = 1
|
||||
data["batch_size"] = 1
|
||||
if callback.data == "gen4" or callback.data == "gen_hr4":
|
||||
dataOrig["batch_size"] = 4
|
||||
data["batch_size"] = 4
|
||||
if callback.data == "gen10":
|
||||
dataOrig["batch_size"] = 10
|
||||
data["batch_size"] = 10
|
||||
if callback.data == "gen_hr" or callback.data == "gen_hr4":
|
||||
dataOrig["enable_hr"] = "true"
|
||||
dataOrig["hr_resize_x"] = dataOrig["width"] * 2
|
||||
dataOrig["hr_resize_y"] = dataOrig["height"] * 2
|
||||
print(dataOrig)
|
||||
res = api.txt2img(**dataOrig) # TODO заменить dataOrig на data, исправить костыль
|
||||
data["enable_hr"] = "true"
|
||||
data["hr_resize_x"] = data["width"] * 2
|
||||
data["hr_resize_y"] = data["height"] * 2
|
||||
#data['prompt'] = 'толстый кот в машине'
|
||||
#data['prompt'] = translateRuToEng(data['prompt'])
|
||||
print(data)
|
||||
res = api.txt2img(**data)
|
||||
if dataParams["img_thumb"] == "true" or dataParams["img_thumb"] == "True":
|
||||
await bot.send_media_group(
|
||||
chat_id=callback.message.chat.id, media=pilToImages(res, "thumbs")
|
||||
@ -371,6 +395,43 @@ async def inl_gen1(callback: types.CallbackQuery) -> None:
|
||||
async def random_prompt(callback: types.CallbackQuery) -> None:
|
||||
await bot.send_message(chat_id=callback.from_user.id, text=get_random_prompt())
|
||||
|
||||
# Получить опции
|
||||
@dp.message_handler(commands=["opt"])
|
||||
@dp.callback_query_handler(text="opt")
|
||||
async def cmd_opt(message: Union[types.Message, types.CallbackQuery]) -> None:
|
||||
print("cmd_opt")
|
||||
keyboard = InlineKeyboardMarkup(inline_keyboard=[getOpt(0), getStart(0)])
|
||||
await getKeyboardUnion("Опции", message, keyboard)
|
||||
|
||||
# Вызов settings
|
||||
@dp.message_handler(commands=["settings"])
|
||||
@dp.callback_query_handler(text="settings")
|
||||
async def inl_settings(message: Union[types.Message, types.CallbackQuery]) -> None:
|
||||
print("inl_settings")
|
||||
keyboard = InlineKeyboardMarkup(inline_keyboard=[getSet(0), getStart(0)])
|
||||
await getKeyboardUnion("Настройки", message, keyboard)
|
||||
|
||||
# Вызов script
|
||||
@dp.message_handler(commands=["scripts"])
|
||||
@dp.callback_query_handler(text="scripts")
|
||||
async def inl_scripts(message: Union[types.Message, types.CallbackQuery]) -> None:
|
||||
print("inl_scripts")
|
||||
keyboard = InlineKeyboardMarkup(inline_keyboard=[getScripts(0), getStart(0)])
|
||||
await getKeyboardUnion("Скрипты", message, keyboard)
|
||||
|
||||
# Вызов change_param
|
||||
@dp.callback_query_handler(text="change_param")
|
||||
async def inl_change_param(callback: types.CallbackQuery) -> None:
|
||||
print("inl_change_param")
|
||||
keyboard = InlineKeyboardMarkup(inline_keyboard=[getSet(0), getStart(0)])
|
||||
json_list = [f"/{key} = {value}" for key, value in data.items()]
|
||||
json_list_params = [f"/{key} = {value}" for key, value in dataParams.items()]
|
||||
json_str = "\n".join(json_list)
|
||||
json_str_params = "\n".join(json_list_params)
|
||||
await callback.message.edit_text(
|
||||
f"JSON параметры:\n{json_str}\n{json_str_params}", reply_markup=keyboard
|
||||
)
|
||||
|
||||
# Получить LORA
|
||||
@dp.message_handler(commands=["get_lora"])
|
||||
@dp.callback_query_handler(text="get_lora")
|
||||
@ -423,13 +484,13 @@ async def cmd_test(message: Union[types.Message, types.CallbackQuery]) -> None:
|
||||
#options['ttt'] = '11'
|
||||
#api.set_options(options)
|
||||
#print(api.get_options())
|
||||
translator = Translator(to_lang="en")
|
||||
translator = Translator(from_lang="ru", to_lang="en")
|
||||
|
||||
text = input("Введите текст на русском языке: ")
|
||||
text = 'толстый кот в машине'
|
||||
|
||||
translation = translator.translate(text)
|
||||
|
||||
print("Перевод на английский язык: ", translation)
|
||||
print(translation)
|
||||
|
||||
@dp.message_handler(commands=["start2"])
|
||||
async def cmd_start(message: Union[types.Message, types.CallbackQuery]) -> None:
|
||||
@ -444,6 +505,65 @@ async def cmd_test2(message: Union[types.Message, types.CallbackQuery]) -> None:
|
||||
print(getAttrtxt2img()['enable_hr'])
|
||||
print(data['enable_hr'])
|
||||
|
||||
# Ввели любой текст
|
||||
@dp.message_handler(lambda message: True)
|
||||
async def change_json(message: types.Message):
|
||||
print("change_json")
|
||||
keyboard = InlineKeyboardMarkup(inline_keyboard=[getSet(0), getStart(0)])
|
||||
text = message.text
|
||||
print(514)
|
||||
print(text)
|
||||
nam = text.split()[0][1:] # txt из /txt 321
|
||||
state_names = [attr for attr in dir(Form) if isinstance(getattr(Form, attr), State)]
|
||||
print(516)
|
||||
print(nam)
|
||||
print(state_names)
|
||||
args = message.get_args() # это 321, когда ввели /txt 321
|
||||
# Поиск команд из data
|
||||
if nam in state_names:
|
||||
print(524)
|
||||
if args == "":
|
||||
print(526)
|
||||
await message.answer("Напиши любое " + nam)
|
||||
print(528)
|
||||
if nam in state_names:
|
||||
await getattr(Form, nam).set()
|
||||
else:
|
||||
print("Ошибка какая-то")
|
||||
else:
|
||||
# /txt 321 пишем 321 в data['txt']
|
||||
print(533)
|
||||
data[nam] = args
|
||||
# TODO answer поменять на edit_text
|
||||
await message.answer(
|
||||
f"JSON параметры:\n{getJson()}\n{getJson(1)}", reply_markup=keyboard
|
||||
)
|
||||
else:
|
||||
data["prompt"] = message.text
|
||||
await message.answer(
|
||||
f"Записали промпт. JSON параметры:\n{getJson()}\n{getJson(1)}",
|
||||
reply_markup=keyboard,
|
||||
)
|
||||
|
||||
# Ввели ответ на change_json
|
||||
@dp.message_handler(state=Form)
|
||||
async def answer_handler(message: types.Message, state: FSMContext):
|
||||
print('answer_handler')
|
||||
keyboard = InlineKeyboardMarkup(inline_keyboard=[getSet(0), getStart(0)])
|
||||
current_state = await state.get_state() # Form:команда
|
||||
txt = message.text
|
||||
for key, val in dataParams.items():
|
||||
if current_state == "Form:" + key:
|
||||
dataParams[key] = txt
|
||||
break
|
||||
for key, val in data.items():
|
||||
if current_state == "Form:" + key:
|
||||
data[key] = txt
|
||||
break
|
||||
await state.reset_state()
|
||||
await message.answer(
|
||||
f"JSON параметры:\n{getJson()}\n{getJson(1)}", reply_markup=keyboard
|
||||
)
|
||||
|
||||
|
||||
# -------- BOT POLLING ----------
|
||||
|
Loading…
Reference in New Issue
Block a user