Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| from io import StringIO | |
| from llama_index.llms import HuggingFaceInferenceAPI | |
| from llama_index.embeddings import HuggingFaceInferenceAPIEmbedding | |
| from llama_index import ServiceContext, VectorStoreIndex | |
| from llama_index.schema import Document | |
| import uuid | |
| from llama_index.vector_stores.types import MetadataFilters, ExactMatchFilter | |
| from typing import List | |
| from pydantic import BaseModel | |
# Token handed to both Hugging Face inference clients below, read from
# Streamlit's secrets store.
# NOTE(review): the key is spelled "INFRERENCE" — it must match the secret
# name configured for the app, so confirm before correcting the spelling.
inference_api_key = st.secrets["INFRERENCE_API_TOKEN"]
| # embed_model_name = st.text_input( | |
| # 'Embed Model name', "Gooly/gte-small-en-fine-tuned-e-commerce") | |
| # llm_model_name = st.text_input( | |
| # 'Embed Model name', "mistralai/Mistral-7B-Instruct-v0.2") | |
# Structured-output schema: passed as `output_cls` to the query engine so the
# answer can be read back through a `price` attribute.
# NOTE(review): pydantic publishes the class docstring as the schema
# description, so it may end up in the LLM prompt — confirm before rewording.
class PriceModel(BaseModel):
    """Data model for price"""
    # Price kept as free text; no numeric parsing or validation happens here.
    price: str
# Hugging Face model ids: one for embedding the document, one for answering.
embed_model_name = "jinaai/jina-embedding-s-en-v1"
llm_model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# Both remote clients authenticate with the same inference token.
llm = HuggingFaceInferenceAPI(
    model_name=llm_model_name,
    token=inference_api_key,
)
embed_model = HuggingFaceInferenceAPIEmbedding(
    model_name=embed_model_name,
    token=inference_api_key,
    model_kwargs={"device": ""},
    encode_kwargs={"normalize_embeddings": True},
)

# Bundle the two models so index construction and querying share them.
service_context = ServiceContext.from_defaults(
    llm=llm,
    embed_model=embed_model,
)

# Question asked about the uploaded page; pre-filled with the price question.
query = st.text_input(
    label='Query',
    value="What is the price of the product?",
)
# Accept a single HTML page; nothing below runs until a file is uploaded.
html_file = st.file_uploader("Upload a html file", type=["html"])

if html_file is not None:
    # Decode the uploaded bytes directly (no StringIO round trip needed).
    # Saved web pages are not always valid UTF-8, so undecodable bytes are
    # replaced instead of aborting the run with UnicodeDecodeError.
    string_data = html_file.getvalue().decode("utf-8", errors="replace")
    with st.expander("Uploaded HTML"):
        st.write(string_data)

    if not query.strip():
        # Nothing to ask: skip building the index and querying the model.
        st.warning("Enter a query to run against the uploaded HTML.")
    else:
        # Tag the document with a fresh id and restrict retrieval to that id.
        document_id = str(uuid.uuid4())
        document = Document(text=string_data)
        document.metadata["id"] = document_id
        documents = [document]
        filters = MetadataFilters(
            filters=[ExactMatchFilter(key="id", value=document_id)])

        # NOTE(review): `metadata=` is forwarded as an extra keyword to the
        # index constructor; verify it has any effect, or move it onto
        # `document.metadata` if it was meant to label the document.
        index = VectorStoreIndex.from_documents(
            documents,
            show_progress=True,
            metadata={"source": "HTML"},
            service_context=service_context,
        )

        # `output_cls=PriceModel` requests a structured answer, which is why
        # `response.price` is read below.
        query_engine = index.as_query_engine(
            filters=filters,
            service_context=service_context,
            response_mode="tree_summarize",
            output_cls=PriceModel,
        )
        response = query_engine.query(query)
        st.write(response.response)
        st.write(f'Price: {response.price}')
| # if st.button('Start Pipeline'): | |
| # if html_file is not None and embed_model_name is not None and llm_model_name is not None and query is not None: | |
| # st.write('Running Pipeline') | |
| # llm = HuggingFaceInferenceAPI( | |
| # model_name=llm_model_name, token=inference_api_key) | |
| # embed_model = HuggingFaceInferenceAPIEmbedding( | |
| # model_name=embed_model_name, | |
| # token=inference_api_key, | |
| # model_kwargs={"device": ""}, | |
| # encode_kwargs={"normalize_embeddings": True}, | |
| # ) | |
| # service_context = ServiceContext.from_defaults( | |
| # embed_model=embed_model, llm=llm) | |
| # stringio = StringIO(html_file.getvalue().decode("utf-8")) | |
| # string_data = stringio.read() | |
| # with st.expander("Uploaded HTML"): | |
| # st.write(string_data) | |
| # document_id = str(uuid.uuid4()) | |
| # document = Document(text=string_data) | |
| # document.metadata["id"] = document_id | |
| # documents = [document] | |
| # filters = MetadataFilters( | |
| # filters=[ExactMatchFilter(key="id", value=document_id)]) | |
| # index = VectorStoreIndex.from_documents( | |
| # documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context) | |
| # retriever = index.as_retriever() | |
| # ranked_nodes = retriever.retrieve( | |
| # query) | |
| # with st.expander("Ranked Nodes"): | |
| # for node in ranked_nodes: | |
| # st.write(node.node.get_content(), "-> Score:", node.score) | |
| # query_engine = index.as_query_engine( | |
| # filters=filters, service_context=service_context) | |
| # response = query_engine.query(query) | |
| # st.write(response.response) | |
| # st.write(response.source_nodes) | |
| # else: | |
| # st.error('Please fill in all the fields') | |
| # else: | |
| # st.write('Press start to begin') | |
| # # if html_file is not None: | |
| # # stringio = StringIO(html_file.getvalue().decode("utf-8")) | |
| # # string_data = stringio.read() | |
| # # with st.expander("Uploaded HTML"): | |
| # # st.write(string_data) | |
| # # document_id = str(uuid.uuid4()) | |
| # # document = Document(text=string_data) | |
| # # document.metadata["id"] = document_id | |
| # # documents = [document] | |
| # # filters = MetadataFilters( | |
| # # filters=[ExactMatchFilter(key="id", value=document_id)]) | |
| # # index = VectorStoreIndex.from_documents( | |
| # # documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context) | |
| # # retriever = index.as_retriever() | |
| # # ranked_nodes = retriever.retrieve( | |
| # # "Get me all the information about the product") | |
| # # with st.expander("Ranked Nodes"): | |
| # # for node in ranked_nodes: | |
| # # st.write(node.node.get_content(), "-> Score:", node.score) | |
| # # query_engine = index.as_query_engine( | |
| # # filters=filters, service_context=service_context) | |
| # # response = query_engine.query( | |
| # # "Get me all the information about the product") | |
| # # st.write(response) | |