Skip to content

Cross Dataset Evaluation #703

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 20 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/whats_new.rst
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ Enhancements
- Adding :class:`moabb.dataset.Beetl2021A` and :class:`moabb.dataset.Beetl2021B`(:gh:`675` by `Samuel Boehm_`)
- Adding :class:`moabb.evaluations.splitters.CrossSessionSplitter` (:gh:`720` by `Bruna Lopes`_ and `Bruno Aristimunha`_)
- Adding :class:`moabb.dataset.base.BaseBIDSDataset` and :class:`moabb.dataset.base.LocalBIDSDataset` (:gh:`724` by `Pierre Guetschel`_)

- Adding :class:`moabb.evaluations.CrossDatasetEvaluation` for cross-dataset evaluation, enabling training on one dataset and testing on another (:gh:`703` by `Ali Imran`_)

Bugs
~~~~
@@ -547,3 +547,4 @@ API changes
.. _AFF: https://github.com/allwaysFindFood
.. _Marco Congedo: https://github.com/Marco-Congedo
.. _Samuel Boehm: https://github.com/Samuel-Boehm
.. _Ali Imran: https://github.com/EazyAl
186 changes: 186 additions & 0 deletions examples/advanced_examples/plot_cross_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
"""Cross-dataset motor imagery classification example.

This example demonstrates how to perform cross-dataset evaluation using MOABB,
training on one dataset and testing on another.
"""

# Standard library imports
import logging
from typing import List

# Third-party imports
import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
from mne.io import RawArray
from mne.io.cnt.cnt import RawCNT
from pyriemann.estimation import Covariances
from pyriemann.spatialfilters import CSP
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.svm import SVC

# MOABB imports
from moabb import set_log_level
from moabb.analysis.plotting import score_plot
from moabb.datasets import BNCI2014001, Zhou2016
from moabb.evaluations.evaluations import CrossDatasetEvaluation
from moabb.paradigms import MotorImagery


# Configure logging
set_log_level("WARNING")
logging.getLogger("mne").setLevel(logging.ERROR)


def create_pipeline(common_channels: List[str]) -> Pipeline:
"""Create classification pipeline with CSP and SVM.

Parameters
----------
common_channels : List[str]
List of channel names to use in the pipeline

Returns
-------
Pipeline
Sklearn pipeline for classification
"""

def raw_to_data(X: np.ndarray) -> np.ndarray:
"""Convert raw MNE data to numpy array format.

Parameters
----------
X : np.ndarray or mne.io.Raw
Input data to convert

Returns
-------
np.ndarray
Converted data array
"""
if hasattr(X, "get_data"):
picks = mne.pick_channels(
X.info["ch_names"], include=common_channels, ordered=True
)
data = X.get_data()
if data.ndim == 2:
data = data.reshape(1, *data.shape)
data = data[:, picks, :]
return data
return X

pipeline = Pipeline(
[
("to_array", FunctionTransformer(raw_to_data)),
("covariances", Covariances(estimator="oas")),
("csp", CSP(nfilter=4, log=True)),
("classifier", SVC(kernel="rbf", C=0.1)),
]
)

return pipeline


# Define datasets
train_dataset = BNCI2014001()
test_dataset = Zhou2016()

# Create a dictionary of datasets for easier handling
datasets_dict = {"train_dataset": train_dataset, "test_dataset": test_dataset}

# Get the list of channels from each dataset before matching
print("\nChannels before matching:")
for ds_name, ds in datasets_dict.items():
try:
# Load data for first subject to get channel information
data = ds.get_data([ds.subject_list[0]]) # Get data for first subject
first_subject = list(data.keys())[0]
first_session = list(data[first_subject].keys())[0]
first_run = list(data[first_subject][first_session].keys())[0]
run_data = data[first_subject][first_session][first_run]

if isinstance(run_data, (RawArray, RawCNT)):
channels = run_data.info["ch_names"]
else:
# Assuming the channels are stored in the dataset class after loading
channels = ds.channels
print(f"{ds_name}: {channels}")
except Exception as e:
print(f"Error getting channels for {ds_name}: {str(e)}")
Comment on lines +94 to +112
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this


# Use MOABB's match_all for channel handling
print("\nMatching channels across datasets...")
paradigm = MotorImagery()

# Apply match_all to all datasets
all_datasets = list(datasets_dict.values())
paradigm.match_all(all_datasets, channel_merge_strategy="intersect")

# Get channels from all datasets after matching to ensure we have the correct intersection
all_channels_after_matching = []
print("\nChannels after matching:")
for i, (ds_name, _) in enumerate(datasets_dict.items()):
ds = all_datasets[i] # Get the matched dataset
try:
data = ds.get_data([ds.subject_list[0]])
subject = list(data.keys())[0]
session = list(data[subject].keys())[0]
run = list(data[subject][session].keys())[0]
run_data = data[subject][session][run]

if isinstance(run_data, (RawArray, RawCNT)):
channels = run_data.info["ch_names"]
else:
channels = ds.channels
all_channels_after_matching.append(set(channels))
print(f"{ds_name}: {channels}")
except Exception as e:
print(f"Error getting channels for {ds_name} after matching: {str(e)}")

# Get the intersection of all channel sets
common_channels = sorted(list(set.intersection(*all_channels_after_matching)))
print(f"\nCommon channels after matching: {common_channels}")
print(f"Number of common channels: {len(common_channels)}")

# Update the datasets_dict with the matched datasets
for i, (name, _) in enumerate(datasets_dict.items()):
datasets_dict[name] = all_datasets[i]

train_dataset = datasets_dict["train_dataset"]
test_dataset = datasets_dict["test_dataset"]

# Initialize the paradigm with common channels
paradigm = MotorImagery(channels=common_channels, n_classes=2, fmin=8, fmax=32)
Comment on lines +122 to +156
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this.
match_all don't change the number of channels in the dataset,
it just automatically set the filter in the paradigm.


# Initialize the CrossDatasetEvaluation
evaluation = CrossDatasetEvaluation(
paradigm=paradigm,
train_dataset=train_dataset,
test_dataset=test_dataset,
hdf5_path="./res_test",
save_model=True,
)

# Run the evaluation
results = []
for result in evaluation.evaluate(
dataset=None, pipelines={"CSP_SVM": create_pipeline(common_channels)}
):
result["subject"] = "all"
print(f"Cross-dataset score: {result.get('score', 'N/A'):.3f}")
results.append(result)

# Convert results to DataFrame and process
results_df = pd.DataFrame(results)
results_df["dataset"] = results_df["dataset"].apply(lambda x: x.__class__.__name__)

# Print evaluation scores
print("\nCross-dataset evaluation scores:")
print(results_df[["dataset", "score", "time"]])

# Plot the results
score_plot(results_df)
plt.show()
Loading
Loading