diff --git a/medcat/cdb/cdb.py b/medcat/cdb/cdb.py index 339b65a..b4c2dd4 100644 --- a/medcat/cdb/cdb.py +++ b/medcat/cdb/cdb.py @@ -1,10 +1,11 @@ -from typing import Iterable, Any, Collection +from typing import Iterable, Any, Collection, Union from medcat.storage.serialisables import AbstractSerialisable from medcat.cdb.concepts import CUIInfo, NameInfo, TypeInfo from medcat.cdb.concepts import get_new_cui_info, get_new_name_info from medcat.cdb.concepts import reset_cui_training -from medcat.storage.serialisers import deserialise +from medcat.storage.serialisers import ( + deserialise, AvailableSerialisers, serialise) from medcat.utils.defaults import default_weighted_average, StatusTypes as ST from medcat.utils.hasher import Hasher from medcat.preprocessors.cleaners import NameDescriptor @@ -480,6 +481,23 @@ def get_basic_info(self) -> CDBInfo: "Supervised training history": sup_history, } + def save(self, save_path: str, + serialiser: Union[ + str, AvailableSerialisers] = AvailableSerialisers.dill, + overwrite: bool = False, + ) -> None: + """Save CDB at path. + + Args: + save_path (str): + The path to save at. + serialiser (Union[ str, AvailableSerialisers], optional): + The serialiser. Defaults to AvailableSerialisers.dill. + overwrite (bool, optional): + Whether to allow overwriting existing files. Defaults to False. + """ + serialise(serialiser, self, save_path, overwrite=overwrite) + @classmethod def load(cls, path: str) -> 'CDB': cdb = deserialise(path) diff --git a/medcat/vocab.py b/medcat/vocab.py index 02aedd7..2511880 100644 --- a/medcat/vocab.py +++ b/medcat/vocab.py @@ -1,11 +1,12 @@ -from typing import Optional, Any, cast +from typing import Optional, Any, cast, Union from typing_extensions import TypedDict # import dill import numpy as np from medcat.storage.serialisables import AbstractSerialisable -from medcat.storage.serialisers import deserialise +from medcat.storage.serialisers import ( + deserialise, AvailableSerialisers, serialise) WordDescriptor = TypedDict('WordDescriptor', @@ -293,6 +294,23 @@ def __eq__(self, other: Any) -> bool: self.index2word == other.index2word and self.vec_index2word == other.vec_index2word) + def save(self, save_path: str, + serialiser: Union[ + str, AvailableSerialisers] = AvailableSerialisers.dill, + overwrite: bool = False, + ) -> None: + """Save Vocab at path. + + Args: + save_path (str): + The path to save at. + serialiser (Union[ str, AvailableSerialisers], optional): + The serialiser. Defaults to AvailableSerialisers.dill. + overwrite (bool, optional): + Whether to allow overwriting existing files. Defaults to False. + """ + serialise(serialiser, self, save_path, overwrite=overwrite) + @classmethod def load(cls, path: str) -> 'Vocab': vocab = deserialise(path) diff --git a/tests/cdb/test_cdb.py b/tests/cdb/test_cdb.py index 0e4545d..50634fa 100644 --- a/tests/cdb/test_cdb.py +++ b/tests/cdb/test_cdb.py @@ -7,6 +7,7 @@ from medcat.preprocessors.cleaners import NameDescriptor from unittest import TestCase +import tempfile from .. import UNPACKED_EXAMPLE_MODEL_PACK_PATH @@ -21,7 +22,14 @@ class CDBTests(TestCase): def setUpClass(cls): cls.cdb = cast(cdb.CDB, deserialise(cls.CDB_PATH)) - def test_convenience_methods(self): + def test_convenience_method_save(self): + with tempfile.TemporaryDirectory() as dir: + self.cdb.save(dir) + self.assertTrue(os.path.exists(dir)) + obj = deserialise(dir) + self.assertIsInstance(obj, cdb.CDB) + + def test_convenience_method_load(self): ccdb = cdb.CDB.load(self.CDB_PATH) self.assertIsInstance(ccdb, cdb.CDB) diff --git a/tests/test_vocab.py b/tests/test_vocab.py index 028a5b2..b5b445b 100644 --- a/tests/test_vocab.py +++ b/tests/test_vocab.py @@ -181,6 +181,13 @@ def test_has_correct_vectors(self): with self.subTest(w): self.assertEqual(info['vector'].shape, self.EXP_SHAPE) + def test_convenience_save(self): + with tempfile.TemporaryDirectory() as dir: + self.vocab.save(dir) + self.assertTrue(os.path.exists(dir)) + obj = deserialise(dir) + self.assertIsInstance(obj, Vocab) + def test_convenience_load(self): vocab = Vocab.load(self.VOCAB_PATH) self.assertIsInstance(vocab, Vocab)