rasa培训课程:Rasa微服务Action自定义及Slot Validation详解

539 阅读7分钟

第3课:Rasa微服务Action自定义及Slot Validation详解

\

开发 Rasa 智能对话机器人时,微服务 Action 的自定义开发是绕不开的内容:只要做开发,就一定会用到 Actions、Tracker、Dispatcher 和 Events。

\

Action 类是所有自定义动作(custom action)的基类。要定义一个自定义动作,需要创建 Action 类的子类,并覆盖两个必需的方法:name 和 run。当 Action 服务器收到运行某个动作的请求时,会根据各个动作 name 方法的返回值找到并调用对应的动作。

\

\

官网文档的一个示例:

\

from typing import Any, Dict, List, Text

from rasa_sdk import Action, Tracker
from rasa_sdk.executor import CollectingDispatcher


class MyCustomAction(Action):
    """Minimal skeleton of a custom action.

    A custom action subclasses ``Action`` and overrides two methods:
    ``name`` (the value the action server uses to pick this action when it
    receives a request to run it) and ``run`` (the action's side effects).
    """

    def name(self) -> Text:
        """Return the identifier of this action."""
        return "action_name"

    async def run(
        self,
        dispatcher: CollectingDispatcher,
        tracker: Tracker,
        domain: Dict[Text, Any],
    ) -> List[Dict[Text, Any]]:
        """Execute the action.

        This skeleton has no side effects and returns no events.
        """
        return []

Rasa 服务端(rasa server)中的 Action 基类是一个普通的类:

\

class Action:
    """Next action to be taken in response to a dialogue state."""

    def name(self) -> Text:
        """Unique identifier of this simple action."""

        raise NotImplementedError

    async def run(
        self,
        output_channel: "OutputChannel",
        nlg: "NaturalLanguageGenerator",
        tracker: "DialogueStateTracker",
        domain: "Domain",
    ) -> List[Event]:
        """Execute the side effects of this action.

        Args:
            nlg: which ``nlg`` to use for response generation
            output_channel: ``output_channel`` to which to send the resulting message.
            tracker (DialogueStateTracker): the state tracker for the current
                user. You can access slot values using
                ``tracker.get_slot(slot_name)`` and the most recent user
                message is ``tracker.latest_message.text``.
            domain (Domain): the bot's domain

        Returns:
            A list of :class:`rasa.core.events.Event` instances
        """
        raise NotImplementedError

    def __str__(self) -> Text:
        """Returns a text representation of the action, e.g. ``Action('name')``."""
        return f"{self.__class__.__name__}('{self.name()}')"

    def event_for_successful_execution(
        self, prediction: PolicyPrediction
    ) -> ActionExecuted:
        """Event which should be logged for the successful execution of this action.

        Args:
            prediction: Prediction which led to the execution of this event.

        Returns:
            Event which should be logged onto the tracker.
        """
        return ActionExecuted(
            self.name(),
            prediction.policy_name,
            prediction.max_confidence,
            hide_rule_turn=prediction.hide_rule_turn,
            metadata=prediction.action_metadata,
        )

\

Rasa SDK(Action Server 端)中的 Action 类:

\

class Action:
    """Next action to be taken in response to a dialogue state."""

    def name(self) -> Text:
        """Unique identifier of this simple action."""

        raise NotImplementedError("An action must implement a name")

    async def run(
        self,
        dispatcher: "CollectingDispatcher",
        tracker: Tracker,
        domain: "DomainDict",
    ) -> List[Dict[Text, Any]]:
        """Execute the side effects of this action.

        Args:
            dispatcher: the dispatcher which is used to
                send messages back to the user. Use
                ``dispatcher.utter_message()`` for sending messages.
            tracker: the state tracker for the current
                user. You can access slot values using
                ``tracker.get_slot(slot_name)``, the most recent user message
                is ``tracker.latest_message.text`` and any other
                ``rasa_sdk.Tracker`` property.
            domain: the bot's domain

        Returns:
            A list of event dictionaries (see ``rasa_sdk.events``) that is
            returned through the endpoint.
        """

        raise NotImplementedError("An action must implement its run method")

    def __str__(self) -> Text:
        """Returns a text representation of the action, e.g. ``Action('name')``."""
        return f"Action('{self.name()}')"

\

\

reminder 机器人示例复写了 run 方法,重点关注其中的 dispatcher: CollectingDispatcher 和 tracker: Tracker 两个参数:

\

class ActionSetReminder(Action):
    """Schedules a reminder, supplied with the last message's entities."""

    # Delay before the reminder fires. It is used both for the trigger time and
    # for the confirmation text, so the two cannot drift apart.
    REMINDER_DELAY_SECONDS = 5

    def name(self) -> Text:
        return "action_set_reminder"

    async def run(
        self,
        dispatcher: CollectingDispatcher,
        tracker: Tracker,
        domain: Dict[Text, Any],
    ) -> List[Dict[Text, Any]]:
        """Confirm to the user and schedule ``EXTERNAL_reminder``.

        The entities of the latest user message are passed along with the
        scheduled reminder. Returns a single ``ReminderScheduled`` event.
        """
        delay = self.REMINDER_DELAY_SECONDS
        dispatcher.utter_message(f"I will remind you in {delay} seconds.")

        date = datetime.datetime.now() + datetime.timedelta(seconds=delay)
        entities = tracker.latest_message.get("entities")

        reminder = ReminderScheduled(
            "EXTERNAL_reminder",
            trigger_date_time=date,
            entities=entities,
            name="my_reminder",
            kill_on_user_message=False,
        )

        return [reminder]

\

\

rasa_sdk 中 executor.py 的 CollectingDispatcher 类:

\

\

class CollectingDispatcher:
    """Send messages back to user.

    Messages are collected rather than sent immediately: every ``utter_*``
    call appends one message dict to ``self.messages``.
    """

    def __init__(self) -> None:
        """Start with an empty list of collected messages."""
        self.messages: List[Dict[Text, Any]] = []

    def utter_message(
        self,
        text: Optional[Text] = None,
        image: Optional[Text] = None,
        json_message: Optional[Dict[Text, Any]] = None,
        template: Optional[Text] = None,
        response: Optional[Text] = None,
        attachment: Optional[Text] = None,
        buttons: Optional[List[Dict[Text, Any]]] = None,
        elements: Optional[List[Dict[Text, Any]]] = None,
        **kwargs: Any,
    ) -> None:
        """Collect a message to be sent to the output channel.

        Args:
            text: message text.
            image: image URL.
            json_message: custom payload, stored under the ``"custom"`` key.
            template: deprecated alias of ``response``; only used when
                ``response`` is not given, and emits a ``FutureWarning``.
            response: name of the response to utter.
            attachment: attachment payload.
            buttons: list of button dicts.
            elements: list of element dicts.
            **kwargs: extra keys merged into the message dict; they override
                the standard keys on collision.
        """
        if template and not response:
            response = template
            warnings.warn(
                "Please pass the parameter response instead of template "
                "to utter_message. template will be deprecated in Rasa 3.0.0. ",
                FutureWarning,
            )
        message = {
            "text": text,
            "buttons": buttons or [],
            "elements": elements or [],
            "custom": json_message or {},
            # The response name is stored under both keys; presumably "template"
            # is kept for consumers still reading the old key.
            "template": response,
            "response": response,
            "image": image,
            "attachment": attachment,
        }
        message.update(kwargs)

        self.messages.append(message)

    # deprecated
    def utter_custom_message(self, *elements: Dict[Text, Any], **kwargs: Any) -> None:
        """Deprecated: collect a message with custom elements."""
        warnings.warn(
            "Use of utter_custom_message is deprecated. "
            "Use utter_message(elements=<list of elements>) instead.",
            FutureWarning,
        )
        self.utter_message(elements=list(elements), **kwargs)

    def utter_elements(self, *elements: Dict[Text, Any], **kwargs: Any) -> None:
        """Deprecated: collect a message with custom elements."""
        warnings.warn(
            "Use of utter_elements is deprecated. "
            "Use utter_message(elements=<list of elements>) instead.",
            FutureWarning,
        )
        self.utter_message(elements=list(elements), **kwargs)

    def utter_button_message(
        self, text: Text, buttons: List[Dict[Text, Any]], **kwargs: Any
    ) -> None:
        """Deprecated: collect a message with text and buttons."""
        warnings.warn(
            "Use of utter_button_message is deprecated. "
            "Use utter_message(text=<text> , buttons=<list of buttons>) instead.",
            FutureWarning,
        )

        self.utter_message(text=text, buttons=buttons, **kwargs)

    def utter_attachment(self, attachment: Text, **kwargs: Any) -> None:
        """Deprecated: collect a message with an attachment."""
        warnings.warn(
            "Use of utter_attachment is deprecated. "
            "Use utter_message(attachment=<attachment>) instead.",
            FutureWarning,
        )

        self.utter_message(attachment=attachment, **kwargs)

    # noinspection PyUnusedLocal
    def utter_button_template(
        self,
        template: Text,
        buttons: List[Dict[Text, Any]],
        tracker: "Tracker",
        silent_fail: bool = False,
        **kwargs: Any,
    ) -> None:
        """Deprecated: collect a response with buttons.

        ``tracker`` and ``silent_fail`` are accepted but unused. Passing
        ``template=`` on to ``utter_message`` also triggers that method's own
        deprecation warning.
        """
        warnings.warn(
            "Use of utter_button_template is deprecated. "
            "Use utter_message(template=<template name>, buttons=<list of buttons>) instead.",
            FutureWarning,
        )

        self.utter_message(template=template, buttons=buttons, **kwargs)

    # noinspection PyUnusedLocal
    def utter_template(
        self, template: Text, tracker: "Tracker", silent_fail: bool = False, **kwargs: Any
    ) -> None:
        """Deprecated: collect a response by name (``tracker``/``silent_fail`` unused)."""
        warnings.warn(
            "Use of utter_template is deprecated. "
            "Use utter_message(response=<template_name>) instead.",
            FutureWarning,
        )

        self.utter_message(response=template, **kwargs)

    def utter_custom_json(self, json_message: Dict[Text, Any], **kwargs: Any) -> None:
        """Deprecated: collect a custom JSON payload."""
        warnings.warn(
            "Use of utter_custom_json is deprecated. "
            "Use utter_message(json_message=<message dict>) instead.",
            FutureWarning,
        )

        self.utter_message(json_message=json_message, **kwargs)

    def utter_image_url(self, image: Text, **kwargs: Any) -> None:
        """Deprecated: collect a message with an image URL."""
        warnings.warn(
            "Use of utter_image_url is deprecated. "
            "Use utter_message(image=<image url>) instead.",
            FutureWarning,
        )

        self.utter_message(image=image, **kwargs)

Tracker类:

\

class Tracker:
    """Maintains the state of a conversation."""

    @classmethod
    def from_dict(cls, state: "TrackerState") -> "Tracker":
        """Create a tracker from dump."""

        return Tracker(
            state["sender_id"],
            state.get("slots", {}),
            state.get("latest_message", {}),
            state.get("events", []),
            state.get("paused", False),
            state.get("followup_action"),
            # fall back to the older "active_form" key if "active_loop" is absent
            state.get("active_loop", state.get("active_form", {})),
            state.get("latest_action_name"),
        )

    def __init__(
        self,
        sender_id: Text,
        slots: Dict[Text, Any],
        latest_message: Optional[Dict[Text, Any]],
        events: List[Dict[Text, Any]],
        paused: bool,
        followup_action: Optional[Text],
        active_loop: Dict[Text, Any],
        latest_action_name: Optional[Text],
    ) -> None:
        """Initialize the tracker."""

        # list of previously seen events
        self.events = events
        # id of the source of the messages
        self.sender_id = sender_id
        # slots that can be filled in this domain
        self.slots = slots

        self.followup_action = followup_action

        self._paused = paused

        # latest_message is `parse_data`,
        # which is a dict: {"intent": UserUttered.intent,
        #                   "entities": UserUttered.entities,
        #                   "text": text}
        self.latest_message = latest_message if latest_message else {}
        self.active_loop = active_loop
        self.latest_action_name = latest_action_name

    @property
    def active_form(self) -> Dict[Text, Any]:
        """Deprecated alias of ``active_loop``; emits a ``DeprecationWarning``."""
        warnings.warn(
            "Use of active_form is deprecated. Please use `active_loop insteaad.",
            DeprecationWarning,
        )
        return self.active_loop

    def current_state(self) -> Dict[Text, Any]:
        """Return the current tracker state as an object."""

        if len(self.events) > 0:
            latest_event_time = self.events[-1].get("timestamp")
        else:
            latest_event_time = None

        return {
            "sender_id": self.sender_id,
            "slots": self.slots,
            "latest_message": self.latest_message,
            "latest_event_time": latest_event_time,
            "paused": self.is_paused(),
            "events": self.events,
            "latest_input_channel": self.get_latest_input_channel(),
            "active_loop": self.active_loop,
            "latest_action_name": self.latest_action_name,
        }

    def current_slot_values(self) -> Dict[Text, Any]:
        """Return the currently set values of the slots"""
        return self.slots

    def get_slot(self, key: Text) -> Optional[Any]:
        """Retrieves the value of a slot.

        Returns ``None`` (and logs at INFO level) if the slot does not exist.
        """

        if key in self.slots:
            return self.slots[key]
        else:
            logger.info(f"Tried to access non existent slot '{key}'.")
            return None

    def get_latest_entity_values(
        self,
        entity_type: Text,
        entity_role: Optional[Text] = None,
        entity_group: Optional[Text] = None,
    ) -> Iterator[Text]:
        """Get entity values found for the passed entity type and optional role and
        group in latest message.

        If you are only interested in the first entity of a given type use
        ``next(tracker.get_latest_entity_values("my_entity_name"), None)``.
        If no entity is found, ``None`` is then the default result.

        Role and group must match exactly: with the default ``None`` only
        entities that have no role / no group are returned.

        Args:
            entity_type: the entity type of interest
            entity_role: optional entity role of interest
            entity_group: optional entity group of interest

        Returns:
            A lazily evaluated generator of entity values (not a list).
        """
        entities = self.latest_message.get("entities", [])
        return (
            x.get("value")
            for x in entities
            if x.get("entity") == entity_type
            and x.get("group") == entity_group
            and x.get("role") == entity_role
        )

    def get_latest_input_channel(self) -> Optional[Text]:
        """Get the name of the input_channel of the latest UserUttered event"""

        for e in reversed(self.events):
            if e.get("event") == "user":
                return e.get("input_channel")
        return None

    def is_paused(self) -> bool:
        """State whether the tracker is currently paused."""
        return self._paused

    def idx_after_latest_restart(self) -> int:
        """Return the idx of the most recent restart in the list of events.

        If the conversation has not been restarted, 0 is returned.
        """
        idx = 0
        for i, event in enumerate(self.events):
            if event.get("event") == "restart":
                idx = i + 1
        return idx

    def events_after_latest_restart(self) -> List[dict]:
        """Return a list of events after the most recent restart."""
        return list(self.events)[self.idx_after_latest_restart() :]

    @property
    def active_loop_name(self) -> Optional[Text]:
        """Get the name of the currently active loop.

        Returns: None if no active loop or the name of the currently active loop.
        """
        # the "should_not_be_set" name is treated the same as no active loop
        if not self.active_loop or self.active_loop.get("name") == "should_not_be_set":
            return None

        return self.active_loop.get("name")

    # NOTE: defining __eq__ without __hash__ makes Tracker instances unhashable.
    def __eq__(self, other: Any) -> bool:
        """Trackers are equal when they have the same sender id and events."""
        if isinstance(self, type(other)):
            return other.events == self.events and self.sender_id == other.sender_id
        else:
            return False

    def __ne__(self, other: Any) -> bool:
        return not self.__eq__(other)

    def copy(self) -> "Tracker":
        """Return a copy; slots, latest message and events are deep-copied."""
        return Tracker(
            self.sender_id,
            copy.deepcopy(self.slots),
            copy.deepcopy(self.latest_message),
            copy.deepcopy(self.events),
            self._paused,
            self.followup_action,
            self.active_loop,
            self.latest_action_name,
        )

    def last_executed_action_has(self, name: Text, skip: int = 0) -> bool:
        """Check whether the last executed action (ignoring action_listen) is ``name``."""
        last = self.get_last_event_for(
            "action", exclude=[ACTION_LISTEN_NAME], skip=skip
        )
        return last is not None and last["name"] == name

    def get_last_event_for(
        self,
        event_type: Text,
        exclude: Optional[List[Text]] = None,
        skip: int = 0,
    ) -> Optional[Dict[Text, Any]]:
        """Return the most recent applied event of the given type.

        Args:
            event_type: value of the event's ``"event"`` key to look for.
            exclude: action names to ignore; only applies to ``"action"`` events.
                Defaults to no exclusions.
            skip: number of matching events to skip, newest first.

        Returns:
            The matching event dict, or ``None`` if there is none.
        """
        # None instead of a mutable [] default; behaves exactly like the old default
        excluded_names = exclude or []

        def filter_function(e: Dict[Text, Any]) -> bool:
            has_instance = e["event"] == event_type
            excluded = e["event"] == "action" and e["name"] in excluded_names

            return has_instance and not excluded

        filtered = filter(filter_function, reversed(self.applied_events()))
        for _ in range(skip):
            next(filtered, None)

        return next(filtered, None)

    def applied_events(self) -> List[Dict[Text, Any]]:
        """Returns all actions that should be applied - w/o reverted events."""

        def undo_till_previous(event_type: Text, done_events: List[Dict[Text, Any]]):
            """Removes events from `done_events` until the first
            occurrence `event_type` is found which is also removed."""
            # list gets modified - hence we need to copy events!
            for e in reversed(done_events[:]):
                del done_events[-1]
                if e["event"] == event_type:
                    break

        applied_events: List[Dict[Text, Any]] = []
        for event in self.events:
            event_type = event.get("event")
            if event_type == "restart":
                applied_events = []
            elif event_type == "undo":
                undo_till_previous("action", applied_events)
            elif event_type == "rewind":
                # Seeing a user uttered event automatically implies there was
                # a listen event right before it, so we'll first rewind the
                # user utterance, then get the action right before it (also removes
                # the `action_listen` action right before it).
                undo_till_previous("user", applied_events)
                undo_till_previous("action", applied_events)
            else:
                applied_events.append(event)
        return applied_events

    def slots_to_validate(self) -> Dict[Text, Any]:
        """Get slots which were recently set.

        This can e.g. be used to validate form slots after they were extracted.

        Returns:
            A mapping of extracted slot candidates and their values.
        """

        slots: Dict[Text, Any] = {}
        count: int = 0

        for event in reversed(self.events):
            # The `FormAction` in Rasa Open Source will append all slot candidates
            # at the end of the tracker events.
            if event["event"] == "slot":
                count += 1
            else:
                # Stop as soon as there is another event type as this means that we
                # checked all potential slot candidates.
                break

        for event in self.events[len(self.events) - count :]:
            slots[event["name"]] = event["value"]

        return slots

    def add_slots(self, slots: List["EventType"]) -> None:
        """Adds slots to the current tracker.

        Events whose ``"event"`` is not ``"slot"`` are ignored; each slot event
        updates ``self.slots`` and is appended to ``self.events``.

        Args:
            slots: `SlotSet` events.
        """
        for event in slots:
            if not event.get("event") == "slot":
                continue
            self.slots[event["name"]] = event["value"]
            self.events.append(event)

    def get_intent_of_latest_message(
        self, skip_fallback_intent: bool = True
    ) -> Optional[Text]:
        """Retrieves the intent of the last user message.

        The intent is read from the first entry of ``intent_ranking``.

        Args:
            skip_fallback_intent: Optionally skip the nlu_fallback intent
                and return the next.

        Returns:
            Intent of latest message if available.
        """
        latest_message = self.latest_message
        if not latest_message:
            return None

        intent_ranking = latest_message.get("intent_ranking")
        if not intent_ranking:
            return None

        highest_ranking_intent = intent_ranking[0]
        if (
            highest_ranking_intent["name"] == NLU_FALLBACK_INTENT_NAME
            and skip_fallback_intent
        ):
            if len(intent_ranking) >= 2:
                return intent_ranking[1]["name"]
            else:
                return None
        else:
            return highest_ranking_intent["name"]

run 方法返回的是由 rasa_sdk.events 中的事件(Event)组成的列表,其类型为 List[Dict[Text, Any]]。

\

用户输入一条消息后,如果要把从中提取出的实体(Entities)加入对话状态,可以使用 EntitiesAdded 事件,它用于将提取出的实体添加到跟踪器(Tracker)的状态中。

\

Rasa Core 中 events.py(rasa.core.events 模块)里的 EntitiesAdded 类:

\

class EntitiesAdded(SkipEventInMDStoryMixin):
    """Event that is used to add extracted entities to the tracker state."""

    type_name = "entities"

    def __init__(
        self,
        entities: List[Dict[Text, Any]],
        timestamp: Optional[float] = None,
        metadata: Optional[Dict[Text, Any]] = None,
    ) -> None:
        """Initializes event.

        Args:
            entities: Entities extracted from previous user message. This can either
                be done by NLU components or end-to-end policy predictions.
            timestamp: the timestamp
            metadata: some optional metadata
        """
        super().__init__(timestamp, metadata)
        self.entities = entities

    def __str__(self) -> Text:
        """Returns the string representation of the event."""
        entity_str = [e[ENTITY_ATTRIBUTE_TYPE] for e in self.entities]
        return f"{self.__class__.__name__}({entity_str})"

    def __hash__(self) -> int:
        """Returns the hash value of the event.

        Keys are sorted before serializing. ``__eq__`` compares the entity
        dicts by value, so two events whose dicts differ only in key order are
        equal and must also hash equally.
        """
        return hash(json.dumps(self.entities, sort_keys=True))

    def __eq__(self, other: Any) -> bool:
        """Compares this event with another event."""
        if not isinstance(other, EntitiesAdded):
            return NotImplemented

        return self.entities == other.entities

    @classmethod
    def _from_parameters(cls, parameters: Dict[Text, Any]) -> "EntitiesAdded":
        return EntitiesAdded(
            parameters.get(ENTITIES),
            parameters.get("timestamp"),
            parameters.get("metadata"),
        )

    def as_dict(self) -> Dict[Text, Any]:
        """Converts the event into a dict.

        Returns:
            A dict that represents this event.
        """
        d = super().as_dict()
        d.update({ENTITIES: self.entities})
        return d

    def apply_to(self, tracker: "DialogueStateTracker") -> None:
        """Applies event to current conversation state.

        Entities already present in the latest message are not added twice.

        Args:
            tracker: The current conversation state.
        """
        if tracker.latest_action_name != ACTION_LISTEN_NAME:
            # entities belong only to the last user message
            # a user message always comes after action listen
            return

        for entity in self.entities:
            if entity not in tracker.latest_message.entities:
                tracker.latest_message.entities.append(entity)

Rasa 官网的例子:

\

from typing import Any, Dict, List, Text

from rasa_sdk import Action, Tracker
from rasa_sdk.events import SlotSet
from rasa_sdk.executor import CollectingDispatcher


class ActionCheckRestaurants(Action):
    """Look up a restaurant for the ``cuisine`` slot and store it in ``matches``."""

    def name(self) -> Text:
        return "action_check_restaurants"

    def run(
        self,
        dispatcher: CollectingDispatcher,
        tracker: Tracker,
        domain: Dict[Text, Any],
    ) -> List[Dict[Text, Any]]:
        """Query one restaurant for the requested cuisine.

        Returns a single ``SlotSet`` event for ``matches``; an empty list is
        stored when the query returns ``None``.
        """
        cuisine = tracker.get_slot("cuisine")

        # Slot values may come from user input, so never splice them into the
        # SQL text (SQL injection); pass the value as a bound parameter instead.
        # NOTE(review): `db` is a placeholder from the official example. This
        # assumes its query() accepts DB-API-style parameters; adapt the
        # placeholder style ("%s", "?", ...) to the driver actually used.
        q = "select * from restaurants where cuisine = %s limit 1"
        result = db.query(q, (cuisine,))

        return [SlotSet("matches", result if result is not None else [])]

————————————————

rasa完整视频