From 9d61e76ed179028ef9a1bd38a42af78e6af03ef8 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Fri, 22 Sep 2023 13:29:29 -0700 Subject: [PATCH 1/3] fixed regression --- pyproject.toml | 2 +- src/smashed/mappers/promptsource.py | 8 ++++++-- tests/test_promptsource_recipe.py | 12 ++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0807a5f..823f6c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "smashed" -version = "0.21.5" +version = "0.21.6" description = """\ SMASHED is a toolkit designed to apply transformations to samples in \ datasets, such as fields extraction, tokenization, prompting, batching, \ diff --git a/src/smashed/mappers/promptsource.py b/src/smashed/mappers/promptsource.py index 0a77f04..9688361 100644 --- a/src/smashed/mappers/promptsource.py +++ b/src/smashed/mappers/promptsource.py @@ -147,10 +147,14 @@ def approx_input_fields(self) -> Tuple[Set[str], ...]: for field in self.get_vars_from_txt(self.template) if field not in self.extra_vars ) - return tuple( - {v for v in all_variables if v in fragment} + out = tuple( + { + field for field in all_variables + if (field in fragment and field not in self.extra_vars) + } for fragment in self.template.split("|||") ) + return out @property def template_text(self) -> Tuple[str, ...]: diff --git a/tests/test_promptsource_recipe.py b/tests/test_promptsource_recipe.py index 6f0e62c..6a07527 100644 --- a/tests/test_promptsource_recipe.py +++ b/tests/test_promptsource_recipe.py @@ -106,6 +106,18 @@ def test_few_shot_truncation(self): # The fact that the prompt is a bit different from the template # is totally fine: T5 removes multiple spaces, turns newlines into # spaces, and decoding strips the trailing spaces. + + print('\n' + self.tokenizer.decode(mapped_dataset[0]["input_ids"])) + print('-----------') + print( + f"Q: {FEW_SHOT_DATASET[0]['question'][:14].rstrip()} " + f"A: {FEW_SHOT_DATASET[0]['answer'][:14].rstrip()} " + f"Q: {FEW_SHOT_DATASET[1]['question'][:14].rstrip()} " + f"A: {FEW_SHOT_DATASET[1]['answer'][:14].rstrip()} " + f"Q: {FEW_SHOT_DATASET[2]['question'][:14].rstrip()} " + "A:" + ) + return self.assertEqual( self.tokenizer.decode(mapped_dataset[0]["input_ids"]), ( From 43afaa56e6de49606946194aee6486e6d9eb1e59 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Fri, 22 Sep 2023 13:39:32 -0700 Subject: [PATCH 2/3] formatting --- src/smashed/mappers/promptsource.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/smashed/mappers/promptsource.py b/src/smashed/mappers/promptsource.py index 9688361..5240f5e 100644 --- a/src/smashed/mappers/promptsource.py +++ b/src/smashed/mappers/promptsource.py @@ -149,7 +149,8 @@ def approx_input_fields(self) -> Tuple[Set[str], ...]: ) out = tuple( { - field for field in all_variables + field + for field in all_variables if (field in fragment and field not in self.extra_vars) } for fragment in self.template.split("|||") From 0cc1e47ed9800218d8aa8b06da2676c9526e5a3c Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Thu, 23 May 2024 18:11:51 -0700 Subject: [PATCH 3/3] wip --- src/smashed/mappers/promptsource.py | 38 +++++++++++++++++++---------- tests/test_promptsource_recipe.py | 12 --------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/smashed/mappers/promptsource.py b/src/smashed/mappers/promptsource.py index 5240f5e..65a57fe 100644 --- a/src/smashed/mappers/promptsource.py +++ b/src/smashed/mappers/promptsource.py @@ -141,21 +141,33 @@ def approx_input_fields(self) -> Tuple[Set[str], ...]: are used, nor cases where members of a variable are accessed. """ - # we compute variables first because some might - all_variables = set( - field - for field in self.get_vars_from_txt(self.template) - if field not in self.extra_vars - ) - out = tuple( - { + output = tuple( + set( field - for field in all_variables - if (field in fragment and field not in self.extra_vars) - } - for fragment in self.template.split("|||") + for field in self.get_vars_from_txt(t) + if field not in self.extra_vars + ) + for t in self.template.split("|||") ) - return out + print(output) + + return output + + # # we compute variables first because some might + # all_variables = set( + # field + # for field in self.get_vars_from_txt(self.template) + # if field not in self.extra_vars + # ) + # out = tuple( + # { + # field + # for field in all_variables + # if (field in fragment and field not in self.extra_vars) + # } + # for fragment in self.template.split("|||") + # ) + # return out @property def template_text(self) -> Tuple[str, ...]: diff --git a/tests/test_promptsource_recipe.py b/tests/test_promptsource_recipe.py index 6a07527..6f0e62c 100644 --- a/tests/test_promptsource_recipe.py +++ b/tests/test_promptsource_recipe.py @@ -106,18 +106,6 @@ def test_few_shot_truncation(self): # The fact that the prompt is a bit different from the template # is totally fine: T5 removes multiple spaces, turns newlines into # spaces, and decoding strips the trailing spaces. - - print('\n' + self.tokenizer.decode(mapped_dataset[0]["input_ids"])) - print('-----------') - print( - f"Q: {FEW_SHOT_DATASET[0]['question'][:14].rstrip()} " - f"A: {FEW_SHOT_DATASET[0]['answer'][:14].rstrip()} " - f"Q: {FEW_SHOT_DATASET[1]['question'][:14].rstrip()} " - f"A: {FEW_SHOT_DATASET[1]['answer'][:14].rstrip()} " - f"Q: {FEW_SHOT_DATASET[2]['question'][:14].rstrip()} " - "A:" - ) - return self.assertEqual( self.tokenizer.decode(mapped_dataset[0]["input_ids"]), (