diff --git a/sdk/atriumdb/adb_functions.py b/sdk/atriumdb/adb_functions.py index 6a19be27..ce5e0fd0 100644 --- a/sdk/atriumdb/adb_functions.py +++ b/sdk/atriumdb/adb_functions.py @@ -784,9 +784,14 @@ def merge_sorted_messages(message_starts_1, message_sizes_1, values_1, combined_timestamps = np.concatenate((timestamps_1, timestamps_2)) combined_values = np.concatenate((values_1, values_2)) - sorted_indices = np.argsort(combined_timestamps) - sorted_timestamps = combined_timestamps[sorted_indices] - sorted_values = combined_values[sorted_indices] + # Create array indicators to ensure stable sorting where array 2 overwrites array 1 + array_indicators = np.concatenate((np.zeros(len(timestamps_1), dtype=int), + np.ones(len(timestamps_2), dtype=int))) + + # Sort by timestamp first, then by array indicator (so array 2 comes after array 1 for ties) + sort_keys = np.lexsort((array_indicators, combined_timestamps)) + sorted_timestamps = combined_timestamps[sort_keys] + sorted_values = combined_values[sort_keys] # Convert the sorted timestamps into "sample times" sample_times = sorted_timestamps / period_ns diff --git a/sdk/atriumdb/atrium_sdk.py b/sdk/atriumdb/atrium_sdk.py index f3c17c4e..02f00358 100644 --- a/sdk/atriumdb/atrium_sdk.py +++ b/sdk/atriumdb/atrium_sdk.py @@ -4438,15 +4438,12 @@ def get_label_time_series(self, label_name=None, label_name_id=None, device_tag= # Create a binary array to indicate presence of a label for each timestamp, if not provided. if out is not None: - allowed_dtypes = [np.bool_] + np.sctypes['int'] # Allowed dtypes: boolean and all integer types - if out.shape != timestamp_array.shape: raise ValueError( f"The 'out' array shape {out.shape} doesn't match expected shape {timestamp_array.shape}.") - if out.dtype not in allowed_dtypes: - valid_dtypes_str = ", ".join([dtype.__name__ for dtype in allowed_dtypes]) - raise ValueError(f"The 'out' array dtype is {out.dtype}, but expected one of: {valid_dtypes_str}.") + if out.dtype.kind not in ('b', 'i', 'u'): # boolean, signed int, unsigned int + raise ValueError(f"The 'out' array dtype is {out.dtype}, but expected boolean or integer type.") if not np.all(out == 0): # Ensure that the out array starts with all zeros raise ValueError("The 'out' array should be initialized with zeros. It contains non-zero values.") diff --git a/sdk/atriumdb/windowing/definition.py b/sdk/atriumdb/windowing/definition.py index 609fc784..d9a98f00 100644 --- a/sdk/atriumdb/windowing/definition.py +++ b/sdk/atriumdb/windowing/definition.py @@ -536,8 +536,9 @@ def _check_times_and_warn(self, times, source_type, source_id): raise ValueError(f"{source_type} {source_id}: {key} cannot be negative") if value < 1e9 or (value < 1e16 and key in ['start', 'end', 'time0']): - warnings.warn(f"{source_type} {source_id}: The epoch for {key}: {value} looks like it's " - f"formatted in seconds. However {key} will be interpreted as nanosecond data.") + # warnings.warn(f"{source_type} {source_id}: The epoch for {key}: {value} looks like it's " + # f"formatted in seconds. However {key} will be interpreted as nanosecond data.") + pass if ('pre' in time_dict or 'post' in time_dict) and 'time0' not in time_dict: raise ValueError(f"{source_type} {source_id}: 'pre' and 'post' cannot be provided without 'time0'") diff --git a/sdk/atriumdb/windowing/light_mapped_iterator.py b/sdk/atriumdb/windowing/light_mapped_iterator.py index ed380e12..131e9cf1 100644 --- a/sdk/atriumdb/windowing/light_mapped_iterator.py +++ b/sdk/atriumdb/windowing/light_mapped_iterator.py @@ -242,7 +242,7 @@ def _get_window(self, source_info, window_start_time, window_end_time): # Create Window object window = Window( signals=signals, - start_time=window_start_time, + start_time=int(window_start_time), device_id=device_id, patient_id=patient_id, label_time_series=label_time_series, diff --git a/sdk/atriumdb/windowing/verify_definition.py b/sdk/atriumdb/windowing/verify_definition.py index f34122e7..d30e0a0c 100644 --- a/sdk/atriumdb/windowing/verify_definition.py +++ b/sdk/atriumdb/windowing/verify_definition.py @@ -201,13 +201,14 @@ def _validate_sources(definition, sdk, validated_measure_list, gap_tolerance=Non def _get_validated_entries(time_specs, validated_measures, sdk, device_id=None, patient_id=None, gap_tolerance=None, start_time_n=None, end_time_n=None): - gap_tolerance = 60 * 60 * 1_000_000_000 if gap_tolerance is None else gap_tolerance # 1 hour nano default + gap_tolerance = 60 * 60 * 1_000_000_000 if gap_tolerance is None else gap_tolerance - union_intervals = intervals_union_list( - [sdk.get_interval_array( - measure_info['id'], device_id=device_id, patient_id=patient_id, gap_tolerance_nano=gap_tolerance, - start=start_time_n, end=end_time_n) - for measure_info in validated_measures]) + union_intervals = intervals_union_list([ + sdk.get_interval_array( + measure_info['id'], device_id=device_id, patient_id=patient_id, + gap_tolerance_nano=gap_tolerance, start=start_time_n, end=end_time_n) + for measure_info in validated_measures + ]) merged_union_intervals = [] for start, end in union_intervals: @@ -224,11 +225,24 @@ def _get_validated_entries(time_specs, validated_measures, sdk, device_id=None, f"time regions for the specified measures. Skipping") return None + # Apply global bounds to ALL cases, including "all" if time_specs == "all": - return union_intervals.tolist() + # Apply global start/end time constraints to union_intervals + constrained_intervals = [] + for start, end in union_intervals: + # Constrain each interval to global bounds + if start_time_n is not None: + start = max(start, start_time_n) + if end_time_n is not None: + end = min(end, end_time_n) - interval_list = [] + # Only include if the interval is still valid after constraining + if start < end: + constrained_intervals.append([start, end]) + + return constrained_intervals + interval_list = [] for region_data in time_specs: if 'time0' in region_data: start, end = region_data['time0'] - region_data['pre'], region_data['time0'] + region_data['post'] @@ -251,7 +265,6 @@ def _get_validated_entries(time_specs, validated_measures, sdk, device_id=None, return interval_list - def compute_hash(data): """Compute a SHA256 hash of the given data.""" data_string = json.dumps(data, sort_keys=True) diff --git a/sdk/atriumdb/windowing/windowing_functions.py b/sdk/atriumdb/windowing/windowing_functions.py index 4c0b9c78..4c75f493 100644 --- a/sdk/atriumdb/windowing/windowing_functions.py +++ b/sdk/atriumdb/windowing/windowing_functions.py @@ -194,7 +194,7 @@ def get_window_list(device_id, patient_id, validated_measure_list, source_batch_ result_window = Window( signals=signal_dictionary, - start_time=window_start_time, + start_time=int(window_start_time), device_id=device_id, patient_id=patient_id, label_time_series=label_time_series, diff --git a/sdk/docs/source/conf.py b/sdk/docs/source/conf.py index 0d8435b7..5b02ef6c 100644 --- a/sdk/docs/source/conf.py +++ b/sdk/docs/source/conf.py @@ -13,7 +13,7 @@ project = 'AtriumDB' copyright = '2024, The Hospital for Sick Children' author = 'LaussenLabs' -release = '2.5.0' +release = '2.5.1' # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/sdk/pyproject.toml b/sdk/pyproject.toml index 82367ada..d688ac26 100644 --- a/sdk/pyproject.toml +++ b/sdk/pyproject.toml @@ -12,7 +12,7 @@ mypkg = ["*.so", "*.dll"] [project] name = "atriumdb" -version = "2.5.0" +version = "2.5.1" description = "Timeseries Database" readme = "README.md" authors = [{name = "Robert Greer, William Dixon, Spencer Vecile"}, { name = "Robert Greer", email = "robert.greer@sickkids.ca"}, { name = "William Dixon", email = "will.dixon@sickkids.ca" }, { name = "Spencer Vecile", email = "spencer.vecile@sickkids.ca"}] @@ -37,55 +37,66 @@ classifiers = [ keywords = ["atriumdb", "timeseries", "database", "waveform", "medical data", "machine learning", "data", "data science"] dependencies = [ 'tomli; python_version < "3.11"', - "numpy >= 1.21.4, < 2", - "PyYAML >= 6.0", - "tqdm >= 4.65.0, < 5", + "numpy >= 2.2.6, < 3", + "PyYAML >= 6.0.2", + "tqdm >= 4.67.1, < 5", ] [project.optional-dependencies] mariadb = [ - "mariadb == 1.1.10", + "mariadb == 1.1.13", ] remote = [ - "requests >= 2.28.2, < 3", - "PyJWT[crypto] >= 2.8.0, < 3", - "python-dotenv >= 0.21, < 1", - "websockets >= 12.0, < 13", + "requests >= 2.32.4, < 3", + "PyJWT[crypto] >= 2.6.0, < 3", + "python-dotenv >= 0.21.1, < 1", + "websockets >= 15.0.1, < 16", ] cli = [ - "requests >= 2.28.2, < 3", + "requests >= 2.32.4, < 3", "qrcodeT >= 1.0.4, < 2", - "click >= 8.1.3, < 9", - "pandas >= 1.5, < 2", + "click >= 8.2.1, < 9", + "pandas >= 2.3.1, < 3", "tabulate >= 0.9.0, < 1", - "fastparquet == 2023.2.0", - "python-dotenv >= 0.21, < 1", - "PyYAML >= 6.0" + "fastparquet == 2024.11.0", + "python-dotenv >= 0.21.1, < 1", + "PyYAML >= 6.0.2" ] all = [ - "mariadb == 1.1.10", - "requests >= 2.28.2, < 3", - "PyJWT[crypto] >= 2.8.0, < 3", - "websockets >= 12.0, < 13", + "mariadb == 1.1.13", + "requests >= 2.32.4, < 3", + "PyJWT[crypto] >= 2.6.0, < 3", + "websockets >= 15.0.1, < 16", "qrcodeT >= 1.0.4, < 2", - "python-dotenv >= 0.21, < 1", - "click >= 8.1.3, < 9", - "pandas >= 1.5, < 2", + "python-dotenv >= 0.21.1, < 1", + "click >= 8.2.1, < 9", + "pandas >= 2.3.1, < 3", "tabulate >= 0.9.0, < 1", - "fastparquet >= 2023.2.0, < 2024", - "wfdb >= 4.1.0, < 5", - "pyarrow >= 16.0.0, < 17", - "tzdata >= 2024.1, < 2025" + "fastparquet >= 2024.11.0, < 2025", + "wfdb >= 4.2.0, < 5", + "pyarrow >= 21.0.0, < 22", + "tzdata >= 2025.2, < 2026" ] - -[dev-dependencies] testing = [ - "wfdb >= 4.1.0, < 5", + "mariadb == 1.1.13", + "requests >= 2.32.4, < 3", + "PyJWT[crypto] >= 2.6.0, < 3", + "websockets >= 15.0.1, < 16", + "qrcodeT >= 1.0.4, < 2", + "python-dotenv >= 0.21.1, < 1", + "click >= 8.2.1, < 9", + "pandas >= 2.3.1, < 3", + "tabulate >= 0.9.0, < 1", + "fastparquet >= 2024.11.0, < 2025", + "wfdb >= 4.2.0, < 5", + "pyarrow >= 21.0.0, < 22", + "tzdata >= 2025.2, < 2026", "names >= 0.3.0, < 1", - "uvicorn >= 0.27.0, < 1", - "pytest" + "uvicorn >= 0.35.0, < 1", + "pytest >= 7.2.1", + "fastapi >= 0.95.0, < 1" ] - +[dev-dependencies] requires-python = ">=3.10" [tool.pytest.ini_options] @@ -100,4 +111,4 @@ Documentation = "https://docs.atriumdb.io/" [project.scripts] hello = "atriumdb.cli.hello:hello" -atriumdb = "atriumdb.cli.atriumdb_cli:cli" +atriumdb = "atriumdb.cli.atriumdb_cli:cli" \ No newline at end of file diff --git a/sdk/requirements-all.txt b/sdk/requirements-all.txt index d7998b4b..13831a5f 100644 --- a/sdk/requirements-all.txt +++ b/sdk/requirements-all.txt @@ -1,24 +1,23 @@ -auth0_python==4.2.0 -click==8.1.3 -fastapi==0.95.2 -uvicorn==0.27.0 -importlib_resources==5.12.0 -mariadb==1.1.6 +auth0_python==4.1.0 +click==8.2.1 +fastapi==0.95.0 +uvicorn==0.35.0 +mariadb==1.1.13 names==0.3.0 -numpy==1.21.6 -pandas==1.5.3 +numpy==2.2.6 +pandas==2.3.1 pytest==7.2.1 python-decouple==3.8 -python-dotenv==1.0.0 -PyYAML==6.0 +python-dotenv==0.21.1 +PyYAML==6.0.2 qrcodeT==1.0.4 -requests==2.28.2 -setuptools==65.5.0 +requests==2.32.4 +setuptools==65.6.3 tabulate==0.9.0 tomli==2.0.1 -tqdm==4.65.0 +tqdm==4.67.1 urllib3==1.26.14 -wfdb==4.1.1 -PyJWT[crypto]~=2.8.0 -websockets ~= 12.0 -pyarrow~=16.0.0 \ No newline at end of file +wfdb==4.2.0 +PyJWT[crypto]~=2.6.0 +websockets ~= 15.0.1 +pyarrow~=21.0.0 \ No newline at end of file diff --git a/sdk/requirements.txt b/sdk/requirements.txt index 3386c2ca..b1100755 100644 --- a/sdk/requirements.txt +++ b/sdk/requirements.txt @@ -1,10 +1,10 @@ -numpy==1.21.6 -requests==2.28.2 -mariadb==1.1.6 +numpy==2.2.6 +requests==2.32.4 +mariadb==1.1.13 tomli==2.0.1 -click>=8.1.3 +click>=8.2.1 qrcodeT==1.0.4 -pandas==1.5.3 -PyYAML==6.0 -PyJWT[crypto]~=2.8.0 -websockets ~= 12.0 +pandas==2.3.1 +PyYAML==6.0.2 +PyJWT[crypto]~=2.6.0 +websockets ~= 15.0.1 \ No newline at end of file diff --git a/sdk/tests/test_api.py b/sdk/tests/test_api.py index c0d688ac..02e32ec4 100644 --- a/sdk/tests/test_api.py +++ b/sdk/tests/test_api.py @@ -31,7 +31,7 @@ def test_api(): def start_server(): - uvicorn.run(app) + uvicorn.run(app, port=8123) # start server in daemon thread so it exits when complete api_thread = threading.Thread(target=start_server, daemon=True) @@ -53,7 +53,7 @@ def _test_api(db_type, dataset_location, connection_params): app.dependency_overrides[get_sdk_instance] = lambda: sdk # set up remote mode sdk to connect to the api - api_sdk = AtriumSDK(metadata_connection_type="api", api_url="http://127.0.0.1:8000", validate_token=False) + api_sdk = AtriumSDK(metadata_connection_type="api", api_url="http://127.0.0.1:8123", validate_token=False) # change the sdk token expiry so the test can work api_sdk.token_expiry = time.time() + 1_000_000 @@ -74,7 +74,7 @@ def _test_api_labels(db_type, dataset_location, connection_params): app.dependency_overrides[get_sdk_instance] = lambda: sdk # set up remote mode sdk to connect to the api - api_sdk = AtriumSDK(metadata_connection_type="api", api_url="http://127.0.0.1:8000", validate_token=False) + api_sdk = AtriumSDK(metadata_connection_type="api", api_url="http://127.0.0.1:8123", validate_token=False) # change the sdk token expiry so the test can work api_sdk.token_expiry = time.time() + 1_000_000 diff --git a/sdk/tests/test_block_select.py b/sdk/tests/test_block_select.py index f4d185c3..5144cba3 100644 --- a/sdk/tests/test_block_select.py +++ b/sdk/tests/test_block_select.py @@ -38,7 +38,7 @@ def test_block_select(): - maria_handler = MariaDBHandler(host, user, password, DB_NAME) + maria_handler = MariaDBHandler(host, user, password, DB_NAME, port) maria_handler.maria_connect_no_db().cursor().execute(f"DROP DATABASE IF EXISTS {DB_NAME}") maria_handler.create_schema() _test_block_select(maria_handler) diff --git a/sdk/tests/test_definition.py b/sdk/tests/test_definition.py index 43604a09..3b03c786 100644 --- a/sdk/tests/test_definition.py +++ b/sdk/tests/test_definition.py @@ -35,14 +35,8 @@ [ ("error1.yaml", ValueError, None, "Unexpected key: patient_id"), ("error2.yaml", ValueError, None, "Patient ID John must be an integer"), - ("error3.yaml", None, UserWarning, "patient_id 12345: The epoch for start: 1659344515 looks " - "like it's formatted in seconds. However start will be " - "interpreted as nanosecond data."), ("error4.yaml", ValueError, None, "Invalid time key: en. Allowed keys are: " "start, end, time0, pre, post"), - ("error5.yaml", None, UserWarning, "patient_id 12345: The epoch for pre: 60 looks like it's " - "formatted in seconds. However pre will be interpreted " - "as nanosecond data."), ("error6.yaml", ValueError, None, "patient_id 12345: start time 1682739300000000000 must be " "less than end time 1682739300000000000"), ("error7.yaml", ValueError, None, "pre cannot be negative"), diff --git a/sdk/tests/test_definition_start_end.py b/sdk/tests/test_definition_start_end.py index e0855ae3..9c303c37 100644 --- a/sdk/tests/test_definition_start_end.py +++ b/sdk/tests/test_definition_start_end.py @@ -70,6 +70,7 @@ def _test_transfer_start_end(db_type, dataset_location, connection_params): assert has_windows_outside_middle_third, "Sanity Check Failed" # Now use global start-end and confirm that all the windows are inside the bounds + definition = DatasetDefinition(measures=["measure"], device_ids={device_id: "all"}) for window in sdk_1.get_iterator(definition, window_duration_nano, window_slide_nano, time_units="ns", start_time=third_time, end_time=second_third_time): diff --git a/sdk/tests/test_iterator.py b/sdk/tests/test_iterator.py index a77e5c31..4363c466 100644 --- a/sdk/tests/test_iterator.py +++ b/sdk/tests/test_iterator.py @@ -28,6 +28,9 @@ DB_NAME = 'iterator' +TEST_DIR = Path(__file__).parent +EXAMPLE_DATA_DIR = TEST_DIR / "example_data" + def test_iterator(): _test_for_both(DB_NAME, _test_iterator) @@ -39,15 +42,13 @@ def _test_iterator(db_type, dataset_location, connection_params): # larger test write_mit_bih_to_dataset(sdk, max_records=2, seed=42) - # Uncomment line below to recreate test files - # create_test_definition_files(sdk) test_parameters = [ # filename, expected_device_id_type, expected_patient_id_type - ("./example_data/mitbih_seed_42_all_devices.yaml", int, int), - ("./example_data/mitbih_seed_42_all_patients.yaml", int, int), - ("./example_data/mitbih_seed_42_all_mrns.yaml", int, int), - ("./example_data/mitbih_seed_42_all_tags.yaml", int, int), + (str(EXAMPLE_DATA_DIR / "mitbih_seed_42_all_devices.yaml"), int, int), + (str(EXAMPLE_DATA_DIR / "mitbih_seed_42_all_patients.yaml"), int, int), + (str(EXAMPLE_DATA_DIR / "mitbih_seed_42_all_mrns.yaml"), int, int), + (str(EXAMPLE_DATA_DIR / "mitbih_seed_42_all_tags.yaml"), int, int), ] window_size_nano = 1_024 * 1_000_000_000 @@ -129,9 +130,9 @@ def _test_iterator(db_type, dataset_location, connection_params): assert isinstance(window.label, np.ndarray) # Test Definition Loader - definition_to_load = DatasetDefinition(filename="./example_data/mitbih_seed_42_device_one_only.yaml") + definition_to_load = DatasetDefinition(filename=str(EXAMPLE_DATA_DIR / "mitbih_seed_42_device_one_only.yaml")) sdk.load_definition(definition_to_load) - iterator = sdk.get_iterator(DatasetDefinition(filename="./example_data/mitbih_seed_42_all_devices.yaml"), + iterator = sdk.get_iterator(DatasetDefinition(filename=str(EXAMPLE_DATA_DIR / "mitbih_seed_42_all_devices.yaml")), window_size_nano, window_size_nano) expected_values = {} @@ -201,14 +202,6 @@ def _test_iterator(db_type, dataset_location, connection_params): assert window.label_time_series is None assert window.label is None -def _test_filter(db_type, dataset_location, connection_params): - sdk = AtriumSDK.create_dataset( - dataset_location=dataset_location, database_type=db_type, connection_params=connection_params) - - window_size = 100 - - times = np.arange() - def create_test_definition_files(sdk): measures = [] for measure_id, measure_info in sdk.get_all_measures().items(): @@ -224,31 +217,25 @@ def create_test_definition_files(sdk): labels = [label_info['name'] for label_info in sdk.get_all_label_names().values()] - print() - print(sdk.get_all_measures()) - print(sdk.get_all_devices()) - print(sdk.get_all_patients()) - print(sdk.get_all_label_names()) - definition = DatasetDefinition(measures=measures, device_ids=device_ids, labels=labels) - definition.save("./example_data/mitbih_seed_42_all_devices.yaml", force=True) + definition.save(str(EXAMPLE_DATA_DIR / "mitbih_seed_42_all_devices.yaml"), force=True) definition = DatasetDefinition(measures=measures, patient_ids=patient_ids, labels=labels) - definition.save("./example_data/mitbih_seed_42_all_patients.yaml", force=True) + definition.save(str(EXAMPLE_DATA_DIR / "mitbih_seed_42_all_patients.yaml"), force=True) definition = DatasetDefinition(measures=measures, mrns=mrns, labels=labels) - definition.save("./example_data/mitbih_seed_42_all_mrns.yaml", force=True) + definition.save(str(EXAMPLE_DATA_DIR / "mitbih_seed_42_all_mrns.yaml"), force=True) definition = DatasetDefinition(measures=measures, device_tags=device_tags, labels=labels) - definition.save("./example_data/mitbih_seed_42_all_tags.yaml", force=True) + definition.save(str(EXAMPLE_DATA_DIR / "mitbih_seed_42_all_tags.yaml"), force=True) def get_index_of_first_nan(arr): nan_index = np.argmax(np.isnan(arr)) if nan_index == 0 and not np.isnan(arr[0]): return len(arr) - 1 - return nan_index + return nan_index \ No newline at end of file diff --git a/sdk/tests/test_small_block.py b/sdk/tests/test_small_block.py index 729901ff..13cd366f 100644 --- a/sdk/tests/test_small_block.py +++ b/sdk/tests/test_small_block.py @@ -19,7 +19,6 @@ V_TYPE_DOUBLE import numpy as np import random -from matplotlib import pyplot as plt from tests.generate_wfdb import get_records from tests.test_mit_bih import create_gaps diff --git a/sdk/tests/test_transfer.py b/sdk/tests/test_transfer.py index a3c58d3e..550a0f0d 100644 --- a/sdk/tests/test_transfer.py +++ b/sdk/tests/test_transfer.py @@ -34,6 +34,7 @@ def test_transfer(): _test_for_both(DB_NAME, _test_transfer_without_re_encoding) _test_for_both(DB_NAME, _test_transfer_with_patient_context) _test_for_both(DB_NAME, _test_transfer_with_patient_context_deidentify_timeshift) + _test_for_both(DB_NAME, _test_partition_dataset) def _test_transfer(db_type, dataset_location, connection_params): @@ -45,6 +46,26 @@ def _test_transfer(db_type, dataset_location, connection_params): device_patient_dict = write_mit_bih_to_dataset(sdk_1, max_records=MAX_RECORDS, seed=SEED) + measures = [measure_info['tag'] for measure_info in sdk_1.get_all_measures().values()] + device_ids = {np.int64(device_id): "all" for device_id in sdk_1.get_all_devices().keys()} + definition = DatasetDefinition( + measures=measures, device_ids=device_ids, + labels=[label_name_info['name'] for label_name_info in sdk_1.get_all_label_names().values()]) + + + transfer_data(sdk_1, sdk_2, definition, gap_tolerance=None, deidentify=False, patient_info_to_transfer=None, + include_labels=False, reencode_waveforms=True) + + assert_mit_bih_to_dataset(sdk_2, device_patient_map=device_patient_dict, max_records=MAX_RECORDS, seed=SEED) + + +def _test_partition_dataset(db_type, dataset_location, connection_params): + # Setup + sdk_1 = AtriumSDK.create_dataset( + dataset_location=dataset_location, database_type=db_type, connection_params=connection_params) + + device_patient_dict = write_mit_bih_to_dataset(sdk_1, max_records=MAX_RECORDS, seed=SEED) + measures = [measure_info['tag'] for measure_info in sdk_1.get_all_measures().values()] device_ids = {np.int64(device_id): "all" for device_id in sdk_1.get_all_devices().keys()} label_name = list(sdk_1.get_all_label_names().values())[0]['name'] @@ -89,11 +110,6 @@ def _test_transfer(db_type, dataset_location, connection_params): verbose=False ) - transfer_data(sdk_1, sdk_2, definition, gap_tolerance=None, deidentify=False, patient_info_to_transfer=None, - include_labels=False, reencode_waveforms=True) - - assert_mit_bih_to_dataset(sdk_2, device_patient_map=device_patient_dict, max_records=MAX_RECORDS, seed=SEED) - def _test_transfer_without_re_encoding(db_type, dataset_location, connection_params): # Setup diff --git a/sdk/tests/testing_framework.py b/sdk/tests/testing_framework.py index a5d8f334..1382408e 100644 --- a/sdk/tests/testing_framework.py +++ b/sdk/tests/testing_framework.py @@ -39,7 +39,7 @@ def _test_for_both(db_name, test_function, *args): db_type = 'mariadb' shutil.rmtree(maria_dataset_path, ignore_errors=True) - maria_handler = MariaDBHandler(host, user, password, db_name) + maria_handler = MariaDBHandler(host, user, password, db_name, port) connection_params = { 'sqltype': db_type, 'host': host,