diff --git a/jsonstore.py b/jsonstore.py index 216b502..cdfe4c5 100644 --- a/jsonstore.py +++ b/jsonstore.py @@ -7,8 +7,12 @@ import json import os.path import sys +import io +import pyAesCrypt +from os import stat, remove from collections import OrderedDict from copy import deepcopy +from tempfile import mktemp __all__ = ["JsonStore"] @@ -34,11 +38,30 @@ def _do_auto_commit(self): self._save() def _load(self): + # Check if file exists. Create it if it doesn't if not os.path.exists(self._path): - with open(self._path, "w+b") as store: - store.write("{}".encode("utf-8")) - with open(self._path, "r+b") as store: - raw_data = store.read().decode("utf-8") + empty_json_data = "{}".encode(self._encoding) + if self._secure: + tempFile = mktemp() + with open(tempFile, "wb") as tempStore: + tempStore.write(empty_json_data) + pyAesCrypt.encryptFile(tempFile, self._path, self._password, self._bufferSize) + os.remove(tempFile) + else: + with open(self._path, "wb") as store: + store.write(empty_json_data) + + # Read the contents of the file + if self._secure: + tempFile = mktemp() + pyAesCrypt.decryptFile(self._path, tempFile, self._password, self._bufferSize) + with open(tempFile, "rb") as tmp: + raw_data = tmp.read().decode(self._encoding) + os.remove(tempFile) + else: + with open(self._path, "rb") as store: + raw_data = store.read().decode(self._encoding) + if not raw_data: data = OrderedDict() else: @@ -48,26 +71,39 @@ def _load(self): raise ValueError("Root element is not an object") self.__dict__["_data"] = data + def get_dump(self): + return self._data + def _save(self): temp = self._path + "~" - with open(temp, "wb") as store: - output = json.dumps(self._data, indent=self._indent) - store.write(output.encode("utf-8")) + tempFile = temp + "2" if self._secure else temp + with open(temp, "wb") as tempStore: + jsonStr = json.dumps(self._data, indent=self._indent) + data = jsonStr.encode(self._encoding) + tempStore.write(data) + + if self._secure: + pyAesCrypt.encryptFile(temp, tempFile, self._password, self._bufferSize) + os.remove(temp) if sys.version_info >= (3, 3): - os.replace(temp, self._path) + os.replace(tempFile, self._path) elif os.name == "windows": os.remove(self._path) - os.rename(temp, self._path) + os.rename(tempFile, self._path) else: - os.rename(temp, self._path) + os.rename(tempFile, self._path) - def __init__(self, path, indent=2, auto_commit=True): + def __init__(self, path, indent=2, auto_commit=False, password=None): self.__dict__.update( { "_auto_commit": auto_commit, "_data": None, "_path": path, + "_encoding": "utf-8", + "_secure": True if password else None, + "_password": password, + "_bufferSize": 64 * 1024, "_indent": indent, "_states": [], } diff --git a/test_jsonstore_secure.py b/test_jsonstore_secure.py new file mode 100644 index 0000000..7e875af --- /dev/null +++ b/test_jsonstore_secure.py @@ -0,0 +1,339 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +from __future__ import absolute_import + +import json +import os +import unittest +import pyAesCrypt +from tempfile import mktemp + +from jsonstore import JsonStore + +class TransactionBreaker(Exception): + pass + + +class Tests(unittest.TestCase): + TEST_DATA = ( + ("string", "hello"), + ("unicode", u"💩"), + ("integer", 1), + ("none", None), + ("big_integer", 18446744073709551616), + ("float", 1.0), + ("boolean", True), + ("list", [1, 2]), + ("tuple", (1, 2)), + ("dictionary", {"key": "value"}), + ) + + TEST_PASSWORD = "fooTestIzDaBest" + BUFFER_SIZE = 64 * 1024 + + def setUp(self): + self._store_file = mktemp() + ".aes" + self.store = JsonStore(self._store_file, indent=None, auto_commit=True, password=self.TEST_PASSWORD) + + def tearDown(self): + os.remove(self._store_file) + + def _setattr(self, key, value): + """ + Return a callable that assigns self.store.key to value + """ + + def handle(): + setattr(self.store, key, value) + + return handle + + def _setitem(self, key, value): + """ + Return a callable that assigns self.store[key] to value + """ + + def handle(): + self.store[key] = value + + return handle + + def _getattr(self, key): + """ + Return a callable that assigns self.store.key to value + """ + + def handle(): + return getattr(self.store, key) + + return handle + + def _getitem(self, key): + """ + Return a callable that assigns self.store[key] to value + """ + + def handle(): + return self.store[key] + + return handle + + def test_new_store(self): + store_file = mktemp() + ".aes" + JsonStore(store_file, auto_commit=True) + decryptedTemp = store_file + "~2" + pyAesCrypt.decryptFile(self._store_file, decryptedTemp, self.TEST_PASSWORD, self.BUFFER_SIZE) + with open(decryptedTemp) as secure: + self.assertEqual(secure.read(), "{}") + os.remove(store_file) + os.remove(decryptedTemp) + + JsonStore(store_file, auto_commit=False) + decryptedTemp = store_file + "~2" + pyAesCrypt.decryptFile(self._store_file, decryptedTemp, self.TEST_PASSWORD, self.BUFFER_SIZE) + with open(decryptedTemp) as handle: + self.assertEqual(handle.read(), "{}") + + os.remove(store_file) + os.remove(decryptedTemp) + + def test_assign_valid_types(self): + for name, value in self.TEST_DATA: + self.store[name] = value + self.store[name] == value + getattr(self.store, name) == value + + del self.store[name] + self.assertRaises(KeyError, self._getitem(name)) + self.assertRaises(AttributeError, self._getattr(name)) + + setattr(self.store, name, value) + self.store[name] == value + getattr(self.store, name) == value + + delattr(self.store, name) + self.assertRaises(KeyError, self._getitem(name)) + self.assertRaises(AttributeError, self._getattr(name)) + + def test_assign_invalid_types(self): + for method in (self._setattr, self._setitem): + + def assign(value): + return method("key", value) + + self.assertRaises(AttributeError, assign(set())) + self.assertRaises(AttributeError, assign(object())) + self.assertRaises(AttributeError, assign(None for i in range(2))) + + def test_assign_bad_keys(self): + # FIXME: a ValueError would make more sense + self.assertRaises(AttributeError, self._setitem(1, 2)) + + def test_retrieve_values(self): + for name, value in self.TEST_DATA: + self.store[name] = value + self.assertEqual(getattr(self.store, name), value) + self.assertEqual(self.store[name], value) + + def test_has_values(self): + for name, value in self.TEST_DATA: + self.store[name] = value + self.assertTrue(name in self.store) + + self.assertFalse("foo" in self.store) + + def test_empty_key(self): + with self.assertRaises(KeyError): + return self.store[""] + + def test_empty_store(self): + store_file = mktemp() + with open(store_file, "wb") as f: + f.write(b"") + self.assertTrue(JsonStore(f.name)) + + def test_assign_cycle(self): + test_list = [] + test_dict = {} + test_list.append(test_dict) + test_dict["list"] = test_list + for method in (self._setattr, self._setitem): + self.assertRaises(ValueError, method("key", test_list)) + self.assertRaises(ValueError, method("key", test_dict)) + + def test_nested_dict_helper(self): + self.assertRaises(KeyError, self._setitem("dictionary.noexist", None)) + self.assertRaises(KeyError, self._getitem("dictionary.noexist")) + + for access_key in ("dictionary.exist", ("dictionary", "exist"), ["dictionary", "exist"]): + self.store.dictionary = {"a": 1} + self.store["dictionary.exist"] = None + self.assertIsNone(self.store.dictionary["exist"]) + self.assertIsNone(self.store[access_key]) + + self.store["dictionary.a"] = 2 + del self.store[access_key] + self.assertRaises(KeyError, self._getitem(access_key)) + self.assertNotIn("exist", self.store.dictionary) + self.assertEqual(self.store.dictionary, {"a": 2}) + + def test_nested_getitem(self): + self.store["list"] = [ + { + "key": [None, "value", "last"] + } + ] + assert self.store["list", 0, "key", 1] == "value" + assert self.store[["list", 0, "key", -1]] == "last" + self.assertRaises(TypeError, self._getitem("list.0.key.1")) + assert len(self.store["list", 0, "key", 1:]) == 2 + + def test_del(self): + self.store.key = None + del self.store.key + self.assertRaises(KeyError, self._getitem("key")) + + self.store.key = None + del self.store["key"] + self.assertRaises(KeyError, self._getitem("key")) + + def test_context_and_deserialisation(self): + store_file = mktemp() + for name, value in self.TEST_DATA: + if isinstance(value, tuple): + value = list(value) + with JsonStore(store_file) as store: + store[name] = value + with JsonStore(store_file) as store: + self.assertEqual(getattr(store, name), value) + + def test_deep_copying(self): + inner_list = [] + outer_list = [inner_list] + inner_dict = {} + outer_dict = {"key": inner_dict} + + for method in (self._getattr, self._getitem): + self.store.list = outer_list + self.assertIsNot(method("list")(), outer_list) + self.assertIsNot(method("list")()[0], inner_list) + + self.store.dict = outer_dict + self.assertIsNot(method("dict")(), outer_dict) + self.assertIsNot(method("dict")()["key"], inner_dict) + + self.assertIsNot(method("list")(), method("list")()) + self.assertIsNot(method("list")()[0], method("list")()[0]) + self.assertIsNot(method("dict")(), method("dict")()) + self.assertIsNot(method("dict")()["key"], method("dict")()["key"]) + + def test_load(self): + for good_data in ("{}", '{"key": "value"}'): + data = good_data.encode("utf-8") + if self._store_file.endswith(".aes"): + tempStoreFile = self._store_file + "_TEST" + with open(tempStoreFile, "wb") as temp: + temp.write(data) + pyAesCrypt.encryptFile(tempStoreFile, self._store_file, self.TEST_PASSWORD, self.BUFFER_SIZE) + os.remove(tempStoreFile) + else: + with open(self._store_file, "wb") as handle: + handle.write(data) + self.store._load() + + for bad_data in ("[]", "1", "nill", '"x"'): + data = bad_data.encode("utf-8") + if self._store_file.endswith(".aes"): + tempStoreFile = self._store_file + "_TEST" + with open(tempStoreFile, "wb") as temp: + temp.write(data) + pyAesCrypt.encryptFile(tempStoreFile, self._store_file, self.TEST_PASSWORD, self.BUFFER_SIZE) + os.remove(tempStoreFile) + else: + with open(self._store_file, "wb") as handle: + handle.write(data) + self.assertRaises(ValueError, self.store._load) + + def test_auto_commit(self): + store_file = mktemp() + store = JsonStore(store_file, indent=None, auto_commit=True) + store.value1 = 1 + with open(store_file) as handle: + self.assertEqual({"value1": 1}, json.load(handle)) + store["value2"] = 2 + with open(store_file) as handle: + self.assertEqual({"value1": 1, "value2": 2}, json.load(handle)) + + def test_no_auto_commit(self): + store_file = mktemp() + store = JsonStore(store_file, indent=None, auto_commit=False) + store.value1 = 1 + store["value2"] = 2 + with open(store_file) as handle: + self.assertEqual({}, json.load(handle)) + + def test_transaction_rollback(self): + self.store.value = 1 + try: + with self.store: + self.store.value = 2 + try: + with self.store: + self.store.value = 3 + raise TransactionBreaker + except TransactionBreaker: + pass + self.assertEqual(self.store.value, 2) + raise TransactionBreaker + except TransactionBreaker: + pass + self.assertEqual(self.store.value, 1) + + def test_transaction_commit(self): + self.store.value = 1 + self.store.remove_me = "bye" + with self.store: + self.store.value = 2 + del self.store.remove_me + self.assertEqual(self.store.value, 2) + self.assertRaises(AttributeError, self._getattr("remove_me")) + + def test_transaction_write(self): + with self.store: + self.store.value1 = 1 + decryptedFile = mktemp() + pyAesCrypt.decryptFile(self._store_file, decryptedFile, self.TEST_PASSWORD, self.BUFFER_SIZE) + with open(decryptedFile) as handle: + self.assertEqual(handle.read(), "{}") + os.remove(decryptedFile) + + with self.store: + self.store.value2 = 2 + + decryptedFile = mktemp() + pyAesCrypt.decryptFile(self._store_file, decryptedFile, self.TEST_PASSWORD, self.BUFFER_SIZE) + with open(decryptedFile) as handle: + self.assertEqual(handle.read(), "{}") + os.remove(decryptedFile) + + decryptedFile = mktemp() + pyAesCrypt.decryptFile(self._store_file, decryptedFile, self.TEST_PASSWORD, self.BUFFER_SIZE) + with open(decryptedFile) as handle: + self.assertEqual(handle.read(), '{"value1": 1, "value2": 2}') + os.remove(decryptedFile) + + def test_list_concat_inplace(self): + self.store.list = [] + extension = [{"key": "value"}] + + # make sure += happens + self.store["list"] += extension + self.store.list += extension + self.assertEqual(self.store.list, extension * 2) + + # make sure a deepcopy occurred + self.assertIsNot(self.store.list[0], extension[0]) + + +if __name__ == "__main__": + unittest.main()