Skip to content

Memory Module

erniebot_agent.memory

Memory

The base class of memory

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `msg_manager` | `MessageManager` | The message manager of a conversation. |

Returns:

Type Description

A memory object.

Source code in erniebot-agent/src/erniebot_agent/memory/base.py
class Memory:
    """
    Base class for conversation memory.

    A memory owns a ``MessageManager`` and offers helpers to add messages to it,
    read them back, read the system message and clear the history.

    Attributes:
        msg_manager (MessageManager): the message manager of a conversation.
    """

    def __init__(self):
        self.msg_manager = MessageManager()

    def add_messages(self, messages: List[Message]):
        """Add every message of ``messages``, in order, through ``self.add_message``."""
        add_one = self.add_message
        for msg in messages:
            add_one(msg)

    def add_message(self, message: Message):
        """
        Append ``message`` to the message manager.

        When ``message`` is an ``AIMessage``, its ``query_tokens_count`` is first
        handed to ``msg_manager.update_last_message_token_count`` (presumably to
        record the token count of the preceding query message).
        """
        manager = self.msg_manager
        if isinstance(message, AIMessage):
            manager.update_last_message_token_count(message.query_tokens_count)
        manager.add_message(message)

    def get_messages(self) -> List[Message]:
        """Return the messages held in memory, via ``msg_manager.retrieve_messages()``."""
        manager = self.msg_manager
        return manager.retrieve_messages()

    def get_system_message(self) -> SystemMessage:
        """Return the system message stored on the message manager."""
        system_message = self.msg_manager.system_message
        return system_message

    def clear_chat_history(self):
        """Reset the memory by calling ``msg_manager.clear_messages()``."""
        manager = self.msg_manager
        manager.clear_messages()

add_message(message)

Add a message to memory.

Source code in erniebot-agent/src/erniebot_agent/memory/base.py
def add_message(self, message: Message):
    """
    Append ``message`` to the message manager.

    When ``message`` is an ``AIMessage``, its ``query_tokens_count`` is first
    handed to ``msg_manager.update_last_message_token_count``.
    """
    manager = self.msg_manager
    if isinstance(message, AIMessage):
        manager.update_last_message_token_count(message.query_tokens_count)
    manager.add_message(message)

add_messages(messages)

Add a list of messages to memory.

Source code in erniebot-agent/src/erniebot_agent/memory/base.py
def add_messages(self, messages: List[Message]):
    """Add every message of ``messages``, in order, through ``self.add_message``."""
    add_one = self.add_message
    for msg in messages:
        add_one(msg)

clear_chat_history()

Reset the memory.

Source code in erniebot-agent/src/erniebot_agent/memory/base.py
def clear_chat_history(self):
    """Reset the memory by calling ``msg_manager.clear_messages()``."""
    manager = self.msg_manager
    manager.clear_messages()

get_messages()

Get all the messages in memory.

Source code in erniebot-agent/src/erniebot_agent/memory/base.py
def get_messages(self) -> List[Message]:
    """Return the messages held in memory, via ``msg_manager.retrieve_messages()``."""
    manager = self.msg_manager
    return manager.retrieve_messages()

get_system_message()

Get the system message in memory.

Source code in erniebot-agent/src/erniebot_agent/memory/base.py
def get_system_message(self) -> SystemMessage:
    """Return the system message stored on the message manager."""
    system_message = self.msg_manager.system_message
    return system_message

WholeMemory

The type of memory that includes all the messages.

Attributes:

Name Type Description
msg_manager MessageManager

the message manager of a conversation.

Source code in erniebot-agent/src/erniebot_agent/memory/whole_memory.py
class WholeMemory(Memory):
    """
    Memory that keeps every message: it adds no pruning on top of ``Memory``.

    Attributes:
        msg_manager (MessageManager): the message manager of a conversation.
    """

    def __init__(self):
        # No extra configuration; everything is set up by the base class.
        super().__init__()

LimitTokensMemory

The class of memory that limits the number of tokens. If the number of tokens in the context exceeds max_token_limit, messages are popped from msg_manager.

Parameters:

Name Type Description Default
max_token_limit int

The maximum number of tokens in the context.

3000

Attributes:

Name Type Description
max_token_limit int

The maximum number of tokens in the context.

mem_token_count int

The number of tokens in the context.

Examples:

.. code-block:: python
    from erniebot_agent.memory import LimitTokensMemory
    memory = LimitTokensMemory(max_token_limit=3000)
    memory.add_message(AIMessage("Hello world!"))
Source code in erniebot-agent/src/erniebot_agent/memory/limit_tokens_memory.py
class LimitTokensMemory(Memory):
    """
    The class of memory that limits the number of tokens.
    Each time an AIMessage is added, messages are popped from msg_manager while the
    number of tokens in the context exceeds max_token_limit.

    Args:
        max_token_limit (int): The maximum number of tokens in the context, or None
            to disable pruning. Default to 3000.

    Attributes:
        max_token_limit (int): The maximum number of tokens in the context.
        mem_token_count (int): The number of tokens in the context, as tracked by
            prune_message.

    Raises:
        ValueError: If max_token_limit is neither None nor a positive number.

    Examples:

        .. code-block:: python

            from erniebot_agent.memory import LimitTokensMemory
            memory = LimitTokensMemory(max_token_limit=3000)
            memory.add_message(AIMessage("Hello world!"))

    """

    def __init__(self, max_token_limit=3000):
        super().__init__()
        self.max_token_limit = max_token_limit
        self.mem_token_count = 0

        # Explicit check rather than ``assert``: assertions are stripped under
        # ``python -O``, which would silently accept an invalid limit.
        if not (max_token_limit is None or max_token_limit > 0):
            raise ValueError(f"max_token_limit should be None or positive integer, but got {max_token_limit}")

    def add_message(self, message: Message):
        """
        Add a message to memory. When the message is an AIMessage, prune messages
        while the number of tokens in memory exceeds max_token_limit.

        Args:
            message (Message): The message to be added.

        Returns:
            None
        """
        super().add_message(message)
        # TODO(shiyutang): Pruning only when an AIMessage is added means a HumanMessage
        # may exceed the length limit when it is passed to the LLM. The better approach
        # is to determine token_count when each message is created, so that
        # prune_message can run every time a message is added.
        if isinstance(message, AIMessage):
            self.prune_message()

    def prune_message(self):
        """
        Account for the latest messages and prune while over the token limit.

        The token count of the last message, and of the message before it when one
        exists, is added to mem_token_count. If max_token_limit is not None, the
        messages returned by msg_manager.pop_message() are then removed one at a time
        while mem_token_count exceeds max_token_limit.

        Raises:
            RuntimeError: If memory is empty after pruning.

        Returns:
            None

        """
        messages = self.msg_manager.messages
        self.mem_token_count += messages[-1].token_count
        # The message before the AI message is presumably the human query whose token
        # count was filled in by Memory.add_message. Guard the lookup so an AI message
        # added to otherwise-empty memory no longer raises IndexError.
        if len(messages) >= 2:
            self.mem_token_count += messages[-2].token_count  # add human message token length

        if self.max_token_limit is None:
            return

        while self.mem_token_count > self.max_token_limit:
            deleted_message = self.msg_manager.pop_message()
            self.mem_token_count -= deleted_message.token_count

        # Checked after the loop completes (the original ``while ... else`` always ran
        # this too). Memory can only be empty here if at least one pop happened above.
        if len(self.get_messages()) == 0:
            raise RuntimeError(
                "The messsage is now empty. \
It indicates {} which takes up {} tokens and exeeded tokens limits of {} tokens.".format(
                    deleted_message, deleted_message.token_count, self.max_token_limit
                )
            )

add_message(message)

Add a message to memory. When the message is an AIMessage, prune messages while the number of tokens in memory exceeds max_token_limit.

Parameters:

Name Type Description Default
message Message

The message to be added.

required

Returns:

Type Description

None

Source code in erniebot-agent/src/erniebot_agent/memory/limit_tokens_memory.py
def add_message(self, message: Message):
    """
    Add ``message`` to memory; for an ``AIMessage``, also call ``prune_message``.

    Args:
        message (Message): The message to be added.

    Returns:
        None
    """
    super().add_message(message)
    # TODO(shiyutang): Pruning only when an AIMessage is added means a HumanMessage
    # may exceed the length limit when it is passed to the LLM. The better approach
    # is to determine token_count when each message is created, so that
    # prune_message can run every time a message is added.
    if not isinstance(message, AIMessage):
        return
    self.prune_message()

prune_message()

Prune messages while the number of tokens in memory exceeds max_token_limit.

Raises:

Type Description
RuntimeError

If memory is empty after pruning.

Returns:

Type Description

None

Source code in erniebot-agent/src/erniebot_agent/memory/limit_tokens_memory.py
    def prune_message(self):
        """
        Prune the message if number of tokens in memory >= max_token_limit.

        Raises:
            RuntimeError: If the message is empty after pruning.

        Returns:
            None

        """
        self.mem_token_count += self.msg_manager.messages[-1].token_count
        self.mem_token_count += self.msg_manager.messages[-2].token_count  # add human message token length
        if self.max_token_limit is not None:
            while self.mem_token_count > self.max_token_limit:
                deleted_message = self.msg_manager.pop_message()
                self.mem_token_count -= deleted_message.token_count
            else:
                if len(self.get_messages()) == 0:
                    raise RuntimeError(
                        "The messsage is now empty. \
It indicates {} which takes up {} tokens and exeeded tokens limits of {} tokens.".format(
                            deleted_message, deleted_message.token_count, self.max_token_limit
                        )
                    )

SlidingWindowMemory

This class limits the number of message rounds kept in memory using a sliding-window strategy. Each round contains one human message and one AI message.

Attributes:

Name Type Description
max_round(int)

Max number of rounds.

retained_round(int)

The number of initial rounds of memory that are always preserved. Defaults to 0.

Raises:

Type Description
ValueError

If max_round is not a positive integer.

Source code in erniebot-agent/src/erniebot_agent/memory/sliding_window_memory.py
class SlidingWindowMemory(Memory):
    """
    This class limits the number of message rounds kept in memory using a sliding window.
    Each round contains a piece of human message and a piece of AI message.

    Attributes:
        max_round(int): Max number of rounds.
        retained_round(int): The number of initial rounds of memory that are preserved. Default to 0.

    Raises:
        ValueError: If max_round is not a positive integer, or if retained_round is
            negative or not smaller than max_round.
    """

    def __init__(self, max_round: int, retained_round: int = 0) -> None:
        """
        Args:
            max_round(int): Max number of rounds(round: human message and AI message).
            retained_round(int): The number of initial rounds of memory to be retained.
                Must satisfy ``0 <= retained_round < max_round``. Default to 0.

        Raises:
            ValueError: If max_round is not positive, or retained_round is out of range.
        """

        super().__init__()
        self.max_round = max_round
        self.retained_round = retained_round

        if max_round <= 0:
            raise ValueError(f"max_round should be positive integer, but got {max_round}")
        # prune_message pops at index ``retained_round * 2``. When the retained prefix
        # is as long as the whole window, that index lands on (or past) the newest
        # message, so the incoming human message would be dropped instead of an old
        # round. A negative value would presumably index from the end of the list.
        if not 0 <= retained_round < max_round:
            raise ValueError(
                f"retained_round should satisfy 0 <= retained_round < max_round, "
                f"but got retained_round={retained_round} and max_round={max_round}"
            )

    def add_message(self, message: Message) -> None:
        """Add a message to memory, then prune it back to max_round rounds."""
        super().add_message(message=message)
        self.prune_message()

    def prune_message(self) -> None:
        """
        Prune memory to max_round rounds if necessary.

        Messages are popped at index ``retained_round * 2``, i.e. just after the
        retained prefix, so the first retained_round rounds stay in memory.
        """
        while len(self.get_messages()) > self.max_round * 2:
            first_prunable = self.retained_round * 2
            self.msg_manager.pop_message(first_prunable)

            num_message = len(self.get_messages())
            # An even count after one pop means the popped message's partner is
            # presumably still present; pop it as well so a whole round is removed.
            if num_message % 2 == 0:
                if num_message > first_prunable:
                    self.msg_manager.pop_message(first_prunable)
                else:
                    # Defensive: only reachable if retained_round was changed to a value
                    # >= max_round after construction. Drop the newest message instead.
                    self.msg_manager.pop_message(num_message - 1)

__init__(max_round, retained_round=0)

Parameters:

Name Type Description Default
max_round(int)

Max number of rounds(round: human message and AI message).

required
retained_round(int)

The number of initial rounds of memory to be retained. Defaults to 0.

0
Source code in erniebot-agent/src/erniebot_agent/memory/sliding_window_memory.py
def __init__(self, max_round: int, retained_round: int = 0) -> None:
    """
    Args:
        max_round(int): Max number of rounds(round: human message and AI message).
        retained_round(int): The number of initial rounds of memory to be retained.
            Must satisfy ``0 <= retained_round < max_round``. Default to 0.

    Raises:
        ValueError: If max_round is not positive, or retained_round is out of range.
    """

    super().__init__()
    self.max_round = max_round
    self.retained_round = retained_round

    if max_round <= 0:
        raise ValueError(f"max_round should be positive integer, but got {max_round}")
    # prune_message pops at index ``retained_round * 2``. When the retained prefix
    # is as long as the whole window, that index lands on (or past) the newest
    # message, so the incoming human message would be dropped instead of an old
    # round. A negative value would presumably index from the end of the list.
    if not 0 <= retained_round < max_round:
        raise ValueError(
            f"retained_round should satisfy 0 <= retained_round < max_round, "
            f"but got retained_round={retained_round} and max_round={max_round}"
        )

add_message(message)

Add a message to memory.

Source code in erniebot-agent/src/erniebot_agent/memory/sliding_window_memory.py
def add_message(self, message: Message) -> None:
    """Store ``message`` via the base implementation, then trim to the sliding window."""
    parent = super()
    parent.add_message(message=message)
    self.prune_message()

prune_message()

Prune memory to max_round if necessary.

Source code in erniebot-agent/src/erniebot_agent/memory/sliding_window_memory.py
def prune_message(self) -> None:
    """
    Trim memory until it holds at most ``max_round * 2`` messages.

    Each pass pops the message at index ``retained_round * 2`` (just after the
    retained prefix). If the remaining count is then even, one more message is
    popped: again at that index when enough messages remain, otherwise the last one.
    """
    while len(self.get_messages()) > self.max_round * 2:
        keep = self.retained_round * 2
        self.msg_manager.pop_message(keep)

        remaining = len(self.get_messages())
        if remaining % 2 != 0:
            continue
        # Even count: presumably the popped message's partner is still present.
        if len(self.get_messages()) <= self.retained_round * 2:
            self.msg_manager.pop_message(remaining - 1)
        else:
            self.msg_manager.pop_message(self.retained_round * 2)