forked from soaska/sd_telegram
fix use_prompt from random
add logging (not completely)
This commit is contained in:
parent
82ed25552e
commit
51b8eca33c
82
bot.py
82
bot.py
@ -30,6 +30,15 @@ import inspect
|
|||||||
from translate import Translator
|
from translate import Translator
|
||||||
import base64
|
import base64
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# Настройка логгера
|
||||||
|
logging.basicConfig(format="[%(asctime)s] %(levelname)s : %(name)s : %(message)s",
|
||||||
|
level=logging.DEBUG, datefmt="%d-%m-%y %H:%M:%S")
|
||||||
|
|
||||||
|
logging.getLogger('aiogram').setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# from https://t.me/BotFather
|
# from https://t.me/BotFather
|
||||||
API_TOKEN = "TOKEN_HERE"
|
API_TOKEN = "TOKEN_HERE"
|
||||||
@ -157,7 +166,9 @@ def getJson(params=0):
|
|||||||
return json_str
|
return json_str
|
||||||
|
|
||||||
# генератор промптов https://huggingface.co/FredZhang7/distilgpt2-stable-diffusion-v2
|
# генератор промптов https://huggingface.co/FredZhang7/distilgpt2-stable-diffusion-v2
|
||||||
def get_random_prompt(text = data["prompt"], max_length = 120):
|
def get_random_prompt(text = data['prompt'], max_length = 120):
|
||||||
|
if str(dataParams['use_prompt']).lower() == 'true':
|
||||||
|
text = data['prompt']
|
||||||
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
|
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
|
||||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||||
model = GPT2LMHeadModel.from_pretrained("FredZhang7/distilgpt2-stable-diffusion-v2")
|
model = GPT2LMHeadModel.from_pretrained("FredZhang7/distilgpt2-stable-diffusion-v2")
|
||||||
@ -178,7 +189,10 @@ def get_random_prompt(text = data["prompt"], max_length = 120):
|
|||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def rnd_prmt_lxc():
|
def rnd_prmt_lxc():
|
||||||
txt = random.choice(submit_get('https://lexica.art/api/v1/search?q='+data['prompt'], '').json()['images'])['prompt']
|
txt = data['prompt']
|
||||||
|
if str(dataParams['use_prompt']).lower() == 'false':
|
||||||
|
txt = dataOrig['prompt']
|
||||||
|
txt = random.choice(submit_get('https://lexica.art/api/v1/search?q='+txt, '').json()['images'])['prompt']
|
||||||
return txt
|
return txt
|
||||||
|
|
||||||
# get settings. TODO - cut 4000 symbols
|
# get settings. TODO - cut 4000 symbols
|
||||||
@ -480,7 +494,8 @@ async def show_thumbs(chat_id, res):
|
|||||||
@dp.message_handler(commands=["help"])
|
@dp.message_handler(commands=["help"])
|
||||||
@dp.message_handler(commands=["start"])
|
@dp.message_handler(commands=["start"])
|
||||||
async def cmd_start(message: Union[types.Message, types.CallbackQuery]) -> None:
|
async def cmd_start(message: Union[types.Message, types.CallbackQuery]) -> None:
|
||||||
print("cmd_start")
|
logging.info("cmd_start")
|
||||||
|
#print("cmd_start")
|
||||||
txt = "Это бот для локального запуска SD\n" + getTxt()
|
txt = "Это бот для локального запуска SD\n" + getTxt()
|
||||||
await getKeyboardUnion(txt, message, getStart())
|
await getKeyboardUnion(txt, message, getStart())
|
||||||
|
|
||||||
@ -636,35 +651,39 @@ async def inl_gen(message: Union[types.Message, types.CallbackQuery]) -> None:
|
|||||||
dataPromptOld = data['prompt']
|
dataPromptOld = data['prompt']
|
||||||
if sd == '✅':
|
if sd == '✅':
|
||||||
for itemTxt in data['prompt'].split(';'):
|
for itemTxt in data['prompt'].split(';'):
|
||||||
msgTime = await bot.send_message(
|
try:
|
||||||
chat_id=chatId,
|
msgTime = await bot.send_message(
|
||||||
text='Начали'
|
chat_id=chatId,
|
||||||
)
|
text='Начали'
|
||||||
# Включаем асинхрон, чтоб заработал await api.txt2img
|
|
||||||
data["use_async"] = "True"
|
|
||||||
data["prompt"] = itemTxt # только для **data
|
|
||||||
asyncio.create_task(getProgress(msgTime))
|
|
||||||
# TODO try catch if wrong data
|
|
||||||
res = await api.txt2img(**data)
|
|
||||||
# show_thumbs dont work because use_async
|
|
||||||
if dataParams["img_thumb"] == "true" or dataParams["img_thumb"] == "True":
|
|
||||||
await bot.send_media_group(
|
|
||||||
chat_id=chatId, media=pilToImages(res, "thumbs")
|
|
||||||
)
|
)
|
||||||
if dataParams["img_tg"] == "true" or dataParams["img_tg"] == "True":
|
# Включаем асинхрон, чтоб заработал await api.txt2img
|
||||||
await bot.send_media_group(
|
data["use_async"] = "True"
|
||||||
chat_id=chatId, media=pilToImages(res, "tg")
|
data["prompt"] = itemTxt # только для **data
|
||||||
|
asyncio.create_task(getProgress(msgTime))
|
||||||
|
# TODO try catch if wrong data
|
||||||
|
res = await api.txt2img(**data)
|
||||||
|
# show_thumbs dont work because use_async
|
||||||
|
if dataParams["img_thumb"] == "true" or dataParams["img_thumb"] == "True":
|
||||||
|
await bot.send_media_group(
|
||||||
|
chat_id=chatId, media=pilToImages(res, "thumbs")
|
||||||
|
)
|
||||||
|
if dataParams["img_tg"] == "true" or dataParams["img_tg"] == "True":
|
||||||
|
await bot.send_media_group(
|
||||||
|
chat_id=chatId, media=pilToImages(res, "tg")
|
||||||
|
)
|
||||||
|
if dataParams["img_real"] == "true" or dataParams["img_real"] == "True":
|
||||||
|
await bot.send_media_group(
|
||||||
|
chat_id=chatId, media=pilToImages(res, "real")
|
||||||
|
)
|
||||||
|
await bot.send_message(
|
||||||
|
chat_id=chatId,
|
||||||
|
text=data["prompt"] + "\n" + str(res.info["all_seeds"])
|
||||||
)
|
)
|
||||||
if dataParams["img_real"] == "true" or dataParams["img_real"] == "True":
|
# Удаляем сообщение с прогрессом
|
||||||
await bot.send_media_group(
|
await bot.delete_message(chat_id=msgTime.chat.id, message_id=msgTime.message_id)
|
||||||
chat_id=chatId, media=pilToImages(res, "real")
|
except Exception as e:
|
||||||
)
|
logging.error(f"gen error: {e}")
|
||||||
await bot.send_message(
|
|
||||||
chat_id=chatId,
|
|
||||||
text=data["prompt"] + "\n" + str(res.info["all_seeds"])
|
|
||||||
)
|
|
||||||
# Удаляем сообщение с прогрессом
|
|
||||||
await bot.delete_message(chat_id=msgTime.chat.id, message_id=msgTime.message_id)
|
|
||||||
keyboard = InlineKeyboardMarkup(inline_keyboard=[getSet(0), getOpt(0), getStart(0)])
|
keyboard = InlineKeyboardMarkup(inline_keyboard=[getSet(0), getOpt(0), getStart(0)])
|
||||||
await bot.send_message(
|
await bot.send_message(
|
||||||
chat_id=chatId,
|
chat_id=chatId,
|
||||||
@ -840,7 +859,7 @@ async def inl_rnd_inf(message: Union[types.Message, types.CallbackQuery]) -> Non
|
|||||||
# PROMPT
|
# PROMPT
|
||||||
if str(dataParams['use_prompt']).lower() == 'false':
|
if str(dataParams['use_prompt']).lower() == 'false':
|
||||||
t = requests.get('https://random-word-api.herokuapp.com/word?lang=en').text
|
t = requests.get('https://random-word-api.herokuapp.com/word?lang=en').text
|
||||||
text = json.loads(t)[0] # data["prompt"] # from JSON
|
text = json.loads(t)[0] # from JSON
|
||||||
prompt = get_random_prompt(text, 20)
|
prompt = get_random_prompt(text, 20)
|
||||||
prompt_lxc = random.choice(submit_get('https://lexica.art/api/v1/search?q=' + prompt, '').json()['images'])['prompt']
|
prompt_lxc = random.choice(submit_get('https://lexica.art/api/v1/search?q=' + prompt, '').json()['images'])['prompt']
|
||||||
data['prompt'] = prompt + ', ' + prompt_lxc
|
data['prompt'] = prompt + ', ' + prompt_lxc
|
||||||
@ -1057,6 +1076,7 @@ async def handle_file(message: types.Message):
|
|||||||
|
|
||||||
# -------- BOT POLLING ----------
|
# -------- BOT POLLING ----------
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
logging.info("starting bot")
|
||||||
executor.start_polling(dp, skip_updates=True)
|
executor.start_polling(dp, skip_updates=True)
|
||||||
|
|
||||||
# -------- COPYRIGHT ----------
|
# -------- COPYRIGHT ----------
|
||||||
|
Loading…
Reference in New Issue
Block a user