1. Original code

import asyncio
from inspect import iscoroutinefunction
from typing import Awaitable, Callable, List, Optional, Sequence, Union, cast

from autogen_core.base import CancellationToken

from ..base import Response
from ..messages import ChatMessage, HandoffMessage, TextMessage
from ._base_chat_agent import BaseChatAgent

# Define input function types more precisely
SyncInputFunc = Callable[[str], str]
AsyncInputFunc = Callable[[str, Optional[CancellationToken]], Awaitable[str]]
InputFuncType = Union[SyncInputFunc, AsyncInputFunc]


class UserProxyAgent(BaseChatAgent):
    """An agent that can represent a human user through an input function.

    This agent can be used to represent a human user in a chat system by providing a custom input function.

    Args:
        name (str): The name of the agent.
        description (str, optional): A description of the agent.
        input_func (Optional[Callable[[str], str]], Callable[[str, Optional[CancellationToken]], Awaitable[str]]): A function that takes a prompt and returns a user input string.

    .. note::

        Using :class:`UserProxyAgent` puts a running team in a temporary blocked
        state until the user responds. So it is important to time out the user input
        function and cancel using the :class:`~autogen_core.base.CancellationToken` if the user does not respond.
        The input function should also handle exceptions and return a default response if needed.

        For typical use cases that involve
        slow human responses, it is recommended to use termination conditions
        such as :class:`~autogen_agentchat.task.HandoffTermination` or :class:`~autogen_agentchat.task.SourceMatchTermination`
        to stop the running team and return the control to the application.
        You can run the team again with the user input. This way, the state of the team
        can be saved and restored when the user responds.

        See `Pause for User Input <https://microsoft.github.io/autogen/dev/user-guide/agentchat-user-guide/tutorial/teams.html#pause-for-user-input>`_ for more information.

    """

    def __init__(
        self,
        name: str,
        *,
        description: str = "A human user",
        input_func: Optional[InputFuncType] = None,
    ) -> None:
        """Initialize the UserProxyAgent."""
        super().__init__(name=name, description=description)
        self.input_func = input_func or input
        self._is_async = iscoroutinefunction(self.input_func)

    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        """Message types this agent can produce."""
        return [TextMessage, HandoffMessage]

    def _get_latest_handoff(self, messages: Sequence[ChatMessage]) -> Optional[HandoffMessage]:
        """Find the HandoffMessage in the message sequence that addresses this agent."""
        if len(messages) > 0 and isinstance(messages[-1], HandoffMessage):
            if messages[-1].target == self.name:
                return messages[-1]
            else:
                raise RuntimeError(f"Handoff message target does not match agent name: {messages[-1].source}")
        return None

    async def _get_input(self, prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
        """Handle input based on function signature."""
        try:
            if self._is_async:
                # Cast to AsyncInputFunc for proper typing
                async_func = cast(AsyncInputFunc, self.input_func)
                return await async_func(prompt, cancellation_token)
            else:
                # Cast to SyncInputFunc for proper typing
                sync_func = cast(SyncInputFunc, self.input_func)
                loop = asyncio.get_event_loop()
                return await loop.run_in_executor(None, sync_func, prompt)

        except asyncio.CancelledError:
            raise
        except Exception as e:
            raise RuntimeError(f"Failed to get user input: {str(e)}") from e

    async def on_messages(
        self, messages: Sequence[ChatMessage], cancellation_token: Optional[CancellationToken] = None
    ) -> Response:
        """Handle incoming messages by requesting user input."""
        try:
            # Check for handoff first
            handoff = self._get_latest_handoff(messages)
            prompt = (
                f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: "
            )
            # print(prompt)
            user_input = await self._get_input(prompt, cancellation_token)
            # print(user_input)
            # Return appropriate message type based on handoff presence
            if handoff:
                return Response(
                    chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name)
                )
            else:
                return Response(chat_message=TextMessage(content=user_input, source=self.name))

        except asyncio.CancelledError:
            raise
        except Exception as e:
            raise RuntimeError(f"Failed to get user input: {str(e)}") from e

    async def on_reset(self, cancellation_token: Optional[CancellationToken] = None) -> None:
        """Reset agent state."""
        pass

2. Testing the code

import asyncio

from autogen_agentchat.agents import UserProxyAgent
from autogen_agentchat.messages import TextMessage
from autogen_core.base import CancellationToken


def custom_input(prompt: str) -> str:
    return "The height of the eiffel tower is 324 meters. Aloha!"


agent = UserProxyAgent(name="test_user", input_func=custom_input)
messages = [
    TextMessage(content="What is the height of the eiffel tower?", source="assistant")
]


response = asyncio.run(agent.on_messages(messages, CancellationToken()))

print(response)

Output

Response(chat_message=TextMessage(source='test_user', models_usage=None, content='The height of the eiffel tower is 324 meters. Aloha!'), inner_messages=None)

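3. How the code runs

The code has two parts:

  • The definition of the UserProxyAgent class: an agent that stands in for a human user. It obtains the user's reply by calling input_func (the input function).
  • An example that uses UserProxyAgent: it creates a UserProxyAgent instance and calls its on_messages method to simulate a simple interaction.

Let's walk through the logic step by step.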

4. Core functionality of the UserProxyAgent class

  1. The __init__ method

    def __init__(self, name: str, *, description: str = "A human user", input_func: Optional[InputFuncType] = None) -> None:
        super().__init__(name=name, description=description)
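        self.input_func = input_func or input  # fall back to the built-in input() function
        self._is_async = iscoroutinefunction(self.input_func)  # is input_func an async function?
    
    • input_func: the function used to obtain the user's input. It can be synchronous or asynchronous.
    • self._is_async: records whether input_func is synchronous or asynchronous. If input_func is a coroutine function (defined with async def), self._is_async is True; otherwise it is False.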
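To make the detection step concrete, here is a minimal standalone sketch of how inspect.iscoroutinefunction classifies input functions. The sample functions sync_reply and async_reply and the helper pick are hypothetical and not part of autogen; the async one accepts a plain object where the real code would pass a CancellationToken.

```python
from inspect import iscoroutinefunction
from typing import Optional


def sync_reply(prompt: str) -> str:
    # Hypothetical sync input function: takes only the prompt.
    return "sync answer"


async def async_reply(prompt: str, cancellation_token: Optional[object] = None) -> str:
    # Hypothetical async input function: takes the prompt and a token.
    return "async answer"


def pick(input_func=None):
    # Same fallback and check as in __init__: default to the built-in input().
    func = input_func or input
    return func, iscoroutinefunction(func)


print(pick(sync_reply)[1])   # False
print(pick(async_reply)[1])  # True
print(pick()[1])             # False (built-in input is not a coroutine function)
```

The check runs once in the constructor, so the branch taken later in _get_input is fixed by how the function was defined, not by what it returns.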
  2. The produced_message_types property

    @property
    def produced_message_types(self) -> List[type[ChatMessage]]:
        return [TextMessage, HandoffMessage]
    

    This property declares the message types that UserProxyAgent can produce. It returns TextMessage and HandoffMessage, the two message types the agent can generate.

  3. The _get_input method

    async def _get_input(self, prompt: str, cancellation_token: Optional[CancellationToken]) -> str:
        try:
            if self._is_async:
                async_func = cast(AsyncInputFunc, self.input_func)
                return await async_func(prompt, cancellation_token)
            else:
                sync_func = cast(SyncInputFunc, self.input_func)
                loop = asyncio.get_event_loop()
                return await loop.run_in_executor(None, sync_func, prompt)
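        except asyncio.CancelledError:
            raise
        except Exception as e:
            raise RuntimeError(f"Failed to get user input: {str(e)}") from e
    

    _get_input obtains the user's input according to the kind of input_func (async or sync). Cancellation is re-raised unchanged; any other exception is wrapped in a RuntimeError.

    • If it is an async function, it is awaited directly: await async_func(prompt, cancellation_token).
    • If it is a sync function, run_in_executor runs it in the default thread pool so that a blocking call such as input() does not block the event loop. The sync function receives only the prompt, not the cancellation token.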
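The sync branch can be tried in isolation with only the standard library. The sketch below is an illustration rather than autogen's code: it offloads a blocking function (a hypothetical slow_input that sleeps instead of reading the keyboard) to the default executor while another coroutine keeps running on the event loop. It uses asyncio.get_running_loop(), the usual choice inside a coroutine.

```python
import asyncio
import time


def slow_input(prompt: str) -> str:
    # Hypothetical stand-in for a blocking input(): sleeps, then answers.
    time.sleep(0.2)
    return f"reply to {prompt!r}"


async def ticker(ticks: list) -> None:
    # Runs on the event loop while slow_input blocks a worker thread.
    for i in range(3):
        ticks.append(i)
        await asyncio.sleep(0.05)


async def main():
    ticks: list = []
    loop = asyncio.get_running_loop()
    # Passing None selects the loop's default ThreadPoolExecutor.
    answer, _ = await asyncio.gather(
        loop.run_in_executor(None, slow_input, "Enter your response: "),
        ticker(ticks),
    )
    return answer, ticks


answer, ticks = asyncio.run(main())
print(answer)
print(ticks)
```

If slow_input were called directly inside the coroutine, ticker could not advance until it returned; running it in the executor keeps the loop responsive.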
  4. The on_messages method

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: Optional[CancellationToken] = None) -> Response:
        try:
            # Check for handoff first
            handoff = self._get_latest_handoff(messages)
            prompt = (
                f"Handoff received from {handoff.source}. Enter your response: " if handoff else "Enter your response: "
            )
            user_input = await self._get_input(prompt, cancellation_token)
            if handoff:
                return Response(
                    chat_message=HandoffMessage(content=user_input, target=handoff.source, source=self.name)
                )
            else:
                return Response(chat_message=TextMessage(content=user_input, source=self.name))
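        except asyncio.CancelledError:
            raise
        except Exception as e:
            raise RuntimeError(f"Failed to get user input: {str(e)}") from e
    

    on_messages is the core method: it handles the incoming messages and asks the user for input through input_func.

    • It first calls _get_latest_handoff to check whether the last message is a HandoffMessage addressed to this agent. A handoff is a message that transfers control to a named target agent; if the last message is a handoff whose target is a different agent, a RuntimeError is raised.
    • It then builds the prompt depending on whether a handoff was found and calls _get_input to obtain the user's input.
    • Finally, it wraps the input in a Response: a HandoffMessage addressed back to handoff.source if a handoff was found, otherwise a TextMessage.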
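The branching on the handoff can be summarized in a few lines. The sketch below uses small dataclasses (Text, Handoff) as simplified stand-ins for TextMessage and HandoffMessage so that it runs without autogen installed. These names and the reply_for helper are hypothetical; only the decision logic mirrors the method above.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Union


@dataclass
class Text:
    # Stand-in for TextMessage (hypothetical, not the autogen class).
    content: str
    source: str


@dataclass
class Handoff:
    # Stand-in for HandoffMessage (hypothetical, not the autogen class).
    content: str
    source: str
    target: str


Message = Union[Text, Handoff]


def latest_handoff(messages: Sequence[Message], name: str) -> Optional[Handoff]:
    # Only the last message is inspected, as in _get_latest_handoff.
    if messages and isinstance(messages[-1], Handoff):
        if messages[-1].target == name:
            return messages[-1]
        raise RuntimeError(f"Handoff is addressed to {messages[-1].target}, not {name}")
    return None


def reply_for(messages: Sequence[Message], name: str, user_input: str) -> Message:
    handoff = latest_handoff(messages, name)
    if handoff:
        # Hand control back to whoever handed off to the user.
        return Handoff(content=user_input, source=name, target=handoff.source)
    return Text(content=user_input, source=name)


plain = reply_for([Text("Hi", "assistant")], "user", "hello")
back = reply_for([Handoff("Your turn", "assistant", "user")], "user", "done")
print(plain)
print(back)
```

The reply type therefore depends on the last incoming message, not on what the user typed.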

5. Using the UserProxyAgent class

  1. Define a synchronous input function, custom_input

    def custom_input(prompt: str) -> str:
        return "The height of the eiffel tower is 324 meters. Aloha!"
    

    This is a simple synchronous function that simulates user input. Whatever prompt is passed in, it always returns "The height of the eiffel tower is 324 meters. Aloha!".

  2. Create a UserProxyAgent instance

    agent = UserProxyAgent(name="test_user", input_func=custom_input)
    

    This creates a UserProxyAgent instance with name set to "test_user" and input_func set to the custom_input function defined above.

  3. Create the message list, messages

    messages = [
        TextMessage(content="What is the height of the eiffel tower?", source="assistant")
    ]
    

    This is a message list containing a single TextMessage from "assistant" that asks "What is the height of the eiffel tower?".

  4. Call the on_messages method

    response = asyncio.run(agent.on_messages(messages, CancellationToken()))
    

    asyncio.run() is used here to run the asynchronous on_messages method.
    messages is passed to on_messages to simulate a conversation, and the agent obtains the user's input through custom_input.

  5. Print the response

    print(response)
    

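    The printed response is the value returned by on_messages. Depending on whether the last incoming message is a handoff addressed to the agent, the agent returns either a HandoffMessage or a TextMessage; in this example it is a TextMessage.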

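6. Runtime flow

  • on_messages is called and the agent starts waiting for user input. The last message is not a HandoffMessage, so no handoff is detected.
  • prompt is set to "Enter your response: " and custom_input is called to obtain the user's input.
  • The simulated user input is "The height of the eiffel tower is 324 meters. Aloha!".
  • A TextMessage is created and returned; its content is the text the user entered and its source is test_user.
  • Finally, response is a Response object that contains this TextMessage.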

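7. Summary

UserProxyAgent is an agent that represents a human user; it obtains the user's answer through either kind of input function, synchronous or asynchronous.
The on_messages method handles messages coming from other agents or from the system, waits for and collects the user's input, and returns it wrapped in the appropriate message type.
The example simulates a simple conversation in which the user's text is passed through UserProxyAgent and returned.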
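
The example above only exercises a synchronous input function. As a hedged sketch of the asynchronous form, the function below follows the (prompt, cancellation_token) shape of AsyncInputFunc and applies a timeout with asyncio.wait_for, returning a default reply if no answer arrives in time, in line with the advice in the class docstring. The names timed_input and slow_user, the timeout values, and the default reply are assumptions for illustration. The token is typed as a plain object so the sketch runs without autogen, and it is not used here.

```python
import asyncio
from typing import Optional


async def slow_user(delay: float) -> str:
    # Hypothetical user who needs `delay` seconds to answer.
    await asyncio.sleep(delay)
    return "Here is my answer"


async def timed_input(
    prompt: str,
    cancellation_token: Optional[object] = None,  # unused in this sketch
    *,
    delay: float = 0.01,
    timeout: float = 0.1,
) -> str:
    try:
        # wait_for cancels the inner coroutine when the timeout expires.
        return await asyncio.wait_for(slow_user(delay), timeout=timeout)
    except asyncio.TimeoutError:
        # Fall back to a default reply instead of blocking the team.
        return "No response from user"


fast = asyncio.run(timed_input("Enter your response: "))
late = asyncio.run(timed_input("Enter your response: ", delay=0.5, timeout=0.05))
print(fast)  # Here is my answer
print(late)  # No response from user
```

Because a function of this shape is defined with async def, iscoroutinefunction reports True for it and _get_input would await it directly instead of using the thread pool.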

Reference: https://github.com/microsoft/autogen/blob/main/python/packages/autogen-agentchat/tests/test_userproxy_agent.py

If you have any questions, feel free to ask in the comments.
