Enhancing Text-to-Image AI: Prompt Recommendation System for Stable Diffusion Using Qdrant Vector Search and RAG
azhar
Posted on January 18, 2024
Stable Diffusion has emerged as a groundbreaking text-to-image model, transforming the way digital art and image synthesis are approached. By converting textual descriptions into detailed and nuanced images, Stable Diffusion opens a world of possibilities for artists, designers, and content creators. However, the effectiveness of this technology hinges on the quality of the input prompts, which guide the AI in generating relevant images.
The Challenge of Prompting Stable Diffusion
Crafting the perfect prompt for Stable Diffusion is a nuanced art. The model responds to the intricacies of language, and a well-constructed prompt can lead to stunning visual outputs. Conversely, vague or poorly structured prompts may result in unsatisfactory images. The challenge for users is navigating through and understanding the vast array of potential prompts to find one that aligns with their vision.
Solution
To assist users in this task, a sophisticated system using Vector Search and Retrieval Augmented Generation (RAG) can be employed. This system aims to analyze a vast database of successful prompts, identifying and suggesting the most relevant ones to the user’s input, thus streamlining the process of initiating Stable Diffusion.
Vector Search — A Key Solution
Vector Search plays a pivotal role in this system. It transforms text into high-dimensional vectors using embedding models such as BGE. These vectors capture the semantic essence of the text, which makes semantic search possible: by comparing the vector of a user's input with the vectors in a prompt database, the system can identify the most semantically similar prompts.
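To make the comparison step concrete, here is a brute-force sketch in plain Python. The three-dimensional vectors are made-up stand-ins for real embeddings (BGE vectors have hundreds of dimensions), and `top_k_similar` is a hypothetical helper; a vector database produces the same kind of ranking using an index instead of scanning every stored vector.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k_similar(query_vec, prompt_vecs, k=2):
    """Rank every stored prompt by cosine similarity to the query (brute force)."""
    scored = [(cosine_similarity(query_vec, vec), prompt) for prompt, vec in prompt_vecs.items()]
    scored.sort(reverse=True)  # highest similarity first
    return [prompt for _, prompt in scored[:k]]


# Toy 3-dimensional "embeddings" (hypothetical values, for illustration only)
prompt_vecs = {
    "a misty forest at dawn, oil painting": [0.9, 0.1, 0.0],
    "portrait of an old sailor, dramatic lighting": [0.1, 0.9, 0.1],
    "sunlit pine woods, watercolor": [0.8, 0.2, 0.1],
}
query = [0.85, 0.15, 0.05]  # stands in for the embedding of a user's forest-themed prompt
print(top_k_similar(query, prompt_vecs))
```

The two landscape prompts point in nearly the same direction as the query, so they outrank the portrait prompt even though no words were compared.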
Utilizing Qdrant for Vector Database
Qdrant, chosen for its efficiency and scalability, serves as the vector database. It offers fast indexing and querying capabilities, essential for handling large volumes of vector data. Qdrant’s support for different distance metrics and filtering options further enhances the search’s accuracy and relevance.
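The sketch below is a toy, pure-Python illustration of what a filtered vector search computes; it is not Qdrant's API. Each point carries a vector and a payload (metadata). Points whose payload fails the filter are dropped, and the rest are ranked by the chosen metric. The `style` payload field and the `filtered_search` helper are hypothetical, and the `must` argument only loosely mirrors the `must` clause of Qdrant filters. Qdrant itself does this over an approximate nearest-neighbour index rather than a linear scan.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def filtered_search(points, query, metric="cosine", must=None, limit=3):
    """Toy filtered search (hypothetical helper).

    Keeps only points whose payload matches every key/value pair in `must`,
    then ranks them by the chosen metric and returns their ids.
    """
    must = must or {}
    candidates = [p for p in points if all(p["payload"].get(k) == v for k, v in must.items())]
    if metric == "euclid":
        # Smaller distance means more similar, so sort ascending
        ranked = sorted(candidates, key=lambda p: euclidean(query, p["vector"]))
    else:
        # Larger score means more similar, so sort descending
        score = cosine if metric == "cosine" else dot
        ranked = sorted(candidates, key=lambda p: score(query, p["vector"]), reverse=True)
    return [p["id"] for p in ranked[:limit]]


points = [
    {"id": 1, "vector": [0.9, 0.1], "payload": {"style": "painting"}},
    {"id": 2, "vector": [0.8, 0.3], "payload": {"style": "photo"}},
    {"id": 3, "vector": [0.2, 0.9], "payload": {"style": "painting"}},
]
print(filtered_search(points, [1.0, 0.2], must={"style": "painting"}))
```

Point 2 is the second-closest vector overall, but the filter removes it, so only painting-style prompts are returned.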
This system would involve several key steps:
1. Prompt Database Creation
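This step involves compiling a diverse and comprehensive collection of prompts previously used with Stable Diffusion. Here the prompts come from DiffusionDB, a public dataset of prompts written by real Stable Diffusion users.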
The Importance of Diversity and Comprehensiveness
- Diversity: This implies that the prompts should cover a wide range of subjects, styles, and themes. The goal is to encompass as many different types of imagery as possible — from landscapes and portraits to abstract art and specific object representations. Diversity ensures that the system can cater to a broad spectrum of user requests.
- Comprehensiveness: A comprehensive database is one that not only covers a wide range of subjects but also includes variations in the detail, complexity, and structure of the prompts. This includes prompts of varying lengths, different levels of descriptiveness, and diverse linguistic styles. A comprehensive database allows the system to understand and generate more nuanced and tailored prompts.
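One rough way to check the "varying lengths" aspect of comprehensiveness is to bucket prompts by word count and inspect the distribution. `length_profile` and the bucket size of 10 words are illustrative assumptions, not part of the original pipeline.

```python
from collections import Counter


def length_profile(prompts, bucket_size=10):
    """Count prompts per word-count bucket, e.g. '10-19' words (hypothetical helper)."""
    counts = Counter()
    for p in prompts:
        n = len(p.split())
        low = (n // bucket_size) * bucket_size
        counts[f"{low}-{low + bucket_size - 1}"] += 1
    # Order buckets by their lower bound so the distribution reads left to right
    return dict(sorted(counts.items(), key=lambda kv: int(kv[0].split("-")[0])))


sample = [
    "a cat",
    "portrait of a knight in silver armor, dramatic lighting, trending on artstation",
    "a quiet harbor at sunset",
]
print(length_profile(sample))
```

A profile dominated by a single bucket suggests the collection lacks either short, open-ended prompts or long, highly descriptive ones.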
import csv, json, os

import pandas as pd
import Levenshtein  # pip install Levenshtein
from urllib.request import urlretrieve

######################### Part 1: Load DiffusionDB ############################
# Download the DiffusionDB metadata table (Parquet)
table_url = 'https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/metadata.parquet'
urlretrieve(table_url, 'metadata.parquet')

# Read the table using pandas
raw_df = pd.read_parquet('metadata.parquet')
print(raw_df.head())

# Keep the first 10K prompts
prompts_raw = raw_df['prompt'].head(10000).tolist()
del raw_df

######################### Part 2: Data Preparation ############################
# Remove prompts with fewer than 10 words (and any missing values)
def filter_strings_with_word_count(strings, min_words=10):
    return [text for text in strings
            if isinstance(text, str) and len(text.split()) >= min_words]

prompts_filtered = filter_strings_with_word_count(prompts_raw)

# Remove near-duplicate prompts: a prompt is kept only if its Levenshtein
# distance to every prompt already kept is greater than the threshold.
# (A sequential loop: each check depends on the prompts kept so far.)
def remove_similar_strings(strings, threshold):
    unique_strings = []
    for s in strings:
        if all(Levenshtein.distance(s, us) > threshold for us in unique_strings):
            unique_strings.append(s)
    return unique_strings

# Set a similarity threshold (adjust as needed)
similarity_threshold = 10

# Remove similar prompts
prompts_unique = remove_similar_strings(prompts_filtered, similarity_threshold)

########################## Part 3: Data Storage ###############################
# Specify the CSV file name
csv_file_name = "prompts_unique.csv"

# Write the first 1,000 unique prompts to a UTF-8 CSV file, one prompt per row
with open(csv_file_name, mode="w", newline="", encoding="utf-8") as csv_file:
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(["prompt example"])
    for string in prompts_unique[:1000]:
        csv_writer.writerow([string])
This script is the first stage of the pipeline. It downloads the DiffusionDB metadata, keeps the first 10,000 prompts, removes prompts shorter than 10 words as well as near-duplicates (prompts within a Levenshtein distance of 10 of a prompt already kept), and stores the first 1,000 remaining prompts in a CSV file for the next step.
This kind of preprocessing is crucial for creating an effective dataset for tasks like training AI models or creating a prompt recommendation system.
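The deduplication step relies on the `Levenshtein` package. To see what the threshold of 10 does without installing it, here is a pure-Python edit distance plugged into the same keep-if-far-from-everything-kept logic as `remove_similar_strings` in the script. The sample prompts are made up for illustration.

```python
def levenshtein(a, b):
    """Edit distance (insertions, deletions, substitutions) via a two-row DP table."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,                # delete ca
                            curr[j - 1] + 1,            # insert cb
                            prev[j - 1] + (ca != cb)))  # substitute (free if equal)
        prev = curr
    return prev[-1]


def remove_similar_strings(strings, threshold):
    """Keep a string only if it is more than `threshold` edits from every kept string."""
    unique = []
    for s in strings:
        if all(levenshtein(s, u) > threshold for u in unique):
            unique.append(s)
    return unique


prompts = [
    "a castle on a hill at sunset, oil painting, highly detailed",
    "a castle on a hill at sunrise, oil painting, highly detailed",
    "cyberpunk street market at night, neon signs, rain, 35mm photo",
]
print(remove_similar_strings(prompts, threshold=10))
```

"sunset" and "sunrise" differ by only 3 edits, so the second prompt is dropped. Because every new prompt is compared with every prompt kept so far, the number of distance computations grows roughly quadratically with the input size, which is one reason to cap the input at 10,000 prompts.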
2. Vector Embedding
We use an embedding model (BGE) to convert these prompts into semantic vectors and index them in Qdrant.
- Semantic Representation: The vectors produced by the embedding model are not random numbers. The model is trained so that prompts with similar meanings receive nearby vector representations, so closeness in vector space approximates semantic similarity.
- High-Dimensional Space: These vectors usually exist in a high-dimensional space (hundreds or thousands of dimensions), enabling them to encapsulate a wide range of linguistic features.
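The configuration that follows sets `normalize_embeddings=True`. The small check below shows why that matters: once vectors are scaled to unit length, their plain dot product equals their cosine similarity, so a dot-product ranking matches a cosine ranking. The two-dimensional vectors are toy values.

```python
import math


def normalize(v):
    """Scale a vector to unit (L2) length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


a, b = [3.0, 4.0], [4.0, 3.0]
print(cosine(a, b))                     # 0.96
print(dot(normalize(a), normalize(b)))  # approximately 0.96 (floating-point rounding)
```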
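import os

# pip install langchain-community qdrant-client sentence-transformers
from langchain_community.document_loaders import CSVLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Qdrant

# Small English BGE embedding model (384-dimensional vectors)
model_name = "BAAI/bge-small-en-v1.5"
encode_kwargs = {'normalize_embeddings': True}  # unit-length vectors, so dot product equals cosine similarity
embeddings = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs={'device': 'cpu'},
    encode_kwargs=encode_kwargs,
)

# Make sure the CSV produced in step 1 exists
file_path = 'prompts_unique.csv'  # or the correct relative path to your file
if not os.path.exists(file_path):
    raise FileNotFoundError(f"The file {file_path} was not found.")

# Load one LangChain Document per CSV row
loader = CSVLoader(file_path=file_path, encoding='utf-8')
documents = loader.load()

# Embed the documents and index them in an in-memory Qdrant collection
index_from_loader = Qdrant.from_documents(
    documents,
    embeddings,
    location=":memory:",  # local mode with in-memory storage only
    collection_name="my_documents",
)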
The conversion of prompts into semantic vectors and their subsequent indexing in a vector database like Qdrant is a foundational step in creating a prompt recommendation system for Stable Diffusion.
This makes the prompts searchable by meaning: any new input can be embedded with the same model and compared against every stored prompt, which is the basis for the search and retrieval steps that follow.
3. Semantic Search Implementation
When a user inputs a prompt, the system converts it into a vector and performs a semantic search in Qdrant, retrieving closely related prompts.
def semantic_search(index, original_prompt, k=4):
    """Retrieval step of the RAG setup: return the text of the k prompts
    most similar to original_prompt (k=4 is LangChain's default)."""
    relevant_prompts = index.similarity_search(original_prompt, k=k)
    return [doc.page_content for doc in relevant_prompts]
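One practical detail: LangChain's CSVLoader stores each row as `column: value` text, so with the `prompt example` header written in step 1, each returned string typically starts with `prompt example: `. A small hypothetical helper, `strip_csv_prefix`, can remove that label before the prompts are shown to the user or passed to the LLM.

```python
PREFIX = "prompt example: "  # CSV header from step 1 plus the loader's ": " separator


def strip_csv_prefix(texts, prefix=PREFIX):
    """Drop the 'column: ' label that a CSV-to-document loader prepends (hypothetical helper)."""
    return [t[len(prefix):] if t.startswith(prefix) else t for t in texts]


results = ["prompt example: a misty forest at dawn, oil painting", "sunlit pine woods"]
print(strip_csv_prefix(results))
```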
The Process of Semantic Search Implementation
User Input Conversion
- Initial Step: When a user inputs a prompt into the system, the first step is to interpret this input in a way that the machine understands — as a vector.
- The system employs a language model to convert the textual prompt into a high-dimensional vector. This process involves analyzing the linguistic characteristics of the prompt and encoding them into numerical form.
Performing the Semantic Search in Qdrant
- Searching for Similar Vectors: The user's input vector is then used to query a vector database, in this case Qdrant.
- How Qdrant Works: Qdrant has indexed a vast array of prompts (also converted into vectors) in its database. When it receives the vector representation of a user’s prompt, it performs a search to find the most similar vectors from its index.
- Semantic Similarity: The similarity between vectors is determined based on their positioning in the high-dimensional space. Vectors that are close to each other represent prompts that are semantically similar.
Retrieving Closely Related Prompts
- Result Generation: The output of this search is a list of prompts whose vectors are most similar to the vector of the user’s input. These are the prompts that, semantically, closely relate to what the user is looking for.
- Advantage Over Keyword Searches: Unlike a traditional keyword search, this method matches the meaning and context of the user's input rather than exact words, so it can surface relevant prompts that share few or no words with the query.
Semantic search is therefore the core of the system: it makes finding suitable prompts for text-to-image models faster and more precise, and helps the retrieved prompts reflect the user's creative intent.
4. Integration with RAG
The top results from the vector search are then fed into a RAG setup, which intelligently combines elements from these prompts with the user’s original input, refining the prompt further.
For the Retrieval Augmented Generation (RAG) component, we used the Mistral 7B model, served locally through LM Studio's OpenAI-compatible server.
Integration Process
Combining with User’s Original Input:
- The RAG setup takes these top-ranked prompts and intelligently merges their elements with the user’s original input.
- This integration is crucial as it ensures that the essence of the user’s initial intent is preserved, while enriching it with ideas and expressions from the retrieved prompts.
Refining the Prompt
- The RAG language model then works on this combined input to generate a new, refined prompt.
- This refinement process involves creatively fusing the various elements, ensuring that the new prompt is not only relevant but also likely to produce more effective and accurate results when used in a text-to-image model.
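The LM Studio request that follows sends a single `selected_prompt`, while the description above speaks of combining several top results. A minimal sketch of folding several retrieved prompts into one OpenAI-style message list is shown here; `build_messages` and the `max_examples=3` cap are hypothetical choices, and the system text paraphrases the author's instruction.

```python
def build_messages(original_prompt, retrieved_prompts, max_examples=3):
    """Assemble a chat message list for the refinement step (hypothetical helper)."""
    system = ("This app generates prompts for image generation. Using the selected "
              "prompts as inspiration, only slightly revise the original prompt. "
              "Keep the generated prompt clear, complete, and under 50 words.")
    # Include only the top-ranked retrieved prompts, as a bulleted list
    examples = "\n".join(f"- {p}" for p in retrieved_prompts[:max_examples])
    user = (f"Original Prompt: {original_prompt}\n\n"
            f"Selected Prompts:\n{examples}\n\n"
            "Generated Prompt:")
    return [{"role": "system", "content": system},
            {"role": "user", "content": user}]


messages = build_messages("a cat", ["p1", "p2", "p3", "p4"])
print(messages[1]["content"])
```

Capping the number of examples keeps the request short and limits how far the model can drift from the user's original wording.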
import requests

# LM Studio local server endpoint (OpenAI-compatible chat completions API)
url = "http://localhost:1234/v1/chat/completions"

def generate_prompt(original_prompt, selected_prompt):
    """Ask the local Mistral 7B model to slightly revise original_prompt using selected_prompt."""
    payload = {
        "messages": [
            {"role": "system", "content": (
                "This app generates prompts for image generation. The user will provide an "
                "Original Prompt for image generation. Based on the Selected Prompt, only "
                "slightly revise the Original Prompt. Please keep the Generated Prompt clear, "
                "complete, and under 50 words."
            )},
            {"role": "user", "content": (
                f"Original Prompt: {original_prompt}\n\n"
                f"Selected Prompt: {selected_prompt}\n\n"
                "Generated Prompt: "
            )},
        ],
        "temperature": 0.7,
        "max_tokens": -1,
        "stream": False,
    }
    # json= serializes the payload and sets the Content-Type: application/json header
    response = requests.post(url, json=payload)
    # Check if the request was successful
    if response.status_code == 200:
        result = response.json()
        return result['choices'][0]['message']['content']
    print("Error:")
    return response.text
This retrieve-then-generate step streamlines prompt creation for complex models like Stable Diffusion and aims to improve the quality and effectiveness of the resulting prompts. It shows how combining retrieval and generative techniques can produce practical tools for AI applications.
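The pieces can be wired together as retrieve-then-refine. The sketch below keeps the search and refinement steps as injected functions, so the control flow can be tested without a vector index or a running LLM server; `recommend_prompt` is a hypothetical name. In the real pipeline, `search` could wrap `semantic_search` from step 3 and `refine` the LM Studio request from step 4. Using only the top-ranked result and falling back to the original prompt when nothing is retrieved are assumptions.

```python
from typing import Callable, List


def recommend_prompt(
    original_prompt: str,
    search: Callable[[str], List[str]],
    refine: Callable[[str, str], str],
) -> str:
    """Retrieve similar prompts, then refine the user's prompt with the best match.

    Falls back to the original prompt if retrieval returns nothing.
    """
    candidates = search(original_prompt)
    if not candidates:
        return original_prompt
    return refine(original_prompt, candidates[0])


# Stubs standing in for the vector search and the LLM call
def fake_search(prompt: str) -> List[str]:
    return ["a misty forest at dawn, oil painting"]


def fake_refine(original: str, selected: str) -> str:
    return f"{original}, in the style of: {selected}"


print(recommend_prompt("forest landscape", fake_search, fake_refine))
```

Passing the steps in as functions also makes it straightforward to swap the LLM or the vector store later.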
5. Prompt Testing with Stable Diffusion
The refined prompts can be tested with the Stable Diffusion model to demonstrate their effectiveness in generating high-quality images.
The primary goal of prompt testing is to evaluate how well the refined prompts perform when used with the Stable Diffusion text-to-image model. This involves feeding the refined prompts into Stable Diffusion and analyzing the quality, relevance, and accuracy of the images produced.
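One way to structure this evaluation is a paired comparison: render an image from the original prompt and one from the refined prompt, score both against the user's original intent, and count how often the refined version wins. In the sketch below, `compare_prompts`, `generate`, and `score` are hypothetical placeholders: `generate` could wrap a Stable Diffusion pipeline, and `score` could be a CLIP image-text similarity or a human rating. The stubs in the example are toys that only exercise the control flow.

```python
from typing import Callable, Dict, List, Tuple


def compare_prompts(
    pairs: List[Tuple[str, str]],
    generate: Callable[[str], object],
    score: Callable[[str, object], float],
) -> Dict[str, float]:
    """For each (original, refined) pair, score both images against the original
    intent and report the fraction of pairs where the refined prompt scores higher."""
    wins = 0
    for original, refined in pairs:
        s_orig = score(original, generate(original))
        s_ref = score(original, generate(refined))
        if s_ref > s_orig:
            wins += 1
    return {"refined_win_rate": wins / len(pairs) if pairs else 0.0}


# Toy stubs: the "image" is just the prompt text, and longer text "scores" higher
def stub_generate(prompt: str) -> str:
    return prompt


def stub_score(intent: str, image: str) -> float:
    return float(len(image.split()))


pairs = [
    ("a cat", "a fluffy cat on a windowsill, soft light"),
    ("a dog on a beach", "a dog"),
]
print(compare_prompts(pairs, stub_generate, stub_score))
```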
The goal is a user-friendly system that significantly reduces the time and effort needed to discover effective prompts for Stable Diffusion. By leveraging Vector Search and RAG, users can quickly find and refine prompts, leading to more satisfying and relevant image generation outcomes.
Code
GitHub Code: Vector Search and RAG for Stable Diffusion using Qdrant DB
Conclusion
The integration of Vector Search and RAG into the process of generating prompts for Stable Diffusion represents a significant step forward in democratizing AI-driven art creation. It addresses a key challenge faced by many users of these advanced models and opens up new avenues for creative expression. As these technologies continue to evolve, we can expect even more sophisticated tools and systems to emerge, further enhancing the accessibility and utility of AI in artistic and design endeavors.