Skip to content

Commit a6436fb

Browse files
committed
fix errors and add option to use azure openai with key
1 parent bfa879d commit a6436fb

File tree

11 files changed

+83
-29
lines changed

11 files changed

+83
-29
lines changed

app/pages/Settings.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from util.openai_wrapper import (
99
UIOpenAIConfiguration,
1010
key,
11+
openai_azure_auth_type,
1112
openai_azure_model_key,
1213
openai_endpoint_key,
1314
openai_type_key,
@@ -33,7 +34,6 @@ def main():
3334
st.set_page_config(layout="wide", initial_sidebar_state="collapsed", page_icon="app/myapp.ico", page_title='Intelligence Toolkit | Settings')
3435
load_multipage_app()
3536
openai_config = UIOpenAIConfiguration().get_configuration()
36-
print('openai_config', openai_config.api_type)
3737
st.header("Settings")
3838
secrets_handler = SecretsHandler()
3939

@@ -51,6 +51,14 @@ def main():
5151
st.rerun()
5252

5353
if type_input == "Azure OpenAI":
54+
types_az = ["Managed Identity", "Azure Key"]
55+
index_az = types_az.index(openai_config.az_auth_type) if openai_config.az_auth_type in types_az else 0
56+
type_input_az = st.radio("Azure OpenAI Auth Type", types_az, index=index_az, disabled=is_mode_cloud)
57+
if type_input_az != openai_config.az_auth_type:
58+
print('type_input_az', type_input_az)
59+
print('openai_config.az_auth_type', openai_config.az_auth_type)
60+
on_change(secrets_handler, openai_azure_auth_type, type_input_az)()
61+
st.rerun()
5462
col1, col2, col3 = st.columns(3)
5563
with col1:
5664
endpoint = st.text_input("Azure OpenAI Endpoint", disabled=is_mode_cloud, type="password", value=openai_config.api_base)
@@ -69,9 +77,9 @@ def main():
6977
if version != openai_config.api_version:
7078
on_change(secrets_handler, openai_version_key, version)()
7179
st.rerun()
72-
else:
80+
if type_input == "OpenAI" or type_input_az != "Managed Identity":
7381
placeholder = "Enter key here..."
74-
secret_input = st.text_input('Enter your OpenAI key', type="password", disabled=is_mode_cloud, placeholder=placeholder, value=secret)
82+
secret_input = st.text_input('Enter your key', type="password", disabled=is_mode_cloud, placeholder=placeholder, value=secret)
7583

7684
if secret and len(secret) > 0:
7785
st.info("Your key is saved securely.")

app/util/openai_wrapper.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
openai_version_key = 'openai_version'
1111
openai_endpoint_key = 'openai_endpoint'
1212
openai_azure_model_key = 'openai_azure_model'
13+
openai_azure_auth_type = 'openai_azure_auth_type'
1314

1415
class UIOpenAIConfiguration():
1516
def __init__(
@@ -23,13 +24,15 @@ def get_configuration(self):
2324
endpoint = self._secrets.get_secret(openai_endpoint_key) or None
2425
secret_key = self._secrets.get_secret(key) or None
2526
model = self._secrets.get_secret(openai_azure_model_key) or None
27+
az_auth_type = self._secrets.get_secret(openai_azure_auth_type) or None
2628

2729
config = {
2830
'api_type': type,
2931
'api_version': version,
3032
'api_base': endpoint,
3133
'api_key': secret_key,
32-
'model': model
34+
'model': model,
35+
'az_auth_type': az_auth_type
3336
}
3437
values = {k: v for k, v in config.items() if v is not None}
3538
return OpenAIConfiguration(values)

app/workflows/question_answering/functions.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,15 @@
1818
from python.AI.text_splitter import TextSplitter
1919

2020
sv_home = SessionVariables('home')
21-
ai_configuration = UIOpenAIConfiguration().get_configuration()
2221

23-
embedder = Embedder(ai_configuration, config.cache_dir)
22+
23+
def embedder():
    """Create an Embedder from the current UI OpenAI configuration.

    Reads the configuration through UIOpenAIConfiguration().get_configuration()
    and passes it, together with config.cache_dir, to Embedder. If anything
    raises while doing so, the error is shown with st.error and st.stop() is
    called.
    """
    try:
        settings = UIOpenAIConfiguration().get_configuration()
        instance = Embedder(settings, config.cache_dir)
    except Exception as e:
        st.error(f'Error creating connection: {e}')
        st.stop()
    else:
        return instance
2430

2531
def chunk_files(sv, files):
2632
pb = st.progress(0, 'Chunking files...')
@@ -52,11 +58,11 @@ def chunk_files(sv, files):
5258
for chunk in chunks:
5359
file_chunks.append((file, chunk))
5460
file.set_text(doc_text)
55-
61+
functions_embedder = embedder()
5662
for cx, (file, chunk) in enumerate(file_chunks):
5763
pb.progress((cx+1) / len(file_chunks), f'Embedding chunk {cx+1} of {len(file_chunks)}...')
5864
formatted_chunk = chunk.replace("\n", " ")
59-
chunk_vec = embedder.embed_store_one(formatted_chunk, sv_home.save_cache.value)
65+
chunk_vec = functions_embedder.embed_store_one(formatted_chunk, sv_home.save_cache.value)
6066
file.add_chunk(chunk, np.array(chunk_vec), cx+1)
6167
pb.empty()
6268

app/workflows/question_answering/workflow.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import scipy.spatial.distance
1111
import streamlit as st
1212
import workflows.question_answering.classes as classes
13-
import workflows.question_answering.config as config
1413
import workflows.question_answering.functions as functions
1514
import workflows.question_answering.prompts as prompts
1615
from util import ui_components
@@ -110,8 +109,10 @@ def create(sv: SessionVariables, workflow = None):
110109
iteration = 0
111110
source_counts = Counter()
112111
used_chunks = set()
112+
functions_embedder = functions.embedder()
113+
113114
while True:
114-
qe = np.array(functions.embedder.embed_store_one(question, sv_home.save_cache.value))
115+
qe = np.array(functions_embedder.embed_store_one(question, sv_home.save_cache.value))
115116
iteration += 1
116117
cosine_distances = sorted([(t, c, scipy.spatial.distance.cosine(qe, v)) for (t, c, v) in all_units], key=lambda x:x[2], reverse=False)
117118
chunk_index = sv.answering_target_matches.value
@@ -185,15 +186,16 @@ def create(sv: SessionVariables, workflow = None):
185186
qas_raw = ui_components.generate_text(messages, callbacks=[on_callback])
186187
status_history += qas_raw + '<br/><br/>'
187188
try:
189+
functions_embedder = functions.embedder()
188190
qas = json.loads(qas_raw)
189191
for qa in qas:
190192
q = qa['question']
191193
a = qa['answer']
192194
raw_refs = qa['source']
193195
file_page_refs = [tuple([int(x[1:]) for x in r.split(';')]) for r in raw_refs]
194196

195-
q_vec = np.array(functions.embedder.embed_store_one(q, sv_home.save_cache.value))
196-
a_vec = np.array(functions.embedder.embed_store_one(a, sv_home.save_cache.value))
197+
q_vec = np.array(functions_embedder.embed_store_one(q, sv_home.save_cache.value))
198+
# functions_embedder is already the object returned by functions.embedder();
# call embed_store_one on it directly, exactly as the q_vec line does.
a_vec = np.array(functions_embedder.embed_store_one(a, sv_home.save_cache.value))
197199

198200
qid = sv.answering_next_q_id.value
199201
sv.answering_next_q_id.value += 1
@@ -209,7 +211,7 @@ def create(sv: SessionVariables, workflow = None):
209211
if t == 'chunk' and c[0].id == f.id and c[1] == cx:
210212
all_units.remove((t, c, v))
211213

212-
status_history += f'Augmenting user question with partial answers:<br/>'
214+
status_history += 'Augmenting user question with partial answers:<br/>'
213215
new_question = functions.update_question(sv, sv.answering_question_history.value, new_questions, lazy_answering_placeholder, status_history)
214216
status_history += new_question + '<br/><br/>'
215217
sv.answering_question_history.value.append(new_question)

app/workflows/record_matching/functions.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT license. See LICENSE file in the project.
33
#
4+
import streamlit as st
45
from util.openai_wrapper import UIOpenAIConfiguration
56
from workflows.record_matching import config
67

78
from python.AI.embedder import Embedder
89

9-
ai_configuration = UIOpenAIConfiguration().get_configuration()
10-
embedder = Embedder(ai_configuration, config.cache_dir)
10+
11+
def embedder():
    """Create an Embedder from the current UI OpenAI configuration.

    Reads the configuration through UIOpenAIConfiguration().get_configuration()
    and passes it, together with config.cache_dir, to Embedder. If anything
    raises while doing so, the error is shown with st.error and st.stop() is
    called.
    """
    try:
        settings = UIOpenAIConfiguration().get_configuration()
        instance = Embedder(settings, config.cache_dir)
    except Exception as e:
        st.error(f'Error creating connection: {e}')
        st.stop()
    else:
        return instance
1118

1219
def convert_to_sentences(df, skip):
1320
sentences = []

app/workflows/record_matching/workflow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,9 @@ def on_embedding_batch_change(current, total):
194194

195195
callback = classes.BatchEmbeddingCallback()
196196
callback.on_embedding_batch_change = on_embedding_batch_change
197-
embeddings = functions.embedder.embed_store_many(all_sentences,[callback], sv_home.save_cache.value)
197+
functions_embedder = functions.embedder()
198+
199+
embeddings = functions_embedder.embed_store_many(all_sentences,[callback], sv_home.save_cache.value)
198200
pb.empty()
199201

200202
nbrs = NearestNeighbors(n_neighbors=50, n_jobs=1, algorithm='auto', leaf_size=20, metric='cosine').fit(embeddings)

app/workflows/risk_networks/functions.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,21 @@
66

77
import networkx as nx
88
import pandas as pd
9+
import streamlit as st
910
import workflows.risk_networks.config as config
1011
from streamlit_agraph import Config, Edge, Node
1112
from util.openai_wrapper import UIOpenAIConfiguration
1213

1314
from python.AI.embedder import Embedder
1415

15-
ai_configuration = UIOpenAIConfiguration().get_configuration()
16-
embedder = Embedder(ai_configuration, config.cache_dir)
16+
17+
def embedder():
    """Create an Embedder from the current UI OpenAI configuration.

    Reads the configuration through UIOpenAIConfiguration().get_configuration()
    and passes it, together with config.cache_dir, to Embedder. If anything
    raises while doing so, the error is shown with st.error and st.stop() is
    called.
    """
    try:
        settings = UIOpenAIConfiguration().get_configuration()
        instance = Embedder(settings, config.cache_dir)
    except Exception as e:
        st.error(f'Error creating connection: {e}')
        st.stop()
    else:
        return instance
1724

1825
def hsl_to_hex(h, s, l):
1926
rgb = colorsys.hls_to_rgb(h / 360, l / 100, s / 100)

app/workflows/risk_networks/workflow.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,8 @@ def on_embedding_batch_change(current, total):
255255

256256
callback = classes.BatchEmbeddingCallback()
257257
callback.on_embedding_batch_change = on_embedding_batch_change
258-
embeddings = functions.embedder.embed_store_many(texts,[callback], sv_home.save_cache.value)
258+
functions_embedder = functions.embedder()
259+
embeddings = functions_embedder.embed_store_many(texts,[callback], sv_home.save_cache.value)
259260
pb.empty()
260261

261262
vals = [(n, t, e) for (n, t), e in zip(text_types, embeddings)]

python/AI/client.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import logging
55
from typing import List
66

7-
from openai import AzureOpenAI, OpenAI
87
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
8+
from openai import AzureOpenAI, OpenAI
9+
910
from .classes import LLMCallback
1011
from .defaults import API_BASE_REQUIRED_FOR_AZURE, DEFAULT_EMBEDDING_MODEL
1112
from .openai_configuration import OpenAIConfiguration
@@ -31,16 +32,24 @@ def create_openai_client(self) -> None:
3132
api_base,
3233
)
3334

34-
token_provider = get_bearer_token_provider(
35-
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
36-
)
35+
if self.configuration.az_auth_type == 'Managed Identity':
36+
token_provider = get_bearer_token_provider(
37+
DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
38+
)
3739

38-
self._client = AzureOpenAI(
39-
api_version=self.configuration.api_version,
40-
# Azure-Specifics
41-
azure_ad_token_provider=token_provider,
42-
azure_endpoint=api_base,
43-
)
40+
self._client = AzureOpenAI(
41+
api_version=self.configuration.api_version,
42+
# Azure-Specifics
43+
azure_ad_token_provider=token_provider,
44+
azure_endpoint=api_base,
45+
)
46+
else:
47+
self._client = AzureOpenAI(
48+
api_version=self.configuration.api_version,
49+
# Azure-Specifics
50+
azure_endpoint=api_base,
51+
api_key=self.configuration.api_key,
52+
)
4453
else:
4554
log.info("Creating OpenAI client")
4655
self._client = OpenAI(

python/AI/defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
DEFAULT_LLM_MODEL = "gpt-4o"
88
DEFAULT_AZURE_LLM_MODEL = "gpt-4o"
99
DEFAULT_LLM_MAX_TOKENS = 4000
10+
DEFAULT_AZ_AUTH_TYPE = "Managed Identity"
1011
#
1112
# Text Embedding Parameters
1213
DEFAULT_EMBEDDING_MODEL = "text-embedding-ada-002"

0 commit comments

Comments
 (0)