A custom jinja2 loader for OSS-backed prompt-poet prompt template storage

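By default, prompt-poet only loads templates from the local file system or from template content passed in directly, which is not very convenient in practice. With a small extension to the jinja2 loader, templates can be loaded through fsspec, which makes it possible to support many kinds of storage backends. Below is a simple example. Because prompt-poet's internal loading logic is currently fixed, for now this can only be done by modifying its code directly.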

Reference implementation

  • jinja2 fsspec loader

fsspec_loader.py: the code comes from jinjarope; I added search_path support.

from __future__ import annotations
 
import pathlib
from typing import TYPE_CHECKING,Any
 
import fsspec
import fsspec.core
import jinja2
 
if TYPE_CHECKING:
    from collections.abc import Callable
 
 
class FsSpecFileSystemLoader(jinja2.BaseLoader):
 
    ID = "fsspec"
 
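    def __init__(self, fs: fsspec.AbstractFileSystem | str, **kwargs: Any):
        """Constructor.

        Arguments:
            fs: Either a protocol path string or an fsspec filesystem instance.
                Also supports "::dir" prefix to set the root path.
            kwargs: Optional storage options for the filesystem. The extra
                "search_path" key is consumed by this loader and is not
                forwarded to fsspec.
        """
        super().__init__()
        # search_path belongs to the loader itself, so remove it before the
        # remaining kwargs are handed to fsspec as storage options.
        self.search_path = kwargs.pop("search_path", None)
        match fs:
            case str() if "://" in fs:
                # Full URL such as "oss://bucket/dir": fsspec resolves fs and path.
                self.fs, self.path = fsspec.core.url_to_fs(fs, **kwargs)
            case str():
                # Bare protocol name such as "oss".
                self.fs, self.path = fsspec.filesystem(fs, **kwargs), ""
            case _:
                # Already an fsspec filesystem instance.
                self.fs, self.path = fs, ""
        self.storage_options = kwargs

    def __eq__(self, other):
        return (
            type(self) == type(other)
            and self.storage_options == other.storage_options
            and self.search_path == other.search_path
            and self.fs == other.fs
            and self.path == other.path
        )

    def __hash__(self):
        return (
            hash(tuple(sorted(self.storage_options.items())))
            + hash(self.search_path)
            + hash(self.fs)
            + hash(self.path)
        )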
 
    def list_templates(self) -> list[str]:
        return [
            f"{path}{self.fs.sep}{f}" if path else f
            for path, _dirs, files in self.fs.walk(self.fs.root_marker)
            for f in files
        ]
 
    def get_source(
        self,
        environment: jinja2.Environment,
        template: str,
    ) -> tuple[str, str, Callable[[], bool] | None]:
        try:
            if self.search_path:
                template = self.search_path + "/" + template
            with self.fs.open(template) as file:
                src = file.read().decode()
        except FileNotFoundError as e:
            raise jinja2.TemplateNotFound(template) from e
        path = pathlib.Path(template).as_posix()
        return src, path, lambda: True
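
A loader of this kind can be tried without OSS credentials by letting fsspec's built-in in-memory filesystem stand in for the bucket. The following is only a minimal sketch under stated assumptions: `MiniFsspecLoader` is a condensed, hypothetical stand-in for the class above (just `get_source` plus `search_path`), and the `/demo/000001` paths and template contents are invented for illustration.

```python
import fsspec
import jinja2


class MiniFsspecLoader(jinja2.BaseLoader):
    """Condensed, hypothetical stand-in for FsSpecFileSystemLoader."""

    def __init__(self, fs, search_path=None):
        self.fs = fs
        self.search_path = search_path

    def get_source(self, environment, template):
        # Prefix the template name with the search path, as the loader above does.
        full = f"{self.search_path}/{template}" if self.search_path else template
        try:
            with self.fs.open(full, "rb") as f:
                src = f.read().decode()
        except FileNotFoundError as e:
            raise jinja2.TemplateNotFound(template) from e
        # No change detection: always report the cached template as up to date.
        return src, full, lambda: True


# The in-memory filesystem plays the role of the OSS bucket (made-up paths).
fs = fsspec.filesystem("memory")
fs.pipe(
    "/demo/000001/app.yaml.j2",
    b"{% include 'sub/part.yml.j2' %}\n- name: {{ name }}\n",
)
fs.pipe("/demo/000001/sub/part.yml.j2", b"- name: included\n")

env = jinja2.Environment(loader=MiniFsspecLoader(fs, search_path="/demo/000001"))
rendered = env.get_template("app.yaml.j2").render(name="main")
print(rendered)
```

Because `{% include %}` goes through the same loader, the included file is also resolved relative to `search_path`, which is why the OSS layout shown later keeps `mydemo/` under the same directory as `app.yaml.j2`.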
  • Modify template_registry.py to support loading from OSS
    the _load_template method
def _load_template(
    self,
    template_name: str,
    template_dir: str = None,
    package_name: str = None,
) -> j2.Template:
    """Load template from disk."""
    loader = None
    if template_dir is None and package_name is None:
        raise ValueError(
            "Either `template_dir` or `package_name` must be provided."
        )
 
    try:
        if package_name is not None:
            loader = j2.PackageLoader(
                package_name=package_name, package_path=template_dir
            )
        else:
            print("template_dir", template_dir)
            # Hard-coded here for a quick test; in practice these values should be passed in as parameters.
            loader = FsSpecFileSystemLoader("oss", endpoint='xxx', key="xxxxx", secret="xxxxx", search_path=template_dir)
            # loader = j2.FileSystemLoader(searchpath=template_dir)
    except j2.TemplateNotFound as ex:
        raise j2.TemplateNotFound(
            f"Template not found: {ex} {template_name=} {template_dir=} {package_name=}"
        )
    env = j2.Environment(loader=loader)
    template = env.get_template(template_name)
    return template
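
The hard-coded endpoint and credentials above are only meant for a quick test. One possible way to pass them in, sketched here under assumptions, is to read them from environment variables and build the loader's keyword arguments in one place. The helper `oss_loader_kwargs` and the variable names `OSS_ENDPOINT`, `OSS_KEY` and `OSS_SECRET` are hypothetical and not part of prompt-poet or any OSS library; the resulting keys simply mirror the `endpoint`, `key`, `secret` and `search_path` arguments used in the call above.

```python
import os
from collections.abc import Mapping


def oss_loader_kwargs(
    template_dir: str, env: Mapping[str, str] = os.environ
) -> dict[str, str]:
    """Collect loader keyword arguments from the environment (hypothetical helper)."""
    required = {"endpoint": "OSS_ENDPOINT", "key": "OSS_KEY", "secret": "OSS_SECRET"}
    missing = [var for var in required.values() if not env.get(var)]
    if missing:
        raise ValueError(f"Missing OSS settings: {', '.join(missing)}")
    kwargs = {name: env[var] for name, var in required.items()}
    kwargs["search_path"] = template_dir
    return kwargs


# Example with an explicit mapping instead of the real environment.
demo_env = {
    "OSS_ENDPOINT": "https://oss.example.invalid",
    "OSS_KEY": "k",
    "OSS_SECRET": "s",
}
kwargs = oss_loader_kwargs("/prompt-template/000001", demo_env)
print(kwargs)
# Inside _load_template this could then become:
#   loader = FsSpecFileSystemLoader("oss", **kwargs)
```

Failing early with a clear error when a setting is missing tends to be easier to diagnose than an authentication error raised later by the storage client.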
  • Layout in OSS storage


app.yaml.j2

{% include 'mydemo/system_instruction.yml.j2' %}
 
- name: system demo
  role: system
  content: |
    you are a system user
 
{% if mydemo() == "dalong" %}
- name: system audio 
  role: system
  content: |
    you are a system audio user
{% endif %}

mydemo/system_instruction.yml.j2

- name: system demo
  role: system
  content: |
    demo contents
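
How these two templates combine can be checked locally, without prompt-poet or OSS. The sketch below is an illustration only: it loads copies of the two templates through jinja2's built-in `DictLoader` and renders them with a callable named `mydemo`, once returning "dalongv1" and once returning "dalong", to show when the conditional block is emitted.

```python
import jinja2

# Copies of the two templates shown above, keyed by template name.
templates = {
    "app.yaml.j2": (
        "{% include 'mydemo/system_instruction.yml.j2' %}\n"
        "\n"
        "- name: system demo\n"
        "  role: system\n"
        "  content: |\n"
        "    you are a system user\n"
        "\n"
        "{% if mydemo() == \"dalong\" %}\n"
        "- name: system audio\n"
        "  role: system\n"
        "  content: |\n"
        "    you are a system audio user\n"
        "{% endif %}\n"
    ),
    "mydemo/system_instruction.yml.j2": (
        "- name: system demo\n"
        "  role: system\n"
        "  content: |\n"
        "    demo contents\n"
    ),
}

env = jinja2.Environment(loader=jinja2.DictLoader(templates))
template = env.get_template("app.yaml.j2")

# Callables passed as template data can be invoked from inside the template.
with_v1 = template.render(mydemo=lambda: "dalongv1")
with_dalong = template.render(mydemo=lambda: "dalong")

print("system audio" in with_v1)      # False: the condition does not match
print("system audio" in with_dalong)  # True: the conditional block is rendered
```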
  • mydemov2.py
from prompt_poet import Prompt
 
def mydemo():
    return "dalongv1"
 
template_data = {
  "character_name": "Character Assistant",
  "username": "dalong",
  "mydemo": mydemo
}
 
prompt = Prompt(
    template_path="/prompt-template/000001/app.yaml.j2",
    template_data=template_data
)
 
print(prompt.messages)
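
It can help to trace which objects the loader ends up requesting for this script. The sketch below assumes that prompt-poet splits `template_path` into a directory and a file name in the manner of `os.path.split`; that is an assumption about its internals and worth verifying against the installed version. The `resolve` helper is hypothetical and only mirrors the `search_path + "/" + template` join in the loader above.

```python
import posixpath

template_path = "/prompt-template/000001/app.yaml.j2"

# Assumed split into (template_dir, template_name); posixpath keeps "/" on every OS.
template_dir, template_name = posixpath.split(template_path)


def resolve(search_path: str, template: str) -> str:
    """Mirror the loader's join: search_path + "/" + template."""
    return search_path + "/" + template if search_path else template


main_key = resolve(template_dir, template_name)
include_key = resolve(template_dir, "mydemo/system_instruction.yml.j2")

print(template_dir)  # /prompt-template/000001
print(main_key)      # /prompt-template/000001/app.yaml.j2
print(include_key)   # /prompt-template/000001/mydemo/system_instruction.yml.j2
```

With object-store filesystems in fsspec, the first path segment is typically interpreted as the bucket name, so under that reading `prompt-template` would be the bucket and `000001/...` the object key.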
  • Result

Notes

With these small changes, templates can be stored and managed much more flexibly, which gives prompt-poet an enhanced template repository.

References

https://github.com/character-ai/prompt-poet
https://github.com/althonos/jinja2-fsloader
https://github.com/phil65/jinjarope

posted on 2024-10-14 06:40  荣锋亮
