diff --git a/pyproject.toml b/pyproject.toml
index 0807a5f..823f6c1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "smashed"
-version = "0.21.5"
+version = "0.21.6"
 description = """\
 SMASHED is a toolkit designed to apply transformations to samples in \
 datasets, such as fields extraction, tokenization, prompting, batching, \
diff --git a/src/smashed/mappers/promptsource.py b/src/smashed/mappers/promptsource.py
index 0a77f04..65a57fe 100644
--- a/src/smashed/mappers/promptsource.py
+++ b/src/smashed/mappers/promptsource.py
@@ -141,16 +141,14 @@ def approx_input_fields(self) -> Tuple[Set[str], ...]:
         are used, nor cases where members of a variable are accessed.
         """
 
-        # we compute variables first because some might
-        all_variables = set(
-            field
-            for field in self.get_vars_from_txt(self.template)
-            if field not in self.extra_vars
-        )
+        # Extract variables from each "|||"-separated fragment on its own.
+        # The previous implementation tested `v in fragment` on the raw
+        # fragment text, which is a substring match and so also reported a
+        # variable whenever its name merely appeared inside a longer word.
         return tuple(
-            {v for v in all_variables if v in fragment}
+            {
+                field
+                for field in self.get_vars_from_txt(fragment)
+                if field not in self.extra_vars
+            }
             for fragment in self.template.split("|||")
         )
 
     @property
     def template_text(self) -> Tuple[str, ...]: