diff --git a/outlines/prompts.py b/outlines/prompts.py index 01e900c96..a7824451a 100644 --- a/outlines/prompts.py +++ b/outlines/prompts.py @@ -24,6 +24,7 @@ class Prompt: def __post_init__(self): self.parameters: List[str] = list(self.signature.parameters.keys()) + self.jinja_environment = create_jinja_template(self.template) def __call__(self, *args, **kwargs) -> str: """Render and return the template. @@ -35,7 +36,7 @@ def __call__(self, *args, **kwargs) -> str: """ bound_arguments = self.signature.bind(*args, **kwargs) bound_arguments.apply_defaults() - return render(self.template, **bound_arguments.arguments) + return self.jinja_environment.render(**bound_arguments.arguments) def __str__(self): return self.template @@ -182,6 +183,11 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str: A string that contains the rendered template. """ + jinja_template = create_jinja_template(template) + return jinja_template.render(**values) + + +def create_jinja_template(template: str): # Dedent, and remove extra linebreak cleaned_template = inspect.cleandoc(template) @@ -210,8 +216,7 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str: env.filters["args"] = get_fn_args jinja_template = env.from_string(cleaned_template) - - return jinja_template.render(**values) + return jinja_template def get_fn_name(fn: Callable):