diff --git a/chatmaild/src/chatmaild/metadata.py b/chatmaild/src/chatmaild/metadata.py index bcc34c2..2165a76 100644 --- a/chatmaild/src/chatmaild/metadata.py +++ b/chatmaild/src/chatmaild/metadata.py @@ -12,6 +12,9 @@ import sys import logging import os import requests +import marshal +from contextlib import contextmanager +import filelock DICTPROXY_LOOKUP_CHAR = "L" @@ -24,6 +27,38 @@ DICTPROXY_TRANSACTION_CHARS = "SBC" METADATA_TOKEN_KEY = "devicetoken" +class PersistentDict: + """Concurrency-safe multi-reader-single-writer Persistent Dict.""" + + def __init__(self, path, timeout=5.0): + self.path = path + self.lock_path = path.with_name(path.name + ".lock") + self.timeout = timeout + + @contextmanager + def modify(self): + try: + with filelock.FileLock(self.lock_path, timeout=self.timeout): + data = self.get() + yield data + write_path = self.path.with_suffix(".tmp") + with write_path.open("wb") as f: + marshal.dump(data, f) + os.rename(write_path, self.path) + except filelock.Timeout: + logging.warning("could not obtain lock, removing: %r", self.lock_path) + os.remove(self.lock_path) + with self.modify() as d: + yield d + + def get(self): + try: + with self.path.open("rb") as f: + return marshal.load(f) + except FileNotFoundError: + return {} + + class Notifier: def __init__(self, vmail_dir): self.vmail_dir = vmail_dir diff --git a/chatmaild/src/chatmaild/tests/test_metadata.py b/chatmaild/src/chatmaild/tests/test_metadata.py index 63f48f3..cf53480 100644 --- a/chatmaild/src/chatmaild/tests/test_metadata.py +++ b/chatmaild/src/chatmaild/tests/test_metadata.py @@ -5,6 +5,7 @@ from chatmaild.metadata import ( handle_dovecot_request, handle_dovecot_protocol, Notifier, + PersistentDict, ) @@ -215,3 +216,28 @@ def test_notifier_thread_run_gone_removes_token(notifier): url, data, timeout = requests[1] assert data == "45678" assert notifier.get_tokens("user@example.org") == ["45678"] + + +class TestPersistentDict: + @pytest.fixture + def store(self, tmp_path): + return PersistentDict(tmp_path.joinpath("metadata")) + + def test_basic(self, store): + assert store.get() == {} + with store.modify() as d: + d["devicetoken"] = [1, 2, 3] + d["456"] = 4.2 + new = store.get() + assert new["devicetoken"] == [1, 2, 3] + assert new["456"] == 4.2 + + def test_dying_lock(self, tmp_path, caplog): + store1 = PersistentDict(tmp_path.joinpath("metadata")) + store2 = PersistentDict(tmp_path.joinpath("metadata"), timeout=0.1) + with store1.modify() as d: + with store2.modify() as d2: + d2["1"] = "2" + assert "could not obtain" in caplog.records[0].msg + d["1"] = "3" + assert store1.get()["1"] == "3"