Support loading textual inversion embeddings from safetensors files

This commit is contained in:
Lee Bousfield 2023-01-10 18:40:34 -07:00
parent 9cfd10cdef
commit f9706acf43
No known key found for this signature in database
GPG Key ID: 51137D1C9B477CBD

View File

@ -9,6 +9,7 @@ import tqdm
import html import html
import datetime import datetime
import csv import csv
import safetensors.torch
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
@ -150,6 +151,8 @@ class EmbeddingDatabase:
name = data.get('name', name) name = data.get('name', name)
elif ext in ['.BIN', '.PT']: elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu") data = torch.load(path, map_location="cpu")
elif ext in ['.SAFETENSORS']:
data = safetensors.torch.load_file(path, device="cpu")
else: else:
return return