Compare commits

...

6 Commits
main ... dev

18 changed files with 1497 additions and 5 deletions

View File

@@ -16,3 +16,4 @@ __version__ = '0.1.0'
from .identity import Identity
from .document import Document
from .replica import Replica

84
earthsnake/compare.py Normal file
View File

@@ -0,0 +1,84 @@
"""Comparison functions"""
from enum import Enum
from typing import Any, List, Literal, Optional
class Cmp(Enum):
"""Comparison results"""
LT = -1
EQ = 0
GT = 1
def deep_compare(val_a: Any, val_b: Any) -> Cmp:
"""Compare two values, descending into dictionaries key by key"""
if isinstance(val_a, dict):
for key, elem_a in val_a.items():
elem_b = val_b[key]
cmp = deep_compare(elem_a, elem_b)
if cmp == Cmp.EQ:
continue
return cmp
return compare_arrays(list(val_a.keys()), list(val_b.keys()))
if val_a > val_b:
return Cmp.GT
if val_a < val_b:
return Cmp.LT
return Cmp.EQ
def compare_basic(val_a: Any, val_b: Any, order: Literal['ASC', 'DESC'] = 'ASC') -> Cmp:
"""Compare two basic values"""
cmp = deep_compare(val_a, val_b)
if cmp == Cmp.EQ:
return Cmp.EQ
if order == 'ASC':
return cmp
return Cmp.LT if cmp == Cmp.GT else Cmp.GT
def compare_arrays(
arr_a: List[Any],
arr_b: List[Any],
sort_orders: Optional[List[Literal['ASC', 'DESC']]] = None,
) -> Cmp:
"""Compare two arrays"""
sort_orders = sort_orders or []
for idx, (elem_a, elem_b) in enumerate(zip(arr_a, arr_b)):
try:
sort_order = sort_orders[idx]
except IndexError:
sort_order = 'ASC'
elem_cmp = compare_basic(elem_a, elem_b, sort_order)
if elem_cmp != Cmp.EQ:
return elem_cmp
if len(arr_a) == len(arr_b):
return Cmp.EQ
idx = min(len(arr_a), len(arr_b))
try:
sort_order = sort_orders[idx]
except IndexError:
sort_order = 'ASC'
return compare_basic(len(arr_a), len(arr_b), sort_order)
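
A quick usage sketch of the comparators above (assuming the module is importable as earthsnake.compare): sort_orders applies per index, and the length tie-break reuses the order at the first unmatched index.

from earthsnake.compare import Cmp, compare_arrays, compare_basic, deep_compare

assert compare_basic(1, 2) == Cmp.LT
assert compare_basic(1, 2, 'DESC') == Cmp.GT  # DESC flips the result
assert deep_compare({'a': 1}, {'a': 1}) == Cmp.EQ
# Per-index sort orders: first element descending, second ascending
assert compare_arrays([10, 'sig.a'], [20, 'sig.a'], ['DESC', 'ASC']) == Cmp.GT
# On a tie over the common prefix, the shorter array sorts first
assert compare_arrays([1], [1, 2]) == Cmp.LT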

View File

@@ -1,4 +1,4 @@
"""Format validator for raw (JSON) documents in the es.4 format"""
"""Document class for the es.4 format"""
from datetime import datetime, timezone
from hashlib import sha256
@@ -29,9 +29,7 @@ class RawDocument(TypedDict, total=False):
class Es4Document(Document): # pylint: disable=too-many-instance-attributes
"""Validator for the 'es.4' format
Checks if documents are spec-compliant before ingesting, and signs them according to spec.
"""An es.4 format document
See https://earthstar-project.org/specs/data-spec
"""
@@ -129,7 +127,7 @@ class Es4Document(Document): # pylint: disable=too-many-instance-attributes
signature: Optional[str] = None,
delete_after: Optional[datetime] = None,
):
self.author: Identity = author
self.author = author
self.path = path
self.signature = signature
self._content = content or ''
@@ -142,6 +140,8 @@ class Es4Document(Document): # pylint: disable=too-many-instance-attributes
def from_json(cls, raw_document: Dict[str, Any]) -> 'Es4Document':
"""Validate raw_document as an es.4 document and create an ``Es4Document`` from it
Checks if documents are spec-compliant before ingesting, and signs them according to spec.
:returns: a new ``Es4Document``
:raises ValidationError: if anything is wrong
"""
@@ -245,14 +245,24 @@ class Es4Document(Document): # pylint: disable=too-many-instance-attributes
hasher = sha256()
for key, value in sorted(hash_keys.items(), key=lambda elem: elem[0]):
# Skip null fields
if value is None:
continue
# Otherwise, append the fieldname and value.
# Tab and newline are our field separators.
# Convert integers to strings here.
# (The newline is included on the last field.)
hasher.update(f'{key}\t{value}\n'.encode('utf-8'))
# Binary digest, not hex digest string! Then convert bytes to Earthstar b32 format with
# leading 'b'.
return base32_bytes_to_string(hasher.digest())
def sign(self, identity: Optional[Identity] = None) -> None:
"""Sign the document and store the signature into the document (mutating it)
"""
if identity and identity != self.author:
raise ValidationError(
"when signing a document, keypair address must match document author"
@@ -358,3 +368,21 @@ class Es4Document(Document): # pylint: disable=too-many-instance-attributes
if content is not None and calculated_hash != content_hash:
raise ValidationError("content does not match contentHash")
@staticmethod
def compare_newest_first(doc_a, doc_b):
"""Compare two documents based on their time stamp"""
if doc_a.timestamp < doc_b.timestamp:
return Cmp.LT
if doc_b.timestamp > doc_b.timestamp:
return Cmp.GT
if doc_a.signature < doc_b.signature:
return Cmp.LT
if doc_a.signature > doc_b.signature:
return Cmp.GT
return Cmp.EQ
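
Because compare_newest_first is a three-way comparator returning a Cmp member, sorting goes through functools.cmp_to_key on the enum's value. A sketch with stand-in objects (constructing a real Es4Document needs a full keypair, so SimpleNamespace stands in, and Es4Document is assumed to be in scope):

from functools import cmp_to_key
from types import SimpleNamespace

docs = [
    SimpleNamespace(timestamp=1, signature='bxxx'),
    SimpleNamespace(timestamp=2, signature='bzzz'),
    SimpleNamespace(timestamp=2, signature='baaa'),
]
docs.sort(key=cmp_to_key(lambda a, b: Es4Document.compare_newest_first(a, b).value))
# -> timestamps [2, 2, 1]; the tied pair is ordered by signature ascending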

View File

@@ -11,3 +11,7 @@ class EarthsnakeError(Exception):
class ValidationError(EarthsnakeError):
"""Raised when something doesnt pass as a valid Earthsnake object"""
class ReplicaIsClosedError(EarthsnakeError):
"""A ReplicaBase or ReplicaDriverBase object was used after close() was called on it"""

View File

@@ -147,6 +147,10 @@ class Identity:
return f'{self.name} {mnemonic}'
@property
def can_sign(self) -> bool:
"""Whether this identity holds its signing key and can thus sign documents"""
return bool(self.sign_key)
def sign(self, data: str) -> str:
"""Sign data"""
@@ -168,3 +172,6 @@ class Identity:
return False
return True
def __hash__(self) -> int:
return hash(str(self))

View File

@@ -94,3 +94,6 @@ class Path:
"""Check if path ends with sub"""
return self.path.endswith(sub)
def __hash__(self) -> int:
return hash(self.path)

6
earthsnake/peer.py Normal file
View File

@@ -0,0 +1,6 @@
from .syncer import SyncerBase
class Peer:
def __init__(self, syncer: SyncerBase) -> None:
self.syncer = syncer

598
earthsnake/replica.py Normal file
View File

@@ -0,0 +1,598 @@
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum, auto
from functools import cmp_to_key
import json
import logging
from typing import List, Literal, Mapping, Optional, Tuple, Union
#import { Lock, Superbus } from "../../deps.ts";
from .compare import Cmp, compare_arrays
#import {
# LocalIndex,
# ShareAddress,
#} from "../util/doc-types.ts";
from .document import Document, IncompleteDocument
from .identity import Identity
from .path import Path
from .share import Share
from .query import HistoryMode, Query
#import {
# IngestEvent,
# IReplica,
# IReplicaDriver,
# ReplicaBusChannel,
# ReplicaId,
#} from "./replica-types.ts";
#import { IFormatValidator } from "../format-validators/format-validator-types.ts";
from .format_validator import FormatValidatorBase
from .exc import EarthsnakeError, ReplicaIsClosedError, ValidationError
#import { randomId } from "../util/misc.ts";
from .util import microsecond_now, random_id
#import { compareArrays } from "./compare.ts";
#import { checkShareIsValid } from "../core-validators/addresses.ts";
#
#import { Crypto } from "../crypto/crypto.ts";
#
# --------------------------------------------------
#
#import { Logger } from "../util/log.ts";
J = json.dumps
logger = logging.getLogger(__name__)
#
# ================================================================================
class IngestEventKind(Enum):
failure = auto()
nothing_happened = auto()
success = auto()
class IngestEventBase:
kind: IngestEventKind
max_local_index: int
class IngestEventFailure(IngestEventBase):
reason: Literal['write_error', 'invalid_document']
err: Optional[Exception]
class IngestEventNothingHappened(IngestEventBase):
reason: Literal['obsolete_from_same_author', 'already_had_it']
doc: Document
class IngestEventSuccess(IngestEventBase):
doc: Document
doc_is_latest: bool
prev_doc_from_same_author: Optional[Document]
prev_latest_doc: Optional[Document]
IngestEvent = Union[IngestEventFailure, IngestEventNothingHappened, IngestEventSuccess]
def doc_compare_newest_first(a: Document, b: Document) -> Cmp:
# Sorts by timestamp DESC (newest first) and breaks ties using the signature ASC.
return compare_arrays(
[a.timestamp, a.signature],
[b.timestamp, b.signature],
["DESC", "ASC"],
)
class Replica: # (IReplica):
"""A replica of a share's data, used to read, write, and synchronise data to
Should be closed using the `close` method when no longer being used.
.. code-block:: python
replica = Replica('+a.a123', FormatValidator.ES4, MemoryReplicaDriver())
"""
replica_id: str # todo: save it to the driver too, and reload it when starting up
#: The address of the share this replica belongs to
share: Share
#: The validator used to validate ingested documents
format_validator: FormatValidatorBase
replica_driver: 'ReplicaDriverBase'
bus: 'Superbus[ReplicaBusChannel]'
_is_closed = False
_ingest_lock: 'Lock[IngestEvent]'
def __init__(
self,
share: Union[str, Share],
validator: FormatValidatorBase,
driver: ReplicaDriverBase,
) -> None:
if isinstance(share, str):
share = Share.from_string(share)
logger.debug(
'constructor. driver = %s',
driver.__class__.__name__,
)
# If we got a class instead of an actual driver object, let's instantiate the driver
if isinstance(driver, type):
driver = driver(share)
self.replica_id = 'replica-' + random_id()
self.share = share
self.format_validator = validator
self.replica_driver = driver
self.bus = Superbus('|')
self._ingest_lock = Lock()
# --------------------------------------------------
# LIFECYCLE
def is_closed(self) -> bool:
"""Returns whether the replica is closed or not
"""
return self._is_closed
async def close(self, erase: bool = False) -> None:
"""Closes the replica, preventing new documents from being ingested or events being emitted
Any method called after closing will raise `ReplicaIsClosedError`
:param erase: Erase the contents of the replica. Defaults to `False`
"""
logger.debug('closing...')
if self._is_closed:
raise ReplicaIsClosedError()
# TODO: do this all in a lock?
logger.debug(' sending willClose blockingly...')
await self.bus.send_and_wait("willClose")
logger.debug(' marking self as closed...')
self._is_closed = True
logger.debug(' closing ReplicaDriver (erase = %s)...', erase)
await self.replica_driver.close(erase)
logger.debug(' sending didClose...')
await self.bus.send_and_wait('didClose')
logger.debug('...closing done')
# --------------------------------------------------
# CONFIG
async def get_config(self, key: str) -> Optional[str]:
"""Get a specific config value"""
if self._is_closed:
raise ReplicaIsClosedError()
return await self.replica_driver.get_config(key)
async def set_config(self, key: str, value: str) -> None:
"""Set a specific configuration value"""
if self._is_closed:
raise ReplicaIsClosedError()
return await self.replica_driver.set_config(key, value)
async def list_config_keys(self) -> List[str]:
"""List all available configuration keys
"""
if self._is_closed:
raise ReplicaIsClosedError()
return await self.replica_driver.list_config_keys()
async def delete_config(self, key: str) -> bool:
"""Delete a key from the configuration
"""
if self._is_closed:
raise ReplicaIsClosedError()
return await self.replica_driver.delete_config(key)
# --------------------------------------------------
# GET
def get_max_local_index(self) -> int:
"""Returns the max local index of all stored documents
"""
if self._is_closed:
raise ReplicaIsClosedError()
return self.replica_driver.get_max_local_index()
async def get_docs_after_local_index(
self,
history_mode: HistoryMode,
start_after: int,
limit: Optional[int] = None,
) -> List[Document]:
"""Get all documents after a specific index
"""
logger.debug(
'get_docs_after_local_index(%s, %s, %d)',
history_mode,
start_after,
limit or -1,
)
if self._is_closed:
raise ReplicaIsClosedError()
query: Query = {
'history_mode': history_mode,
'order_by': 'localIndex ASC',
'start_after': {
'local_index': start_after,
},
'limit': limit,
}
return await self.replica_driver.query_docs(query)
async def get_all_docs(self) -> List[Document]:
"""Returns all documents, including historical versions of documents by other identities
"""
logger.debug('get_all_docs()')
if self._is_closed:
raise ReplicaIsClosedError()
return await self.replica_driver.query_docs(
history_mode=HistoryMode.ALL,
order_by='path ASC',
)
async def get_latest_docs(self) -> List[Document]:
"""Returns latest document from every path
"""
logger.debug('get_latest_docs()')
if self._is_closed:
raise ReplicaIsClosedError()
return await self.replica_driver.query_docs(
history_mode=HistoryMode.LATEST,
order_by='path ASC',
)
async def get_all_docs_at_path(self, path: Path) -> List[Document]:
"""Returns all versions of a document by different authors from a specific path
"""
logger.debug('get_all_docs_at_path("%s")', path)
if self._is_closed:
raise ReplicaIsClosedError()
return await self.replica_driver.query_docs(
history_mode=HistoryMode.ALL,
order_by='path ASC',
filter={'path': path},
)
async def get_latest_doc_at_path(self, path: Path) -> Optional[Document]:
"""Returns the most recently written version of a document at a path"""
logger.debug('get_latest_doc_at_path("%s")', path)
if self._is_closed:
raise ReplicaIsClosedError()
docs = await self.replica_driver.query_docs(
history_mode=HistoryMode.LATEST,
order_by='path ASC',
filter={'path': path},
)
if not docs:
return None
return docs[0]
async def query_docs(self, query: Optional[Query] = None) -> List[Document]:
"""Returns an array of docs for a given query
.. code-block:: python
my_query = {
'filter': {
'path_ends_with': '.txt'
},
'limit': 5,
}
first_five_text_docs = await my_replica.query_docs(my_query)
"""
logger.debug('queryDocs %s', query)
if self._is_closed:
raise ReplicaIsClosedError()
return await self.replica_driver.query_docs(query)
# def query_paths(query: Optional[Query]) -> List[Path]: pass
# def query_authors(query: Optional[Query]) -> List[AuthorAddress]: pass
# --------------------------------------------------
# SET
async def set(
self,
keypair: Identity,
doc_to_set: IncompleteDocument,
) -> IngestEvent:
"""Adds a new document to the replica
If a document signed by the same identity exists at the same path, it will be overwritten.
"""
logger.debug('set %s', doc_to_set)
if self._is_closed:
raise ReplicaIsClosedError()
logger.debug(
'...deciding timestamp: getting latest doc at the same path (from any author)',
)
timestamp: int
if isinstance(doc_to_set.timestamp, int):
timestamp = doc_to_set.timestamp
logger.debug('...docToSet already has a timestamp; not changing it from %d', timestamp)
else:
# bump timestamp if needed to win over existing latest doc at same path
latest_doc_same_path = await self.get_latest_doc_at_path(doc_to_set.path)
if latest_doc_same_path is None:
timestamp = microsecond_now()
logger.debug(
'...no existing latest doc, setting timestamp to now() = %s', timestamp
)
else:
timestamp = max(microsecond_now(), latest_doc_same_path.timestamp + 1)
logger.debug(
'...existing latest doc found, bumping timestamp to win if needed = %s',
timestamp,
)
doc = Document(
format=ValidationFormat.ES4,
author=keypair.address,
content=doc_to_set.content,
content_hash=await Crypto.sha256base32(doc_to_set.content),
delete_after=doc_to_set.delete_after or None,
path=doc_to_set.path,
timestamp=timestamp,
workspace=self.share,
signature='?', # signature will be added in just a moment
# _localIndex will be added during upsert. it's not needed for the signature.
)
logger.debug('...signing doc')
try:
signed_doc = await self.format_validator.sign_document(keypair, doc)
except EarthsnakeError as exc:
return {
'kind': 'failure',
'reason': 'invalid_document',
'err': exc,
'max_local_index': self.replica_driver.get_max_local_index(),
}
logger.debug('...signature = %s', signed_doc.signature)
logger.debug('...ingesting')
logger.debug('-----------------------')
ingest_event = await self.ingest(signed_doc)
logger.debug('-----------------------')
logger.debug('...done ingesting')
logger.debug('...set is done.')
return ingest_event
async def ingest(self, doc_to_ingest: Document) -> IngestEvent:
"""Ingest an existing signed document to the replica
"""
logger.debug('ingest %s', doc_to_ingest)
if self._is_closed:
raise ReplicaIsClosedError()
logger.debug('...removing extra fields')
try:
remove_results_or_err = self.format_validator.remove_extra_fields(doc_to_ingest)
except EarthsnakeError as exc:
return {
'kind': "failure",
'reason': "invalid_document",
'err': exc,
'max_local_index': self.replica_driver.get_max_local_index(),
}
doc_to_ingest = remove_results_or_err.doc # a copy of doc without extra fields
extra_fields = remove_results_or_err.extras # any extra fields starting with underscores
if extra_fields:
logger.debug('...extra fields found: %s', J(extra_fields))
try:
# now actually check doc validity against core schema
self.format_validator.check_document_is_valid(doc_to_ingest)
except EarthsnakeError as exc:
return {
'kind': "failure",
'reason': "invalid_document",
'err': exc,
'max_local_index': self.replica_driver.get_max_local_index(),
}
async def write_to_driver_with_lock() -> IngestEvent:
# get other docs at the same path
logger.debug(' >> ingest: start of protected region')
logger.debug(' > getting other history docs at the same path by any author')
existing_docs_same_path = await self.get_all_docs_at_path(doc_to_ingest.path)
logger.debug(' > ...got %d', len(existing_docs_same_path))
logger.debug(' > getting prevLatest and prevSameAuthor')
prev_latest: Optional[Document] = existing_docs_same_path[0] if existing_docs_same_path else None
docs_same_author = [
document
for document in existing_docs_same_path
if document.author == doc_to_ingest.author
]
prev_same_author: Optional[Document] = docs_same_author[0] if docs_same_author else None
logger.debug(' > checking if new doc is latest at this path')
existing_docs_same_path.append(doc_to_ingest)
existing_docs_same_path.sort(key=cmp_to_key(lambda a, b: doc_compare_newest_first(a, b).value))
is_latest = existing_docs_same_path[0] == doc_to_ingest
logger.debug(' > ...isLatest: %s', is_latest)
if not is_latest and prev_same_author is not None:
logger.debug(
' > new doc is not latest and there is another one from the same author...'
)
# check if this is obsolete or redundant from the same author
doc_comp = doc_compare_newest_first(doc_to_ingest, prev_same_author)
if doc_comp == Cmp.GT:
logger.debug(' > new doc is GT prevSameAuthor, so it is obsolete')
return {
'kind': "nothing_happened",
'reason': "obsolete_from_same_author",
'doc': doc_to_ingest,
'max_local_index': self.replica_driver.get_max_local_index(),
}
if doc_comp == Cmp.EQ:
logger.debug(
' > new doc is EQ prevSameAuthor, so it is redundant (already_had_it)',
)
return {
'kind': "nothing_happened",
'reason': "already_had_it",
'doc': doc_to_ingest,
'max_local_index': self.replica_driver.get_max_local_index(),
}
# save it
logger.debug(" > upserting into ReplicaDriver...")
# TODO: pass existing_docs_same_path to save another lookup
doc_as_written = await self.replica_driver.upsert(doc_to_ingest)
logger.debug(" > ...done upserting into ReplicaDriver")
logger.debug(" > ...getting ReplicaDriver maxLocalIndex...")
max_local_index = self.replica_driver.get_max_local_index()
logger.debug(
' >> ingest: end of protected region, returning a WriteEvent from the lock'
)
return {
'kind': "success",
'max_local_index': max_local_index,
'doc': doc_as_written, # with updated extra properties like _localIndex
'doc_is_latest': is_latest,
'prev_doc_from_same_author': prev_same_author,
'prev_latest_doc': prev_latest,
}
logger.debug(" >> ingest: running protected region...")
ingest_event: IngestEvent = await self._ingest_lock.run(
write_to_driver_with_lock,
)
logger.debug(" >> ingest: ...done running protected region")
logger.debug("...send ingest event after releasing the lock")
logger.debug("...ingest event: %s", ingest_event)
await self.bus.send_and_wait(
f'ingest|{doc_to_ingest.path}',
ingest_event,
) # include the path in the channel even on failures
return ingest_event
async def overwrite_all_docs_by_author(self, keypair: Identity) -> int:
"""Overwrite every document from this author, including history versions, with an empty doc
:returns: the number of documents changed, or an error if signing or ingesting failed.
"""
logger.debug('overwriteAllDocsByAuthor("%s")', keypair.address)
if self._is_closed:
raise ReplicaIsClosedError()
# TODO: do this in batches
query = Query(
flt={'author': keypair.address},
history_mode=HistoryMode.ALL,
)
docs_to_overwrite = await self.query_docs(query)
logger.debug(' ...found %d docs to overwrite', len(docs_to_overwrite))
num_overwritten = 0
num_already_empty = 0
for doc in docs_to_overwrite:
if not doc.content:
num_already_empty += 1
continue
# remove extra fields
cleaned_result = self.format_validator.remove_extra_fields(doc)
cleaned_doc = cleaned_result.doc
# make new doc which is empty and just barely newer than the original
empty_doc = Document(
cleaned_doc,
content='',
content_hash=await Crypto.sha256base32(''),
timestamp=doc.timestamp + 1,
signature='?',
)
try:
# sign and ingest it
signed_doc = await self.format_validator.sign_document(keypair, empty_doc)
except EarthsnakeError as exc:
return exc
ingest_event = await self.ingest(signed_doc)
if ingest_event['kind'] == 'failure':
return ValidationError(
'ingestion error during overwriteAllDocsBySameAuthor: '
f'{ingest_event["reason"]}: {ingest_event["err"]}',
)
if ingest_event['kind'] == 'nothing_happened':
return ValidationError(
f'ingestion did nothing during overwriteAllDocsBySameAuthor: {ingest_event["reason"]}',
)
# success
num_overwritten += 1
logger.debug(
' ...done; %d overwritten to be empty; %d were already empty; out of total %d docs',
num_overwritten,
num_already_empty,
len(docs_to_overwrite)
)
return num_overwritten
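
The timestamp handling in set() above boils down to one rule: a new doc must sort strictly after the current latest doc at its path, even when the local clock lags behind it. A standalone sketch of that rule (hypothetical helper, microsecond integers):

from typing import Optional

def choose_timestamp(now_us: int, latest_ts: Optional[int]) -> int:
    """Pick a timestamp that wins over the existing latest doc, if any"""
    if latest_ts is None:
        return now_us
    return max(now_us, latest_ts + 1)

assert choose_timestamp(1_000, None) == 1_000
assert choose_timestamp(1_000, 500) == 1_000    # clock is ahead: keep now
assert choose_timestamp(1_000, 5_000) == 5_001  # clock lags: bump past latest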

View File

@@ -0,0 +1,171 @@
"""Replica related things"""
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import List, Optional
from ..document import Document
from ..identity import Identity
from ..path import Path
from ..query import Query
from ..share import Share
class HistoryMode(Enum):
ALL = auto()
LATEST = auto()
class IngestEvent:
pass
class Replica(ABC):
"""A replica of a shares data
Used to read, write, and synchronise data to.
Should be closed using the ``close()`` method when no longer being used.
"""
def __init__(self, share: Share, **driver_kwargs):
pass
@property
@abstractmethod
def is_closed(self):
"""Tells whether the replica is closed or not
"""
@abstractmethod
def close(self):
"""Closes the replica, preventing further use
Any method called after closing will raise ``ReplicaIsClosedError``
"""
@property
@abstractmethod
def max_local_index(self) -> int:
"""Returns the maximum local index of all stored documents"""
@abstractmethod
def get_docs_after_local_index(
self,
history_mode: HistoryMode,
start_after: int,
limit: Optional[int] = None,
) -> List[Document]:
"""Get all documents after a specific local index"""
@abstractmethod
def get_all_docs(self) -> List[Document]:
"""Get all documents, including historical versions by other identities"""
@abstractmethod
def get_latest_docs(self) -> List[Document]:
"""Get the latest from every path"""
@abstractmethod
def get_all_docs_at_path(self, path: Path) -> List[Document]:
"""Get all versions of a document by different authors from a specific path"""
@abstractmethod
def get_latest_doc_at_path(self, path: Path) -> Optional[Document]:
"""Get the most recent version of a document at a specific path"""
@abstractmethod
def query_docs(self, query: Query) -> List[Document]:
"""Get a list of documents matching a given query"""
@abstractmethod
def query_paths(self, query: Query) -> List[Path]:
"""Get all document paths where documents match a given query"""
@abstractmethod
def query_authors(self, query: Query) -> List[Identity]:
"""Get all document authors where documents match a given query"""
@abstractmethod
def set(self, doc_to_set: Document) -> IngestEvent:
"""Add a new document to the replica
If a document signed by the same identity exists at the same path, it will be overwritten.
"""
@abstractmethod
def ingest(self, document: Document) -> IngestEvent:
"""Ingest an existing, signed document to the replica"""
@abstractmethod
def overwrite_docs_by_author(self, identity: Identity) -> int:
"""Overwrite every document from this author with an empty doc
This includes historical versions of documents.
:returns: the number of documents changed.
"""
"""Workspace drivers"""
from abc import ABC, ABCMeta
from typing import List, Optional
from ...document import Document
from .. import Workspace
from ...query import Query
class WorkspaceDriverConfig(ABC):
"""Base class for configurable workspace drivers"""
def get_config(self, key: str) -> Optional[str]:
"""Get a configuration value"""
raise NotImplementedError()
def set_config(self, key: str, value: str) -> None:
"""Set a configuration value"""
raise NotImplementedError()
def list_config_keys(self) -> List[str]:
"""List all configuration keys"""
raise NotImplementedError()
def delete_config(self, key: str) -> bool:
"""Delete a configuration value"""
raise NotImplementedError()
class WorkspaceDriverBase(WorkspaceDriverConfig, metaclass=ABCMeta):
"""Base class for workspace drivers"""
workspace: Workspace
@property
def is_closed(self) -> bool:
"""Tells if a workspace is closed"""
raise NotImplementedError()
def close(self, erase: bool) -> None:
"""Close the workspace"""
raise NotImplementedError()
def get_max_local_index(self) -> int:
"""Get the maximum local index count"""
raise NotImplementedError()
def query_docs(self, query: Query) -> List[Document]:
"""Query a list of documents"""
raise NotImplementedError()
def upsert(self, doc: Document) -> Document:
"""Insert or update a document"""
raise NotImplementedError()

View File

@@ -0,0 +1,99 @@
"""A share driver that stores data in memory
"""
from typing import Dict, List, Tuple
from ..document import Document
from ..exc import ReplicaIsClosedError
from ..identity import Identity
from ..path import Path
from ..query import HistoryMode, Query
from ..share import Share
from . import Replica
class InMemoryReplica(Replica):
"""In-memory Replica"""
def __init__(self, share: Share, **driver_kwargs) -> None:
self.share = share
self._is_closed = False
self._max_local_index = -1
# Local Index <=> Document pairs
self._documents: List[Tuple[int, Document]] = []
@property
def is_closed(self) -> bool:
return self._is_closed
def close(self, erase: bool = False) -> None:
if self._is_closed:
raise ReplicaIsClosedError()
if erase:
self._max_local_index = -1
self._documents = []
self._is_closed = True
@property
def max_local_index(self) -> int:
if self._is_closed:
raise ReplicaIsClosedError()
return self._max_local_index
def _get_all_docs(self) -> List[Document]:
"""Get all documents"""
if self._is_closed:
raise ReplicaIsClosedError()
return [document for _, document in self._documents]
def _get_latest_docs(self) -> List[Document]:
"""Get the latest version of each document"""
if self._is_closed:
raise ReplicaIsClosedError()
docs_by_path: Dict[str, Document] = {}
for _, document in self._documents:
if (
str(document.path) not in docs_by_path
or docs_by_path[str(document.path)].timestamp <= document.timestamp
):
docs_by_path[str(document.path)] = document
return list(docs_by_path.values())
def query_docs(self, query: Query) -> List[Document]:
"""Query a list of documents"""
if self._is_closed:
raise ReplicaIsClosedError()
if query.history_mode == HistoryMode.ALL:
docs = self._get_all_docs()
else:
docs = self._get_latest_docs()
docs_to_local_index = {
document: local_index for local_index, document in self._documents
}
return query({docs_to_local_index[document]: document for document in docs})
def upsert(self, new_document: Document) -> Document:
if self._is_closed:
raise ReplicaIsClosedError()
self._max_local_index += 1
self._documents = [
(local_index, document)
for local_index, document in self._documents
if document.author != new_document.author or document.path != new_document.path
]
self._documents.append((self._max_local_index, new_document))
return new_document
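
The upsert bookkeeping above reduces to: bump a monotonically growing local index, drop the author's previous version at that path, and append the new pair. A standalone sketch of that invariant (plain tuples stand in for documents):

docs = []               # list of (local_index, (author, path, content))
max_local_index = -1

def upsert(author, path, content):
    global max_local_index
    max_local_index += 1
    # keep everything except this author's previous version at this path
    docs[:] = [(i, d) for i, d in docs if (d[0], d[1]) != (author, path)]
    docs.append((max_local_index, (author, path, content)))

upsert('@suzy', '/wiki/Spam', 'v1')
upsert('@suzy', '/wiki/Spam', 'v2')  # replaces v1, but the index keeps growing
assert docs == [(1, ('@suzy', '/wiki/Spam', 'v2'))]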

View File

@@ -0,0 +1,104 @@
"""Workspace syncing classes"""
from enum import Enum, auto
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar
from ..watchable import Watchable, WatchableSet
FnsBag = Dict[str, Callable[[Any], Any]]
Thunk = Callable[[], None]
T = TypeVar('T')
class TransportStatus(Enum):
OPEN = auto()
CLOSED = auto()
class ConnectionStatus(Enum):
CONNECTING = auto()
CLOSED = auto()
class TransportBase:
"""Base class for workspace syncers"""
status: Watchable[TransportStatus]
is_closed: bool
methods: FnsBag
device_id: str
connections: 'WatchableSet[Connection]'
def __init__(self, device_id: str, methods: FnsBag):
raise NotImplementedError()
def on_close(self, cb: Thunk) -> Thunk:
"""Set a handler for when the connection closes"""
raise NotImplementedError()
def close(self) -> None:
"""Close the syncers connection"""
raise NotImplementedError()
class TransportLocal(TransportBase):
def __init__(self, device_id: str, methods: FnsBag, description: str) -> None:
self.device_id = device_id
self.methods = methods
self.description = description
self.status = Watchable(TransportStatus.OPEN)
self.connections = WatchableSet()
@property
def is_closed(self) -> bool:
return self.status.get() == TransportStatus.CLOSED
def on_close(self, func: Thunk) -> Thunk:
return self.status.on_change_to(TransportStatus.CLOSED)(func)
def close(self) -> None:
if self.is_closed:
return
self.status.set(TransportStatus.CLOSED)
for conn in self.connections:
conn.close()
self.connections.clear()
def add_connection(self, other_trans: 'TransportLocal') -> Tuple['Connection', 'Connection']:
if self.is_closed:
raise Exception("Can't use a transport after it's closed")
this_conn: 'Connection'
other_conn: 'Connection'
this_conn = Connection(
description=f'conn {self.device_id} to {other_trans.device_id}',
transport=self,
device_id=self.device_id,
methods=self.methods,
send_envelope=lambda conn, env: other_conn.handle_incoming_envelope(env),
)
other_conn = Connection(
description=f'conn {other_trans.device_id} to {self.device_id}',
transport=other_trans,
device_id=other_trans.device_id,
methods=other_trans.methods,
send_envelope=lambda conn, env: this_conn.handle_incoming_envelope(env),
)
@this_conn.on_close
def close_other():
other_conn.close()
self.connections.remove(this_conn)
@other_conn.on_close
def close_this():
this_conn.close()
self.connections.add(this_conn)
other_trans.connections.add(other_conn)
return this_conn, other_conn

View File

@@ -0,0 +1,118 @@
import logging
from typing import Any, Callable, Dict, Generic, Optional, Set, TypeVar
from .envelope import Envelope, EnvelopeNotify, EnvelopeRequest, EnvelopeResponseWithData, EnvelopeResponseWithError
from .util import make_id
from .watchable import Watchable
from . import ConnectionStatus, TransportBase
logger = logging.getLogger(__name__)
T = TypeVar('T')
Thunk = Callable[[], None]
class Connection(Generic[T]):
def __init__(self, transport: TransportBase, device_id: str, description: str, methods: T, send_envelope: Callable[['Connection[T]', Envelope], None]) -> None:
self.status = Watchable(ConnectionStatus.CONNECTING)
self._close_cbs: Set[Thunk] = set()
self.description = description
self._transport = transport
self._device_id = device_id
self._other_device_id: Optional[str] = None
self._methods = methods
self._send_envelope = send_envelope
self._deferred_requests: Dict[str, 'Deferred[Any]'] = {}
self._last_seen = 0
@property
def is_closed(self) -> bool:
return self.status.get() == ConnectionStatus.CLOSED
def on_close(self, func: Thunk) -> Thunk:
if self.is_closed:
raise RpcErrorUseAfterClose('the connection is closed')
self._close_cbs.add(func)
def del_cb():
self._close_cbs.remove(func)
return del_cb
def close(self) -> None:
if self.is_closed:
return
self.status.set(ConnectionStatus.CLOSED)
for func in self._close_cbs:
func()
self._close_cbs.clear()
def handle_incoming_envelope(self, env: Envelope) -> None:
# TODO: maybe this function should be in a lock to ensure it only runs one at a time
if self.is_closed:
raise RpcErrorUseAfterClose('the connection is closed')
# TODO: throw error if status is ERROR?
if env.kind == 'NOTIFY':
if env.method not in self._methods:
logger.warning('> error in NOTIFY handler: no notify method called %s', env.method)
else:
try:
self._methods[env.method](*env.args)
except BaseException as exc:
logger.warning('> error when running NOTIFY method: %s %s', env, exc)
elif env.kind == 'REQUEST':
try:
if env.method not in self._methods:
raise RpcErrorUnknownMethod(f'unknown method in REQUEST: {env.method}')
data = self._methods[env.method](*env.args)
response_env_data = EnvelopeResponseWithData(kind='RESPONSE', from_device_id=self._device_id, envelope_id=env.envelope_id, data=data)
self._send_envelope(self, response_env_data)
except BaseException as exc:
response_env_error = EnvelopeResponseWithError(kind='RESPONSE', from_device_id=self._device_id, envelope_id=env.envelope_id, error=str(exc))
self._send_envelope(self, response_env_error)
elif env.kind == 'RESPONSE':
deferred = self._deferred_requests.get(env.envelope_id)
if not deferred:
return
if getattr(env, 'data', None) is not None:
deferred.resolve(env.data)
elif getattr(env, 'error', None):
deferred.reject(RpcErrorFromMethod(env.error))
else:
logger.warning('> RESPONSE has neither data nor error. This should never happen')
deferred.reject(RpcError('> RESPONSE has neither data nor error??'))
del self._deferred_requests[env.envelope_id]
def notify(self, method: str, *args: Any) -> None:
if self.is_closed:
raise RpcErrorUseAfterClose('the connection is closed')
env = EnvelopeNotify(
kind='NOTIFY',
from_device_id=self._device_id,
envelope_id=f'env:{make_id()}',
method=method,
args=args
)
self._send_envelope(self, env)
def request(self, method: str, *args: Any) -> Any:
if self.is_closed:
raise RpcErrorUseAfterClose('the connection is closed')
env = EnvelopeRequest(kind='REQUEST', from_device_id=self._device_id, envelope_id=f'env:{make_id()}', method=method, args=args)
deferred = make_deferred()
self._deferred_requests[env.envelope_id] = deferred
self._send_envelope(self, env)
return deferred.promise

View File

@@ -0,0 +1,115 @@
"""RPC Envelope handling"""
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Literal, Union
Fn = Callable[..., Any]
FnsBag = Dict[str, Fn]
# pylint: disable=too-few-public-methods
class EnvelopeKind(Enum):
"""Types of envelopes"""
NOTIFY = auto()
REQUEST = auto()
RESPONSE = auto()
class EnvelopeBase:
"""Base envelope type"""
kind: EnvelopeKind
from_device_id: str
envelope_id: str
def as_dict(self) -> Dict[str, Any]:
"""Convert the envelope to a dictionary"""
return {
'kind': self.kind.name,
'fromDeviceId': self.from_device_id,
'envelopeId': self.envelope_id,
}
class EnvelopeNotify(EnvelopeBase):
"""Envelope type for a notification"""
kind: Literal[EnvelopeKind.NOTIFY]
method: str
args: List[str]
def as_dict(self) -> Dict[str, Any]:
envelope = super().as_dict()
envelope.update(
{
'method': self.method,
'args': self.args,
}
)
return envelope
class EnvelopeRequest(EnvelopeBase):
"""Envelope type for a request"""
kind: Literal[EnvelopeKind.REQUEST]
method: str
args: List[str]
def as_dict(self) -> Dict[str, Any]:
envelope = super().as_dict()
envelope.update(
{
'method': self.method,
'args': self.args,
}
)
return envelope
class EnvelopeResponseWithData(EnvelopeBase):
"""Envelope type for a data response"""
kind: Literal[EnvelopeKind.RESPONSE]
data: List[str]
def as_dict(self) -> Dict[str, Any]:
envelope = super().as_dict()
envelope.update(
{
'data': self.data,
}
)
return envelope
class EnvelopeResponseWithError(EnvelopeBase):
"""Envelope type for an error response"""
kind: Literal[EnvelopeKind.RESPONSE]
error: str
def as_dict(self) -> Dict[str, Any]:
envelope = super().as_dict()
envelope.update(
{
'error': self.error,
}
)
return envelope
EnvelopeResponse = Union[EnvelopeResponseWithData, EnvelopeResponseWithError]
Envelope = Union[
EnvelopeNotify,
EnvelopeRequest,
EnvelopeResponseWithData,
EnvelopeResponseWithError,
]
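
A sketch of how these envelopes serialise (attributes are assigned directly, since the classes above only declare annotations; the ids are made up):

env = EnvelopeNotify()
env.kind = EnvelopeKind.NOTIFY
env.from_device_id = 'device:abc'
env.envelope_id = 'env:123456789012345'
env.method = 'ping'
env.args = []
assert env.as_dict() == {
    'kind': 'NOTIFY',
    'fromDeviceId': 'device:abc',
    'envelopeId': 'env:123456789012345',
    'method': 'ping',
    'args': [],
}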

View File

@@ -0,0 +1,2 @@
from random import randint
def make_id() -> str:
"""Return a random, zero-padded 15-digit id"""
return str(randint(0, 999999999999999)).zfill(15)

View File

@@ -1,8 +1,15 @@
"""Generic types and definitions
"""
from typing import Literal
ALPHA_LOWER = 'abcdefghijklmnopqrstuvwxyz'
ALPHA_UPPER = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
DIGIT = '0123456789'
B32_CHAR = ALPHA_LOWER + '234567'
ALPHA_LOWER_OR_DIGIT = ALPHA_LOWER + DIGIT
PRINTABLE_ASCII = bytes(range(32, 127)).decode('utf-8')
ASCII = bytes(range(128))
Cmp = Literal[-1, 0, 1]

11
earthsnake/util.py Normal file
View File

@@ -0,0 +1,11 @@
from datetime import datetime
from random import random
def microsecond_now() -> int:
"""Return the current UTC time as an integer number of microseconds"""
return int(datetime.utcnow().timestamp() * 1_000_000)
def random_id() -> str:
# TODO: better randomness here
return f'{random()}{random()}'

134
earthsnake/watchable.py Normal file
View File

@@ -0,0 +1,134 @@
"""Watchable variables"""
from typing import Callable, Dict, Generic, Iterable, Optional, Set, TypeVar
Thunk = Callable[[], None]
T = TypeVar('T')
CbOldNew = Callable[[T, T], None]
CbValue = Callable[[T], None]
class Watchable(Generic[T]):
"""A non-seamless proxy to watch a variables value"""
def __init__(self, value: T):
self._cbs: Set[CbOldNew[T]] = set()
self._cbs_by_target: Dict[T, Set[CbOldNew[T]]] = {}
self.value = value
def get(self) -> T:
"""Get the current value of the variable"""
return self.value
def set(self, new_val: T) -> None:
"""Set the variable to a new value"""
old_val = self.value
self.value = new_val
if new_val != old_val:
for func in self._cbs:
func(old_val, new_val)
for target_func in self._cbs_by_target.get(new_val, []):
target_func(old_val, new_val)
def on_change(self, func: CbOldNew[T]) -> Thunk:
"""Add a callback to be called when the variable changes"""
self._cbs.add(func)
def del_cb() -> None:
self._cbs.remove(func)
return del_cb
def on_change_to(self, target: T) -> Callable[[CbOldNew[T]], Thunk]:
"""Add a callback to be called when the variable is set to a specific value"""
def decorator(func: CbOldNew[T]) -> Thunk:
self._cbs_by_target.setdefault(target, set()).add(func)
def del_cb() -> None:
self._cbs_by_target[target].remove(func)
return del_cb
return decorator
class WatchableSet(Set[T]):
"""A set that can be watched for changes"""
def __init__(self, iterable: Optional[Iterable[T]] = None) -> None:
if iterable is None:
super().__init__()
else:
super().__init__(iterable)
self._add_cbs: Set[CbValue[T]] = set()
self._remove_cbs: Set[CbValue[T]] = set()
self._change_cbs: Set[Thunk] = set()
def add(self, value: T) -> None:
had = value in self
super().add(value)
if not had:
for func in self._add_cbs:
func(value)
for change_func in self._change_cbs:
change_func()
def remove(self, value: T) -> None:
had = value in self
super().remove(value)
if had:
for func in self._remove_cbs:
func(value)
for change_func in self._change_cbs:
change_func()
def clear(self) -> None:
for value in super().copy():
super().remove(value)
for func in self._remove_cbs:
func(value)
for change_func in self._change_cbs:
change_func()
def on_add(self, func: CbValue[T]) -> Thunk:
"""Add a callback function to be called when an item gets added to the set"""
self._add_cbs.add(func)
def del_cb() -> None:
self._add_cbs.remove(func)
return del_cb
def on_remove(self, func: CbValue[T]) -> Thunk:
"""Add a callback function to be called when an item gets removed from the set"""
self._remove_cbs.add(func)
def del_cb() -> None:
self._remove_cbs.remove(func)
return del_cb
def on_change(self, func: Thunk) -> Thunk:
"""Add a callback function to be called when the set changes"""
self._change_cbs.add(func)
def del_cb() -> None:
self._change_cbs.remove(func)
return del_cb
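
A usage sketch for the two classes above (values are made up):

from earthsnake.watchable import Watchable, WatchableSet

status = Watchable('CONNECTING')
unsubscribe = status.on_change(lambda old, new: print(f'{old} -> {new}'))

@status.on_change_to('CLOSED')
def announce(old, new):
    print(f'closed (was {old})')

status.set('CLOSED')  # prints 'CONNECTING -> CLOSED', then 'closed (was CONNECTING)'
unsubscribe()

conns = WatchableSet()
conns.on_add(lambda value: print(f'added {value}'))
conns.add('conn-1')   # prints 'added conn-1'
conns.clear()         # fires the on_remove/on_change callbacks for each member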