add a persistent dict impl
This commit is contained in:
parent
0a93c76e66
commit
15e7458666
@ -12,6 +12,9 @@ import sys
|
||||
import logging
|
||||
import os
|
||||
import requests
|
||||
import marshal
|
||||
from contextlib import contextmanager
|
||||
import filelock
|
||||
|
||||
|
||||
DICTPROXY_LOOKUP_CHAR = "L"
|
||||
@ -24,6 +27,38 @@ DICTPROXY_TRANSACTION_CHARS = "SBC"
|
||||
METADATA_TOKEN_KEY = "devicetoken"
|
||||
|
||||
|
||||
class PersistentDict:
|
||||
"""Concurrency-safe multi-reader-single-writer Persistent Dict."""
|
||||
|
||||
def __init__(self, path, timeout=5.0):
|
||||
self.path = path
|
||||
self.lock_path = path.with_name(path.name + ".lock")
|
||||
self.timeout = timeout
|
||||
|
||||
@contextmanager
|
||||
def modify(self):
|
||||
try:
|
||||
with filelock.FileLock(self.lock_path, timeout=self.timeout):
|
||||
data = self.get()
|
||||
yield data
|
||||
write_path = self.path.with_suffix(".tmp")
|
||||
with write_path.open("wb") as f:
|
||||
marshal.dump(data, f)
|
||||
os.rename(write_path, self.path)
|
||||
except filelock.Timeout:
|
||||
logging.warning("could not obtain lock, removing: %r", self.lock_path)
|
||||
os.remove(self.lock_path)
|
||||
with self.modify() as d:
|
||||
yield d
|
||||
|
||||
def get(self):
|
||||
try:
|
||||
with self.path.open("rb") as f:
|
||||
return marshal.load(f)
|
||||
except FileNotFoundError:
|
||||
return {}
|
||||
|
||||
|
||||
class Notifier:
|
||||
def __init__(self, vmail_dir):
|
||||
self.vmail_dir = vmail_dir
|
||||
|
@ -5,6 +5,7 @@ from chatmaild.metadata import (
|
||||
handle_dovecot_request,
|
||||
handle_dovecot_protocol,
|
||||
Notifier,
|
||||
PersistentDict,
|
||||
)
|
||||
|
||||
|
||||
@ -215,3 +216,28 @@ def test_notifier_thread_run_gone_removes_token(notifier):
|
||||
url, data, timeout = requests[1]
|
||||
assert data == "45678"
|
||||
assert notifier.get_tokens("user@example.org") == ["45678"]
|
||||
|
||||
|
||||
class TestPersistentDict:
|
||||
@pytest.fixture
|
||||
def store(self, tmp_path):
|
||||
return PersistentDict(tmp_path.joinpath("metadata"))
|
||||
|
||||
def test_basic(self, store):
|
||||
assert store.get() == {}
|
||||
with store.modify() as d:
|
||||
d["devicetoken"] = [1, 2, 3]
|
||||
d["456"] = 4.2
|
||||
new = store.get()
|
||||
assert new["devicetoken"] == [1, 2, 3]
|
||||
assert new["456"] == 4.2
|
||||
|
||||
def test_dying_lock(self, tmp_path, caplog):
|
||||
store1 = PersistentDict(tmp_path.joinpath("metadata"))
|
||||
store2 = PersistentDict(tmp_path.joinpath("metadata"), timeout=0.1)
|
||||
with store1.modify() as d:
|
||||
with store2.modify() as d2:
|
||||
d2["1"] = "2"
|
||||
assert "could not obtain" in caplog.records[0].msg
|
||||
d["1"] = "3"
|
||||
assert store1.get()["1"] == "3"
|
||||
|
Loading…
Reference in New Issue
Block a user