Skip to content

Tools Module

erniebot_agent.tools.schema

OpenAPIProperty

Source code in erniebot-agent/src/erniebot_agent/tools/schema.py
class OpenAPIProperty(BaseModel):
    # Pydantic model describing the attributes of a single OpenAPI schema property.

    # Type of the property as a string (presumably a JSON-schema type name -- confirm
    # against the code that builds these objects). The only field without a default.
    type: str
    # Optional extra string-to-string metadata attached to the property.
    json_schema_extra: Optional[Dict[str, str]] = None
    # Optional human-readable description.
    description: Optional[str] = None
    # Optional list of names; presumably the required sub-properties -- TODO confirm.
    required: Optional[List[str]] = None
    # Optional list of allowed values; each entry is an int or a str.
    enum: Optional[List[Union[int, str]]] = None
    # A fresh empty dict per instance by default; presumably the item schema of an
    # array property -- TODO confirm.
    items: dict = Field(default_factory=dict)
    # A fresh empty dict per instance by default; presumably the nested properties of
    # an object property -- TODO confirm.
    properties: dict = Field(default_factory=dict)

ToolParameterView

Source code in erniebot-agent/src/erniebot_agent/tools/schema.py
class ToolParameterView(BaseModel):
    # Base pydantic model for tool input/output schemas. Concrete views are created
    # dynamically by `from_openapi_dict` and `from_dict` through `create_model`.
    # (Documented with comments rather than a class docstring so the model's
    # generated schema description is left untouched.)

    # Optional prompt attached to a view; `from_openapi_dict` fills it from the
    # schema's "x-ebagent-prompt" key.
    __prompt__: Optional[str] = None

    class Config:
        # pydantic option: populate fields with enum values instead of Enum members.
        use_enum_values = True

    @classmethod
    def from_openapi_dict(cls, schema: dict) -> Type[ToolParameterView]:
        """Build a ToolParameterView subclass from an OpenAPI component schema.

        Args:
            schema (dict): the component schema. Every entry of its "properties"
                mapping that has both a "type" and a "description" key becomes a
                model field; its optional "x-ebagent-prompt" value becomes the
                `__prompt__` of the created model.

        Returns:
            Type[ToolParameterView]: a model named "OpenAPIParameterView" created
                with `create_model` and based on `ToolParameterView`.
        """

        # TODO(wj-Mcat): to load Optional field
        fields = {}
        for field_name, field_dict in schema.get("properties", {}).items():
            # skip loading invalid field to improve compatibility
            if "type" not in field_dict or "description" not in field_dict:
                continue

            # `python_type_from_json_type` is defined elsewhere; the identity checks
            # below presumably rely on it returning exactly `List[ToolParameterView]`
            # for arrays of objects and `ToolParameterView` for objects -- confirm.
            field_type = python_type_from_json_type(field_dict)

            if field_type is List[ToolParameterView]:
                # Recurse into the item schema and use the generated model as item type.
                SubParameterView: Type[ToolParameterView] = ToolParameterView.from_openapi_dict(
                    field_dict["items"]
                )
                field_type = List[SubParameterView]  # type: ignore
            elif field_type is ToolParameterView:
                # Nested object: build a sub-model from the property itself.
                field_type = ToolParameterView.from_openapi_dict(field_dict)
            elif "enum" in field_dict:
                field_type = create_enum_class(field_name, field_dict["enum"])

            # TODO(wj-Mcat): remove supporting for `summary` field
            # NOTE(review): fields without "description" were skipped above, so when
            # "summary" is present "description" is present too and always wins here.
            if "summary" in field_dict:
                description = field_dict["summary"]
                logger.info("`summary` field will be deprecated, please use `description`")

                if "description" in field_dict:
                    logger.info("`description` field will be used instead of `summary`")
                    description = field_dict["description"]
            else:
                description = field_dict.get("description", None)

            # Normalise a falsy description (e.g. None) to an empty string.
            description = description or ""

            # Collect extra metadata: the optional "format" plus every "x-ebagent*" key.
            format = field_dict.get("format", None)
            json_schema_extra = {}
            if format is not None:
                json_schema_extra["format"] = format

            json_schema_extra.update(
                {key: value for key, value in field_dict.items() if key.startswith("x-ebagent")}
            )

            # Keep the raw item schema for list-typed fields.
            # NOTE(review): for arrays of objects `field_type` was rebound above to
            # `List[SubParameterView]`, so the second condition is true for them as
            # well -- confirm whether that is intended.
            if get_typing_list_type(field_type) is not None and field_type is not List[ToolParameterView]:
                json_schema_extra["array_items_schema"] = field_dict["items"]

            field_info_param = dict(
                annotation=field_type, description=description, json_schema_extra=json_schema_extra
            )
            # Only pass a default when the schema declares one.
            if "default" in field_dict:
                field_info_param["default"] = field_dict["default"]
            field = FieldInfo(**field_info_param)  # type: ignore

            # TODO(wj-Mcat): to handle list field required & not-required
            # if get_typing_list_type(field_type) is not None:
            #     field.default_factory = list

            fields[field_name] = (field_type, field)

        model = create_model("OpenAPIParameterView", __base__=ToolParameterView, **fields)  # type: ignore

        # get the prompt for schema
        model.__prompt__ = schema.get("x-ebagent-prompt", None)
        return model

    @classmethod
    def to_openapi_dict(cls) -> dict:
        """Convert this parameter view into an OpenAPI object-schema dict.

        Returns:
            dict: a dict with "type" set to "object", a "properties" mapping built
                from `cls.model_fields`, and a "required" list when at least one
                field is required; the dict is passed through `scrub_dict` and an
                empty dict is returned if nothing is left.
        """

        required_names, properties = [], {}
        for field_name, field_info in cls.model_fields.items():
            # A field is listed as required only if pydantic reports it as required
            # and its annotation is not an optional type.
            if field_info.is_required() and not is_optional_type(field_info.annotation):
                required_names.append(field_name)

            properties[field_name] = dict(get_field_openapi_property(field_info))

        result = {
            "type": "object",
            "properties": properties,
        }
        if len(required_names) > 0:
            result["required"] = required_names
        # `scrub_dict` is defined elsewhere; presumably it strips empty entries -- confirm.
        result = scrub_dict(result, remove_empty_dict=True)  # type: ignore
        return result or {}

    @classmethod
    def function_call_schema(cls) -> dict:
        """Get the function_call schema of this parameter view.

        Returns:
            dict: the result of `to_openapi_dict()`.
        """
        return cls.to_openapi_dict()

    @classmethod
    def from_dict(cls, field_map: Dict[str, Any]) -> Type[ToolParameterView]:
        """
        Class method to create a Pydantic model dynamically based on a dictionary.

        Args:
            field_map (Dict[str, Any]): A dictionary mapping field names to a dict that
            holds the field's "type" and "description".

        Returns:
            PydanticModel: A dynamically created Pydantic model with fields specified by the
            input dictionary. The model is named after `cls` but is always based on
            `ToolParameterView`.

        Raises:
            KeyError: if an entry of `field_map` lacks "type" or "description".

        """
        fields = {}
        for field_name, field_dict in field_map.items():
            field_type = field_dict["type"]
            description = field_dict["description"]
            field = FieldInfo(annotation=field_type, description=description)
            fields[field_name] = (field_type, field)
        return create_model(cls.__name__, __base__=ToolParameterView, **fields)  # type: ignore

from_dict(field_map) classmethod

Class method to create a Pydantic model dynamically based on a dictionary.

Parameters:

Name Type Description Default
field_map Dict[str, Any]

A dictionary mapping field names to their corresponding type and description.

required

Returns:

Name Type Description
PydanticModel

A dynamically created Pydantic model with fields specified by the

input dictionary.

Note

This method is used to create a Pydantic model dynamically based on the provided dictionary, where each field's type and description are specified in the input.

Source code in erniebot-agent/src/erniebot_agent/tools/schema.py
@classmethod
def from_dict(cls, field_map: Dict[str, Any]):
    """
    Build a parameter-view model on the fly from a plain field specification.

    Args:
        field_map (Dict[str, Any]): Maps each field name to a dict carrying the
        field's "type" and "description".

    Returns:
        PydanticModel: A model created with `create_model`, named after `cls`,
        based on `ToolParameterView`, with one field per entry of `field_map`.

    Raises:
        KeyError: If an entry has no "type" or no "description".

    """
    definitions = {
        name: (spec["type"], FieldInfo(annotation=spec["type"], description=spec["description"]))
        for name, spec in field_map.items()
    }
    return create_model(cls.__name__, __base__=ToolParameterView, **definitions)  # type: ignore

from_openapi_dict(schema) classmethod

Parse an OpenAPI component schema into a ParameterView. Args: schema (dict): the component schema whose "properties" are converted into model fields

Returns:

Name Type Description
model Type[ToolParameterView]

the dynamically created parameter view model

Source code in erniebot-agent/src/erniebot_agent/tools/schema.py
@classmethod
def from_openapi_dict(cls, schema: dict) -> Type[ToolParameterView]:
    """Build a ToolParameterView subclass from an OpenAPI component schema.

    Args:
        schema (dict): the component schema. Every entry of its "properties"
            mapping that has both a "type" and a "description" key becomes a
            model field; its optional "x-ebagent-prompt" value becomes the
            `__prompt__` of the created model.

    Returns:
        Type[ToolParameterView]: a model named "OpenAPIParameterView" created
            with `create_model` and based on `ToolParameterView`.
    """

    # TODO(wj-Mcat): to load Optional field
    fields = {}
    for field_name, field_dict in schema.get("properties", {}).items():
        # skip loading invalid field to improve compatibility
        if "type" not in field_dict or "description" not in field_dict:
            continue

        # `python_type_from_json_type` is defined elsewhere; the identity checks
        # below presumably rely on it returning exactly `List[ToolParameterView]`
        # for arrays of objects and `ToolParameterView` for objects -- confirm.
        field_type = python_type_from_json_type(field_dict)

        if field_type is List[ToolParameterView]:
            # Recurse into the item schema and use the generated model as item type.
            SubParameterView: Type[ToolParameterView] = ToolParameterView.from_openapi_dict(
                field_dict["items"]
            )
            field_type = List[SubParameterView]  # type: ignore
        elif field_type is ToolParameterView:
            # Nested object: build a sub-model from the property itself.
            field_type = ToolParameterView.from_openapi_dict(field_dict)
        elif "enum" in field_dict:
            field_type = create_enum_class(field_name, field_dict["enum"])

        # TODO(wj-Mcat): remove supporting for `summary` field
        # NOTE(review): fields without "description" were skipped above, so when
        # "summary" is present "description" is present too and always wins here.
        if "summary" in field_dict:
            description = field_dict["summary"]
            logger.info("`summary` field will be deprecated, please use `description`")

            if "description" in field_dict:
                logger.info("`description` field will be used instead of `summary`")
                description = field_dict["description"]
        else:
            description = field_dict.get("description", None)

        # Normalise a falsy description (e.g. None) to an empty string.
        description = description or ""

        # Collect extra metadata: the optional "format" plus every "x-ebagent*" key.
        format = field_dict.get("format", None)
        json_schema_extra = {}
        if format is not None:
            json_schema_extra["format"] = format

        json_schema_extra.update(
            {key: value for key, value in field_dict.items() if key.startswith("x-ebagent")}
        )

        # Keep the raw item schema for list-typed fields.
        # NOTE(review): for arrays of objects `field_type` was rebound above to
        # `List[SubParameterView]`, so the second condition is true for them as
        # well -- confirm whether that is intended.
        if get_typing_list_type(field_type) is not None and field_type is not List[ToolParameterView]:
            json_schema_extra["array_items_schema"] = field_dict["items"]

        field_info_param = dict(
            annotation=field_type, description=description, json_schema_extra=json_schema_extra
        )
        # Only pass a default when the schema declares one.
        if "default" in field_dict:
            field_info_param["default"] = field_dict["default"]
        field = FieldInfo(**field_info_param)  # type: ignore

        # TODO(wj-Mcat): to handle list field required & not-required
        # if get_typing_list_type(field_type) is not None:
        #     field.default_factory = list

        fields[field_name] = (field_type, field)

    model = create_model("OpenAPIParameterView", __base__=ToolParameterView, **fields)  # type: ignore

    # get the prompt for schema
    model.__prompt__ = schema.get("x-ebagent-prompt", None)
    return model

function_call_schema() classmethod

get function_call schema

Returns:

Name Type Description
dict dict

the schema of function_call

Source code in erniebot-agent/src/erniebot_agent/tools/schema.py
@classmethod
def function_call_schema(cls) -> dict:
    """Get the function_call schema of this parameter view.

    Returns:
        dict: the schema of function_call, i.e. the result of `to_openapi_dict()`.
    """
    return cls.to_openapi_dict()

to_openapi_dict() classmethod

convert ParametersView to openapi spec dict

Returns:

Name Type Description
dict dict

schema of openapi

Source code in erniebot-agent/src/erniebot_agent/tools/schema.py
@classmethod
def to_openapi_dict(cls) -> dict:
    """Render this parameter view as an OpenAPI object schema.

    Returns:
        dict: "type" ("object"), "properties" (one entry per model field) and,
            when at least one field is mandatory, "required"; the dict is passed
            through `scrub_dict` and `{}` is returned if nothing is left.
    """
    mandatory = []
    property_schemas = {}
    for name, info in cls.model_fields.items():
        # Mandatory means: pydantic requires it and the annotation is not optional.
        must_be_given = info.is_required() and not is_optional_type(info.annotation)
        if must_be_given:
            mandatory.append(name)
        property_schemas[name] = dict(get_field_openapi_property(info))

    spec = {"type": "object", "properties": property_schemas}
    if mandatory:
        spec["required"] = mandatory

    cleaned = scrub_dict(spec, remove_empty_dict=True)  # type: ignore
    if cleaned:
        return cleaned
    return {}

RemoteToolView dataclass

Source code in erniebot-agent/src/erniebot_agent/tools/schema.py
@dataclass
class RemoteToolView:
    """Description of one remote tool operation, convertible to and from OpenAPI dicts.

    The first five fields identify the operation. The remaining optional fields
    hold the request ("parameters") and response ("returns") views together with
    their descriptions, content types and the schema reference ids they were
    resolved from.
    """

    uri: str
    method: str
    name: str
    description: str
    version: str

    parameters: Optional[Type[ToolParameterView]] = None
    parameters_description: Optional[str] = None
    parameters_content_type: Optional[str] = None

    returns: Optional[Type[ToolParameterView]] = None
    returns_description: Optional[str] = None
    returns_content_type: Optional[str] = None

    returns_ref_uri: Optional[str] = None
    parameters_ref_uri: Optional[str] = None

    def to_openapi_dict(self):
        """Serialise the view as an OpenAPI operation keyed by its HTTP method.

        Returns:
            dict: `{method: operation}` where the operation carries "operationId",
                "description" and, when the matching view is set, "responses"
                and/or "requestBody" pointing at "#/components/schemas/<ref id>".
        """

        def content_for(content_type, ref_uri):
            # One media type whose schema is a reference; a missing ref id becomes "".
            return {content_type: {"schema": {"$ref": "#/components/schemas/" + (ref_uri or "")}}}

        operation = {
            "operationId": self.name,
            "description": self.description,
        }
        if self.returns is not None:
            operation["responses"] = {
                "200": {
                    "description": self.returns_description,
                    "content": content_for(self.returns_content_type, self.returns_ref_uri),
                }
            }

        if self.parameters is not None:
            operation["requestBody"] = {
                "required": True,
                "content": content_for(self.parameters_content_type, self.parameters_ref_uri),
            }
        return {self.method: operation}

    @staticmethod
    def from_openapi_dict(
        uri: str,
        method: str,
        path_info: dict,
        parameters_views: dict[str, Type[ToolParameterView]],
        version: str,
    ) -> RemoteToolView:
        """Create a RemoteToolView from the OpenAPI description of one operation.

        Args:
            uri (str): the url path of the remote tool
            method (str): http method: one of [get, post, put, delete]
            path_info (dict): the operation object of the spec; "operationId" is
                mandatory, "requestBody" and "responses" are optional
            parameters_views (dict[str, ParametersView]):
                parameter views keyed by schema reference id, used to resolve the
                "$ref" of the request body and of the first response
            version (str): the version of the remote tool

        Returns:
            RemoteToolView: the instance of remote tool view

        Raises:
            AssertionError: if a content mapping does not hold exactly one media
                type, or a referenced schema id is missing from `parameters_views`.
        """

        def resolve(section: dict):
            # Shared by request body and response: pick the single media type,
            # follow its "$ref" and look the referenced view up.
            content = section["content"]
            media_types = list(content.keys())
            assert len(media_types) == 1
            content_type = media_types[0]
            ref_uri = content[content_type]["schema"]["$ref"].split("/")[-1]
            assert ref_uri in parameters_views
            return content_type, ref_uri, parameters_views[ref_uri], section.get("description", None)

        parameters = parameters_description = None
        parameters_content_type = parameters_ref_uri = None
        if "requestBody" in path_info:
            (
                parameters_content_type,
                parameters_ref_uri,
                parameters,
                parameters_description,
            ) = resolve(path_info["requestBody"])

        returns = returns_description = None
        returns_content_type = returns_ref_uri = None
        if "responses" in path_info:
            # Only the first declared response is taken into account.
            first_response = list(path_info["responses"].values())[0]
            (
                returns_content_type,
                returns_ref_uri,
                returns,
                returns_description,
            ) = resolve(first_response)

        return RemoteToolView(
            uri=uri,
            method=method,
            name=path_info["operationId"],
            # "summary" is only consulted when there is no "description" key.
            description=path_info.get("description", path_info.get("summary")),
            version=version,
            parameters=parameters,
            parameters_description=parameters_description,
            parameters_content_type=parameters_content_type,
            returns=returns,
            returns_description=returns_description,
            returns_content_type=returns_content_type,
            # save ref id info
            returns_ref_uri=returns_ref_uri,
            parameters_ref_uri=parameters_ref_uri,
        )

    def function_call_schema(self):
        """Build the function_call schema of this remote tool.

        Returns:
            dict: "name", "description", "parameters" (an empty object schema when
                no parameter view is set) and, when a returns view is set,
                "responses"; passed through `scrub_dict`, `{}` if nothing is left.
        """
        if self.parameters is None:
            parameters_schema = {"type": "object", "properties": {}}
        else:
            parameters_schema = self.parameters.function_call_schema()  # type: ignore

        payload = {
            "name": self.name,
            "description": self.description,
            "parameters": parameters_schema,
        }
        if self.returns is not None:
            payload["responses"] = self.returns.function_call_schema()  # type: ignore
        return scrub_dict(payload) or {}

from_openapi_dict(uri, method, path_info, parameters_views, version) staticmethod

construct RemoteToolView from openapi spec-dict info

Parameters:

Name Type Description Default
uri str

the url path of remote tool

required
method str

http method: one of [get, post, put, delete]

required
path_info dict

the spec info of remote tool

required
parameters_views dict[str, ParametersView]

the dict of parameters views which are the schema of input/output of tool

required
version str

the version of remote tool

required

Returns:

Name Type Description
RemoteToolView RemoteToolView

the instance of remote tool view

Source code in erniebot-agent/src/erniebot_agent/tools/schema.py
@staticmethod
def from_openapi_dict(
    uri: str,
    method: str,
    path_info: dict,
    parameters_views: dict[str, Type[ToolParameterView]],
    version: str,
) -> RemoteToolView:
    """Create a RemoteToolView from the OpenAPI description of one operation.

    Args:
        uri (str): the url path of the remote tool
        method (str): http method: one of [get, post, put, delete]
        path_info (dict): the operation object of the spec; "operationId" is
            mandatory, "requestBody" and "responses" are optional
        parameters_views (dict[str, ParametersView]):
            parameter views keyed by schema reference id, used to resolve the
            "$ref" of the request body and of the first response
        version (str): the version of the remote tool

    Returns:
        RemoteToolView: the instance of remote tool view

    Raises:
        AssertionError: if a content mapping does not hold exactly one media
            type, or a referenced schema id is missing from `parameters_views`.
    """

    def resolve(section: dict):
        # Shared by request body and response: pick the single media type,
        # follow its "$ref" and look the referenced view up.
        content = section["content"]
        media_types = list(content.keys())
        assert len(media_types) == 1
        content_type = media_types[0]
        ref_uri = content[content_type]["schema"]["$ref"].split("/")[-1]
        assert ref_uri in parameters_views
        return content_type, ref_uri, parameters_views[ref_uri], section.get("description", None)

    parameters = parameters_description = None
    parameters_content_type = parameters_ref_uri = None
    if "requestBody" in path_info:
        (
            parameters_content_type,
            parameters_ref_uri,
            parameters,
            parameters_description,
        ) = resolve(path_info["requestBody"])

    returns = returns_description = None
    returns_content_type = returns_ref_uri = None
    if "responses" in path_info:
        # Only the first declared response is taken into account.
        first_response = list(path_info["responses"].values())[0]
        (
            returns_content_type,
            returns_ref_uri,
            returns,
            returns_description,
        ) = resolve(first_response)

    return RemoteToolView(
        uri=uri,
        method=method,
        name=path_info["operationId"],
        # "summary" is only consulted when there is no "description" key.
        description=path_info.get("description", path_info.get("summary")),
        version=version,
        parameters=parameters,
        parameters_description=parameters_description,
        parameters_content_type=parameters_content_type,
        returns=returns,
        returns_description=returns_description,
        returns_content_type=returns_content_type,
        # save ref id info
        returns_ref_uri=returns_ref_uri,
        parameters_ref_uri=parameters_ref_uri,
    )

Endpoint dataclass

Source code in erniebot-agent/src/erniebot_agent/tools/schema.py
@dataclass
class Endpoint:
    """An endpoint URL together with an optional description."""

    # URL of the endpoint (required).
    url: str
    # Optional free-text description; defaults to None.
    description: Optional[str] = None

EndpointInfo dataclass

Source code in erniebot-agent/src/erniebot_agent/tools/schema.py
@dataclass
class EndpointInfo:
    """Title, description and version strings; all three are required.

    NOTE(review): presumably mirrors the "info" object of an OpenAPI document --
    confirm against the code that constructs it.
    """

    title: str
    description: str
    version: str

erniebot_agent.tools.base

BaseTool

Source code in erniebot-agent/src/erniebot_agent/tools/base.py
class BaseTool(ABC):
    """Abstract interface of a tool.

    Subclasses must provide a name, a list of example messages, an async call
    implementation and a function_call schema.
    """

    @property
    @abstractmethod
    def tool_name(self) -> str:
        """Name of the tool."""
        raise NotImplementedError

    @property
    @abstractmethod
    def examples(self) -> List[Message]:
        """Example messages of the tool."""
        raise NotImplementedError

    @abstractmethod
    async def __call__(self, *args: Any, **kwds: Any) -> Any:
        """Run the tool asynchronously with the given arguments."""
        raise NotImplementedError

    @abstractmethod
    def function_call_schema(self) -> dict:
        """Return the function_call schema of the tool as a dict."""
        raise NotImplementedError

Tool

Source code in erniebot-agent/src/erniebot_agent/tools/base.py
class Tool(BaseTool, ABC):
    """Base class of tools; subclasses implement the async `__call__`.

    Class attributes:
        description: text describing the tool.
        name: optional explicit name; the class name is used when it is unset.
        input_type: optional parameter view of the tool's arguments.
        ouptut_type: optional parameter view of the tool's result (the spelling
            of this attribute is part of the existing interface).
    """

    description: str
    name: Optional[str] = None
    input_type: Optional[Type[ToolParameterView]] = None
    ouptut_type: Optional[Type[ToolParameterView]] = None

    def __str__(self) -> str:
        shown_name = self.name or self.tool_name
        return f"<name: {shown_name}, description: {self.description}>"

    def __repr__(self):
        return self.__str__()

    @property
    def tool_name(self):
        """The explicit `name` if it is truthy, otherwise the class name."""
        if self.name:
            return self.name
        return self.__class__.__name__

    @abstractmethod
    async def __call__(self, *args: Any, **kwds: Any) -> Dict[str, Any]:
        """The body of the tool; must be provided by subclasses.

        Returns:
            Dict[str, Any]: the result of the tool.
        """
        raise NotImplementedError

    def function_call_schema(self) -> dict:
        """Assemble the function_call schema of the tool.

        Returns:
            dict: "name", "description", "examples" (only when there are any),
                "parameters" (an empty object schema when `input_type` is unset)
                and "responses" (only when `ouptut_type` is set); passed through
                `scrub_dict`, `{}` if nothing is left.
        """
        schema = {
            "name": self.tool_name,
            "description": self.description,
        }

        if self.examples:
            schema["examples"] = [message.to_dict() for message in self.examples]

        if self.input_type is None:
            schema["parameters"] = {"type": "object", "properties": {}}
        else:
            schema["parameters"] = self.input_type.function_call_schema()

        if self.ouptut_type is not None:
            schema["responses"] = self.ouptut_type.function_call_schema()

        return scrub_dict(schema) or {}

    @property
    def examples(self) -> List[Message]:
        """Example messages of the tool; none by default."""
        return []

__call__(*args, **kwds) abstractmethod async

the body of tools

Returns:

Name Type Description
Any Dict[str, Any]
Source code in erniebot-agent/src/erniebot_agent/tools/base.py
@abstractmethod
async def __call__(self, *args: Any, **kwds: Any) -> Dict[str, Any]:
    """The body of the tool; must be implemented by subclasses.

    Returns:
        Dict[str, Any]: the result of the tool.

    Raises:
        NotImplementedError: always, in this abstract definition.
    """
    raise NotImplementedError

erniebot_agent.tools.remote_tool

RemoteTool

Source code in erniebot-agent/src/erniebot_agent/tools/remote_tool.py
class RemoteTool(BaseTool):
    def __init__(
        self,
        tool_view: RemoteToolView,
        server_url: str,
        headers: dict,
        version: str,
        file_manager: Optional[FileManager],
        examples: Optional[List[Message]] = None,
        tool_name_prefix: Optional[str] = None,
    ) -> None:
        """Store the configuration of a remote tool.

        Args:
            tool_view: view describing the remote operation.
            server_url: base url that the view's uri is appended to.
            headers: headers used for requests (copied before each request).
            version: version string sent as the `version` query parameter.
            file_manager: optional file manager.
            examples: optional example messages.
            tool_name_prefix: optional prefix; when given, the view's name is
                rewritten to "<prefix>/<name>" unless it already starts that way.
        """
        self.tool_view = tool_view
        self.server_url = server_url
        self.headers = headers
        self.version = version
        self.file_manager = file_manager
        self._examples = examples
        self.tool_name_prefix = tool_name_prefix

        # Namespace the tool name with the prefix, on a copy of the view so the
        # caller's view object is left untouched.
        if tool_name_prefix is not None:
            namespace = f"{tool_name_prefix}/"
            if not self.tool_view.name.startswith(namespace):
                self.tool_view = dataclasses.replace(
                    self.tool_view, name=f"{namespace}{self.tool_view.name}"
                )

        self.response_prompt: Optional[str] = None

    @property
    def examples(self) -> List[Message]:
        """Example messages given at construction, or an empty list when there are none."""
        if self._examples:
            return self._examples
        return []

    def __str__(self) -> str:
        """Human-readable summary: tool name, server url and description."""
        return (
            f"<name: {self.tool_name}, server_url: {self.server_url}, "
            f"description: {self.tool_view.description}>"
        )

    def __repr__(self):
        """Same text as `__str__`."""
        return self.__str__()

    @property
    def tool_name(self):
        """Name of the tool, taken from the tool view (includes the prefix if one was applied)."""
        return self.tool_view.name

    async def __pre_process__(self, tool_arguments: Dict[str, Any]) -> dict:
        """Prepare the arguments before the request is sent.

        File-typed arguments (as reported by `get_file_info_from_param_view`) are
        replaced by the contents of the referenced files, then the arguments are
        passed through the parameters model when the tool view has one.

        Args:
            tool_arguments: arguments keyed by parameter name; file arguments hold
                a file id, optionally wrapped in `<file>...</file>`, or a list of
                such values.

        Returns:
            dict: the processed arguments.

        Raises:
            RemoteToolError: if an argument is not a field of the parameters model.
        """

        async def fileid_to_byte(file_id, file_manager):
            # Look the file up by id and read its contents.
            file = file_manager.look_up_file_by_id(file_id)
            byte_str = await file.read_contents()
            return byte_str

        async def convert_to_file_data(file_data: str, format: str):
            # Strip the <file> markers to get the bare file id; `file_manager` is
            # the closure variable assigned below, before this helper is called.
            value = file_data.replace("<file>", "").replace("</file>", "")
            byte_value = await fileid_to_byte(value, file_manager)
            # For the "byte" format the contents are base64-encoded into a str.
            if format == "byte":
                byte_value = base64.b64encode(byte_value).decode()
            return byte_value

        file_manager = await self._get_file_manager()

        # 1. replace fileid with byte string
        parameter_file_info = get_file_info_from_param_view(self.tool_view.parameters)
        for key in tool_arguments.keys():
            # Reject arguments that the parameters model does not declare.
            if self.tool_view.parameters:
                if key not in self.tool_view.parameters.model_fields:
                    keys = list(self.tool_view.parameters.model_fields.keys())
                    raise RemoteToolError(
                        f"`{self.tool_name}` received unexpected arguments `{key}`. "
                        f"The avaiable arguments are {keys}",
                        stage="Input parsing",
                    )
            # Only file-typed parameters need conversion.
            if key not in parameter_file_info:
                continue
            if self.tool_view.parameters is None:
                break

            argument_value = tool_arguments[key]
            if isinstance(argument_value, list):
                # Lists are converted element by element, in place.
                for index in range(len(argument_value)):
                    argument_value[index] = await convert_to_file_data(
                        argument_value[index], parameter_file_info[key]["format"]
                    )
            else:
                argument_value = await convert_to_file_data(
                    argument_value, parameter_file_info[key]["format"]
                )

            # Re-assigning an existing key does not change the dict's size, so this
            # is safe while iterating over its keys.
            tool_arguments[key] = argument_value

        # 2. instantiate the parameters model with the arguments and turn it back into a dict
        if self.tool_view.parameters is not None:
            tool_arguments = dict(self.tool_view.parameters(**tool_arguments))

        return tool_arguments

    async def __post_process__(self, tool_response: dict) -> dict:
        """Post-process a tool response and attach a prompt when one applies.

        The prompt is chosen in this order: `self.response_prompt`, the
        `__prompt__` of the returns view, then a built-in prompt that is used
        only when the response contains a file. If none applies, no "prompt" key
        is added.

        Args:
            tool_response: the response dict of the remote tool.

        Returns:
            dict: the processed response.
        """
        processed = self.__adhoc_post_process__(tool_response)
        check_json_length(processed)

        prompt = self.response_prompt
        if prompt is None and self.tool_view.returns is not None:
            prompt = self.tool_view.returns.__prompt__
        if prompt is None and tool_response_contains_file(processed):
            prompt = (
                "参考工具说明中对各个结果字段的描述,提取工具调用结果中的信息,生成一段通顺的文本满足用户的需求。"
                "请务必确保每个符合'file-'格式的字段只出现一次,无需将其转换为链接,也无需添加任何HTML、Markdown或其他格式化元素。"
            )
        if prompt is not None:
            processed["prompt"] = prompt

        # TODO(wj-Mcat): open the tool-response valdiation with pydantic model
        # if self.tool_view.returns is not None:
        #     tool_response = dict(self.tool_view.returns(**tool_response))
        return processed

    async def __call__(self, **tool_arguments: Dict[str, Any]) -> Any:
        """Run the remote tool: pre-process the arguments, send the request, post-process the response."""
        prepared_arguments = await self.__pre_process__(tool_arguments)
        raw_response = await self.send_request(prepared_arguments)
        final_response = await self.__post_process__(raw_response)
        return final_response

    async def send_request(self, tool_arguments: Dict[str, Any]) -> dict:
        """Send the HTTP request of the tool and turn the response into a dict.

        The arguments are sent as query parameters for "get", otherwise as JSON,
        form data or multipart data depending on the parameters content type.
        A JSON response without file fields is returned as is; otherwise files
        are extracted with the file manager and reported by their file id.

        Args:
            tool_arguments: the (pre-processed) arguments of the tool.

        Returns:
            dict: the response of the tool.

        Raises:
            RemoteToolError: on an unsupported content type or http method, a
                status code other than 200, or when declared file fields cannot
                be extracted from the response.
        """
        view = self.tool_view
        content_type = view.parameters_content_type
        url = self.server_url + view.uri + "?version=" + self.version

        headers = deepcopy(self.headers)
        headers["Content-Type"] = content_type

        # Decide how the arguments travel: query string, JSON body, form or multipart.
        request_kwargs = {
            "headers": headers,
        }
        if view.method == "get":
            request_kwargs["params"] = tool_arguments
        elif content_type == "application/json":
            request_kwargs["json"] = tool_arguments
        elif content_type == "application/x-www-form-urlencoded":
            request_kwargs["data"] = tool_arguments
        elif content_type == "multipart/form-data":
            uploads = {}
            request_kwargs["files"] = uploads
            for file_key in get_file_info_from_param_view(view.parameters).keys():
                if file_key in tool_arguments:
                    # Move file arguments into the multipart payload and drop the
                    # explicit Content-Type header once a file is attached.
                    uploads[file_key] = tool_arguments.pop(file_key)
                    headers.pop("Content-Type", None)
            request_kwargs["data"] = tool_arguments
        else:
            raise RemoteToolError(
                f"Unsupported content type: {self.tool_view.parameters_content_type}", stage="Executing"
            )

        senders = {
            "get": requests.get,
            "post": requests.post,
            "put": requests.put,
            "delete": requests.delete,
        }
        sender = senders.get(view.method)
        if sender is None:
            raise RemoteToolError(f"method<{self.tool_view.method}> is invalid", stage="Executing")
        response = sender(url, **request_kwargs)  # type: ignore

        if response.status_code != 200:
            logger.debug(f"The resource requested returned the following headers: {response.headers}")
            raise RemoteToolError(
                f"The resource requested by `{self.tool_name}` "
                f"returned {response.status_code}: {response.text}",
                stage="Executing",
            )

        # parse the file from response
        returns_file_infos = get_file_info_from_param_view(view.returns)

        # Plain JSON without declared file fields needs no file handling.
        if not returns_file_infos and is_json_response(response):
            return response.json()

        file_manager = await self._get_file_manager()

        if is_json_response(response) and returns_file_infos:
            payload = response.json()
            extracted = await parse_file_from_json_response(
                payload,
                file_manager=file_manager,
                param_view=view.returns,  # type: ignore
                tool_name=self.tool_name,
            )
            payload.update(extracted)
            return payload

        file = await parse_file_from_response(
            response,
            file_manager,
            file_infos=returns_file_infos,
            file_metadata={"tool_name": self.tool_name},
        )

        if file is not None:
            # Report the file under the first declared file field, or under the
            # returns reference id when no file field is declared.
            if not returns_file_infos:
                return {view.returns_ref_uri: file.id}
            first_file_field = list(returns_file_infos.keys())[0]
            return {first_file_field: file.id}

        if not returns_file_infos:
            return response.json()

        raise RemoteToolError(
            f"<{list(returns_file_infos.keys())}> are defined but cannot be processed from the "
            "response. Please ensure that the response headers contain either the Content-Disposition "
            "or Content-Type field.",
            stage="Output parsing",
        )

    def function_call_schema(self) -> dict:
        """Return the tool view's function_call schema, extended with "examples" when there are any."""
        result = self.tool_view.function_call_schema()

        if self.examples:
            result["examples"] = [message.to_dict() for message in self.examples]

        if result:
            return result
        return {}

    def __adhoc_post_process__(self, tool_response: dict) -> dict:
        """Trim the raw response of a few known toolkits, in place.

        The toolkit is recognised by the prefix and suffix of ``self.tool_name``.
        Every rule requires a different prefix, so at most one rule applies.
        Whatever the toolkit, a top-level ``log_id`` entry is dropped.

        Args:
            tool_response (dict): the decoded response of the remote tool.

        Returns:
            dict: the same ``tool_response`` object after trimming.
        """
        # temporary adhoc post processing logic for certain toolkits
        tool_name = self.tool_name

        def is_tool(prefix: str, suffix: str) -> bool:
            return tool_name.startswith(prefix) and tool_name.endswith(suffix)

        if is_tool("official-doc-rec", "office_doc_rec") or is_tool("doc-analysis", "doc_analysis"):
            # keep only the recognised text of every result line
            if "results" in tool_response and isinstance(tool_response["results"], list):
                tool_response["results"] = [
                    line["words"]["word"]
                    for line in tool_response["results"]
                    if "words" in line and "word" in line["words"]
                ]
        elif is_tool("highacc-ocr", "OCR"):
            if "words_result" in tool_response and isinstance(tool_response["words_result"], list):
                tool_response["words_result"] = [
                    line["words"] for line in tool_response["words_result"] if "words" in line
                ]
        elif is_tool("pic-translate", "pic_translate"):
            if "data" in tool_response:
                for dropped_key in ("content", "sumSrc"):
                    if dropped_key in tool_response["data"]:
                        tool_response["data"].pop(dropped_key)
        elif is_tool("translation", "translation"):
            if "result" in tool_response and "trans_result" in tool_response["result"]:
                translations = tool_response["result"]["trans_result"]
                if isinstance(translations, list):
                    # keep only the translated text of every entry
                    tool_response["result"]["trans_result"] = [
                        {"dst": entry["dst"]} for entry in translations if "dst" in entry
                    ]
        # NOTE(review): the suffix below is "shopping_receip" (no trailing "t");
        # confirm it matches the real operation id.
        elif is_tool("shopping-receipt", "shopping_receip"):
            if "words_result" in tool_response and isinstance(tool_response["words_result"], list):
                receipt_fields = (
                    "shop_name",
                    "receipt_num",
                    "machine_num",
                    "employee_num",
                    "consumption_date",
                    "consumption_time",
                    "total_amount",
                    "change",
                    "currency",
                    "paid_amount",
                    "discount",
                    "print_time",
                )
                for receipt in tool_response["words_result"]:
                    for field_name in receipt_fields:
                        if field_name not in receipt:
                            continue
                        values = receipt[field_name]
                        # drop fields whose first entry holds an empty "word"
                        if len(values) > 0 and "word" in values[0] and values[0]["word"] == "":
                            receipt.pop(field_name)

        # Remove log_id if in tool_response
        if "log_id" in tool_response:
            tool_response.pop("log_id")
        return tool_response

    async def _get_file_manager(self) -> FileManager:
        """Return the file manager this tool should use.

        Returns:
            FileManager: ``self.file_manager`` when one is set, otherwise the
                instance obtained from ``GlobalFileManagerHandler().get()``.
        """
        if self.file_manager is not None:
            return self.file_manager
        return await GlobalFileManagerHandler().get()

erniebot_agent.tools.remote_toolkit

RemoteToolkit dataclass

A RemoteToolkit can be built from an openapi.yaml specification and an endpoint.

Source code in erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py
@dataclass
class RemoteToolkit:
    """RemoteToolkit can be converted by openapi.yaml and endpoint

    A toolkit groups the tools (``paths``) described by one OpenAPI document,
    together with the servers that host them, the HTTP headers used to call
    them and optional few-shot examples.
    """

    # value of the top-level `openapi` field of the specification
    openapi: str
    # parsed `info` section; its title and version form the tool name prefix
    info: EndpointInfo
    # parsed `servers` section; the first server's url is used to build tools
    servers: List[Endpoint]
    # one view per (path, method) pair of the specification
    paths: List[RemoteToolView]
    # file manager handed to every tool built from this toolkit
    file_manager: Optional[FileManager]

    # parameter views built from `components.schemas`, keyed by schema name
    component_schemas: dict[str, Type[ToolParameterView]]
    # HTTP headers handed to every tool built from this toolkit
    headers: dict
    # few-shot example messages shared by all tools of the toolkit
    examples: List[Message] = field(default_factory=list)
    # default hub url; the AISTUDIO_HUB_BASE_URL environment variable overrides it
    _AISTUDIO_HUB_BASE_URL: ClassVar[str] = "https://aistudio-hub.baidu.com"

    @property
    def tool_name_prefix(self) -> str:
        """Prefix of every tool name in this toolkit: ``<title>/<version>``."""
        return f"{self.info.title}/{self.info.version}"

    def __getitem__(self, tool_name: str) -> RemoteTool:
        """Shortcut for :meth:`get_tool`."""
        return self.get_tool(tool_name)

    def _build_tool(self, tool_class, path: RemoteToolView) -> RemoteTool:
        """Instantiate `tool_class` for one path of this toolkit."""
        return tool_class(
            path,
            self.servers[0].url,
            self.headers,
            self.info.version,
            file_manager=self.file_manager,
            examples=self.get_examples_by_name(path.name),
            tool_name_prefix=self.tool_name_prefix,
        )

    def get_tools(self) -> List[RemoteTool]:
        """Build one tool per path of the toolkit.

        Returns:
            List[RemoteTool]: the tools, in the order of ``self.paths``.
        """
        tool_class = tool_registor.get_tool_class(self.info.title)
        return [self._build_tool(tool_class, path) for path in self.paths]

    def get_examples_by_name(self, tool_name: str) -> List[Message]:
        """get examples by tool-name

        The examples are split into conversations, each one starting at a human
        message. Every conversation in which an AI message calls `tool_name` is
        returned, with the function-call names prefixed by `tool_name_prefix`.

        The prefixed AI messages are copies: ``self.examples`` is left untouched,
        so the lookup gives the same answer however often it is called. (The
        original rewrote the shared messages in place, which made later lookups
        of the same tool miss.)

        Args:
            tool_name (str): the name of the tool

        Returns:
            List[Message]: the messages
        """
        import copy

        # 1. split messages into conversations, each starting at a human message
        conversations: List[List[Message]] = []
        current: List[Message] = []
        for example in self.examples:
            if isinstance(example, HumanMessage) and current:
                conversations.append(current)
                current = []
            current.append(example)

        if current:
            conversations.append(current)

        final_examples: List[Message] = []
        # 2. find the target tool examples or empty messages
        for conversation in conversations:
            called_names = [
                message.function_call.get("name", None)
                for message in conversation
                if isinstance(message, AIMessage) and message.function_call is not None
            ]
            called_names = [name for name in called_names if name]
            if tool_name not in called_names:
                continue

            # 3. prepend `tool_name_prefix` to all tool names in examples
            for message in conversation:
                if isinstance(message, AIMessage) and message.function_call is not None:
                    message = copy.deepcopy(message)
                    original_tool_name = message.function_call["name"]
                    message.function_call["name"] = f"{self.tool_name_prefix}/{original_tool_name}"
                final_examples.append(message)

        return final_examples

    def get_tool(self, tool_name: str) -> RemoteTool:
        """Build the tool named `tool_name`.

        Args:
            tool_name (str): the name of a path of this toolkit

        Raises:
            RemoteToolError: when no path, or more than one path, has that name.

        Returns:
            RemoteTool: the tool built from the matching path
        """
        paths = [path for path in self.paths if path.name == tool_name]
        if len(paths) == 0:
            raise RemoteToolError(
                f"`{tool_name}` not found under RemoteToolkit `{self.tool_name_prefix}`", stage="Loading"
            )
        elif len(paths) > 1:
            raise RemoteToolError(
                f"Found duplicate `{tool_name}` under RemoteToolkit `{self.tool_name_prefix}`",
                stage="Loading",
            )

        tool_class = tool_registor.get_tool_class(self.info.title)
        return self._build_tool(tool_class, paths[0])

    def to_openapi_dict(self) -> dict:
        """convert plugin schema to openapi spec dict"""
        spec_dict = {
            "openapi": self.openapi,
            "info": asdict(self.info),
            "servers": [asdict(server) for server in self.servers],
            "paths": {tool_view.uri: tool_view.to_openapi_dict() for tool_view in self.paths},
            "components": {
                "schemas": {
                    uri: parameters_view.to_openapi_dict()
                    for uri, parameters_view in self.component_schemas.items()
                }
            },
        }
        return scrub_dict(spec_dict, remove_empty_dict=True) or {}

    def to_openapi_file(self, file: str):
        """generate openapi configuration file

        Args:
            file (str): the path of the openapi yaml file
        """
        spec_dict = self.to_openapi_dict()
        with open(file, "w+", encoding="utf-8") as f:
            safe_dump(spec_dict, f, indent=4)

    @classmethod
    def from_openapi_dict(
        cls,
        openapi_dict: Dict[str, Any],
        access_token: Optional[str] = None,
        file_manager: Optional[FileManager] = None,
    ) -> RemoteToolkit:
        """Build a toolkit from an already loaded OpenAPI document.

        Args:
            openapi_dict (Dict[str, Any]): the OpenAPI document
            access_token (Optional[str]): token placed in the `Authorization`
                header; the global access token is used when omitted
            file_manager (Optional[FileManager]): file manager handed to the tools

        Returns:
            RemoteToolkit: the toolkit described by the document
        """
        info = EndpointInfo(**openapi_dict["info"])
        servers = [Endpoint(**server) for server in openapi_dict.get("servers", [])]

        if access_token is None:
            access_token = C.get_global_access_token()

        # components: a document may omit `components` or `components.schemas`
        # (or leave them empty), which must not fail the whole load
        component_schemas = (openapi_dict.get("components") or {}).get("schemas") or {}
        fields = {}
        for schema_name, schema in component_schemas.items():
            parameter_view = ToolParameterView.from_openapi_dict(schema)
            fields[schema_name] = parameter_view

        # paths
        paths = []
        for path, path_info in openapi_dict.get("paths", {}).items():
            for method, path_method_info in path_info.items():
                paths.append(
                    RemoteToolView.from_openapi_dict(
                        uri=path,
                        method=method,
                        version=info.version,
                        path_info=path_method_info,
                        parameters_views=fields,
                    )
                )

        return RemoteToolkit(
            openapi=openapi_dict["openapi"],
            info=info,
            servers=servers,
            paths=paths,
            component_schemas=fields,
            headers=cls._get_authorization_headers(access_token),
            file_manager=file_manager,
        )

    @classmethod
    def from_openapi_file(
        cls, file: str, access_token: Optional[str] = None, file_manager: Optional[FileManager] = None
    ) -> RemoteToolkit:
        """only support openapi v3.0.1

        Args:
            file (str): the path of openapi yaml file
            access_token (Optional[str]): token placed in the `Authorization`
                header; the global access token is used when omitted
            file_manager (Optional[FileManager]): file manager handed to the tools

        Raises:
            RemoteToolError: when `validate_openapi_yaml` rejects the file.
        """
        if not validate_openapi_yaml(file):
            raise RemoteToolError(f"invalid openapi yaml file: {file}", stage="Loading")

        if access_token is None:
            access_token = C.get_global_access_token()

        spec_dict, _ = read_from_filename(file)
        return cls.from_openapi_dict(
            spec_dict, access_token=access_token, file_manager=file_manager  # type: ignore
        )

    @classmethod
    def _get_authorization_headers(cls, access_token: Optional[str]) -> dict:
        """Build the JSON request headers, adding `Authorization` when a token is given."""
        headers = {"Content-Type": "application/json"}
        if access_token is None:
            logger.warning("access_token is NOT provided, this may cause 403 HTTP error..")
        else:
            headers["Authorization"] = f"token {access_token}"
        return headers

    @classmethod
    def from_aistudio(
        cls,
        tool_id: str,
        version: Optional[str] = None,
        access_token: Optional[str] = None,
        file_manager: Optional[FileManager] = None,
    ) -> RemoteToolkit:
        """Load a toolkit from the hub, at ``tool-<tool_id>.<hub host>``.

        The hub base url is read from the `AISTUDIO_HUB_BASE_URL` environment
        variable and falls back to `_AISTUDIO_HUB_BASE_URL`.
        """
        from urllib.parse import urlparse

        if access_token is None:
            access_token = C.get_global_access_token()

        aistudio_base_url = os.getenv("AISTUDIO_HUB_BASE_URL", cls._AISTUDIO_HUB_BASE_URL)
        parsed_url = urlparse(aistudio_base_url)
        tool_url = parsed_url._replace(netloc=f"tool-{tool_id}.{parsed_url.netloc}").geturl()
        return cls.from_url(tool_url, version=version, access_token=access_token, file_manager=file_manager)

    @classmethod
    def from_url(
        cls,
        url: str,
        version: Optional[str] = None,
        access_token: Optional[str] = None,
        file_manager: Optional[FileManager] = None,
    ) -> RemoteToolkit:
        """Load a toolkit from ``<url>/.well-known/openapi.yaml``.

        Every server url of the loaded toolkit is replaced by `url`, and the
        examples published next to the specification are attached.

        Args:
            url (str): the base url of the remote toolkit
            version (Optional[str]): appended as the `version` query parameter
            access_token (Optional[str]): token placed in the `Authorization`
                header; the global access token is used when omitted
            file_manager (Optional[FileManager]): file manager handed to the tools

        Raises:
            RemoteToolError: on a non-200 answer or an empty document.
        """
        # 1. download openapi.yaml file to temp directory
        if not url.endswith("/"):
            url += "/"
        openapi_yaml_url = url + ".well-known/openapi.yaml"

        if version:
            openapi_yaml_url = openapi_yaml_url + "?version=" + version

        if access_token is None:
            access_token = C.get_global_access_token()

        with tempfile.TemporaryDirectory() as temp_dir:
            # NOTE(review): no timeout is passed here, so a stalled server blocks the caller.
            response = requests.get(openapi_yaml_url, headers=cls._get_authorization_headers(access_token))
            if response.status_code != 200:
                logger.debug(f"The resource requested returned the following headers: {response.headers}")
                raise RemoteToolError(
                    f"`{openapi_yaml_url}` returned {response.status_code}: {response.text}", stage="Loading"
                )

            file_content = response.content.decode("utf-8")
            if not file_content.strip():
                raise RemoteToolError(f"the content is empty from: {openapi_yaml_url}", stage="Loading")

            # the loader works on files, so the document is written to disk first
            file_path = os.path.join(temp_dir, "openapi.yaml")
            with open(file_path, "w+", encoding="utf-8") as f:
                f.write(file_content)

            toolkit = RemoteToolkit.from_openapi_file(
                file_path, access_token=access_token, file_manager=file_manager
            )
            for server in toolkit.servers:
                server.url = url

            toolkit.examples = cls.load_remote_examples_yaml(url, access_token)

        return toolkit

    @classmethod
    def load_remote_examples_yaml(cls, url: str, access_token: Optional[str] = None) -> List[Message]:
        """load remote examples by url: url/.well-known/examples.yaml

        Args:
            url (str): the base url of the remote toolkit
            access_token (Optional[str]): token placed in the `Authorization`
                header; the global access token is used when omitted

        Raises:
            RemoteToolError: on a non-200 answer or an empty document.

        Returns:
            List[Message]: the examples, or an empty list when the remote file
                does not exist
        """
        if not url.endswith("/"):
            url += "/"
        examples_yaml_url = url + ".well-known/examples.yaml"

        # Resolve the token before the existence probe; the original probed with
        # the unresolved token and could therefore miss a protected file.
        if access_token is None:
            access_token = C.get_global_access_token()
        headers = cls._get_authorization_headers(access_token)

        if not url_file_exists(examples_yaml_url, headers):
            return []

        with tempfile.TemporaryDirectory() as temp_dir:
            response = requests.get(examples_yaml_url, headers=headers)
            if response.status_code != 200:
                logger.debug(f"The resource requested returned the following headers: {response.headers}")
                raise RemoteToolError(
                    f"`{examples_yaml_url}` returned {response.status_code}: {response.text}",
                    stage="Loading",
                )

            file_content = response.content.decode("utf-8")
            if not file_content.strip():
                raise RemoteToolError(f"the content is empty from: {examples_yaml_url}", stage="Loading")

            file_path = os.path.join(temp_dir, "examples.yaml")
            with open(file_path, "w+", encoding="utf-8") as f:
                f.write(file_content)

            examples = cls.load_examples_yaml(file_path)

        return examples

    @classmethod
    def load_examples_dict(cls, examples_dict: Dict[str, Any]) -> List[Message]:
        """Convert the content of an examples file into messages.

        Each entry of ``examples_dict["examples"]`` holds a ``context`` list.
        A ``user`` entry becomes a `HumanMessage`; an entry whose role contains
        ``bot`` becomes an `AIMessage` carrying a function call built from its
        ``plugin`` section.

        Raises:
            RemoteToolError: when an entry has any other role.
        """
        messages: List[Message] = []
        for examples in examples_dict["examples"]:
            examples = examples["context"]
            for example in examples:
                if "user" == example["role"]:
                    messages.append(HumanMessage(example["content"]))
                elif "bot" in example["role"]:
                    plugin = example["plugin"]
                    if "operationId" in plugin:
                        function_call: FunctionCall = {
                            "name": plugin["operationId"],
                            "thoughts": plugin["thoughts"],
                            "arguments": json.dumps(plugin["requestArguments"], ensure_ascii=False),
                        }
                    else:
                        # no operation: keep the thoughts, with an empty call
                        function_call = {
                            "name": "",
                            "thoughts": plugin["thoughts"],
                            "arguments": "{}",
                        }  # type: ignore
                    messages.append(AIMessage("", function_call=function_call))
                else:
                    raise RemoteToolError(f"invald role: <{example['role']}>", stage="Loading")
        return messages

    @classmethod
    def load_examples_yaml(cls, file: str) -> List[Message]:
        """load examples from yaml file

        Args:
            file (str): the path of examples file

        Raises:
            RemoteToolError: when the file is empty or has no `examples` key.

        Returns:
            List[Message]: the list of messages
        """
        content: dict = read_from_filename(file)[0]  # type: ignore
        if len(content) == 0 or "examples" not in content:
            raise RemoteToolError("invalid examples configuration file", stage="Loading")
        return cls.load_examples_dict(content)

    def function_call_schemas(self) -> List[dict]:
        """Return the function-call schema of every tool of the toolkit."""
        return [tool.function_call_schema() for tool in self.get_tools()]

from_openapi_file(file, access_token=None, file_manager=None) classmethod

Only OpenAPI v3.0.1 is supported.

Parameters:

Name Type Description Default
file str

the path of openapi yaml file

required
access_token Optional[str]

the access token sent in the Authorization header (the global access token is used when omitted)

None
Source code in erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py
@classmethod
def from_openapi_file(
    cls, file: str, access_token: Optional[str] = None, file_manager: Optional[FileManager] = None
) -> RemoteToolkit:
    """Build a toolkit from an openapi yaml file (only openapi v3.0.1 is supported).

    Args:
        file (str): the path of openapi yaml file
        access_token (Optional[str]): the access token forwarded to
            `from_openapi_dict`; the global access token is used when omitted
        file_manager (Optional[FileManager]): forwarded to `from_openapi_dict`

    Raises:
        RemoteToolError: when `validate_openapi_yaml` rejects the file.
    """
    is_valid = validate_openapi_yaml(file)
    if not is_valid:
        raise RemoteToolError(f"invalid openapi yaml file: {file}", stage="Loading")

    # fall back to the globally configured token
    token = C.get_global_access_token() if access_token is None else access_token

    spec_dict, _ = read_from_filename(file)
    return cls.from_openapi_dict(spec_dict, access_token=token, file_manager=file_manager)  # type: ignore

get_examples_by_name(tool_name)

get examples by tool-name

Parameters:

Name Type Description Default
tool_name str

the name of the tool

required

Returns:

Type Description
List[Message]

List[Message]: the messages

Source code in erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py
def get_examples_by_name(self, tool_name: str) -> List[Message]:
    """get examples by tool-name

    The examples are split into conversations, each one starting at a human
    message. Every conversation in which an AI message calls `tool_name` is
    returned, with the function-call names prefixed by `tool_name_prefix`.

    The prefixed AI messages are copies: ``self.examples`` is left untouched,
    so the lookup gives the same answer however often it is called. (The
    original rewrote the shared messages in place, which made later lookups of
    the same tool miss.)

    Args:
        tool_name (str): the name of the tool

    Returns:
        List[Message]: the messages
    """
    import copy

    # 1. split messages into conversations, each starting at a human message
    conversations: List[List[Message]] = []
    current: List[Message] = []
    for example in self.examples:
        if isinstance(example, HumanMessage) and current:
            conversations.append(current)
            current = []
        current.append(example)

    if current:
        conversations.append(current)

    final_examples: List[Message] = []
    # 2. find the target tool examples or empty messages
    for conversation in conversations:
        called_names = [
            message.function_call.get("name", None)
            for message in conversation
            if isinstance(message, AIMessage) and message.function_call is not None
        ]
        called_names = [name for name in called_names if name]
        if tool_name not in called_names:
            continue

        # 3. prepend `tool_name_prefix` to all tool names in examples
        for message in conversation:
            if isinstance(message, AIMessage) and message.function_call is not None:
                message = copy.deepcopy(message)
                original_tool_name = message.function_call["name"]
                message.function_call["name"] = f"{self.tool_name_prefix}/{original_tool_name}"
            final_examples.append(message)

    return final_examples

load_examples_yaml(file) classmethod

load examples from yaml file

Parameters:

Name Type Description Default
file str

the path of examples file

required

Returns:

Type Description
List[Message]

List[Message]: the list of messages

Source code in erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py
@classmethod
def load_examples_yaml(cls, file: str) -> List[Message]:
    """Read an examples yaml file and turn it into messages.

    Args:
        file (str): the path of examples file

    Raises:
        RemoteToolError: when the file is empty or has no `examples` key.

    Returns:
        List[Message]: the list of messages
    """
    loaded = read_from_filename(file)
    content: dict = loaded[0]  # type: ignore
    has_examples = len(content) != 0 and "examples" in content
    if not has_examples:
        raise RemoteToolError("invalid examples configuration file", stage="Loading")
    return cls.load_examples_dict(content)

load_remote_examples_yaml(url, access_token=None) classmethod

load remote examples by url: url/.well-known/examples.yaml

Parameters:

Name Type Description Default
url str

the base url of the remote toolkit

required
Source code in erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py
@classmethod
def load_remote_examples_yaml(cls, url: str, access_token: Optional[str] = None) -> List[Message]:
    """load remote examples by url: url/.well-known/examples.yaml

    Args:
        url (str): the base url of the remote toolkit
        access_token (Optional[str]): token placed in the `Authorization`
            header; the global access token is used when omitted

    Raises:
        RemoteToolError: on a non-200 answer or an empty document.

    Returns:
        List[Message]: the examples, or an empty list when the remote file
            does not exist
    """
    if not url.endswith("/"):
        url += "/"
    examples_yaml_url = url + ".well-known/examples.yaml"

    # Resolve the token before the existence probe; the original probed with
    # the unresolved token and could therefore miss a protected file.
    if access_token is None:
        access_token = C.get_global_access_token()
    headers = cls._get_authorization_headers(access_token)

    if not url_file_exists(examples_yaml_url, headers):
        return []

    with tempfile.TemporaryDirectory() as temp_dir:
        response = requests.get(examples_yaml_url, headers=headers)
        if response.status_code != 200:
            logger.debug(f"The resource requested returned the following headers: {response.headers}")
            raise RemoteToolError(
                f"`{examples_yaml_url}` returned {response.status_code}: {response.text}",
                stage="Loading",
            )

        file_content = response.content.decode("utf-8")
        if not file_content.strip():
            raise RemoteToolError(f"the content is empty from: {examples_yaml_url}", stage="Loading")

        # the loader works on files, so the document is written to disk first
        file_path = os.path.join(temp_dir, "examples.yaml")
        with open(file_path, "w+", encoding="utf-8") as f:
            f.write(file_content)

        examples = cls.load_examples_yaml(file_path)

    return examples

to_openapi_dict()

convert plugin schema to openapi spec dict

Source code in erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py
def to_openapi_dict(self) -> dict:
    """Serialize the toolkit into an openapi spec dict.

    Returns:
        dict: the result of passing the spec through
            ``scrub_dict(..., remove_empty_dict=True)``; an empty dict when
            that result is falsy.
    """
    info = asdict(self.info)
    servers = [asdict(server) for server in self.servers]
    paths = {view.uri: view.to_openapi_dict() for view in self.paths}
    schemas = {name: view.to_openapi_dict() for name, view in self.component_schemas.items()}

    spec = {
        "openapi": self.openapi,
        "info": info,
        "servers": servers,
        "paths": paths,
        "components": {"schemas": schemas},
    }
    scrubbed = scrub_dict(spec, remove_empty_dict=True)
    return scrubbed if scrubbed else {}

to_openapi_file(file)

generate openapi configuration file

Parameters:

Name Type Description Default
file str

the path of the openapi yaml file

required
Source code in erniebot-agent/src/erniebot_agent/tools/remote_toolkit.py
def to_openapi_file(self, file: str):
    """Write the openapi spec of the toolkit to a yaml file.

    Args:
        file (str): the path of the openapi yaml file
    """
    # build the spec first so a failure leaves the target file untouched
    spec = self.to_openapi_dict()
    with open(file, "w+", encoding="utf-8") as stream:
        safe_dump(spec, stream, indent=4)

erniebot_agent.tools.tool_manager

ToolManager

A ToolManager instance manages tools for an agent.

This implementation is based on ToolsManager in https://github.com/deepset-ai/haystack/blob/main/haystack/agents/base.py

Source code in erniebot-agent/src/erniebot_agent/tools/tool_manager.py
@final
class ToolManager(object):
    """A `ToolManager` instance manages tools for an agent.

    Tools are kept in a dict keyed by their `tool_name`, in registration order.

    This implementation is based on `ToolsManager` in
    https://github.com/deepset-ai/haystack/blob/main/haystack/agents/base.py
    """

    def __init__(self, tools: List[BaseTool]) -> None:
        """Register every tool of `tools`; duplicate names raise `ValueError`."""
        super().__init__()
        self._tools: Dict[str, BaseTool] = {}
        for each in tools:
            self.add_tool(each)

    def __getitem__(self, tool_name: str) -> BaseTool:
        """Shortcut for `get_tool`."""
        return self.get_tool(tool_name)

    def add_tool(self, tool: BaseTool) -> None:
        """Register `tool` under its name.

        Raises:
            ValueError: when the name is already registered.
        """
        name = tool.tool_name
        if name in self._tools:
            raise ValueError(f"Name {name!r} is already registered.")
        self._tools[name] = tool

    def remove_tool(self, tool: BaseTool) -> None:
        """Unregister `tool`.

        Raises:
            ValueError: when the name of the tool is not registered.
            RuntimeError: when another tool object is registered under that name.
        """
        name = tool.tool_name
        if name not in self._tools:
            raise ValueError(f"Name {name!r} is not registered.")
        registered = self._tools[name]
        if registered is not tool:
            raise RuntimeError(f"The tool with the registered name {name!r} is not the given tool.")
        del self._tools[name]

    def get_tool(self, tool_name: str) -> BaseTool:
        """Return the tool registered as `tool_name`.

        Raises:
            ValueError: when the name is not registered.
        """
        if tool_name in self._tools:
            return self._tools[tool_name]
        raise ValueError(f"Name {tool_name!r} is not registered.")

    def get_tools(self) -> List[BaseTool]:
        """Return the registered tools, in registration order."""
        return [tool for tool in self._tools.values()]

    def get_tool_names(self) -> str:
        """Return the registered names joined by ``", "``."""
        return ", ".join(name for name in self._tools)

    def get_tool_names_with_descriptions(self) -> str:
        """Return one ``name:<json schema>`` line per registered tool."""
        lines = []
        for name, tool in self._tools.items():
            lines.append(f"{name}:{json.dumps(tool.function_call_schema())}")
        return "\n".join(lines)

    def get_tool_schemas(self):
        """Return the function-call schema of every registered tool."""
        schemas = []
        for tool in self._tools.values():
            schemas.append(tool.function_call_schema())
        return schemas

erniebot_agent.tools.baizhong_tool

BaizhongSearchTool

Source code in erniebot-agent/src/erniebot_agent/tools/baizhong_tool.py
class BaizhongSearchTool(Tool):
    """Tool that searches a document store and keeps the hits scoring above a threshold."""

    description: str = "在知识库中检索与用户输入query相关的段落"
    input_type: Type[ToolParameterView] = BaizhongSearchToolInputView
    ouptut_type: Type[ToolParameterView] = BaizhongSearchToolOutputView

    def __init__(
        self, description, db, threshold: float = 0.0, input_type=None, output_type=None, examples=None
    ) -> None:
        """Configure the tool.

        Args:
            description: description stored on the instance
            db: object whose ``search(query, top_k, filters)`` method is called
            threshold (float): hits must score strictly above this value
            input_type: replaces the class-level input view when given
            output_type: replaces the class-level output view when given
            examples: few-shot examples; each item is read with the keys
                ``user``, ``thoughts`` and ``arguments``
        """
        super().__init__()
        self.db = db
        self.description = description
        self.threshold = threshold
        self.few_shot_examples = [] if examples is None else examples
        # only override the class-level views when a replacement is given
        if input_type is not None:
            self.input_type = input_type
        if output_type is not None:
            self.ouptut_type = output_type

    async def __call__(self, query: str, top_k: int = 3, filters: Optional[Dict[str, Any]] = None):
        """Search `query` and return the hits whose score exceeds the threshold."""
        hits = self.db.search(query, top_k, filters)
        return {"documents": [hit for hit in hits if hit["score"] > self.threshold]}

    @property
    def examples(
        self,
    ) -> List[Any]:
        """Few-shot messages: one human message and one function call per example."""
        messages: List[Any] = []
        for shot in self.few_shot_examples:
            question = HumanMessage(shot["user"])
            messages.append(question)
            call = {
                "name": self.tool_name,
                "thoughts": shot["thoughts"],
                "arguments": shot["arguments"],
            }
            messages.append(AIMessage("", function_call=call))

        return messages

erniebot_agent.tools.calculator_tool

CalculatorTool

Source code in erniebot-agent/src/erniebot_agent/tools/calculator_tool.py
class CalculatorTool(Tool):
    """Tool that evaluates a math formula given as a string."""

    description: str = "CalculatorTool用于执行数学公式计算"
    input_type: Type[ToolParameterView] = CalculatorToolInputView
    ouptut_type: Type[ToolParameterView] = CalculatorToolOutputView

    async def __call__(self, math_formula: str) -> Dict[str, float]:
        """Evaluate `math_formula` and return the result under `formula_result`."""
        # NOTE(security): `eval` executes arbitrary Python, and the few-shot
        # examples below show the formula being written by the model. Any
        # expression reaching this call runs with the privileges of the
        # process; replace with a restricted arithmetic evaluator before
        # exposing this tool to untrusted input.
        return {"formula_result": eval(math_formula)}

    @property
    def examples(self) -> List[Message]:
        """Few-shot messages: three questions, each followed by the matching tool call."""
        return [
            HumanMessage("请告诉我三加六等于多少?"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": f"用户想知道3加6等于多少,我可以使用{self.tool_name}工具来计算公式,其中`math_formula`字段的内容为:'3+6'。",
                    "arguments": '{"math_formula": "3+6"}',
                },
                token_usage={
                    "prompt_tokens": 5,
                    "completion_tokens": 7,
                },  # TODO: Functional AIMessage will not add in the memory, will it add token_usage?
            ),
            HumanMessage("一加八再乘以5是多少?"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": f"用户想知道1加8再乘5等于多少,我可以使用{self.tool_name}工具来计算公式,"
                    "其中`math_formula`字段的内容为:'(1+8)*5'。",
                    "arguments": '{"math_formula": "(1+8)*5"}',
                },
                token_usage={"prompt_tokens": 5, "completion_tokens": 7},  # For test only
            ),
            HumanMessage("我想知道十二除以四再加五等于多少?"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": f"用户想知道12除以4再加5等于多少,我可以使用{self.tool_name}工具来计算公式,"
                    "其中`math_formula`字段的内容为:'12/4+5'。",
                    "arguments": '{"math_formula": "12/4+5"}',
                },
                token_usage={"prompt_tokens": 5, "completion_tokens": 7},  # For test only
            ),
        ]

erniebot_agent.tools.current_time_tool

CurrentTimeTool

Source code in erniebot-agent/src/erniebot_agent/tools/current_time_tool.py
class CurrentTimeTool(Tool):
    """Tool that reports the current local time as a formatted string."""

    description: str = "CurrentTimeTool 用于获取当前时间"
    ouptut_type: Type[ToolParameterView] = CurrentTimeToolOutputView

    async def __call__(self) -> Dict[str, str]:
        """Return the current time under the `current_time` key."""
        now = datetime.now()
        return {"current_time": now.strftime("%Y年%m月%d日 %H时%M分%S秒")}

    @property
    def examples(self) -> List[Message]:
        """Few-shot messages: two questions, each followed by an argument-less tool call."""
        first_call = {
            "name": self.tool_name,
            "thoughts": f"用户想知道现在几点了,我可以使用{self.tool_name}来获取当前时间,并从其中获得当前小时时间。",
            "arguments": "{}",
        }
        second_call = {
            "name": self.tool_name,
            "thoughts": f"用户想知道现在几点了,我可以使用{self.tool_name}来获取当前时间",
            "arguments": "{}",
        }
        return [
            HumanMessage("现在几点钟了"),
            AIMessage(
                "",
                function_call=first_call,
                token_usage={"prompt_tokens": 5, "completion_tokens": 7},  # For test only
            ),
            HumanMessage("现在是什么时候?"),
            AIMessage(
                "",
                function_call=second_call,
                token_usage={"prompt_tokens": 5, "completion_tokens": 7},  # For test only
            ),
        ]

erniebot_agent.tools.chat_with_eb

ChatWithEB

Source code in erniebot-agent/src/erniebot_agent/tools/chat_with_eb.py
class ChatWithEB(Tool):
    description: str = (
        "ChatWithEB是一款根据用户的问题,向EB生成式大语言模型进行提问,并获取EB回答结果的工具。EB一般能解决知识型问答、文本创作、信息查询、信息检索等基础的文本生成和信息检索功能"
    )
    input_type: Type[ToolParameterView] = ChatWithEBInputView
    ouptut_type: Type[ToolParameterView] = ChatWithEBOutputView

    def __init__(self, llm: ERNIEBot):
        self.llm = llm

    async def __call__(self, query: str) -> Dict[str, str]:
        response = await self.llm.chat([HumanMessage(query)])
        return {"response": response.content}