import os
import chromadb
import streamlit as st

from langchain_community.document_loaders import HuggingFaceDatasetLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS, Chroma

from langchain.chains import RetrievalQA
from langchain_community.llms import LlamaCpp
from langchain_core.prompts import PromptTemplate
from langchain.callbacks.base import BaseCallbackHandler
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

# Callback handler for streaming tokens into Streamlit without LCEL
class StreamHandler(BaseCallbackHandler):
    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Append each new token and re-render the container so output streams live
        self.text += token
        self.container.markdown(self.text)

####################### Data processing for vectorstore #################################
pdf_folder_path = "./data_source"
documents = []

# Load every PDF in the source folder into LangChain Document objects
for file in os.listdir(pdf_folder_path):
    if file.endswith('.pdf'):
        pdf_path = os.path.join(pdf_folder_path, file)
        loader = PyPDFLoader(pdf_path)
        documents.extend(loader.load())

# Split the documents into overlapping chunks for embedding
text_splitter_rc = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
chunked_documents_rc = text_splitter_rc.split_documents(documents)

####################### EMBEDDINGS #################################
model_path = "sentence-transformers/all-MiniLM-L6-v2"
model_kwargs = {'device': 'mps'}  # Apple Silicon GPU; use 'cpu' or 'cuda' on other hardware
encode_kwargs = {'normalize_embeddings': False}
persist_directory = "./vector_stores"

if not os.path.exists(persist_directory):
    os.makedirs(persist_directory)


embeddings = HuggingFaceEmbeddings(
    model_name=model_path,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)
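# Optional sanity check (not in the original script): all-MiniLM-L6-v2 produces
# 384-dimensional vectors, so the printed length should be 384.
# print(len(embeddings.embed_query("test sentence")))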


def format_docs(docs):
    # Join retrieved chunks into a single context string for the prompt
    return "\n\n".join([doc.page_content for doc in docs])

####################### RAG #################################


prompt_template = """Use the following pieces of context regarding the Titanic ship to answer the question at the end. If you don't know the answer, just say that you don't know; don't try to make up an answer.

{context}

Question: {question}
Helpful Answer:
"""

prompt = PromptTemplate(template=prompt_template, input_variables=['context', 'question'])

# VectorDB creation and saving to disk
client = chromadb.Client()

vectordb = Chroma.from_documents(
    documents=chunked_documents_rc,
    embedding=embeddings,
    persist_directory=persist_directory,
    collection_name='chroma1'
)
vectordb.persist()

# VectorDB - loading from disk
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embeddings, collection_name='chroma1')
retriever = vectordb.as_retriever(search_kwargs={"k": 3})  # return the top-3 most similar chunks
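# Optional retrieval sanity check (not in the original script): assuming the
# vector store was built above, this prints the chunks returned for a sample
# query so the retriever can be verified before wiring it into the chains.
# for hit in vectordb.similarity_search("When did the Titanic sink?", k=3):
#     print(hit.page_content[:200])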


n_gpu_layers = 1  # layers offloaded to the Metal GPU; raise if memory allows
n_batch = 512
# stream_handler = StreamHandler(st.empty())

llm = LlamaCpp(
    model_path="/Users/raunakanand/Documents/Work_R/llm_models/mistral-7b-v0.1.Q4_K_S.gguf",
    n_gpu_layers=n_gpu_layers,
    n_batch=n_batch,
    n_ctx=1024,
    f16_kv=True,
    verbose=True,
    # streaming=True,
    # callbacks=[stream_handler]
    # callbacks=[StreamingStdOutCallbackHandler()]
)

# Classic RetrievalQA chain (kept for reference; inference() below uses the LCEL chain)
qa = RetrievalQA.from_chain_type(llm=llm, chain_type='stuff',
                                 retriever=retriever,
                                 # return_source_documents=True,
                                 chain_type_kwargs={'prompt': prompt},
                                 verbose=False)

# LCEL pipeline: retrieve -> format context -> fill prompt -> LLM -> plain string
rag_chain = ({"context": retriever | format_docs, "question": RunnablePassthrough()} |
             prompt | llm | StrOutputParser())

def inference(query: str):
    # return qa.invoke(query)['result']
    # return qa.run(query)
    # Stream the LCEL chain's output token by token
    return rag_chain.stream(query)

print('final')
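# Minimal Streamlit wiring sketch (not part of the original commit): assuming
# Streamlit >= 1.31, st.write_stream can consume the generator returned by
# inference() directly; the widget label below is illustrative.
# query = st.text_input("Ask a question about the Titanic")
# if query:
#     st.write_stream(inference(query))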