From ec0dead7bcf36caae59f2d0e38292f76a95c193a Mon Sep 17 00:00:00 2001 From: eyasayesh Date: Wed, 1 Oct 2025 15:05:01 -0400 Subject: [PATCH 01/10] remove prints --- encoding/assembly/base_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/encoding/assembly/base_processor.py b/encoding/assembly/base_processor.py index de09078..f283ebd 100644 --- a/encoding/assembly/base_processor.py +++ b/encoding/assembly/base_processor.py @@ -131,8 +131,8 @@ def _process_fullcontext( total_len = len(transcript["word_orig"]) ds_data = transcript["word_orig"].astype(str) stimuli = [] - print(f"this is the lookback: {lookback}") - print(f"heloo") + #print(f"this is the lookback: {lookback}") + #print(f"heloo") for i, w in enumerate(ds_data): if w != "": From c22cb1e4d52fd3e8386c96d5c888bc1c191c9952 Mon Sep 17 00:00:00 2001 From: eyasayesh Date: Wed, 1 Oct 2025 16:12:38 -0400 Subject: [PATCH 02/10] customize story selection for lebel --- encoding/assembly/assembly_generator.py | 4 ++++ encoding/assembly/lebel_processor.py | 5 ++++- encoding/assembly/lpp_processor.py | 1 + encoding/assembly/narratives_processor.py | 1 + 4 files changed, 10 insertions(+), 1 deletion(-) diff --git a/encoding/assembly/assembly_generator.py b/encoding/assembly/assembly_generator.py index ea5aee1..de7ec32 100644 --- a/encoding/assembly/assembly_generator.py +++ b/encoding/assembly/assembly_generator.py @@ -27,6 +27,7 @@ def create( mask_path: Optional[str] = None, analysis_mask_path: Optional[str] = None, tokenizer: Optional[GPT2Tokenizer] = None, + **kwargs, ) -> BaseAssemblyGenerator: """Create a dataset-specific assembly generator. @@ -56,6 +57,7 @@ def create( mask_path, analysis_mask_path, tokenizer, + **kwargs ) @staticmethod @@ -72,6 +74,7 @@ def generate_assembly( generate_temporal_baseline: bool = False, analysis_mask_path: Optional[str] = None, tokenizer: Optional[GPT2Tokenizer] = None, + **kwargs, ) -> SimpleNeuroidAssembly: """Generate assembly for a subject using the appropriate dataset processor. @@ -98,6 +101,7 @@ def generate_assembly( mask_path, analysis_mask_path, tokenizer, + **kwargs ) return generator.generate_assembly( subject, diff --git a/encoding/assembly/lebel_processor.py b/encoding/assembly/lebel_processor.py index d67d7d4..db2c7eb 100644 --- a/encoding/assembly/lebel_processor.py +++ b/encoding/assembly/lebel_processor.py @@ -27,10 +27,12 @@ def __init__( mask_path: Optional[str] = None, analysis_mask_path: Optional[str] = None, tokenizer: Optional[GPT2Tokenizer] = None, + stories: Optional[List[str]] = None, + **kwargs, ): super().__init__(data_dir, dataset_type, tr, use_volume, mask_path, tokenizer) self.analysis_mask = analysis_mask_path - self.stories = [ + self.stories = stories if stories is not None else [ "adollshouse", "adventuresinsayingyes", "alternateithicatom", @@ -115,6 +117,7 @@ def _process_single_story( """Process a single story and return its data using a specified context type. Args: + subject: Subject identifier story_name: Name of the story being processed wordseq: Word sequence data for the story brain_data: Neural activity data for the story diff --git a/encoding/assembly/lpp_processor.py b/encoding/assembly/lpp_processor.py index 2114e42..6e9b873 100644 --- a/encoding/assembly/lpp_processor.py +++ b/encoding/assembly/lpp_processor.py @@ -22,6 +22,7 @@ def __init__( mask_path: Optional[str] = None, analysis_mask_path: Optional[str] = None, tokenizer: Optional[GPT2Tokenizer] = None, + **kwargs, ): super().__init__(data_dir, dataset_type, tr, use_volume, mask_path, tokenizer) self.analysis_mask = analysis_mask_path diff --git a/encoding/assembly/narratives_processor.py b/encoding/assembly/narratives_processor.py index 4a5a6d3..97be115 100644 --- a/encoding/assembly/narratives_processor.py +++ b/encoding/assembly/narratives_processor.py @@ -22,6 +22,7 @@ def __init__( mask_path: Optional[str] = None, analysis_mask_path: Optional[str] = None, tokenizer: Optional[GPT2Tokenizer] = None, + **kwargs, ): super().__init__(data_dir, dataset_type, tr, use_volume, mask_path, tokenizer) self.analysis_mask = analysis_mask_path From e220f8a25f7f2651b7466f519f76d763117e8631 Mon Sep 17 00:00:00 2001 From: eyasayesh Date: Wed, 1 Oct 2025 16:39:30 -0400 Subject: [PATCH 03/10] make audio path optional --- encoding/assembly/lebel_processor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/encoding/assembly/lebel_processor.py b/encoding/assembly/lebel_processor.py index db2c7eb..31787d1 100644 --- a/encoding/assembly/lebel_processor.py +++ b/encoding/assembly/lebel_processor.py @@ -67,6 +67,7 @@ def generate_assembly( context_type: str = "fullcontext", correlation_length: int = 100, generate_temporal_baseline: bool = False, + audio_path: Optional[str] = None, ) -> SimpleNeuroidAssembly: """Generate assembly for a subject by processing all stories. @@ -84,7 +85,6 @@ def generate_assembly( # Process each story for story in self.stories: - audio_path = f"{self.data_dir}/audio_files/{story}.wav" story_data = self._process_single_story( subject, story, @@ -109,7 +109,6 @@ def _process_single_story( self, subject: str, story_name: str, - volume_path: str, correlation_length: int = 100, generate_temporal_baseline: bool = False, audio_path: Optional[str] = None, From 1fc0ba26fc7eef25d41b537066a5b16f53da1e25 Mon Sep 17 00:00:00 2001 From: eyasayesh Date: Wed, 1 Oct 2025 17:56:00 -0400 Subject: [PATCH 04/10] add todo and flexible brainresp path --- encoding/assembly/assembly_generator.py | 1 + encoding/assembly/base_processor.py | 2 ++ encoding/assembly/lebel_processor.py | 23 ++++++++++++----------- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/encoding/assembly/assembly_generator.py b/encoding/assembly/assembly_generator.py index de7ec32..c724e73 100644 --- a/encoding/assembly/assembly_generator.py +++ b/encoding/assembly/assembly_generator.py @@ -109,4 +109,5 @@ def generate_assembly( context_type, correlation_length, generate_temporal_baseline, + **kwargs ) diff --git a/encoding/assembly/base_processor.py b/encoding/assembly/base_processor.py index f283ebd..ae1c771 100644 --- a/encoding/assembly/base_processor.py +++ b/encoding/assembly/base_processor.py @@ -339,6 +339,8 @@ def process_transcript( self, data_dir: str, story_name: str ) -> Tuple[pd.DataFrame, List[int], np.ndarray, np.ndarray]: """Process transcript data and generate split indices and timing information.""" + #TODO: Unify file structure for LITcoder + #TODO: make it so stimulis data is in a dictionary instead of a list, or make file store modular # read pickle file with open(os.path.join(data_dir, f"{self.dataset_type}_data.pkl"), "rb") as f: data = pickle.load(f) diff --git a/encoding/assembly/lebel_processor.py b/encoding/assembly/lebel_processor.py index 31787d1..2c37f1e 100644 --- a/encoding/assembly/lebel_processor.py +++ b/encoding/assembly/lebel_processor.py @@ -68,6 +68,7 @@ def generate_assembly( correlation_length: int = 100, generate_temporal_baseline: bool = False, audio_path: Optional[str] = None, + **kwargs, ) -> SimpleNeuroidAssembly: """Generate assembly for a subject by processing all stories. @@ -84,14 +85,15 @@ def generate_assembly( self.generate_temporal_baseline = generate_temporal_baseline # Process each story + #TODO: fix this to load big files once outside loop for story in self.stories: story_data = self._process_single_story( - subject, - story, - None, - correlation_length, - generate_temporal_baseline, + subject=subject, + story_name=story, + correlation_length=correlation_length, + generate_temporal_baseline=generate_temporal_baseline, audio_path=audio_path, + brain_resp_file= kwargs.get('brain_resp_file', 'brain_resp_huge.pkl'), ) story_data_list.append(story_data) @@ -112,6 +114,7 @@ def _process_single_story( correlation_length: int = 100, generate_temporal_baseline: bool = False, audio_path: Optional[str] = None, + brain_resp_file: Optional[str] = None, ) -> StoryData: """Process a single story and return its data using a specified context type. @@ -127,12 +130,10 @@ def _process_single_story( Returns: StoryData object containing processed story information """ - if self.use_volume: - with open(f"{self.data_dir}/noslice_sub-{subject}_story_data.pkl", "rb") as f: - resp_dict = pickle.load(f) - else: - with open(f"{self.data_dir}/noslice_sub-{subject}_story_data_surface.pkl", "rb") as f: - resp_dict = pickle.load(f) + #TODO: Unify file structure for LITcoder + with open(f"{self.data_dir}/{subject}/{brain_resp_file}", "rb") as f: + resp_dict = pickle.load(f) + brain_data = resp_dict.get(story_name) transcript, split_indices, tr_times, data_times, _ = self.process_transcript( From eb0e2145f4aaa935b982f3e91dc52aa333d77196 Mon Sep 17 00:00:00 2001 From: eyasayesh Date: Wed, 1 Oct 2025 20:17:49 -0400 Subject: [PATCH 05/10] preprocessing helper module --- encoding/data_prep/__init__.py | 0 encoding/data_prep/data_prep_utils.py | 153 ++++++ encoding/data_prep/textgrid.py | 651 ++++++++++++++++++++++++++ 3 files changed, 804 insertions(+) create mode 100644 encoding/data_prep/__init__.py create mode 100644 encoding/data_prep/data_prep_utils.py create mode 100644 encoding/data_prep/textgrid.py diff --git a/encoding/data_prep/__init__.py b/encoding/data_prep/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/encoding/data_prep/data_prep_utils.py b/encoding/data_prep/data_prep_utils.py new file mode 100644 index 0000000..4043e36 --- /dev/null +++ b/encoding/data_prep/data_prep_utils.py @@ -0,0 +1,153 @@ +"""Utils for preparing datasets before using litcode. +Adapted from Huth ridge utils. +TODO: Add proper citation""" +from typing import List, Tuple, Union +from pathlib import Path +import json +import numpy as np +from .textgrid import TextGrid +import pickle +import h5py + +DEFAULT_BAD_WORDS = frozenset(["sentence_start", "sentence_end", "br", "lg", "ls", "ns", "sp"]) + + + + +##########Transcript Preprocessing Utils########### + +def create_lebel_transcripts( story_list: List[str], + textgrids_dir: Union[str, Path], + respdict_path: Union[str, Path], + output_dir: Union[str, Path], + file_name: str = "lebel_transcripts.pkl" + ) -> None: + + """Create transcripts for the given stories and save them to the output directory. + + Args: + story_list: List of story names to process + textgrids_dir: Directory containing the TextGrid files + respdict_path: Path to the response dictionary JSON file + output_dir: Directory to save the generated transcripts + """ + + text_grids = _load_textgrids(story_list, textgrids_dir) + with open(respdict_path, "r") as f: + respdict = json.load(f) + tr_times = _simulate_trtimes(story_list, respdict) + processed_transcripts = _process_textgrids(text_grids, tr_times) + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / file_name + with open(output_path, "wb") as f: + pickle.dump(processed_transcripts, f) + + +def _load_textgrids(stories: List[str], textgrids_dir: Union[str, Path]) -> dict: + """Load TextGrid files for the given stories from the specified TextGrid directory. + + Args: + stories: List of story names + data_dir: Directory containing the TextGrid files + + Returns: + Dictionary mapping story names to their corresponding TextGrid objects + """ + + grids = {} + for story in stories: + grid_path = Path(textgrids_dir) / f"{story}.TextGrid" + grids[story] = TextGrid.load(grid_path) + return grids + + +def _simulate_trtimes(stories: List[str], respdict: dict, tr: float = 2.0, start_time: float = 10.0, pad: int = 10) -> dict: + """Simulate TR times for the given stories based on the response dictionary. + + Args: + stories: List of story names + respdict: Dictionary mapping story names to their response lengths + tr: Expected TR value + start_time: Start time for the simulation + pad: Padding to subtract from the response length + + Returns: + Dictionary mapping story names to their simulated TR times + """ + tr_times = {} + for story in stories: + resp_length = respdict.get(story, 0) + tr_times[story] = list(np.arange(-start_time, (resp_length - pad) * tr, tr)) + return tr_times + +def _process_textgrids(text_grids: dict, + tr_times: dict, + bad_words: frozenset = DEFAULT_BAD_WORDS + ) -> dict[dict]: + """Process the loaded TextGrid files to extract word sequences, filtering out bad words. + + Args: + text_grids: Dictionary mapping story names to their corresponding TextGrid objects + bad_words: Set of words to filter out from the transcripts + """ + processed_transcripts = {} + for story in text_grids.keys(): + simple_transcript = text_grids[story].tiers[1].make_simple_transcript() + ## Filter out bad words + filtered_transcript = [x for x in simple_transcript if x[2].lower().strip("{}").strip() not in bad_words] + # Further processing can be done here as needed + processed_transcripts[story] = _process_single_story(filtered_transcript, tr_times[story]) + + return processed_transcripts + +def _process_single_story(processed_transcript: List[Tuple], + tr_times: List[float]) -> dict: + """Process a single story's transcript and TR times to create a structured representation. + Args: + proceesed_transcript: List of tuples representing the transcript (start_time, end_time, word) + tr_times: List of TR times for the story + Returns: + Tuple containing processed story information + """ + + data_entries = list(zip(*processed_transcript))[2] + if isinstance(data_entries[0], str): + data = list(map(str.lower, list(zip(*processed_transcript))[2])) + else: + data = data_entries + word_starts = np.array(list(map(float, list(zip(*processed_transcript))[0]))) + word_ends = np.array(list(map(float, list(zip(*processed_transcript))[1]))) + word_avgtimes = (word_starts + word_ends)/2.0 + + tr = np.mean(np.diff(tr_times)) + tr_midpoints = np.array(tr_times) + tr/2.0 + + split_inds = [(word_starts<(t+tr)).sum() for t in tr_times][:-1] + return {"words": data, "split_indices": split_inds, "data_times":word_avgtimes,"tr_times": tr_midpoints} + +def create_brain_response_dict(story_list: List[str], + resp_data_dir: Union[str, Path], + output_dir: Union[str, Path], + file_name: str = "brain_resp_huge.pkl" + ) -> None: + """Create a dictionary of brain responses for the given stories and save it to the output path. + + Args: + story_list: List of story names to process + neural_data_dir: Directory containing the neural data files + output_path: Path to save the generated brain response dictionary + file_name: Name of the output file + """ + + brain_responses = {} + for story in story_list: + resp_data_path = Path(resp_data_dir) / f"{story}.hf5" + with h5py.File(resp_data_path, "r") as f: + brain_responses[story] = f["data"][:] + + output_dir = Path(output_dir) + output_dir.parent.mkdir(parents=True, exist_ok=True) + with open(output_dir / file_name, "wb") as f: + pickle.dump(brain_responses, f) \ No newline at end of file diff --git a/encoding/data_prep/textgrid.py b/encoding/data_prep/textgrid.py new file mode 100644 index 0000000..a4ae69d --- /dev/null +++ b/encoding/data_prep/textgrid.py @@ -0,0 +1,651 @@ +# Natural Language Toolkit: TextGrid analysis +# +# Copyright (C) 2001-2011 NLTK Project +# Author: Margaret Mitchell +# Steven Bird (revisions) +# URL: +# For license information, see LICENSE.TXT +# + +""" +Tools for reading TextGrid files, the format used by Praat. + +Module contents +=============== + +The textgrid corpus reader provides 4 data items and 1 function +for each textgrid file. For each tier in the file, the reader +provides 10 data items and 2 functions. + +For the full textgrid file: + + - size + The number of tiers in the file. + + - xmin + First marked time of the file. + + - xmax + Last marked time of the file. + + - t_time + xmax - xmin. + + - text_type + The style of TextGrid format: + - ooTextFile: Organized by tier. + - ChronTextFile: Organized by time. + - OldooTextFile: Similar to ooTextFile. + + - to_chron() + Convert given file to a ChronTextFile format. + + - to_oo() + Convert given file to an ooTextFile format. + +For each tier: + + - text_type + The style of TextGrid format, as above. + + - classid + The style of transcription on this tier: + - IntervalTier: Transcription is marked as intervals. + - TextTier: Transcription is marked as single points. + + - nameid + The name of the tier. + + - xmin + First marked time of the tier. + + - xmax + Last marked time of the tier. + + - size + Number of entries in the tier. + + - transcript + The raw transcript for the tier. + + - simple_transcript + The transcript formatted as a list of tuples: (time1, time2, utterance). + + - tier_info + List of (classid, nameid, xmin, xmax, size, transcript). + + - min_max() + A tuple of (xmin, xmax). + + - time(non_speech_marker) + Returns the utterance time of a given tier. + Excludes entries that begin with a non-speech marker. + +""" + +# needs more cleanup, subclassing, epydoc docstrings + +import sys +import re + +TEXTTIER = "TextTier" +INTERVALTIER = "IntervalTier" + +OOTEXTFILE = re.compile(r"""(?x) + xmin\ =\ (.*)[\r\n]+ + xmax\ =\ (.*)[\r\n]+ + [\s\S]+?size\ =\ (.*)[\r\n]+ +""") + +CHRONTEXTFILE = re.compile(r"""(?x) + [\r\n]+(\S+)\ + (\S+)\ +!\ Time\ domain.\ *[\r\n]+ + (\S+)\ +!\ Number\ of\ tiers.\ *[\r\n]+" +""") + +OLDOOTEXTFILE = re.compile(r"""(?x) + [\r\n]+(\S+) + [\r\n]+(\S+) + [\r\n]+.+[\r\n]+(\S+) +""") + + + +################################################################# +# TextGrid Class +################################################################# + +class TextGrid(object): + """ + Class to manipulate the TextGrid format used by Praat. + Separates each tier within this file into its own Tier + object. Each TextGrid object has + a number of tiers (size), xmin, xmax, a text type to help + with the different styles of TextGrid format, and tiers with their + own attributes. + """ + + def __init__(self, read_file): + """ + Takes open read file as input, initializes attributes + of the TextGrid file. + @type read_file: An open TextGrid file, mode "r". + @param size: Number of tiers. + @param xmin: xmin. + @param xmax: xmax. + @param t_time: Total time of TextGrid file. + @param text_type: TextGrid format. + @type tiers: A list of tier objects. + """ + + self.read_file = read_file + self.size = 0 + self.xmin = 0 + self.xmax = 0 + self.t_time = 0 + self.text_type = self._check_type() + self.tiers = self._find_tiers() + + def __iter__(self): + for tier in self.tiers: + yield tier + + + @staticmethod + def load(file): + """ + @param file: a file in TextGrid format + """ + + return TextGrid(open(file).read()) + + + def _check_type(self): + """ + Figures out the TextGrid format. + """ + #using regex to match and capture spefic pattern from the read_file + #captures the first 4 line + m = re.match("(.*)[\r\n](.*)[\r\n](.*)[\r\n](.*)", self.read_file) + try: + type_id = m.group(1).strip() + except AttributeError: + raise TypeError("Cannot read file -- try TextGrid.load()") + xmin = m.group(4) + if type_id == "File type = \"ooTextFile\"": + if "xmin" not in xmin: + text_type = "OldooTextFile" + else: + text_type = "ooTextFile" + elif type_id == "\"Praat chronological TextGrid text file\"": + text_type = "ChronTextFile" + else: + raise TypeError("Unknown format '(%s)'", (type_id)) + return text_type + + def _find_tiers(self): + """ + Splits the textgrid file into substrings corresponding to tiers. + """ + + if self.text_type == "ooTextFile": + m = OOTEXTFILE + header = " +item \[" + elif self.text_type == "ChronTextFile": + m = CHRONTEXTFILE + header = "\"\S+\" \".*\" \d+\.?\d* \d+\.?\d*" + elif self.text_type == "OldooTextFile": + m = OLDOOTEXTFILE + header = "\".*\"[\r\n]+\".*\"" + + file_info = m.findall(self.read_file)[0] + self.xmin = float(file_info[0]) + self.xmax = float(file_info[1]) + self.t_time = self.xmax - self.xmin + self.size = int(file_info[2]) + tiers = self._load_tiers(header) + return tiers + + def _load_tiers(self, header): + """ + Iterates over each tier and grabs tier information. + """ + + tiers = [] + if self.text_type == "ChronTextFile": + m = re.compile(header) + tier_headers = m.findall(self.read_file) + tier_re = " \d+.?\d* \d+.?\d*[\r\n]+\"[^\"]*\"" + for i in range(0, self.size): + tier_info = [tier_headers[i]] + \ + re.findall(str(i + 1) + tier_re, self.read_file) + tier_info = "\n".join(tier_info) + tiers.append(Tier(tier_info, self.text_type, self.t_time)) + return tiers + + tier_re = header + "[\s\S]+?(?=" + header + "|$$)" + m = re.compile(tier_re) + tier_iter = m.finditer(self.read_file) + for iterator in tier_iter: + (begin, end) = iterator.span() + tier_info = self.read_file[begin:end] + tiers.append(Tier(tier_info, self.text_type, self.t_time)) + return tiers + + + def to_chron(self): + """ + @return: String in Chronological TextGrid file format. + """ + + chron_file = "" + chron_file += "\"Praat chronological TextGrid text file\"\n" + chron_file += str(self.xmin) + " " + str(self.xmax) + chron_file += " ! Time domain.\n" + chron_file += str(self.size) + " ! Number of tiers.\n" + for tier in self.tiers: + idx = (self.tiers.index(tier)) + 1 + tier_header = "\"" + tier.classid + "\" \"" \ + + tier.nameid + "\" " + str(tier.xmin) \ + + " " + str(tier.xmax) + chron_file += tier_header + "\n" + transcript = tier.simple_transcript + for (xmin, xmax, utt) in transcript: + chron_file += str(idx) + " " + str(xmin) + chron_file += " " + str(xmax) +"\n" + chron_file += "\"" + utt + "\"\n" + return chron_file + + def to_oo(self): + """ + @return: A string in OoTextGrid file format. + """ + + oo_file = "" + oo_file += "File type = \"ooTextFile\"\n" + oo_file += "Object class = \"TextGrid\"\n\n" + oo_file += "xmin = ", self.xmin, "\n" + oo_file += "xmax = ", self.xmax, "\n" + oo_file += "tiers? \n" + oo_file += "size = ", self.size, "\n" + oo_file += "item []:\n" + for i in range(len(self.tiers)): + oo_file += "%4s%s [%s]" % ("", "item", i + 1) + _curr_tier = self.tiers[i] + for (x, y) in _curr_tier.header: + oo_file += "%8s%s = \"%s\"" % ("", x, y) + if _curr_tier.classid != TEXTTIER: + for (xmin, xmax, text) in _curr_tier.simple_transcript: + oo_file += "%12s%s = %s" % ("", "xmin", xmin) + oo_file += "%12s%s = %s" % ("", "xmax", xmax) + oo_file += "%12s%s = \"%s\"" % ("", "text", text) + else: + for (time, mark) in _curr_tier.simple_transcript: + oo_file += "%12s%s = %s" % ("", "time", time) + oo_file += "%12s%s = %s" % ("", "mark", mark) + return oo_file + + +################################################################# +# Tier Class +################################################################# + +class Tier(object): + """ + A container for each tier. + """ + + def __init__(self, tier, text_type, t_time): + """ + Initializes attributes of the tier: class, name, xmin, xmax + size, transcript, total time. + Utilizes text_type to guide how to parse the file. + @type tier: a tier object; single item in the TextGrid list. + @param text_type: TextGrid format + @param t_time: Total time of TextGrid file. + @param classid: Type of tier (point or interval). + @param nameid: Name of tier. + @param xmin: xmin of the tier. + @param xmax: xmax of the tier. + @param size: Number of entries in the tier + @param transcript: The raw transcript for the tier. + """ + + self.tier = tier + self.text_type = text_type + self.t_time = t_time + self.classid = "" + self.nameid = "" + self.xmin = 0 + self.xmax = 0 + self.size = 0 + self.transcript = "" + self.tier_info = "" + self._make_info() + self.simple_transcript = self.make_simple_transcript() + if self.classid != TEXTTIER: + self.mark_type = "intervals" + else: + self.mark_type = "points" + self.header = [("class", self.classid), ("name", self.nameid), \ + ("xmin", self.xmin), ("xmax", self.xmax), ("size", self.size)] + + def __iter__(self): + return self + + def _make_info(self): + """ + Figures out most attributes of the tier object: + class, name, xmin, xmax, transcript. + """ + + trans = "([\S\s]*)" + if self.text_type == "ChronTextFile": + classid = "\"(.*)\" +" + nameid = "\"(.*)\" +" + xmin = "(\d+\.?\d*) +" + xmax = "(\d+\.?\d*) *[\r\n]+" + # No size values are given in the Chronological Text File format. + self.size = None + size = "" + elif self.text_type == "ooTextFile": + classid = " +class = \"(.*)\" *[\r\n]+" + nameid = " +name = \"(.*)\" *[\r\n]+" + xmin = " +xmin = (\d+\.?\d*) *[\r\n]+" + xmax = " +xmax = (\d+\.?\d*) *[\r\n]+" + size = " +\S+: size = (\d+) *[\r\n]+" + elif self.text_type == "OldooTextFile": + classid = "\"(.*)\" *[\r\n]+" + nameid = "\"(.*)\" *[\r\n]+" + xmin = "(\d+\.?\d*) *[\r\n]+" + xmax = "(\d+\.?\d*) *[\r\n]+" + size = "(\d+) *[\r\n]+" + m = re.compile(classid + nameid + xmin + xmax + size + trans) + self.tier_info = m.findall(self.tier)[0] + self.classid = self.tier_info[0] + self.nameid = self.tier_info[1] + self.xmin = float(self.tier_info[2]) + self.xmax = float(self.tier_info[3]) + if self.size != None: + self.size = int(self.tier_info[4]) + self.transcript = self.tier_info[-1] + + def make_simple_transcript(self): + """ + @return: Transcript of the tier, in form [(start_time end_time label)] + """ + + if self.text_type == "ChronTextFile": + trans_head = "" + trans_xmin = " (\S+)" + trans_xmax = " (\S+)[\r\n]+" + trans_text = "\"([\S\s]*?)\"" + elif self.text_type == "ooTextFile": + trans_head = " +\S+ \[\d+\]: *[\r\n]+" + trans_xmin = " +\S+ = (\S+) *[\r\n]+" + trans_xmax = " +\S+ = (\S+) *[\r\n]+" + trans_text = " +\S+ = \"([^\"]*?)\"" + elif self.text_type == "OldooTextFile": + trans_head = "" + trans_xmin = "(.*)[\r\n]+" + trans_xmax = "(.*)[\r\n]+" + trans_text = "\"([\S\s]*?)\"" + if self.classid == TEXTTIER: + trans_xmin = "" + trans_m = re.compile(trans_head + trans_xmin + trans_xmax + trans_text) + self.simple_transcript = trans_m.findall(self.transcript) + return self.simple_transcript + + def transcript(self): + """ + @return: Transcript of the tier, as it appears in the file. + """ + + return self.transcript + + def time(self, non_speech_char="."): + """ + @return: Utterance time of a given tier. + Screens out entries that begin with a non-speech marker. + """ + + total = 0.0 + if self.classid != TEXTTIER: + for (time1, time2, utt) in self.simple_transcript: + utt = utt.strip() + if utt and not utt[0] == ".": + total += (float(time2) - float(time1)) + return total + + def tier_name(self): + """ + @return: Tier name of a given tier. + """ + + return self.nameid + + def classid(self): + """ + @return: Type of transcription on tier. + """ + + return self.classid + + def min_max(self): + """ + @return: (xmin, xmax) tuple for a given tier. + """ + + return (self.xmin, self.xmax) + + def __repr__(self): + return "<%s \"%s\" (%.2f, %.2f) %.2f%%>" % (self.classid, self.nameid, self.xmin, self.xmax, 100*self.time()/self.t_time) + + def __str__(self): + return self.__repr__() + "\n " + "\n ".join(" ".join(row) for row in self.simple_transcript) + +def demo_TextGrid(demo_data): + print("** Demo of the TextGrid class. **") + + fid = TextGrid(demo_data) + print("Tiers:", fid.size) + + for i, tier in enumerate(fid): + print("\n***") + print("Tier:", i + 1) + print(tier) + +def demo(): + # Each demo demonstrates different TextGrid formats. + print("Format 1") + demo_TextGrid(demo_data1) + print("\nFormat 2") + demo_TextGrid(demo_data2) + print("\nFormat 3") + demo_TextGrid(demo_data3) + + +demo_data1 = """File type = "ooTextFile" +Object class = "TextGrid" + +xmin = 0 +xmax = 2045.144149659864 +tiers? +size = 3 +item []: + item [1]: + class = "IntervalTier" + name = "utterances" + xmin = 0 + xmax = 2045.144149659864 + intervals: size = 5 + intervals [1]: + xmin = 0 + xmax = 2041.4217474125382 + text = "" + intervals [2]: + xmin = 2041.4217474125382 + xmax = 2041.968276643991 + text = "this" + intervals [3]: + xmin = 2041.968276643991 + xmax = 2042.5281632653062 + text = "is" + intervals [4]: + xmin = 2042.5281632653062 + xmax = 2044.0487352585324 + text = "a" + intervals [5]: + xmin = 2044.0487352585324 + xmax = 2045.144149659864 + text = "demo" + item [2]: + class = "TextTier" + name = "notes" + xmin = 0 + xmax = 2045.144149659864 + points: size = 3 + points [1]: + time = 2041.4217474125382 + mark = ".begin_demo" + points [2]: + time = 2043.8338291031832 + mark = "voice gets quiet here" + points [3]: + time = 2045.144149659864 + mark = ".end_demo" + item [3]: + class = "IntervalTier" + name = "phones" + xmin = 0 + xmax = 2045.144149659864 + intervals: size = 12 + intervals [1]: + xmin = 0 + xmax = 2041.4217474125382 + text = "" + intervals [2]: + xmin = 2041.4217474125382 + xmax = 2041.5438290324326 + text = "D" + intervals [3]: + xmin = 2041.5438290324326 + xmax = 2041.7321032910372 + text = "I" + intervals [4]: + xmin = 2041.7321032910372 + xmax = 2041.968276643991 + text = "s" + intervals [5]: + xmin = 2041.968276643991 + xmax = 2042.232189031843 + text = "I" + intervals [6]: + xmin = 2042.232189031843 + xmax = 2042.5281632653062 + text = "z" + intervals [7]: + xmin = 2042.5281632653062 + xmax = 2044.0487352585324 + text = "eI" + intervals [8]: + xmin = 2044.0487352585324 + xmax = 2044.2487352585324 + text = "dc" + intervals [9]: + xmin = 2044.2487352585324 + xmax = 2044.3102321849011 + text = "d" + intervals [10]: + xmin = 2044.3102321849011 + xmax = 2044.5748932104329 + text = "E" + intervals [11]: + xmin = 2044.5748932104329 + xmax = 2044.8329108578437 + text = "m" + intervals [12]: + xmin = 2044.8329108578437 + xmax = 2045.144149659864 + text = "oU" +""" + +demo_data2 = """File type = "ooTextFile" +Object class = "TextGrid" + +0 +2.8 + +2 +"IntervalTier" +"utterances" +0 +2.8 +3 +0 +1.6229213249309031 +"" +1.6229213249309031 +2.341428074708195 +"demo" +2.341428074708195 +2.8 +"" +"IntervalTier" +"phones" +0 +2.8 +6 +0 +1.6229213249309031 +"" +1.6229213249309031 +1.6428291382019483 +"dc" +1.6428291382019483 +1.65372183721983721 +"d" +1.65372183721983721 +1.94372874328943728 +"E" +1.94372874328943728 +2.13821938291038210 +"m" +2.13821938291038210 +2.341428074708195 +"oU" +2.341428074708195 +2.8 +"" +""" + +demo_data3 = """"Praat chronological TextGrid text file" +0 2.8 ! Time domain. +2 ! Number of tiers. +"IntervalTier" "utterances" 0 2.8 +"IntervalTier" "utterances" 0 2.8 +1 0 1.6229213249309031 +"" +2 0 1.6229213249309031 +"" +2 1.6229213249309031 1.6428291382019483 +"dc" +2 1.6428291382019483 1.65372183721983721 +"d" +2 1.65372183721983721 1.94372874328943728 +"E" +2 1.94372874328943728 2.13821938291038210 +"m" +2 2.13821938291038210 2.341428074708195 +"oU" +1 1.6229213249309031 2.341428074708195 +"demo" +1 2.341428074708195 2.8 +"" +2 2.341428074708195 2.8 +"" +""" + +if __name__ == "__main__": + demo() + From c1a6a46ba8c256fe8aeca64ae6417f54d63a0a29 Mon Sep 17 00:00:00 2001 From: eyasayesh Date: Wed, 1 Oct 2025 20:47:45 -0400 Subject: [PATCH 06/10] added flexibility to lebel stimulus processor --- encoding/assembly/base_processor.py | 11 ++++++++--- encoding/assembly/lebel_processor.py | 10 +++++++--- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/encoding/assembly/base_processor.py b/encoding/assembly/base_processor.py index ae1c771..ba13b9d 100644 --- a/encoding/assembly/base_processor.py +++ b/encoding/assembly/base_processor.py @@ -336,17 +336,22 @@ def compute_word_rate_features( return np.array(word_rates) # Shape: (n_trs, 1) def process_transcript( - self, data_dir: str, story_name: str + self, + data_dir: str, + transcript_file: str, + story_name: str, ) -> Tuple[pd.DataFrame, List[int], np.ndarray, np.ndarray]: """Process transcript data and generate split indices and timing information.""" #TODO: Unify file structure for LITcoder #TODO: make it so stimulis data is in a dictionary instead of a list, or make file store modular # read pickle file - with open(os.path.join(data_dir, f"{self.dataset_type}_data.pkl"), "rb") as f: + with open(os.path.join(data_dir, f"transcripts/{transcript_file}"), "rb") as f: data = pickle.load(f) # this is a list, iterate over it and find the story_name - story = next((s for s in data if s.get("story_name") == story_name), None) + #story = next((s for s in data if s.get("story_name") == story_name), None) + story = data[story_name] + if story is None: available = [s.get("story_name") for s in data] raise ValueError( diff --git a/encoding/assembly/lebel_processor.py b/encoding/assembly/lebel_processor.py index 2c37f1e..5d2b567 100644 --- a/encoding/assembly/lebel_processor.py +++ b/encoding/assembly/lebel_processor.py @@ -94,6 +94,7 @@ def generate_assembly( generate_temporal_baseline=generate_temporal_baseline, audio_path=audio_path, brain_resp_file= kwargs.get('brain_resp_file', 'brain_resp_huge.pkl'), + transcript_file= kwargs.get('transcript_file', 'lebel_transcripts.pkl' ) ) story_data_list.append(story_data) @@ -115,6 +116,7 @@ def _process_single_story( generate_temporal_baseline: bool = False, audio_path: Optional[str] = None, brain_resp_file: Optional[str] = None, + transcript_file: Optional[str] = None ) -> StoryData: """Process a single story and return its data using a specified context type. @@ -136,10 +138,12 @@ def _process_single_story( brain_data = resp_dict.get(story_name) + transcript, split_indices, tr_times, data_times, _ = self.process_transcript( - self.data_dir, - story_name - ) + self.data_dir, + transcript_file, + story_name) + stimuli = self.generate_stimuli_with_context(transcript, self.lookback) if self.analysis_mask is not None: From cd716f4f86521513979522636e2dd24b3f6d9221 Mon Sep 17 00:00:00 2001 From: eyasayesh Date: Thu, 2 Oct 2025 10:51:29 -0400 Subject: [PATCH 07/10] fix library imports for instance comparison --- encoding/features/embeddings.py | 2 +- encoding/features/factory.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/encoding/features/embeddings.py b/encoding/features/embeddings.py index 3ff62dc..fd110dd 100644 --- a/encoding/features/embeddings.py +++ b/encoding/features/embeddings.py @@ -48,7 +48,7 @@ class StaticEmbeddingFeatureExtractor(BaseFeatureExtractor): - tokenizer_pattern (str): ONLY used if input is a single string. Default r"[A-Za-z0-9_']+" (keeps underscores) - Note: This has also been tested with ENG1000. You just have to convert it to the .kv format first. We'll provide a scrip to do that! + Note: This has also been tested with ENG1000. You just have to convert it to the .kv format first. We'll provide a script to do that! """ def __init__(self, config: Dict[str, Any]): diff --git a/encoding/features/factory.py b/encoding/features/factory.py index 537300a..ce6da90 100644 --- a/encoding/features/factory.py +++ b/encoding/features/factory.py @@ -1,11 +1,11 @@ from typing import Dict, Any, Union, Optional, Tuple import numpy as np from datetime import datetime -from .base import BaseFeatureExtractor -from .language_model import LanguageModelFeatureExtractor -from .speech_model import SpeechFeatureExtractor -from .simple_features import WordRateFeatureExtractor -from .embeddings import StaticEmbeddingFeatureExtractor +from litcoder_core.encoding.features.base import BaseFeatureExtractor +from litcoder_core.encoding.features.language_model import LanguageModelFeatureExtractor +from litcoder_core.encoding.features.speech_model import SpeechFeatureExtractor +from litcoder_core.encoding.features.simple_features import WordRateFeatureExtractor +from litcoder_core.encoding.features.embeddings import StaticEmbeddingFeatureExtractor from ..utils import ActivationCache, SpeechActivationCache From 3454ec0c6c5434ec613950da0f32a2403ce38b0c Mon Sep 17 00:00:00 2001 From: eyasayesh Date: Sun, 5 Oct 2025 17:36:41 -0400 Subject: [PATCH 08/10] Bug fixes --- encoding/assembly/assemblies.py | 3 ++- encoding/assembly/lebel_processor.py | 2 +- encoding/assembly/lpp_processor.py | 2 +- encoding/assembly/narratives_processor.py | 2 +- encoding/assembly/story_data.py | 1 + encoding/downsample/downsampling.py | 4 ++-- encoding/downsample/interpdata.py | 2 +- encoding/models/nested_cv.py | 2 ++ 8 files changed, 11 insertions(+), 7 deletions(-) diff --git a/encoding/assembly/assemblies.py b/encoding/assembly/assemblies.py index da9d367..6dcc635 100644 --- a/encoding/assembly/assemblies.py +++ b/encoding/assembly/assemblies.py @@ -10,7 +10,7 @@ class SimpleNeuroidAssembly: """Simple alternative to NeuroidAssembly that doesn't require brainio and Xarray.""" - def __init__(self, story_data_list: List["StoryData"], validation_method: str): + def __init__(self, story_data_list: List["StoryData"], validation_method: str,is_volume:bool): """Initialize assembly with story-level separation. Args: @@ -19,6 +19,7 @@ def __init__(self, story_data_list: List["StoryData"], validation_method: str): self.stories = [story.name for story in story_data_list] self.story_data = {story.name: story for story in story_data_list} self.validation_method = validation_method + self.is_volume = is_volume # Store combined data for backward compatibility self.data = np.vstack([story.brain_data for story in story_data_list]) diff --git a/encoding/assembly/lebel_processor.py b/encoding/assembly/lebel_processor.py index 5d2b567..277d3ec 100644 --- a/encoding/assembly/lebel_processor.py +++ b/encoding/assembly/lebel_processor.py @@ -99,7 +99,7 @@ def generate_assembly( story_data_list.append(story_data) # Create assembly with story-level separation - return SimpleNeuroidAssembly(story_data_list, validation_method="outer") + return SimpleNeuroidAssembly(story_data_list, validation_method="outer",is_volume=self.use_volume) def _discover_stories(self, subject_dir: Path) -> List[Dict[str, str]]: """Discover all stories for a subject from the directory structure. diff --git a/encoding/assembly/lpp_processor.py b/encoding/assembly/lpp_processor.py index 6e9b873..3f6d770 100644 --- a/encoding/assembly/lpp_processor.py +++ b/encoding/assembly/lpp_processor.py @@ -70,7 +70,7 @@ def generate_assembly( story_data_list.append(story_data) # Create assembly with run-level separation - return SimpleNeuroidAssembly(story_data_list, validation_method="inner") + return SimpleNeuroidAssembly(story_data_list, validation_method="inner",is_volume=self.use_volume) def _discover_stories( self, subject_dir: Path, subject: str diff --git a/encoding/assembly/narratives_processor.py b/encoding/assembly/narratives_processor.py index 97be115..bff7394 100644 --- a/encoding/assembly/narratives_processor.py +++ b/encoding/assembly/narratives_processor.py @@ -72,7 +72,7 @@ def generate_assembly( story_data_list.append(story_data) # Create assembly with story-level separation - return SimpleNeuroidAssembly(story_data_list, validation_method="inner") + return SimpleNeuroidAssembly(story_data_list, validation_method="inner",is_volume=self.use_volume) def _discover_stories(self, subject_dir: Path) -> List[Dict[str, str]]: """Discover all stories for a subject from the directory structure.""" diff --git a/encoding/assembly/story_data.py b/encoding/assembly/story_data.py index 5bd2e51..ff25d25 100644 --- a/encoding/assembly/story_data.py +++ b/encoding/assembly/story_data.py @@ -11,6 +11,7 @@ class StoryData: Attributes: name (str): Name identifier for the story/run brain_data (np.ndarray): Brain activation data, shape (n_timepoints, n_voxels/vertices) + is_volume (bool): true if brain data is volume data, false if surface stimuli (List[str]): List of text stimuli corresponding to each timepoint split_indices (List[int]): Indices marking TR boundaries in the data tr_times (np.ndarray): Array of TR timestamps diff --git a/encoding/downsample/downsampling.py b/encoding/downsample/downsampling.py index 12b1e10..a8b47e5 100644 --- a/encoding/downsample/downsampling.py +++ b/encoding/downsample/downsampling.py @@ -143,7 +143,7 @@ def downsample( self, data: np.ndarray, data_times: np.ndarray, tr_times: np.ndarray, **kwargs ) -> np.ndarray: """Downsample using sinc interpolation.""" - return interpdata.sincinterp2D(data, data_times, tr_times, **kwargs) + return interpdata.sincinterp2D(data=data, oldtime=data_times, newtime=tr_times, **kwargs) class LanczosDownsampler(BaseDownsampler): @@ -154,7 +154,7 @@ def downsample( ) -> np.ndarray: """Downsample using Lanczos interpolation.""" # log the kwargs - return interpdata.lanczosinterp2D(data, data_times, tr_times, **kwargs) + return interpdata.lanczosinterp2D(data=data, oldtime=data_times, newtime=tr_times, **kwargs) class GaborDownsampler(BaseDownsampler): diff --git a/encoding/downsample/interpdata.py b/encoding/downsample/interpdata.py index 94bddc1..2ceb51b 100644 --- a/encoding/downsample/interpdata.py +++ b/encoding/downsample/interpdata.py @@ -84,7 +84,7 @@ def sincinterp2D( return newdata -def lanczosinterp2D(data, oldtime, newtime, window=3, cutoff_mult=1.0, rectify=False): +def lanczosinterp2D(data, oldtime, newtime, window=3, cutoff_mult=1.0, rectify=False,**kwags): """Interpolates the columns of [data], assuming that the i'th row of data corresponds to oldtime(i). A new matrix with the same number of columns and a number of rows given by the length of [newtime] is returned. diff --git a/encoding/models/nested_cv.py b/encoding/models/nested_cv.py index afa0d97..4bcb514 100644 --- a/encoding/models/nested_cv.py +++ b/encoding/models/nested_cv.py @@ -87,6 +87,8 @@ def fit_predict( device = "mps:0" elif torch.cuda.is_available(): device = "cuda" + else: #fallback + device = "cpu" else: device = "cpu" logger.info(f"Using device: {device}") From 6168126f2bada7d44a5dca2955b795342eef7733 Mon Sep 17 00:00:00 2001 From: eyasayesh Date: Sun, 5 Oct 2025 17:37:14 -0400 Subject: [PATCH 09/10] Added wandb offline logger --- encoding/trainer.py | 44 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/encoding/trainer.py b/encoding/trainer.py index dc8a4b7..fe8a8d5 100644 --- a/encoding/trainer.py +++ b/encoding/trainer.py @@ -39,6 +39,7 @@ def __init__( # Logging parameters logger_backend: str = "wandb", wandb_project_name: str = "abstract-trainer", + wandb_mode: Optional[str] = None, results_dir: str = "results", run_name: Optional[str] = None, # Processing parameters @@ -61,6 +62,7 @@ def __init__( dataset_type: Dataset type for caching logger_backend: "wandb" or "tensorboard" wandb_project_name: Project name for wandb + wandb_mode: Mode for wandb ('online' or 'offline'). results_dir: Directory for results run_name: Custom run name downsample_config: Downsampling parameters @@ -91,7 +93,7 @@ def __init__( self.stories_to_process = story_selection # Setup logging - self.setup_logger(logger_backend, wandb_project_name, results_dir, run_name) + self.setup_logger(logger_backend, wandb_project_name, wandb_mode, results_dir, run_name) self.model_saver = ModelSaver(base_dir=results_dir) self.brain_plotter = BrainPlotter(self.experiment_logger) @@ -104,18 +106,47 @@ def __init__( logger.info(f"FIR delays: {self.fir_delays}") logger.info(f"Use train/test split: {self.use_train_test_split}") - def setup_logger(self, backend: str, project_name: str, results_dir: str, run_name: Optional[str]): - """Setup experiment logger.""" + def setup_logger(self, backend: str, project_name: str, wandb_mode: Optional[str], results_dir: str, run_name: Optional[str]): + """Setup experiment logger. + + Args: + backend (str): logger backend to use ('wandb' or 'tensorboard') + project_name (str): project name + wandb_mode (str): Mode for wandb ('online' or 'offline'). + If None, reads from WANDB_MODE env var (defaults to 'offline') + results_dir (str): Directory to store the results + run_name (str): custom run_name + + """ if run_name is None: run_name = f"abstract-trainer-{datetime.now().strftime('%Y%m%d-%H%M%S')}" if backend == "wandb": try: import wandb - wandb.init(project=project_name, name=run_name) + import os + + if wandb_mode is None: + wandb_mode = os.environ.get('WANDB_MODE', 'offline') + + os.environ['WANDB_MODE'] = wandb_mode + + + if wandb_mode == "offline": + wandb_dir = f'{results_dir}/wandb' + tmpdir = os.environ.get('TMPDIR', results_dir) + wandb_cache_dir = f'{tmpdir}/wandb_cache' + os.makedirs(wandb_dir, exist_ok=True) + os.makedirs(wandb_cache_dir, exist_ok=True) + os.environ.setdefault('WANDB_DIR', wandb_dir) + os.environ.setdefault('WANDB_CACHE_DIR', wandb_cache_dir) + + wandb.init(project=project_name, name=run_name,start_method="thread") self.experiment_logger = WandBLogger() + except ImportError as e: raise ImportError("wandb not installed. Install with: pip install wandb") from e + elif backend == "tensorboard": run_dir = f"{results_dir}/runs/{run_name}" self.experiment_logger = TensorBoardLogger(log_dir=run_dir) @@ -328,7 +359,10 @@ def log_metrics(self, metrics: Dict): if "correlations" in metrics and "significant_mask" in metrics: correlations = np.array(metrics["correlations"]) significant_mask = np.array(metrics["significant_mask"], dtype=bool) - self.brain_plotter.log_plots(correlations, significant_mask, "", False) + self.brain_plotter.log_plots(correlations = correlations, + significant_mask=significant_mask, + prefix="", + is_volume=self.assembly.is_volume) if "best_alpha" in metrics: self.experiment_logger.log_scalar("best_alpha", float(metrics["best_alpha"])) From f73ab7bf03452b3b6a287109007e3fc0cb366725 Mon Sep 17 00:00:00 2001 From: eyasayesh Date: Sun, 5 Oct 2025 17:37:31 -0400 Subject: [PATCH 10/10] renaming simple to lowlevel --- encoding/features/__init__.py | 2 +- encoding/features/factory.py | 2 +- encoding/features/{simple_features.py => lowlevel_features.py} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename encoding/features/{simple_features.py => lowlevel_features.py} (100%) diff --git a/encoding/features/__init__.py b/encoding/features/__init__.py index 72abfe6..729297f 100644 --- a/encoding/features/__init__.py +++ b/encoding/features/__init__.py @@ -1,6 +1,6 @@ from .language_model import LanguageModelFeatureExtractor from .speech_model import SpeechFeatureExtractor -from .simple_features import WordRateFeatureExtractor +from .lowlevel_features import WordRateFeatureExtractor from .embeddings import StaticEmbeddingFeatureExtractor from .FIR_expander import FIR from .factory import FeatureExtractorFactory diff --git a/encoding/features/factory.py b/encoding/features/factory.py index ce6da90..030c835 100644 --- a/encoding/features/factory.py +++ b/encoding/features/factory.py @@ -4,7 +4,7 @@ from litcoder_core.encoding.features.base import BaseFeatureExtractor from litcoder_core.encoding.features.language_model import LanguageModelFeatureExtractor from litcoder_core.encoding.features.speech_model import SpeechFeatureExtractor -from litcoder_core.encoding.features.simple_features import WordRateFeatureExtractor +from litcoder_core.encoding.features.lowlevel_features import WordRateFeatureExtractor from litcoder_core.encoding.features.embeddings import StaticEmbeddingFeatureExtractor from ..utils import ActivationCache, SpeechActivationCache diff --git a/encoding/features/simple_features.py b/encoding/features/lowlevel_features.py similarity index 100% rename from encoding/features/simple_features.py rename to encoding/features/lowlevel_features.py