Skip to content

Memory Module

erniebot_agent.chat_models.base

ChatModel

The base class for chat-optimized LLMs.

Attributes:

Name Type Description
model str

The model name.

default_chat_kwargs Any

A dict of default arguments for chat requests; supported keys include _config_, top_p, etc.

Source code in erniebot-agent/src/erniebot_agent/chat_models/base.py
class ChatModel(metaclass=ABCMeta):
    """Abstract base for chat-oriented large language models.

    Concrete subclasses implement `chat`, which takes a list of messages and
    returns either one complete reply or a stream of reply chunks.

    Attributes:
        model (str): Name of the underlying model.
        default_chat_kwargs (Any): Default keyword arguments captured at
            construction time for chat requests, e.g. `_config_` or `top_p`.
    """

    def __init__(self, model: str, **default_chat_kwargs: Any):
        self.model, self.default_chat_kwargs = model, default_chat_kwargs

    @overload
    async def chat(
        self, messages: List[Message], *, stream: Literal[False] = ..., **kwargs: Any
    ) -> AIMessage:
        ...

    @overload
    async def chat(
        self, messages: List[Message], *, stream: Literal[True], **kwargs: Any
    ) -> AsyncIterator[AIMessageChunk]:
        ...

    @overload
    async def chat(
        self, messages: List[Message], *, stream: bool, **kwargs: Any
    ) -> Union[AIMessage, AsyncIterator[AIMessageChunk]]:
        ...

    @abstractmethod
    async def chat(
        self, messages: List[Message], *, stream: bool = False, **kwargs: Any
    ) -> Union[AIMessage, AsyncIterator[AIMessageChunk]]:
        """Send a conversation to the LLM asynchronously (must be overridden).

        Args:
            messages (List[Message]): The conversation to send.
            stream (bool): Whether to stream the reply incrementally. Defaults to False.
            **kwargs: Extra generation options, such as `top_p`, `temperature`,
                `penalty_score`, and `system`.

        Returns:
            A single message when `stream` is False; otherwise an asynchronous
            iterator of message chunks.
        """
        raise NotImplementedError()

chat(messages, *, stream=False, **kwargs) abstractmethod async

The abstract method for asynchronously chatting with the LLM.

Parameters:

Name Type Description Default
messages List[Message]

A list of messages.

required
stream bool

Whether to use streaming generation. Defaults to False.

False
**kwargs Any

Keyword arguments, such as top_p, temperature, penalty_score, and system.

{}

Returns:

Type Description
Union[AIMessage, AsyncIterator[AIMessageChunk]]

If stream is False, returns a single message.

Union[AIMessage, AsyncIterator[AIMessageChunk]]

If stream is True, returns an asynchronous iterator of message chunks.

Source code in erniebot-agent/src/erniebot_agent/chat_models/base.py
@abstractmethod
async def chat(
    self, messages: List[Message], *, stream: bool = False, **kwargs: Any
) -> Union[AIMessage, AsyncIterator[AIMessageChunk]]:
    """Send a conversation to the LLM asynchronously (must be overridden).

    Args:
        messages (List[Message]): The conversation to send.
        stream (bool): Whether to stream the reply incrementally. Defaults to False.
        **kwargs: Extra generation options, such as `top_p`, `temperature`,
            `penalty_score`, and `system`.

    Returns:
        A single message when `stream` is False; otherwise an asynchronous
        iterator of message chunks.
    """
    raise NotImplementedError()

erniebot_agent.chat_models.erniebot

ERNIEBot

The implementation of the ERNIE Bot model.

Attributes:

Name Type Description
model str

The model name.

api_type str

The backend of the ERNIE Bot model.

access_token Optional[str]

The access token corresponding to the backend.

default_chat_kwargs Any

A dict of default arguments for chat requests; supported keys include _config_, top_p, etc.

Source code in erniebot-agent/src/erniebot_agent/chat_models/erniebot.py
class ERNIEBot(BaseERNIEBot):
    """The implementation of the ERNIE Bot model.

    Attributes:
        model (str): The model name.
        api_type (str): The backend of the ERNIE Bot model.
        access_token (Optional[str]): The access token corresponding to the backend.
        default_chat_kwargs (Any): A dict of default arguments for chat requests,
            such as `_config_` and `top_p`.
    """

    def __init__(
        self,
        model: str,
        api_type: str = "aistudio",
        access_token: Optional[str] = None,
        enable_multi_step_tool_call: bool = False,
        **default_chat_kwargs: Any,
    ) -> None:
        """Create an `ERNIEBot` chat model.

        Args:
            model (str): The model name: "ernie-3.5", "ernie-turbo", "ernie-4.0", or
                "ernie-longtext".
            api_type (str): The erniebot backend, either "aistudio" or "qianfan".
                Defaults to "aistudio".
            access_token (Optional[str]): The access token for the erniebot backend.
                When None, the global access token is used instead.
            enable_multi_step_tool_call (bool): Whether multi-step tool calling is enabled.
                Defaults to False.
            **default_chat_kwargs: Default keyword arguments for chat requests, such as
                `_config_`, `top_p`, `temperature`, `penalty_score`, and `system`. For the
                "qianfan" backend, `ak` and `sk` may also be given here.
        """
        super().__init__(model=model, **default_chat_kwargs)

        self.api_type = api_type
        # Fall back to the globally configured token when none is passed explicitly.
        self.access_token = (
            access_token if access_token is not None else C.get_global_access_token()
        )
        self._maybe_validate_qianfan_auth()

        # Serialized once here and sent as `extra_data` with every request.
        multi_step_flags = {"multi_step_tool_call_close": not enable_multi_step_tool_call}
        self.enable_multi_step_json = json.dumps(multi_step_flags)

    @overload
    async def chat(
        self,
        messages: List[Message],
        *,
        stream: Literal[False] = ...,
        functions: Optional[List[dict]] = ...,
        **kwargs: Any,
    ) -> AIMessage:
        ...

    @overload
    async def chat(
        self,
        messages: List[Message],
        *,
        stream: Literal[True],
        functions: Optional[List[dict]] = ...,
        **kwargs: Any,
    ) -> AsyncIterator[AIMessageChunk]:
        ...

    @overload
    async def chat(
        self, messages: List[Message], *, stream: bool, functions: Optional[List[dict]] = ..., **kwargs: Any
    ) -> Union[AIMessage, AsyncIterator[AIMessageChunk]]:
        ...

    async def chat(
        self,
        messages: List[Message],
        *,
        stream: bool = False,
        functions: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> Union[AIMessage, AsyncIterator[AIMessageChunk]]:
        """Send messages to the ERNIE Bot model asynchronously.

        Args:
            messages (List[Message]): The conversation. It must not contain `SystemMessage`
                instances; a system prompt can be passed through the `system` keyword instead.
            stream (bool): Whether to stream the reply incrementally. Defaults to False.
            functions (Optional[List[dict]]): Function definitions the model may call.
                Defaults to None.
            **kwargs: Per-request options. Only `top_p`, `temperature`, `penalty_score`,
                `system`, and `plugins` are forwarded; other keys are ignored.

        Returns:
            A single message when `stream` is False; otherwise an asynchronous iterator
            of message chunks.

        Raises:
            ValueError: If `messages` contains a `SystemMessage`.
        """
        config = self._generate_config(messages, functions=functions, **kwargs)
        response = await self._generate_response(config, stream, functions)

        if stream:
            assert isinstance(response, AsyncIterator)
            # The cast works around a known mypy issue:
            # https://github.com/python/mypy/issues/16590
            chunk_stream = cast(AsyncIterator[ChatCompletionResponse], response)
            return (
                convert_response_to_output(chunk, AIMessageChunk) async for chunk in chunk_stream
            )

        assert isinstance(response, ChatCompletionResponse)
        return convert_response_to_output(response, AIMessage)

    def _generate_config(
        self, messages: List[Message], functions: Optional[List[dict]], **kwargs: Any
    ) -> dict:
        """Build the keyword arguments for a single erniebot completion request.

        Starts from `default_chat_kwargs`, fills `_config_` with backend/auth settings, and
        overlays the messages, model name, functions, and whitelisted per-request options.

        Raises:
            ValueError: If `messages` contains a `SystemMessage`.
        """
        if any(isinstance(m, SystemMessage) for m in messages):
            raise ValueError(f"The input messages should not contain SystemMessage: {messages}")
        cfg_dict = self.default_chat_kwargs.copy()

        # Copy `_config_` as well: `.copy()` above is shallow, and writing into the shared
        # nested dict would leak per-request settings into the caller's default kwargs.
        cfg_dict["_config_"] = dict(cfg_dict.get("_config_") or {})
        if self.api_type is not None:
            cfg_dict["_config_"]["api_type"] = self.api_type
        if hasattr(self, "ak") and hasattr(self, "sk"):
            cfg_dict["_config_"]["ak"] = self.ak
            cfg_dict["_config_"]["sk"] = self.sk
        cfg_dict["_config_"]["access_token"] = self.access_token

        cfg_dict["messages"] = [m.to_dict() for m in messages]
        cfg_dict["model"] = self.model
        if functions is not None:
            cfg_dict["functions"] = functions

        name_list = ["top_p", "temperature", "penalty_score", "system", "plugins"]
        for name in name_list:
            if name in kwargs:
                cfg_dict[name] = kwargs[name]

        # An empty/None plugin list is dropped so the request goes to plain ChatCompletion.
        if "plugins" in cfg_dict and (cfg_dict["plugins"] is None or len(cfg_dict["plugins"]) == 0):
            cfg_dict.pop("plugins")

        return cfg_dict

    def _maybe_validate_qianfan_auth(self) -> None:
        """Resolve the AK/SK credentials used by the "qianfan" backend.

        Does nothing unless `api_type` is "qianfan". If both `ak` and `sk` were passed in
        `default_chat_kwargs`, they are removed from it and stored on the instance. Otherwise,
        when no access token is available, the global AK/SK pair is used.

        Raises:
            RuntimeError: If no access token is available and no complete AK/SK pair can be
                found.
        """
        if self.api_type != "qianfan":
            return

        kwargs = self.default_chat_kwargs
        # Both keys must be present; `"ak" and "sk" in kwargs` would only test for "sk".
        if "ak" in kwargs and "sk" in kwargs:
            # Explicit credentials are recorded even when an access token is also set; both
            # are later forwarded in `_config_`. Popping keeps them out of the chat kwargs.
            self.ak = kwargs.pop("ak")
            self.sk = kwargs.pop("sk")
        elif self.access_token is None:
            # With qianfan and no access token, fall back to the global AK/SK pair.
            # NOTE(review): a lone `ak` or `sk` in kwargs is left in place - confirm intent.
            ak, sk = C.get_global_aksk()
            if ak is None or sk is None:
                raise RuntimeError("Please set at least one of ak+sk or access token.")
            self.ak = ak
            self.sk = sk

    async def _generate_response(
        self, cfg_dict: dict, stream: bool, functions: Optional[List[dict]]
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponse]]:
        """Dispatch the request to the matching erniebot completion API.

        When `cfg_dict` contains `plugins`, `erniebot.ChatCompletionWithPlugins` is called with
        only the messages, plugins, `_config_`, and functions. Otherwise every entry of
        `cfg_dict` is passed to `erniebot.ChatCompletion`.
        """
        # TODO: Improve this when erniebot typing issue is fixed.
        # Note: If plugins is not None, erniebot will not use Baidu_search.
        if "plugins" in cfg_dict:
            response = await erniebot.ChatCompletionWithPlugins.acreate(
                messages=cfg_dict["messages"],
                plugins=cfg_dict["plugins"],  # type: ignore
                stream=stream,
                _config_=cfg_dict["_config_"],
                functions=functions,  # type: ignore
                extra_params={
                    "extra_data": self.enable_multi_step_json,
                },
            )
        else:
            response = await erniebot.ChatCompletion.acreate(
                stream=stream,
                extra_params={
                    "extra_data": self.enable_multi_step_json,
                },
                **cfg_dict,
            )

        return response

__init__(model, api_type='aistudio', access_token=None, enable_multi_step_tool_call=False, **default_chat_kwargs)

Initializes an instance of the ERNIEBot class.

Parameters:

Name Type Description Default
model str

The model name. It should be "ernie-3.5", "ernie-turbo", "ernie-4.0", or "ernie-longtext".

required
api_type str

The backend of erniebot. It should be "aistudio" or "qianfan". Defaults to "aistudio".

'aistudio'
access_token Optional[str]

The access token for the backend of erniebot. If access_token is None, the global access_token will be used.

None
enable_multi_step_tool_call bool

Whether to enable the multi-step tool call. Defaults to False.

False
**default_chat_kwargs Any

Keyword arguments, such as _config_, top_p, temperature, penalty_score, and system.

{}
Source code in erniebot-agent/src/erniebot_agent/chat_models/erniebot.py
def __init__(
    self,
    model: str,
    api_type: str = "aistudio",
    access_token: Optional[str] = None,
    enable_multi_step_tool_call: bool = False,
    **default_chat_kwargs: Any,
) -> None:
    """Create an `ERNIEBot` chat model.

    Args:
        model (str): The model name: "ernie-3.5", "ernie-turbo", "ernie-4.0", or
            "ernie-longtext".
        api_type (str): The erniebot backend, either "aistudio" or "qianfan".
            Defaults to "aistudio".
        access_token (Optional[str]): The access token for the erniebot backend.
            When None, the global access token is used instead.
        enable_multi_step_tool_call (bool): Whether multi-step tool calling is enabled.
            Defaults to False.
        **default_chat_kwargs: Default keyword arguments for chat requests, such as
            `_config_`, `top_p`, `temperature`, `penalty_score`, and `system`. For the
            "qianfan" backend, `ak` and `sk` may also be given here.
    """
    super().__init__(model=model, **default_chat_kwargs)

    self.api_type = api_type
    # Fall back to the globally configured token when none is passed explicitly.
    self.access_token = (
        access_token if access_token is not None else C.get_global_access_token()
    )
    self._maybe_validate_qianfan_auth()

    # Serialized once here and sent as `extra_data` with every request.
    multi_step_flags = {"multi_step_tool_call_close": not enable_multi_step_tool_call}
    self.enable_multi_step_json = json.dumps(multi_step_flags)

chat(messages, *, stream=False, functions=None, **kwargs) async

Asynchronously chats with the ERNIE Bot model.

Parameters:

Name Type Description Default
messages List[Message]

A list of messages.

required
stream bool

Whether to use streaming generation. Defaults to False.

False
functions Optional[List[dict]]

The function definitions to be used by the model. Defaults to None.

None
**kwargs Any

Keyword arguments, such as top_p, temperature, penalty_score, and system.

{}

Returns:

Type Description
Union[AIMessage, AsyncIterator[AIMessageChunk]]

If stream is False, returns a single message.

Union[AIMessage, AsyncIterator[AIMessageChunk]]

If stream is True, returns an asynchronous iterator of message chunks.

Source code in erniebot-agent/src/erniebot_agent/chat_models/erniebot.py
async def chat(
    self,
    messages: List[Message],
    *,
    stream: bool = False,
    functions: Optional[List[dict]] = None,
    **kwargs: Any,
) -> Union[AIMessage, AsyncIterator[AIMessageChunk]]:
    """Send messages to the ERNIE Bot model asynchronously.

    Args:
        messages (List[Message]): The conversation. It must not contain `SystemMessage`
            instances; a system prompt can be passed through the `system` keyword instead.
        stream (bool): Whether to stream the reply incrementally. Defaults to False.
        functions (Optional[List[dict]]): Function definitions the model may call.
            Defaults to None.
        **kwargs: Per-request options. Only `top_p`, `temperature`, `penalty_score`,
            `system`, and `plugins` are forwarded; other keys are ignored.

    Returns:
        A single message when `stream` is False; otherwise an asynchronous iterator
        of message chunks.

    Raises:
        ValueError: If `messages` contains a `SystemMessage`.
    """
    config = self._generate_config(messages, functions=functions, **kwargs)
    response = await self._generate_response(config, stream, functions)

    if stream:
        assert isinstance(response, AsyncIterator)
        # The cast works around a known mypy issue:
        # https://github.com/python/mypy/issues/16590
        chunk_stream = cast(AsyncIterator[ChatCompletionResponse], response)
        return (
            convert_response_to_output(chunk, AIMessageChunk) async for chunk in chunk_stream
        )

    assert isinstance(response, ChatCompletionResponse)
    return convert_response_to_output(response, AIMessage)