gpt2 to function

stop sd after run all rnd scripts
This commit is contained in:
Mikhail Guseletov 2023-06-27 14:47:35 +07:00
parent fe965a4084
commit bdecad585d

30
bot.py
View File

@ -61,7 +61,10 @@ data = getAttrtxt2img()
data['prompt'] = 'cat in space' # Ý
data['steps'] = 15
data['sampler_name'] = 'Euler a'
dataParams = {"img_thumb": "true", "img_tg": "true", "img_real": "true"}
dataParams = {"img_thumb": "true",
"img_tg": "true",
"img_real": "true",
"stop_sd": "true"}
dataOld = data.copy()
dataOldParams = dataParams.copy()
dataOrig = data.copy()
@ -150,8 +153,7 @@ def getJson(params=0):
return json_str
# генератор промптов https://huggingface.co/FredZhang7/distilgpt2-stable-diffusion-v2
def get_random_prompt():
text = data["prompt"] # from JSON
def get_random_prompt(text = data["prompt"], max_length = 120):
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model = GPT2LMHeadModel.from_pretrained("FredZhang7/distilgpt2-stable-diffusion-v2")
@ -161,7 +163,7 @@ def get_random_prompt():
do_sample=True,
temperature=0.8,
top_k=8,
max_length=120,
max_length=max_length,
num_return_sequences=1,
repetition_penalty=1.2,
penalty_alpha=0.6,
@ -439,6 +441,8 @@ async def rnd_script(message, typeScript):
,
reply_markup=keyboard
)
if str(dataParams['stop_sd']).lower() == 'true':
await stop_sd()
# show thumb/tg/real
async def show_thumbs(chat_id, res):
@ -721,23 +725,7 @@ async def inl_rnd_inf(message: Union[types.Message, types.CallbackQuery]) -> Non
# PROMPT
t = requests.get('https://random-word-api.herokuapp.com/word?lang=en').text
text = json.loads(t)[0] # data["prompt"] # from JSON
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model = GPT2LMHeadModel.from_pretrained("FredZhang7/distilgpt2-stable-diffusion-v2")
input_ids = tokenizer(text, return_tensors="pt").input_ids
txt = model.generate(
input_ids,
do_sample=True,
temperature=0.7,
top_k=4,
max_length=20,
num_return_sequences=1,
repetition_penalty=1.2,
penalty_alpha=0.5,
no_repeat_ngram_size=0,
early_stopping=True,
)
prompt = tokenizer.decode(txt[0], skip_special_tokens=True)
prompt = get_random_prompt(text, 20)
prompt_lxc = random.choice(submit_get('https://lexica.art/api/v1/search?q=' + prompt, '').json()['images'])['prompt']
data['prompt'] = prompt + ', ' + prompt_lxc
# SCALE