Skip to content

Index

FormAgentExecutor

Bases: StateGraph

Source code in converso/conversational_engine/form_agent/form_agent_executor.py
class FormAgentExecutor(StateGraph):
    """
    State graph that alternates between an "agent" node (model invocation)
    and a "tool" node (tool execution).

    The graph is assembled and compiled in the constructor; the compiled
    graph is stored in ``self.app``.
    """

    # Maximum number of intermediate steps kept in the state before the model
    # is invoked (older steps are dropped).
    MAX_INTERMEDIATE_STEPS = 5

    def __init__(
        self,
        tools: Sequence[Type[Any]] = None,
        on_tool_start: callable = None,
        on_tool_end: callable = None,
    ) -> None:
        """
        Args:
            tools: tools made available to the agent. Defaults to no tools.
            on_tool_start: optional callback invoked as
                ``on_tool_start(tool, tool_input)`` before a tool runs.
            on_tool_end: optional callback invoked as
                ``on_tool_end(tool, tool_output)`` after a tool has run.
        """
        super().__init__(AgentState)

        self._on_tool_start = on_tool_start
        self._on_tool_end = on_tool_end
        # None is used as the default instead of a shared mutable list.
        self._tools = [] if tools is None else tools
        self.__build_graph()

    def __build_graph(self):
        """
        Registers the nodes and conditional edges, then compiles the graph.
        """
        self.add_node("agent", self.call_agent)
        self.add_node("tool", self.call_tool)

        # After the agent: run the requested tools, retry the agent on error,
        # or stop.
        self.add_conditional_edges(
            "agent",
            self.should_continue_after_agent,
            {
                "tool": "tool",
                "error": "agent",
                "end": END
            }
        )

        # After the tools: go back to the agent (on error or to continue),
        # or stop.
        self.add_conditional_edges(
            "tool",
            self.should_continue_after_tool,
            {
                "error": "agent",
                "continue": "agent",
                "end": END
            }
        )

        self.set_entry_point("agent")
        self.app = self.compile()

    def get_tools(self, state: AgentState):
        """
        Returns the tools available for the given state.
        A copy of the configured tools is passed to filter_active_tools.
        """
        return filter_active_tools(self._tools[:], state)

    def get_tool_by_name(self, name: str, agent_state: AgentState):
        """
        Returns the first available tool whose name matches, or None.
        """
        return next(
            (tool for tool in self.get_tools(agent_state)
             if tool.name == name),
            None
        )

    def get_tool_executor(self, state: AgentState):
        """
        Builds a FormToolExecutor over the tools available for the state.
        """
        return FormToolExecutor(self.get_tools(state))

    def should_continue_after_agent(self, state: AgentState):
        """
        Chooses the edge to follow after the agent node:
        "error" if an error is set, "end" if the agent finished,
        "tool" if the agent outcome is a list (of actions).
        """
        if state.get("error"):
            return "error"

        agent_outcome = state.get("agent_outcome")
        if isinstance(agent_outcome, AgentFinish):
            return "end"
        if isinstance(agent_outcome, list):
            return "tool"
        # NOTE(review): any other outcome falls through and returns None,
        # which has no entry in the edge mapping - confirm it cannot occur.

    def should_continue_after_tool(self, state: AgentState):
        """
        Chooses the edge to follow after the tool node:
        "error" if an error is set, "end" if the tool outcome asks to return
        directly, "continue" otherwise.
        """
        if state.get("error"):
            return "error"

        tool_outcome = state.get("tool_outcome")
        if isinstance(tool_outcome, FormToolOutcome) and tool_outcome.return_direct:
            return "end"
        return "continue"

    def build_model(self, state: AgentState):
        """
        Builds the model for the given state with the currently available tools.
        """
        return ModelFactory.build_model(
            state=state,
            tools=self.get_tools(state)
        )

    # Define the function that calls the model
    def call_agent(self, state: AgentState):
        """
        Invokes the model on the state and returns the state updates.

        On success, the agent outcome is stored and tool_choice, tool_outcome
        and error are reset. If the model output cannot be parsed, only the
        error message is returned.
        """
        try:
            # Cap the number of intermediate steps in a prompt to
            # MAX_INTERMEDIATE_STEPS. A missing/None value is treated as
            # "no steps" instead of failing on len(None).
            steps = state.get("intermediate_steps") or []
            if len(steps) > self.MAX_INTERMEDIATE_STEPS:
                state["intermediate_steps"] = steps[-self.MAX_INTERMEDIATE_STEPS:]

            agent_outcome = self.build_model(state=state).invoke(state)

            return {
                "agent_outcome": agent_outcome,
                "tool_choice": None,  # Reset the function call
                "tool_outcome": None,  # Reset the tool outcome
                "error": None  # Reset the error
            }
        # TODO: if other exceptions are raised, we should handle them here
        except OutputParserException as e:
            traceback.print_exc()
            return {"error": str(e)}

    def on_tool_start(self, tool: BaseTool, tool_input: dict):
        """
        Forwards to the on_tool_start callback, if one was provided.
        """
        if self._on_tool_start:
            self._on_tool_start(tool, tool_input)

    def on_tool_end(self, tool: BaseTool, tool_output: Any):
        """
        Forwards to the on_tool_end callback, if one was provided.
        """
        if self._on_tool_end:
            self._on_tool_end(tool, tool_output)

    def call_tool(self, state: AgentState):
        """
        Executes every action in the agent outcome and returns the state
        updates.

        On success, the updates contain one intermediate step per action plus
        the state update and outcome of the last executed tool. On failure,
        the updates contain the error message and, when an action had been
        started, an intermediate step describing the failure of that action.
        """
        # Pre-bound so the error path never references an unbound name
        # (e.g. when the action list is empty or not iterable).
        action = None
        tool_outcome = None

        try:
            actions = state.get("agent_outcome")
            intermediate_steps = []

            for action in actions:
                tool = self.get_tool_by_name(action.tool, state)

                self.on_tool_start(tool=tool, tool_input=action.tool_input)
                tool_outcome = self.get_tool_executor(state).invoke(action)
                self.on_tool_end(tool=tool, tool_output=tool_outcome.output)

                intermediate_steps.append(
                    (
                        action,
                        FunctionMessage(
                            content=str(tool_outcome.output),
                            name=action.tool
                        )
                    )
                )

            if tool_outcome is None:
                # Nothing was executed: report it as an error rather than
                # failing on the missing outcome below.
                raise ValueError("No tool action to execute")

            return {
                **tool_outcome.state_update,
                "intermediate_steps": intermediate_steps,
                "tool_outcome": tool_outcome,  # this isn't really correct with multiple tools
                "agent_outcome": None,
                "error": None
            }

        # No `finally: return`, so exceptions not handled here (e.g.
        # KeyboardInterrupt) propagate instead of being silently replaced.
        except Exception as e:
            traceback.print_exc()
            updates = {"error": str(e)}
            if action is not None:
                updates["intermediate_steps"] = [(action, FunctionMessage(
                    content=f"{type(e).__name__}: {str(e)}",
                    name=action.tool
                ))]
            return updates

    def parse_output(self, graph_output: dict) -> str:
        """
        Parses the final state of the graph.
        Theoretically, only one of tool_outcome and agent_outcome is set.
        Returns the str to be considered the output of the graph
        (None if neither is set).
        """

        state = graph_output[END]

        output = None
        if state.get("tool_outcome"):
            output = state.get("tool_outcome").output
        elif state.get("agent_outcome"):
            output = state.get("agent_outcome").return_values["output"]

        return output

parse_output(graph_output)

Parses the final state of the graph. In theory, only one of tool_outcome and agent_outcome is set. Returns the string to be used as the output of the graph.

Source code in converso/conversational_engine/form_agent/form_agent_executor.py
def parse_output(self, graph_output: dict) -> str:
    """
    Extracts the graph output from its final state.
    A truthy tool_outcome takes precedence over agent_outcome; when neither
    is set, None is returned.
    """
    final_state = graph_output[END]

    tool_outcome = final_state.get("tool_outcome")
    if tool_outcome:
        return tool_outcome.output

    agent_outcome = final_state.get("agent_outcome")
    if agent_outcome:
        return agent_outcome.return_values["output"]

    return None

FormTool

Bases: StructuredTool, ABC

Source code in converso/conversational_engine/form_agent/form_tool.py
class FormTool(StructuredTool, ABC):
    form: BaseModel = None
    state: Union[FormToolState | None] = None
    skip_confirm: Optional[bool] = False

    # Backup attributes for handling changes in the state
    args_schema_: Optional[Type[BaseModel]] = None
    description_: Optional[str] = None
    name_: Optional[str] = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args_schema_ = None
        self.name_ = None
        self.description_ = None
        self.init_state()

    def init_state(self):
        state_initializer = {
            None: self.enter_inactive_state,
            FormToolState.INACTIVE: self.enter_inactive_state,
            FormToolState.ACTIVE: self.enter_active_state,
            FormToolState.FILLED: self.enter_filled_state
        }
        state_initializer[self.state]()

    def enter_inactive_state(self):
        # Guard so that we don't overwrite the original args_schema if
        # set_inactive_state is called multiple times
        if not self.state == FormToolState.INACTIVE:
            self.state = FormToolState.INACTIVE
            self.name_ = self.name
            self.name = f"{self.name_}Start"
            self.description_ = self.description
            self.description = f"""Starts the form {
                self.name}, which {self.description_}"""
            self.args_schema_ = self.args_schema
            self.args_schema = FormToolInactivePayload

    def enter_active_state(self):
        # if not self.state == FormToolState.ACTIVE:
        self.state = FormToolState.ACTIVE
        self.name = f"{self.name_}Update"
        self.description = f"""Updates data for form {
            self.name}, which {self.description_}"""
        self.args_schema = make_optional_model(self.args_schema_)
        if not self.form:
            self.form = self.args_schema()
        elif isinstance(self.form, str):
            self.form = self.args_schema(**json.loads(self.form))

    def enter_filled_state(self):
        self.state = FormToolState.FILLED
        self.name = f"{self.name_}Finalize"
        self.description = f"""Finalizes form {
            self.name}, which {self.description_}"""
        self.args_schema = make_optional_model(self.args_schema_)
        if not self.form:
            self.form = self.args_schema()
        elif isinstance(self.form, str):
            self.form = self.args_schema(**json.loads(self.form))
        self.args_schema = FormToolConfirmPayload

    def activate(
        self,
        *args,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        **kwargs
    ) -> FormToolOutcome:
        self.enter_active_state()
        return FormToolOutcome(
            output=f"""Starting form {
                self.name}. If the user as already provided some information, call {self.name}.""",
            active_form_tool=self,
            tool_choice=self.name
        )

    def update(
        self,
        *args,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        **kwargs
    ) -> FormToolOutcome:
        self._update_form(**kwargs)
        if self.is_form_filled():
            self.enter_filled_state()
            if self.skip_confirm:
                return self.finalize(confirm=True)
            else:
                return FormToolOutcome(
                    active_form_tool=self,
                    output="Form is filled. Ask the user to confirm the information."
                )
        else:
            return FormToolOutcome(
                active_form_tool=self,
                output="Form updated with the provided information. Ask the user for the next field."
            )

    def finalize(
        self,
        *args,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        **kwargs
    ) -> FormToolOutcome:
        if kwargs.get("confirm"):
            # The FormTool could use self.form to get the data, but we pass it as kwargs to
            # keep the signature consistent with _run
            result = self._run_when_complete(**self.form.model_dump())
            return FormToolOutcome(
                active_form_tool=None,
                output=result,
                return_direct=self.return_direct
            )
        else:
            self.enter_active_state()
            return FormToolOutcome(
                active_form_tool=self,
                output="Ask the user to update the form."
            )

    def _run(
        self,
        *args,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        **kwargs
    ) -> str:
        match self.state:
            case FormToolState.INACTIVE:
                return self.activate(*args, **kwargs, run_manager=run_manager)

            case FormToolState.ACTIVE:
                return self.update(*args, **kwargs, run_manager=run_manager)

            case FormToolState.FILLED:
                return self.finalize(*args, **kwargs, run_manager=run_manager)

    @abstractmethod
    def _run_when_complete(self) -> str:
        """
        Should raise an exception if something goes wrong.
        The message should describe the error and will be sent back to the agent to try to fix it.
        """

    def _update_form(self, **kwargs):
        try:
            model_class = type(self.form)
            data = self.form.model_dump()
            data.update(kwargs)
            # Recreate the model with the new data merged to the old one
            # This allows to validate multiple fields at once
            self.form = model_class(**data)
        except ValidationError as e:
            raise ToolException(str(e))

    def get_next_field_to_collect(
        self,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """
        The default implementation returns the first field that is not set.
        """
        if self.state == FormToolState.FILLED:
            return None

        for field_name, field_info in self.args_schema.__fields__.items():
            if not getattr(self.form, field_name):
                return field_name

    def is_form_filled(self) -> bool:
        return self.get_next_field_to_collect() is None

    def get_tool_start_message(self, input: dict) -> str:
        message = ""
        match self.state:
            case FormToolState.INACTIVE:
                message = f"Starting {self.name}"
            case FormToolState.ACTIVE:
                message = f"Updating form for {self.name}"
            case FormToolState.FILLED:
                message = f"Completed {self.name}"
        return message

get_next_field_to_collect(run_manager=None)

The default implementation returns the first field that is not set.

Source code in converso/conversational_engine/form_agent/form_tool.py
def get_next_field_to_collect(
    self,
    run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Optional[str]:
    """
    The default implementation returns the first field that is not set,
    or None if the form is in the FILLED state or every field is set.
    """
    if self.state == FormToolState.FILLED:
        return None

    for field_name in self.args_schema.__fields__:
        # Only None means "not set": a falsy value such as False, 0 or ""
        # is a legitimate answer and must not be asked for again.
        if getattr(self.form, field_name) is None:
            return field_name

    return None

filter_active_tools(tools, context)

Form tools are replaced by their activators if they are not active.

Source code in converso/conversational_engine/form_agent/form_agent_executor.py
def filter_active_tools(
    tools: Sequence[BaseTool],
    context: AgentState
):
    """
    Returns the tools to expose for the given context.

    Without an active form tool, the tools are returned unchanged.
    With one, every FormTool is dropped and replaced by the active form
    tool, followed by a FormReset tool bound to the context.
    """
    active_form_tool = context.get("active_form_tool")
    if not active_form_tool:
        return tools

    # If a form_tool is active, it is the only form tool available
    non_form_tools = [
        candidate for candidate in tools
        if not isinstance(candidate, FormTool)
    ]
    return [
        *non_form_tools,
        active_form_tool,
        FormReset(context=context)
    ]