1010import scipy .spatial .distance
1111import streamlit as st
1212import workflows .question_answering .classes as classes
13- import workflows .question_answering .config as config
1413import workflows .question_answering .functions as functions
1514import workflows .question_answering .prompts as prompts
1615from util import ui_components
@@ -110,8 +109,10 @@ def create(sv: SessionVariables, workflow = None):
110109 iteration = 0
111110 source_counts = Counter ()
112111 used_chunks = set ()
112+ functions_embedder = functions .embedder ()
113+
113114 while True :
114- qe = np .array (functions . embedder .embed_store_one (question , sv_home .save_cache .value ))
115+ qe = np .array (functions_embedder .embed_store_one (question , sv_home .save_cache .value ))
115116 iteration += 1
116117 cosine_distances = sorted ([(t , c , scipy .spatial .distance .cosine (qe , v )) for (t , c , v ) in all_units ], key = lambda x :x [2 ], reverse = False )
117118 chunk_index = sv .answering_target_matches .value
@@ -185,15 +186,16 @@ def create(sv: SessionVariables, workflow = None):
185186 qas_raw = ui_components .generate_text (messages , callbacks = [on_callback ])
186187 status_history += qas_raw + '<br/><br/>'
187188 try :
189+ functions_embedder = functions .embedder ()
188190 qas = json .loads (qas_raw )
189191 for qa in qas :
190192 q = qa ['question' ]
191193 a = qa ['answer' ]
192194 raw_refs = qa ['source' ]
193195 file_page_refs = [tuple ([int (x [1 :]) for x in r .split (';' )]) for r in raw_refs ]
194196
195- q_vec = np .array (functions . embedder .embed_store_one (q , sv_home .save_cache .value ))
196- a_vec = np .array (functions .embedder .embed_store_one (a , sv_home .save_cache .value ))
197+ q_vec = np .array (functions_embedder .embed_store_one (q , sv_home .save_cache .value ))
198+ a_vec = np .array (functions_embedder .embedder .embed_store_one (a , sv_home .save_cache .value ))
197199
198200 qid = sv .answering_next_q_id .value
199201 sv .answering_next_q_id .value += 1
@@ -209,7 +211,7 @@ def create(sv: SessionVariables, workflow = None):
209211 if t == 'chunk' and c [0 ].id == f .id and c [1 ] == cx :
210212 all_units .remove ((t , c , v ))
211213
212- status_history += f 'Augmenting user question with partial answers:<br/>'
214+ status_history += 'Augmenting user question with partial answers:<br/>'
213215 new_question = functions .update_question (sv , sv .answering_question_history .value , new_questions , lazy_answering_placeholder , status_history )
214216 status_history += new_question + '<br/><br/>'
215217 sv .answering_question_history .value .append (new_question )
0 commit comments