Skip to content
This repository was archived by the owner on Jun 30, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions medcat/cdb/cdb.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Iterable, Any, Collection
from typing import Iterable, Any, Collection, Union

from medcat.storage.serialisables import AbstractSerialisable
from medcat.cdb.concepts import CUIInfo, NameInfo, TypeInfo
from medcat.cdb.concepts import get_new_cui_info, get_new_name_info
from medcat.cdb.concepts import reset_cui_training
from medcat.storage.serialisers import deserialise
from medcat.storage.serialisers import (
deserialise, AvailableSerialisers, serialise)
from medcat.utils.defaults import default_weighted_average, StatusTypes as ST
from medcat.utils.hasher import Hasher
from medcat.preprocessors.cleaners import NameDescriptor
Expand Down Expand Up @@ -480,6 +481,23 @@ def get_basic_info(self) -> CDBInfo:
"Supervised training history": sup_history,
}

def save(self, save_path: str,
serialiser: Union[
str, AvailableSerialisers] = AvailableSerialisers.dill,
overwrite: bool = False,
) -> None:
"""Save CDB at path.

Args:
save_path (str):
The path to save at.
serialiser (Union[ str, AvailableSerialisers], optional):
The serialiser. Defaults to AvailableSerialisers.dill.
overwrite (bool, optional):
Whether to allow overwriting existing files. Defaults to False.
"""
serialise(serialiser, self, save_path, overwrite=overwrite)

@classmethod
def load(cls, path: str) -> 'CDB':
cdb = deserialise(path)
Expand Down
22 changes: 20 additions & 2 deletions medcat/vocab.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Any, cast
from typing import Optional, Any, cast, Union
from typing_extensions import TypedDict

# import dill
import numpy as np

from medcat.storage.serialisables import AbstractSerialisable
from medcat.storage.serialisers import deserialise
from medcat.storage.serialisers import (
deserialise, AvailableSerialisers, serialise)


WordDescriptor = TypedDict('WordDescriptor',
Expand Down Expand Up @@ -293,6 +294,23 @@ def __eq__(self, other: Any) -> bool:
self.index2word == other.index2word and
self.vec_index2word == other.vec_index2word)

def save(self, save_path: str,
serialiser: Union[
str, AvailableSerialisers] = AvailableSerialisers.dill,
overwrite: bool = False,
) -> None:
"""Save Vocab at path.

Args:
save_path (str):
The path to save at.
serialiser (Union[ str, AvailableSerialisers], optional):
The serialiser. Defaults to AvailableSerialisers.dill.
overwrite (bool, optional):
Whether to allow overwriting existing files. Defaults to False.
"""
serialise(serialiser, self, save_path, overwrite=overwrite)

@classmethod
def load(cls, path: str) -> 'Vocab':
vocab = deserialise(path)
Expand Down
10 changes: 9 additions & 1 deletion tests/cdb/test_cdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from medcat.preprocessors.cleaners import NameDescriptor

from unittest import TestCase
import tempfile

from .. import UNPACKED_EXAMPLE_MODEL_PACK_PATH

Expand All @@ -21,7 +22,14 @@ class CDBTests(TestCase):
def setUpClass(cls):
cls.cdb = cast(cdb.CDB, deserialise(cls.CDB_PATH))

def test_convenience_methods(self):
def test_convenience_method_save(self):
with tempfile.TemporaryDirectory() as dir:
self.cdb.save(dir)
self.assertTrue(os.path.exists(dir))
obj = deserialise(dir)
self.assertIsInstance(obj, cdb.CDB)

def test_convenience_method_load(self):
ccdb = cdb.CDB.load(self.CDB_PATH)
self.assertIsInstance(ccdb, cdb.CDB)

Expand Down
7 changes: 7 additions & 0 deletions tests/test_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,13 @@ def test_has_correct_vectors(self):
with self.subTest(w):
self.assertEqual(info['vector'].shape, self.EXP_SHAPE)

def test_convenience_save(self):
with tempfile.TemporaryDirectory() as dir:
self.vocab.save(dir)
self.assertTrue(os.path.exists(dir))
obj = deserialise(dir)
self.assertIsInstance(obj, Vocab)

def test_convenience_load(self):
vocab = Vocab.load(self.VOCAB_PATH)
self.assertIsInstance(vocab, Vocab)