1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- import streamlit as st
- import asyncio
- from autogen import AssistantAgent, UserProxyAgent
# Page title: st.write renders a plain string as Markdown, so this is an H1 heading.
st.write("# AutoGen Chat Agents")
def _render_message(message, sender):
    """Echo a message received by an agent into the Streamlit chat transcript.

    Args:
        message: The message being delivered. A string is rendered as-is. For a
            dict message, its ``"content"`` value is rendered when it is present
            and non-empty; otherwise the whole dict is rendered as text so that
            messages without content are still visible.
        sender: The agent that sent the message; ``sender.name`` selects the
            chat bubble the message is shown in.
    """
    if isinstance(message, dict):
        content = message.get("content")
        # Without this, a dict message would be rendered as its Python repr.
        text = str(content) if content else str(message)
    else:
        text = message
    with st.chat_message(sender.name):
        st.markdown(text)


class TrackableAssistantAgent(AssistantAgent):
    """AssistantAgent that also shows every message it receives in the Streamlit chat."""

    # NOTE(review): overrides a private AutoGen hook; its signature may change
    # between autogen versions — re-check when upgrading.
    def _process_received_message(self, message, sender, silent):
        _render_message(message, sender)
        return super()._process_received_message(message, sender, silent)


class TrackableUserProxyAgent(UserProxyAgent):
    """UserProxyAgent that also shows every message it receives in the Streamlit chat."""

    # NOTE(review): overrides a private AutoGen hook; its signature may change
    # between autogen versions — re-check when upgrading.
    def _process_received_message(self, message, sender, silent):
        _render_message(message, sender)
        return super()._process_received_message(message, sender, silent)
# Sidebar inputs: which OpenAI model to use and the API key to call it with.
selected_model = selected_key = None
with st.sidebar:
    st.header("OpenAI Configuration")
    available_models = ['gpt-3.5-turbo', 'gpt-4']
    # index=1 preselects 'gpt-4'.
    selected_model = st.selectbox("Model", available_models, index=1)
    # type="password" masks the key as it is typed.
    selected_key = st.text_input("API Key", type="password")
with st.container():
    user_input = st.chat_input("Type something...")

    if user_input:
        # Both values are needed to build the LLM configuration; st.stop()
        # ends this script run so nothing below executes without them.
        if not selected_key or not selected_model:
            st.warning(
                'You must provide valid OpenAI API key and choose preferred model', icon="⚠️")
            st.stop()

        # NOTE(review): "request_timeout" is the pyautogen 0.1.x key name; newer
        # autogen releases use "timeout" instead — confirm the pinned version.
        llm_config = {
            "request_timeout": 600,
            "config_list": [
                {
                    "model": selected_model,
                    "api_key": selected_key,
                }
            ],
        }

        # The agent named "assistant" answers the user's message via the LLM.
        assistant = TrackableAssistantAgent(
            name="assistant", llm_config=llm_config)

        # The agent named "user" sends the chat input on the user's behalf and
        # never asks for further human input (human_input_mode="NEVER").
        # NOTE(review): UserProxyAgent may execute code blocks it receives unless
        # code_execution_config=False is passed — confirm this is intended for a
        # web-facing app.
        user_proxy = TrackableUserProxyAgent(
            name="user", human_input_mode="NEVER", llm_config=llm_config)

        async def initiate_chat():
            """Start the conversation with the text the user just submitted."""
            await user_proxy.a_initiate_chat(
                assistant,
                message=user_input,
            )

        # asyncio.run creates a fresh event loop, runs the chat to completion,
        # then closes the loop. The previous new_event_loop() call left one
        # unclosed loop behind for every submitted message.
        asyncio.run(initiate_chat())
|