added Try
All checks were successful
CI / build (push) Successful in 18s

This commit is contained in:
2024-11-03 22:39:08 +08:00
parent e0e763170f
commit 6ef7f8107e
10 changed files with 338 additions and 188 deletions

View File

@@ -18,7 +18,9 @@ jobs:
- name: Create virtualenv - name: Create virtualenv
run: | run: |
python -m venv .venv python -m venv .venv
.venv/bin/pip install -r requirements-dev.txt .venv/bin/pip install -r requirements.txt .
- name: Run mypy
run: .venv/bin/python -m mypy -p src
- name: Run unit tests - name: Run unit tests
run: .venv/bin/python -m unittest discover -s tests run: .venv/bin/python -m unittest discover -s tests
- name: Execute build - name: Execute build

View File

@@ -1,140 +0,0 @@
#
# This file is autogenerated by pip-compile with Python 3.12
# by the following command:
#
# pip-compile --extra=dev --output-file=requirements-dev.txt pyproject.toml
#
--index-url https://gitea.woggioni.net/api/packages/woggioni/pypi/simple
--extra-index-url https://pypi.org/simple
asttokens==2.4.1
# via stack-data
build==1.2.2.post1
# via
# pip-tools
# pwo (pyproject.toml)
certifi==2024.8.30
# via requests
cffi==1.17.1
# via cryptography
charset-normalizer==3.4.0
# via requests
click==8.1.7
# via pip-tools
cryptography==43.0.3
# via secretstorage
decorator==5.1.1
# via
# ipdb
# ipython
docutils==0.21.2
# via readme-renderer
executing==2.1.0
# via stack-data
idna==3.10
# via requests
importlib-metadata==8.5.0
# via twine
ipdb==0.13.13
# via pwo (pyproject.toml)
ipython==8.28.0
# via ipdb
jaraco-classes==3.4.0
# via keyring
jaraco-context==6.0.1
# via keyring
jaraco-functools==4.1.0
# via keyring
jedi==0.19.1
# via ipython
jeepney==0.8.0
# via
# keyring
# secretstorage
keyring==25.4.1
# via twine
markdown-it-py==3.0.0
# via rich
matplotlib-inline==0.1.7
# via ipython
mdurl==0.1.2
# via markdown-it-py
more-itertools==10.5.0
# via
# jaraco-classes
# jaraco-functools
mypy==1.13.0
# via pwo (pyproject.toml)
mypy-extensions==1.0.0
# via mypy
nh3==0.2.18
# via readme-renderer
packaging==24.1
# via build
parso==0.8.4
# via jedi
pexpect==4.9.0
# via ipython
pip-tools==7.4.1
# via pwo (pyproject.toml)
pkginfo==1.10.0
# via twine
prompt-toolkit==3.0.48
# via ipython
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
# via stack-data
pycparser==2.22
# via cffi
pygments==2.18.0
# via
# ipython
# readme-renderer
# rich
pyproject-hooks==1.2.0
# via
# build
# pip-tools
readme-renderer==44.0
# via twine
requests==2.32.3
# via
# requests-toolbelt
# twine
requests-toolbelt==1.0.0
# via twine
rfc3986==2.0.0
# via twine
rich==13.9.3
# via twine
secretstorage==3.3.3
# via keyring
six==1.16.0
# via asttokens
stack-data==0.6.3
# via ipython
traitlets==5.14.3
# via
# ipython
# matplotlib-inline
twine==5.1.1
# via pwo (pyproject.toml)
typing-extensions==4.12.2
# via
# mypy
# pwo (pyproject.toml)
urllib3==2.2.3
# via
# requests
# twine
wcwidth==0.2.13
# via prompt-toolkit
wheel==0.44.0
# via pip-tools
zipp==3.20.2
# via importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# pip
# setuptools

View File

@@ -2,10 +2,139 @@
# This file is autogenerated by pip-compile with Python 3.12 # This file is autogenerated by pip-compile with Python 3.12
# by the following command: # by the following command:
# #
# pip-compile --output-file=requirements.txt pyproject.toml # pip-compile --extra=dev --output-file=requirements.txt
# #
--index-url https://gitea.woggioni.net/api/packages/woggioni/pypi/simple --index-url https://gitea.woggioni.net/api/packages/woggioni/pypi/simple
--extra-index-url https://pypi.org/simple --extra-index-url https://pypi.org/simple
typing-extensions==4.12.2 asttokens==2.4.1
# via stack-data
build==1.2.2.post1
# via
# pip-tools
# pwo (pyproject.toml)
certifi==2024.8.30
# via requests
cffi==1.17.1
# via cryptography
charset-normalizer==3.4.0
# via requests
click==8.1.7
# via pip-tools
cryptography==43.0.3
# via secretstorage
decorator==5.1.1
# via
# ipdb
# ipython
docutils==0.21.2
# via readme-renderer
executing==2.1.0
# via stack-data
idna==3.10
# via requests
importlib-metadata==8.5.0
# via twine
ipdb==0.13.13
# via pwo (pyproject.toml) # via pwo (pyproject.toml)
ipython==8.29.0
# via ipdb
jaraco-classes==3.4.0
# via keyring
jaraco-context==6.0.1
# via keyring
jaraco-functools==4.1.0
# via keyring
jedi==0.19.1
# via ipython
jeepney==0.8.0
# via
# keyring
# secretstorage
keyring==25.5.0
# via twine
markdown-it-py==3.0.0
# via rich
matplotlib-inline==0.1.7
# via ipython
mdurl==0.1.2
# via markdown-it-py
more-itertools==10.5.0
# via
# jaraco-classes
# jaraco-functools
mypy==1.13.0
# via pwo (pyproject.toml)
mypy-extensions==1.0.0
# via mypy
nh3==0.2.18
# via readme-renderer
packaging==24.1
# via build
parso==0.8.4
# via jedi
pexpect==4.9.0
# via ipython
pip-tools==7.4.1
# via pwo (pyproject.toml)
pkginfo==1.10.0
# via twine
prompt-toolkit==3.0.48
# via ipython
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.3
# via stack-data
pycparser==2.22
# via cffi
pygments==2.18.0
# via
# ipython
# readme-renderer
# rich
pyproject-hooks==1.2.0
# via
# build
# pip-tools
readme-renderer==44.0
# via twine
requests==2.32.3
# via
# requests-toolbelt
# twine
requests-toolbelt==1.0.0
# via twine
rfc3986==2.0.0
# via twine
rich==13.9.4
# via twine
secretstorage==3.3.3
# via keyring
six==1.16.0
# via asttokens
stack-data==0.6.3
# via ipython
traitlets==5.14.3
# via
# ipython
# matplotlib-inline
twine==5.1.1
# via pwo (pyproject.toml)
typing-extensions==4.12.2
# via
# mypy
# pwo (pyproject.toml)
urllib3==2.2.3
# via
# requests
# twine
wcwidth==0.2.13
# via prompt-toolkit
wheel==0.44.0
# via pip-tools
zipp==3.20.2
# via importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# pip
# setuptools

View File

@@ -7,10 +7,12 @@ from .private import (
tmpdir, tmpdir,
decorator_with_kwargs, decorator_with_kwargs,
classproperty, classproperty,
AsyncQueueIterator AsyncQueueIterator,
aenumerate,
) )
from .maybe import Maybe from .maybe import Maybe
from .notification import TopicManager, Subscriber from .notification import TopicManager, Subscriber
from ._try import Try
__all__ = [ __all__ = [
'format_filesize', 'format_filesize',
@@ -24,5 +26,7 @@ __all__ = [
'classproperty', 'classproperty',
'TopicManager', 'TopicManager',
'Subscriber', 'Subscriber',
'AsyncQueueIterator' 'AsyncQueueIterator',
'aenumerate',
'Try'
] ]

52
src/pwo/_try.py Normal file
View File

@@ -0,0 +1,52 @@
from typing import (
Callable,
TypeVar,
Optional
)
ERR = TypeVar("ERR", bound=Exception)
class Try[T]:
value: T | Exception
def __init__(self, value: T | Exception):
self.value = value
def handle[U](self, cb: Callable[[Optional[T], Optional[Exception]], U]) -> 'Try[U]':
value = self.value
if isinstance(value, Exception):
return Try.of(lambda: cb(None, value))
else:
return Try.of(lambda: cb(value, None))
def get(self, alternative: Optional[T] = None) -> T:
if isinstance(self.value, Exception):
if alternative is None:
raise self.value
else:
return alternative
else:
return self.value
def then_try[U](self, cb: Callable[[T], U]) -> 'Try[U]':
value = self.value
if isinstance(value, Exception):
return Try.failure(value)
else:
return Try.of(lambda: cb(value))
@staticmethod
def success[U](value: U) -> 'Try[U]':
return Try(value)
@staticmethod
def failure[U](ex: Exception) -> 'Try[U]':
return Try(ex)
@staticmethod
def of[U](cb: Callable[[], U]) -> 'Try[U]':
try:
return Try(cb())
except Exception as ex:
return Try(ex)

View File

@@ -24,10 +24,11 @@ class Maybe(Generic[T]):
@property @property
def value(self) -> T: def value(self) -> T:
if self.is_empty: result = self._value
if result is None:
raise ValueError('Empty Maybe') raise ValueError('Empty Maybe')
else: else:
return self._value return result
@property @property
def is_present(self) -> bool: def is_present(self) -> bool:

View File

@@ -8,25 +8,31 @@ log = getLogger(__name__)
class Subscriber: class Subscriber:
_unsubscribe_callback: Callable[['Subscriber'], None] _unsubscribe_callback: Callable[['Subscriber'], None]
_event: Optional[Future] _event: Optional[Future[bool]]
_loop: AbstractEventLoop _loop: AbstractEventLoop
def __init__(self, unsubscribe: Callable[['Subscriber'], None], loop: AbstractEventLoop): def __init__(self, unsubscribe: Callable[['Subscriber'], None], loop: AbstractEventLoop):
self._unsubscribe_callback = unsubscribe self._unsubscribe_callback = unsubscribe
self._event: Optional[Future] = None self._event: Optional[Future[bool]] = None
self._loop = loop self._loop = loop
def unsubscribe(self) -> None: def unsubscribe(self) -> None:
self._event.cancel() evt = self._event
if evt is not None:
evt.cancel()
self._unsubscribe_callback(self) self._unsubscribe_callback(self)
log.debug('Deleted subscriber %s', id(self)) log.debug('Deleted subscriber %s', id(self))
async def wait(self, tout: float) -> bool: async def wait(self, tout: float) -> bool:
self._event = self._loop.create_future() self._event = self._loop.create_future()
def callback(): def callback() -> None:
if not self._event.done(): evt = self._event
self._event.set_result(False) if evt is None:
raise ValueError('Event is None')
evt.cancel()
if not evt.done():
evt.set_result(False)
handle = self._loop.call_later(tout, callback) handle = self._loop.call_later(tout, callback)
try: try:
@@ -39,8 +45,11 @@ class Subscriber:
def notify(self) -> None: def notify(self) -> None:
log.debug('Subscriber %s notified', id(self)) log.debug('Subscriber %s notified', id(self))
if not self._event.done(): evt = self._event
self._event.set_result(True) if evt is None:
raise ValueError('Event is None')
if not evt.done():
evt.set_result(True)
def reset(self) -> None: def reset(self) -> None:
self._event = self._loop.create_future() self._event = self._loop.create_future()
@@ -48,7 +57,7 @@ class Subscriber:
class TopicManager: class TopicManager:
_loop: AbstractEventLoop _loop: AbstractEventLoop
_queue: Queue _queue: Queue[Optional[str]]
_subscribers: dict[str, set[Subscriber]] _subscribers: dict[str, set[Subscriber]]
def __init__(self, loop: AbstractEventLoop): def __init__(self, loop: AbstractEventLoop):
@@ -60,7 +69,7 @@ class TopicManager:
subscriptions = self._subscribers subscriptions = self._subscribers
subscriptions_per_topic = subscriptions.setdefault(topic, set()) subscriptions_per_topic = subscriptions.setdefault(topic, set())
def unsubscribe_callback(subscription): def unsubscribe_callback(subscription: Subscriber) -> None:
subscriptions_per_topic.remove(subscription) subscriptions_per_topic.remove(subscription)
log.debug('Unsubscribed %s from topic %s', id(result), topic) log.debug('Unsubscribed %s from topic %s', id(result), topic)
@@ -69,7 +78,7 @@ class TopicManager:
subscriptions_per_topic.add(result) subscriptions_per_topic.add(result)
return result return result
def _notify_subscriptions(self, topic): def _notify_subscriptions(self, topic: str) -> None:
subscriptions = self._subscribers subscriptions = self._subscribers
subscriptions_per_topic = subscriptions.get(topic, None) subscriptions_per_topic = subscriptions.get(topic, None)
if subscriptions_per_topic: if subscriptions_per_topic:
@@ -77,14 +86,14 @@ class TopicManager:
for s in subscriptions_per_topic: for s in subscriptions_per_topic:
s.notify() s.notify()
async def process_events(self): async def process_events(self) -> None:
async for evt in AsyncQueueIterator(self._queue): async for evt in AsyncQueueIterator(self._queue):
log.debug(f"Processed event for topic '{evt}'") log.debug(f"Processed event for topic '{evt}'")
self._notify_subscriptions(evt) self._notify_subscriptions(evt)
log.debug(f"Event processor has completed") log.debug(f"Event processor has completed")
def post_event(self, topic): def post_event(self, topic: str) -> None:
def callback(): def callback() -> None:
self._queue.put_nowait(topic) self._queue.put_nowait(topic)
log.debug(f"Posted event for topic '{topic}', queue size: {self._queue.qsize()}") log.debug(f"Posted event for topic '{topic}', queue size: {self._queue.qsize()}")

View File

@@ -6,10 +6,21 @@ from inspect import signature
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from time import sleep from time import sleep
from typing import Callable, AsyncIterator from typing import (
Callable,
AsyncIterator,
Optional,
Self,
AsyncIterable,
Awaitable,
Any,
Coroutine,
Never,
Tuple
)
def decorator_with_kwargs(decorator: Callable) -> Callable: def decorator_with_kwargs(decorator: Callable[..., Any]) -> Callable[..., Any]:
"""Decorator factory to give decorated decorators the skill to receive """Decorator factory to give decorated decorators the skill to receive
optional keyword arguments. optional keyword arguments.
@@ -48,7 +59,7 @@ def decorator_with_kwargs(decorator: Callable) -> Callable:
""" """
@wraps(decorator) @wraps(decorator)
def decorator_wrapper(*args, **kwargs): def decorator_wrapper(*args, **kwargs) -> Any: # type: ignore
if (len(kwargs) == 0) and (len(args) == 1) and callable(args[0]): if (len(kwargs) == 0) and (len(args) == 1) and callable(args[0]):
return decorator(args[0]) return decorator(args[0])
if len(args) == 0: if len(args) == 0:
@@ -94,15 +105,15 @@ class ExceptionHandlerOutcome(Enum):
@decorator_with_kwargs @decorator_with_kwargs
def retry( def retry(
function, function: Callable[..., Any],
max_attempts: int = 3, max_attempts: int = 3,
multiplier: float = 2, multiplier: float = 2,
initial_delay: float = 1.0, initial_delay: float = 1.0,
exception_handler: Callable[[Exception], ExceptionHandlerOutcome] = exception_handler: Callable[[Exception], ExceptionHandlerOutcome] =
lambda _: ExceptionHandlerOutcome.CONTINUE lambda _: ExceptionHandlerOutcome.CONTINUE
): ) -> Callable[..., Any]:
@wraps(function) @wraps(function)
def result(*args, **kwargs): def result(*args: Any, **kwargs: Any) -> Any:
attempts = 0 attempts = 0
delay = initial_delay delay = initial_delay
while True: while True:
@@ -121,14 +132,14 @@ def retry(
@decorator_with_kwargs @decorator_with_kwargs
def async_retry( def async_retry(
function, function: Callable[..., Any],
max_attempts: int = 3, max_attempts: int = 3,
multiplier: float = 2, multiplier: float = 2,
initial_delay: float = 1.0, initial_delay: float = 1.0,
exception_handler=lambda _: ExceptionHandlerOutcome.CONTINUE exception_handler: Callable[[Exception], ExceptionHandlerOutcome] = lambda _: ExceptionHandlerOutcome.CONTINUE
): ) -> Callable[..., Any]:
@wraps(function) @wraps(function)
async def result(*args, **kwargs): async def result(*args: Any, **kwargs: Any) -> Any:
attempts = 0 attempts = 0
delay = initial_delay delay = initial_delay
while True: while True:
@@ -145,16 +156,16 @@ def async_retry(
return result return result
def async_test(coro): def async_test(coro: Callable[..., Coroutine[Never, Never, None]]) -> Callable[..., None]:
@wraps(coro) @wraps(coro)
def wrapper(*args, **kwargs): def wrapper(*args: Any, **kwargs: Any) -> None:
with Runner() as runner: with Runner() as runner:
runner.run(coro(*args, **kwargs)) runner.run(coro(*args, **kwargs))
return wrapper return wrapper
@decorator_with_kwargs @decorator_with_kwargs # type: ignore
def tmpdir(f, def tmpdir(f,
argument_name='temp_dir', argument_name='temp_dir',
suffix=None, suffix=None,
@@ -162,7 +173,7 @@ def tmpdir(f,
dir=None, dir=None,
ignore_cleanup_errors=False, ignore_cleanup_errors=False,
delete=True): delete=True):
@wraps(f) @wraps(f) # type: ignore
def result(*args, **kwargs): def result(*args, **kwargs):
with TemporaryDirectory( with TemporaryDirectory(
suffix=suffix, suffix=suffix,
@@ -177,48 +188,67 @@ def tmpdir(f,
return result return result
class ClassPropertyDescriptor: class ClassPropertyDescriptor[T]:
def __init__(self, fget, fset=None): def __init__(self, fget: Callable[[], T], fset: Optional[Callable[[T], None]]=None):
self.fget = fget self.fget = fget
self.fset = fset self.fset = fset
def __get__(self, obj, klass=None): def __get__(self, obj, klass=None): # type: ignore
if klass is None: if klass is None:
klass = type(obj) klass = type(obj)
return self.fget.__get__(obj, klass)() return self.fget.__get__(obj, klass)()
def __set__(self, obj, value): def __set__(self, obj, value): # type: ignore
if not self.fset: if not self.fset:
raise AttributeError("can't set attribute") raise AttributeError("can't set attribute")
type_ = type(obj) type_ = type(obj)
return self.fset.__get__(obj, type_)(value) return self.fset.__get__(obj, type_)(value)
def setter(self, func): def setter(self, func): # type: ignore
if not isinstance(func, (classmethod, staticmethod)): if not isinstance(func, (classmethod, staticmethod)):
func = classmethod(func) func = classmethod(func)
self.fset = func self.fset = func
return self return self
def classproperty(func): def classproperty(func): # type: ignore
if not isinstance(func, (classmethod, staticmethod)): if not isinstance(func, (classmethod, staticmethod)):
func = classmethod(func) func = classmethod(func)
return ClassPropertyDescriptor(func) return ClassPropertyDescriptor(func)
class AsyncQueueIterator[T]: class AsyncQueueIterator[T](AsyncIterator[T]):
_queue: Queue[T] _queue: Queue[Optional[T]]
def __init__(self, queue: Queue[T]): def __init__(self, queue: Queue[Optional[T]]):
self._queue = queue self._queue = queue
def __aiter__(self) -> AsyncIterator[T]: def __aiter__(self) -> AsyncIterator[T]:
return self return self
async def __anext__(self) -> [T]: async def __anext__(self) -> T:
item = await self._queue.get() item = await self._queue.get()
if item is None: if item is None:
raise StopAsyncIteration raise StopAsyncIteration
return item return item
class aenumerate[T](AsyncIterator[Tuple[int, T]]):
"""enumerate for async for"""
_aiterable: AsyncIterable[T]
_i: int
def __init__(self, aiterable: AsyncIterable[T], start: int = 0):
self._aiterable = aiterable
self._i = start - 1
def __aiter__(self) -> Self:
self._ait = self._aiterable.__aiter__()
return self
async def __anext__(self) -> Tuple[int, T]:
val = await self._ait.__anext__()
self._i += 1
return self._i, val

View File

@@ -1,6 +1,7 @@
import unittest import unittest
from src.pwo import retry, async_retry, async_test from pwo import retry, async_retry, async_test, AsyncQueueIterator, aenumerate
from asyncio import Queue
class PrivateTest(unittest.TestCase): class PrivateTest(unittest.TestCase):
@@ -70,5 +71,25 @@ class PrivateTest(unittest.TestCase):
await bar() await bar()
self.assertEqual(max_attempts, attempt) self.assertEqual(max_attempts, attempt)
if __name__ == '__main__': @async_test
unittest.main() async def test_async_queue_iterator(self):
queue = Queue()
queue_size = 10
objects = [object() for _ in range(queue_size)]
async def poll() -> int:
completed = 0
async for i, obj in aenumerate(AsyncQueueIterator(queue)):
self.assertIs(objects[i], obj)
completed += 1
return completed
handle = poll()
for o in objects:
queue.put_nowait(o)
queue.put_nowait(None)
processed = await handle
self.assertEqual(queue_size, processed)

42
tests/test_try.py Normal file
View File

@@ -0,0 +1,42 @@
import unittest
from pwo import Try
class TestException(Exception):
def __init__(self, msg: str):
super().__init__(msg)
class TryTest(unittest.TestCase):
def setUp(self):
pass
def test_try(self):
with self.subTest("Test failure"):
def throw_test_exception():
raise TestException("error")
t = Try.of(throw_test_exception)
with self.assertRaises(TestException):
t.get()
t = Try.failure(TestException("error"))
with self.assertRaises(TestException):
t.get()
with self.subTest("Test success"):
def complete_successfully():
return 42
t = Try.of(complete_successfully)
self.assertEqual(42, t.get())
t2 = t.handle(lambda value, err: value * 2)
self.assertEqual(84, t2.get())