diff --git a/langfun/core/template.py b/langfun/core/template.py index 5bda967e..11147e05 100644 --- a/langfun/core/template.py +++ b/langfun/core/template.py @@ -280,6 +280,10 @@ def _preprocess_template(self, template_str: str) -> str: """ return template_str + def _postprocess_rendered(self, rendered_text: str) -> str: + """Postprocesses the rendered text.""" + return rendered_text + def vars( self, specified: bool | None = None, @@ -441,7 +445,9 @@ def render( # natural language when they are directly returned as rendering # elements in the template. with modality.format_modality_as_ref(): - rendered_text = self._template.render(**inputs) + rendered_text = self._postprocess_rendered( + self._template.render(**inputs) + ) # Carry the modality references passed from the constructor. # This is to support modality objects that is already rendered diff --git a/langfun/core/template_test.py b/langfun/core/template_test.py index eba6d1ad..3f4086e9 100644 --- a/langfun/core/template_test.py +++ b/langfun/core/template_test.py @@ -191,6 +191,22 @@ def _preprocess_template(self, template_str: str) -> str: 'Google is good' ) + def test_postprocess_rendered(self): + + class MyTemplate(Template): + """My template with postprocess. + + $COMPANY {{x}} + """ + + def _postprocess_rendered(self, rendered_text: str) -> str: + return rendered_text.replace('$COMPANY', 'Google') + + self.assertEqual( + MyTemplate(x='is $COMPANY').render(), + 'Google is Google' + ) + class FromValueTest(unittest.TestCase):