Memory error when running python RAG LLM
Jude Gigy
Posted on July 22, 2024
Everytime I run this LLM I get a memory error. Please help.
importimport os
os
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
from transformers import pipeline
Replace with the path to your local folder containing the text files
folder_path = "C:\Users\asokw\Downloads\new"
Function to read and process text files
def read_text_files(folder_path):
all_files = os.listdir(folder_path)
text_files = [os.path.join(folder_path, f) for f in all_files if f.endswith('.txt')]
documents = []
for file_path in text_files:
with open(file_path, 'r', encoding='utf-8') as file:
content = file.read()
documents.append(content)
return documents
Load and preprocess documents
documents = read_text_files(folder_path)
Initialize RAG tokenizer, retriever, and model
tokenizer = RagTokenizer.from_pretrained('facebook/rag-token-base')
retriever = RagRetriever.from_pretrained('facebook/rag-token-base', index_name='exact', passages=documents)
model = RagTokenForGeneration.from_pretrained('facebook/rag-token-base', retriever=retriever)
your_prompt = "What information can be found in these documents?"
inputs = tokenizer(your_prompt, return_tensors="pt")
retrieval_output = model.get_retrieval_vector(inputs)
generation_inputs = {
"input_ids": inputs.input_ids,
"attention_mask": inputs.attention_mask,
"retrieval_logits": retrieval_output,
}
generation_output = model.generate(**generation_inputs)
generated_text = tokenizer.decode(generation_output.sequences[0])
print(f"Retrieved documents:", retrieval_output)
print(f"Generated text:", generated_text)
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
from transformers import pipeline
Posted on July 22, 2024
Join Our Newsletter. No Spam, Only the good stuff.
Sign up to receive the latest update from our blog.