Model Storage
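The previous chapter covered custom graph components; this chapter covers model persistence. Graph components need to persist data during training, and that data must be available to them at inference time. The most typical example is storing model weights. Starting again from the official documentation: as the source-code analysis in the previous chapter showed, the create and load methods of a custom graph component both take a model_storage parameter, which is used to persist and load the component. The resource parameter identifies the component's location within model_storage.
The official documentation gives two examples: writing a resource to the Model Storage and reading it back. Let's go through them one at a time.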
Writing to the Model Storage
from __future__ import annotations
import json
from typing import Optional, Dict, Any, Text

from rasa.engine.graph import GraphComponent, ExecutionContext
from rasa.engine.storage.resource import Resource
from rasa.engine.storage.storage import ModelStorage
from rasa.shared.nlu.training_data.training_data import TrainingData


class MyComponent(GraphComponent):
    def __init__(
        self,
        model_storage: ModelStorage,
        resource: Resource,
        training_artifact: Optional[Dict],
    ) -> None:
        # Store both `model_storage` and `resource` as object attributes to be able
        # to utilize them at the end of the training
        self._model_storage = model_storage
        self._resource = resource

    @classmethod
    def create(
        cls,
        config: Dict[Text, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
    ) -> MyComponent:
        return cls(model_storage, resource, training_artifact=None)

    def train(self, training_data: TrainingData) -> Resource:
        # Train your graph component
        ...

        # Persist your graph component
        with self._model_storage.write_to(self._resource) as directory_path:
            with open(directory_path / "artifact.json", "w") as file:
                json.dump({"my": "training artifact"}, file)

        # Return resource to make sure the training artifacts
        # can be cached.
        return self._resource
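The snippet shows how to write a graph component's data to the Model Storage. To persist the component after training, the train method needs access to model_storage and resource, so both should be stored as attributes when the component is initialized.
The train method of a graph component must return the resource so that Rasa can cache the training results between training runs. The self._model_storage.write_to(self._resource) context manager provides a directory path in which you can persist any data the component needs.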
Reading from the Model Storage
from __future__ import annotations
import json
from typing import Optional, Dict, Any, Text

from rasa.engine.graph import GraphComponent, ExecutionContext
from rasa.engine.storage.resource import Resource
from rasa.engine.storage.storage import ModelStorage


class MyComponent(GraphComponent):
    def __init__(
        self,
        model_storage: ModelStorage,
        resource: Resource,
        training_artifact: Optional[Dict],
    ) -> None:
        self._model_storage = model_storage
        self._resource = resource

    @classmethod
    def load(
        cls,
        config: Dict[Text, Any],
        model_storage: ModelStorage,
        resource: Resource,
        execution_context: ExecutionContext,
        **kwargs: Any,
    ) -> MyComponent:
        try:
            with model_storage.read_from(resource) as directory_path:
                with open(directory_path / "artifact.json", "r") as file:
                    training_artifact = json.load(file)
                    return cls(
                        model_storage, resource, training_artifact=training_artifact
                    )
        except ValueError:
            # This allows you to handle the case if there was no
            # persisted data for your component
            ...
ModelStorage Interface
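Since this series is a source-code walkthrough, let's take a brief look at the ModelStorage interface and at what the two methods used in the examples do. Here is the relevant source: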
class ModelStorage(abc.ABC):
    @contextmanager
    @abc.abstractmethod
    def write_to(self, resource: Resource) -> Generator[Path, None, None]:
        ...

    @contextmanager
    @abc.abstractmethod
    def read_from(self, resource: Resource) -> Generator[Path, None, None]:
        ...
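write_to persists the data of a given Resource: it yields a directory into which the Resource can be written. read_from gives access to the data of a persisted component: it yields a directory containing the persisted data, and raises a ValueError if no data exists for the given Resource.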
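To make the contract of the two methods concrete, here is a minimal sketch of a directory-backed storage. It does not use Rasa at all: SketchResource and SketchModelStorage are hypothetical stand-ins invented for this illustration, and the one-sub-directory-per-resource layout is an assumption, not a description of how Rasa's own storage is organized.

```python
from __future__ import annotations

import json
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator


@dataclass
class SketchResource:
    """Hypothetical stand-in for Rasa's `Resource`; only the name is modeled."""

    name: str


class SketchModelStorage:
    """Hypothetical stand-in: one sub-directory per resource name (an assumption)."""

    def __init__(self, root: Path) -> None:
        self._root = root

    @contextmanager
    def write_to(self, resource: SketchResource) -> Iterator[Path]:
        # Create the resource directory on demand and hand it to the caller.
        directory = self._root / resource.name
        directory.mkdir(parents=True, exist_ok=True)
        yield directory

    @contextmanager
    def read_from(self, resource: SketchResource) -> Iterator[Path]:
        # Mirror the contract described above: no persisted data means ValueError.
        directory = self._root / resource.name
        if not directory.is_dir():
            raise ValueError(f"Resource '{resource.name}' does not exist.")
        yield directory


def round_trip(root: Path) -> dict:
    """Write a small JSON artifact, then read it back through the storage."""
    storage = SketchModelStorage(root)
    resource = SketchResource("my_component")
    with storage.write_to(resource) as directory:
        (directory / "artifact.json").write_text(
            json.dumps({"my": "training artifact"})
        )
    with storage.read_from(resource) as directory:
        return json.loads((directory / "artifact.json").read_text())
```

Calling round_trip with a temporary directory returns the dictionary that was written, while read_from on a resource that was never written raises the ValueError that the load example catches.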
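That gives a rough idea of what ModelStorage is for: it serves as the storage backend for any GraphComponent that needs persistence. At this point you may be wondering what Resource is and what it implements.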
Resource
First, let's look at the source of Resource.
@dataclass
class Resource:
    name: Text
    output_fingerprint: Text = field(
        default_factory=lambda: uuid.uuid4().hex,
        compare=False,
    )
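As the source shows, Resource is quite simple. It is a dataclass with two attributes:
- name: the identifier of the Resource, used to locate its data in the ModelStorage; usually the same as the node name
- output_fingerprint: an identifier for a specific instantiation of the Resource, used to tell instantiations apart when they are stored in the cache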
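One detail in the field definition is easy to miss: compare=False excludes output_fingerprint from the equality check that the dataclass generates. The sketch below reproduces only the two fields with the standard library; SketchResource is a hypothetical name used here so the example runs without Rasa installed.

```python
import uuid
from dataclasses import dataclass, field


@dataclass
class SketchResource:
    """Hypothetical copy of the two fields shown above, for illustration only."""

    name: str
    output_fingerprint: str = field(
        # Each instance gets a fresh random fingerprint unless one is passed in.
        default_factory=lambda: uuid.uuid4().hex,
        # Keep the fingerprint out of the generated __eq__.
        compare=False,
    )


first = SketchResource("my_component")
second = SketchResource("my_component")

# Equality only looks at `name`, so these two instances compare equal ...
same_name_equal = first == second
# ... even though each one received its own fingerprint.
fingerprints_differ = first.output_fingerprint != second.output_fingerprint
```

In other words, two instances with the same name are treated as the same resource, and the fingerprint is carried along separately to tell individual instantiations apart.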
Rasa also provides two methods, from_cache and to_cache, for reading a Resource from the cache and storing a Resource in the cache. Let's see what they do.
from_cache
@classmethod
def from_cache(
    cls,
    node_name: Text,
    directory: Path,
    model_storage: ModelStorage,
    output_fingerprint: Text,
) -> Resource:
    logger.debug(f"Loading resource '{node_name}' from cache.")

    resource = Resource(node_name, output_fingerprint=output_fingerprint)
    if not any(directory.glob("*")):
        logger.debug(f"Cached resource for '{node_name}' was empty.")
        return resource

    try:
        with model_storage.write_to(resource) as resource_directory:
            rasa.utils.common.copy_directory(directory, resource_directory)
    except ValueError:
        if not rasa.utils.io.are_directories_equal(directory, resource_directory):
            raise

    logger.debug(f"Successfully initialized resource '{node_name}' from cache.")

    return resource
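This method loads a Resource from the cache. It takes the following parameters:
- node_name: the name of the node the Resource belongs to
- directory: the directory holding the cached Resource
- model_storage: the ModelStorage to which the cached Resource is added so that other graph nodes can access it
- output_fingerprint: the fingerprint that identifies this particular instantiation of the Resource
The function first creates the Resource and checks whether the directory contains any cached data. If it is empty, the Resource is returned as is. Otherwise the function asks the ModelStorage, through write_to, for a directory in which the Resource can be persisted and copies the cached data into it. If a ValueError is raised along the way, it is re-raised only when the contents of the two directories differ.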
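The core of that flow, an empty check followed by a directory copy, can be sketched with the standard library alone. This is a simplified illustration, not Rasa's code: restore_from_cache is a hypothetical helper, the storage is assumed to be a plain root directory with one sub-directory per node, shutil.copytree stands in for rasa.utils.common.copy_directory, and the ValueError branch is left out.

```python
import shutil
import tempfile
from pathlib import Path


def restore_from_cache(
    cache_directory: Path, storage_root: Path, node_name: str
) -> bool:
    """Copy cached data for `node_name` into the storage.

    Returns True when data was copied and False when the cache was empty.
    """
    # An empty cache directory means there is nothing to restore.
    if not any(cache_directory.glob("*")):
        return False

    # Assumed layout: the storage keeps one sub-directory per node name.
    resource_directory = storage_root / node_name

    # `dirs_exist_ok` requires Python 3.8 or newer.
    shutil.copytree(cache_directory, resource_directory, dirs_exist_ok=True)
    return True
```

Returning a boolean is a choice made for this sketch so the two outcomes are easy to observe; the method shown above returns the Resource in both cases.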
to_cache
def to_cache(self, directory: Path, model_storage: ModelStorage) -> None:
    try:
        with model_storage.read_from(self) as resource_directory:
            rasa.utils.common.copy_directory(resource_directory, directory)
    except ValueError:
        logger.debug(
            f"Skipped caching resource '{self.name}' as no persisted "
            f"data was found."
        )
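This method persists a Resource to the cache. It takes the following parameters:
- directory: the directory in which the Resource is cached
- model_storage: the ModelStorage that currently holds the persisted Resource
The method calls read_from on the ModelStorage to get the Resource's directory and copies the persisted data into the cache directory. If nothing was persisted, read_from raises a ValueError, and caching is skipped with a debug message.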
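The opposite direction follows the same pattern, and the "skip when nothing was persisted" behavior is the part worth seeing in isolation. The sketch below is again an illustration under assumptions: read_from and copy_to_cache are hypothetical module-level helpers, the storage is modeled as a root directory with one sub-directory per resource, and shutil.copytree replaces Rasa's copy helper.

```python
import logging
import shutil
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator

logger = logging.getLogger(__name__)


@contextmanager
def read_from(storage_root: Path, name: str) -> Iterator[Path]:
    """Hypothetical reader: yield the resource directory or raise ValueError."""
    directory = storage_root / name
    if not directory.is_dir():
        raise ValueError(f"Resource '{name}' does not exist.")
    yield directory


def copy_to_cache(storage_root: Path, name: str, cache_directory: Path) -> bool:
    """Copy persisted data into the cache; return False if none was found."""
    try:
        with read_from(storage_root, name) as resource_directory:
            # `dirs_exist_ok` requires Python 3.8 or newer.
            shutil.copytree(resource_directory, cache_directory, dirs_exist_ok=True)
        return True
    except ValueError:
        # Nothing was persisted for this resource, so there is nothing to cache.
        logger.debug("Skipped caching resource '%s': no persisted data.", name)
        return False
```

Treating the missing-data case as a skipped step rather than an error means a component that never writes anything simply has no cached artifact.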
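Summary
This article walked through the source of the ModelStorage interface and the Resource dataclass, both of which deal with data persistence. To summarize:
- ModelStorage provides the persistence backend for the model, that is, for its graph components; its main job is writing and reading the data of persisted Resources.
- Resource identifies a persisted graph component and works with ModelStorage to store persisted data in the cache and to load it back from the cache.
Thanks for reading. If you have questions or suggestions, feel free to leave a comment. Let's keep learning together!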