Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 90 additions & 3 deletions Functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model

from sklearn.metrics import accuracy_score
from tensorflow.keras.utils import to_categorical


###### data segmenting and relabeling functions ######
Expand Down Expand Up @@ -47,9 +51,10 @@ def filter_and_relabel(data, label, keep_labels, new_labels):
filtered_label = np.array([new_labels[l] for l in filtered_label])
return filtered_data, filtered_label

def generate_paths(subj_id, task, nclass, session_num, model_type, data_folder):
def generate_paths(subj_id, task, nclass, session_num, model_type, data_folder, finetune_eval_mode=False):
# get the file paths to the training data
subject_folder = os.path.join(data_folder, f'S{subj_id:02}')
print("Subject folder:", subject_folder)

if task == 'MI':
prefix = '*Imagery'
Expand All @@ -62,18 +67,23 @@ def generate_paths(subj_id, task, nclass, session_num, model_type, data_folder):
suffix = f'{nclass}class_Base' # 3-class model is fine-tuned on 3-class same day data
else:
suffix = 'Base' # 2-class model is fine-tuned on both 2-class and 3-class same day data

if finetune_eval_mode:
suffix = 'Finetune' # hardcoding suffix for evaluation of finetuned model

pattern = os.path.join(subject_folder, f'{prefix_online}*{suffix}')
data_paths = sorted(glob.glob(pattern))
else:
# load the offline session data
offline_pattern = os.path.join(subject_folder, prefix)
print("Loading offline data from:", offline_pattern)
data_paths = sorted(glob.glob(offline_pattern))

# load the prior online sessions
for session in range(1,session_num):
prefix_online = f'{prefix}_Sess{session:02}'
online_pattern = os.path.join(subject_folder, f'{prefix_online}*')
print("Loading online data from:", online_pattern)
data_paths.extend(sorted(glob.glob(online_pattern)))
return data_paths

Expand Down Expand Up @@ -227,9 +237,9 @@ def train_models(data, label, save_name, params):
callback_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=30)

if 'modelpath' in params.keys(): # finetune: smaller starting lr
optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=1e-4)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
else:
optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=0.001)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

model.compile(loss='categorical_crossentropy', optimizer=optimizer,
metrics = ['accuracy'])
Expand Down Expand Up @@ -262,3 +272,80 @@ def train_models(data, label, save_name, params):
print("Training Finished!")
print(f"Model saved to {save_name}")
return save_name


def evaluate_model(model_path, eval_data_paths, params, step_size=128):
    """Evaluate a saved model on held-out trials using per-trial majority voting.

    Each trial is cut into overlapping fixed-length windows, every window is
    classified independently by the model, and the trial-level prediction is
    the most frequent window-level prediction (majority vote).

    Args:
        model_path: Path to a saved Keras model loadable with ``load_model``.
        eval_data_paths: Sequence of file paths handed to ``load_and_filter_data``.
        params: Dict providing at least 'windowlen', 'srate', 'downsrate',
            and 'bandpass_filt' (plus whatever ``load_and_filter_data`` needs).
        step_size: Stride in samples between consecutive window starts.
            Defaults to 128, the step size used for the paper's model training.

    Returns:
        float: Trial-level majority-voting accuracy in [0, 1].
    """
    # Load and preprocess evaluation data
    trial_data, trial_labels, _ = load_and_filter_data(eval_data_paths, params)

    nChan = np.size(trial_data, axis=1)
    DesiredLen = int(params['windowlen'] * params['downsrate'])
    segment_size = int(params['windowlen'] * params['srate'])

    # Segment data - capture the new trial_ids array so that windows can be
    # regrouped into their parent trials for the majority vote below.
    segmented_data, segment_labels, segment_trial_ids = segment_data(
        trial_data, trial_labels, segment_size, step_size
    )

    # --- Preprocessing steps (mirror the training pipeline) ---

    # Downsample each window to DesiredLen samples along the time axis.
    segmented_data = resample(segmented_data, DesiredLen, t=None, axis=2, window=None, domain='time')

    # Bandpass filtering; zero-pad the time axis first so the IIR filter's
    # edge transients fall in the padding, then trim the padding back off.
    padding_length = 100
    segmented_data = np.pad(segmented_data, ((0,0),(0,0),(padding_length,padding_length)), 'constant', constant_values=0)
    b, a = scipy.signal.butter(4, params['bandpass_filt'], btype='bandpass', fs=params['downsrate'])
    segmented_data = scipy.signal.lfilter(b, a, segmented_data, axis=-1)
    segmented_data = segmented_data[:,:,padding_length:-padding_length]

    # Z-score normalization per channel over the time axis.
    segmented_data = scipy.stats.zscore(segmented_data, axis=2, nan_policy='omit')

    # Prepare for model input: (n_windows, channels, samples, 1 kernel).
    kernels, chans, samples = 1, nChan, DesiredLen
    segmented_data = segmented_data.reshape(segmented_data.shape[0], chans, samples, kernels)

    # Shift to 0-based labels for comparison against argmax class indices.
    # NOTE(review): assumes segment_labels are 1-based — confirm against
    # load_and_filter_data / filter_and_relabel.
    true_labels_0based = segment_labels - 1

    model = load_model(model_path, compile=False)  # inference only; skip compiling

    # Predicted class index (0, 1, 2, ...) for every window.
    predictions_prob = model.predict(segmented_data, verbose=0)
    predicted_classes_segment = np.argmax(predictions_prob, axis=1)

    final_trial_predictions = []
    true_trial_labels = []

    for trial_id in np.unique(segment_trial_ids):
        # 1. Group windows belonging to this trial.
        trial_indices = np.where(segment_trial_ids == trial_id)[0]

        # 2. Majority vote over the trial's window predictions.
        #    scipy.stats.mode returns (mode_value, count); with keepdims=False
        #    the mode is a scalar. Ties resolve to the smallest class index.
        trial_segment_predictions = predicted_classes_segment[trial_indices]
        majority_prediction = scipy.stats.mode(trial_segment_predictions, keepdims=False)[0]

        # 3. Store the vote; all windows of a trial share one true label,
        #    so the first window's label stands in for the whole trial.
        final_trial_predictions.append(majority_prediction)
        true_trial_labels.append(true_labels_0based[trial_indices[0]])

    final_trial_predictions = np.array(final_trial_predictions)
    true_trial_labels = np.array(true_trial_labels)

    majority_vote_acc = accuracy_score(true_trial_labels, final_trial_predictions)

    print(f"Total Trials Evaluated: {len(np.unique(segment_trial_ids))}")
    print(f"Majority Voting Accuracy: {majority_vote_acc * 100:.2f}%")

    return majority_vote_acc
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,9 @@ This work was supported by the National Institutes of Health via grants NS124564
[1] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon, S. M., Hung, C. P., & Lance, B. J. EEGNet: a compact convolutional neural network for EEG-based brain-computer interfaces. Journal of neural engineering, 15, 056013. (2018).

Army Research Laboratory (ARL) EEGModels project repository: https://github.com/vlawhern/arl-eegmodels

## Changes

Added the evaluation code, which was previously missing. Run the script `run.sh` to train the base model and fine-tune it for every subject; each run appends its evaluation results, one row at a time, to a CSV file.

The dependencies used are given in `requirements.txt`.
40 changes: 39 additions & 1 deletion main_model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@

# %%

from Functions import load_and_filter_data, generate_paths, train_models
from Functions import load_and_filter_data, generate_paths, train_models, evaluate_model
import tensorflow as tf

import os
import sys
import numpy as np
import pandas as pd

# Read command-line arguments
subj_id = int(sys.argv[1])
Expand Down Expand Up @@ -59,6 +61,42 @@

save_name = os.path.join(save_folder, f'S{subj_id:02}_Sess{session_num:02}_{task}_{nclass}class_{modeltype}.h5')

finetune_eval_mode = False
if modeltype == 'Finetune':
    finetune_eval_mode = True
    # the pre-trained model to be fine-tuned on
    params['modelpath'] = save_name.replace('Finetune', 'Orig')
save_name = train_models(data, label, save_name, params)

# --- Evaluation ---
model_path = save_name

# Evaluation always draws from the 'Finetune' path layout; the
# finetune_eval_mode flag further hardcodes the 'Finetune' file suffix
# when we are evaluating a fine-tuned model.
eval_data_paths = generate_paths(subj_id, task, nclass, session_num,
                                 model_type='Finetune', data_folder=data_folder,
                                 finetune_eval_mode=finetune_eval_mode)
acc = evaluate_model(model_path, eval_data_paths, params)

# One metrics row describing this run.
metrics_df = pd.DataFrame({
    'subject_id': [f'S{subj_id:02}'],
    'session_number': [session_num],
    'nclass': [nclass],
    'task': [task],
    'modeltype': [modeltype],
    'accuracy': [acc],
})

# Append the row to the CSV in place. Writing with mode='a' avoids
# rereading and rewriting the whole file on every run; the header is
# emitted only when the file does not exist yet.
csv_file_path = 'metrics.csv'
metrics_df.to_csv(csv_file_path, mode='a',
                  header=not os.path.exists(csv_file_path), index=False)
5 changes: 5 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
tensorflow==2.4.0
numpy==1.19.5
pandas==1.1.5
scikit-learn==1.3.2
scipy==1.10.1
16 changes: 16 additions & 0 deletions run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/bash
# Sweep every experiment configuration: for each model type (base "Orig",
# then "Finetune"), run main_model_training.py for subjects 1-21, sessions
# 1-2, and 2/3-class setups on the ME task. The Python script appends each
# run's evaluation result to the metrics CSV as it goes.

for modeltype in Orig Finetune; do
    for subject_id in $(seq 1 21); do
        for session in 1 2; do
            for nclass in 2 3; do
                python main_model_training.py "$subject_id" "$session" "$nclass" ME "$modeltype"
            done
        done
    done
done