"""Unit tests for the retry decorators and async queue helpers from ``pwo``."""

import unittest
from asyncio import Queue, wait_for

from pwo import AsyncQueueIterator, aenumerate, async_retry, async_test, retry

# Upper bound (seconds) for the queue-iterator test: if the iterator never
# stops on its end-of-stream sentinel, the test fails instead of hanging.
_QUEUE_TEST_TIMEOUT = 5.0


class PrivateTest(unittest.TestCase):
    """Exercise ``retry``/``async_retry`` and ``AsyncQueueIterator``."""

    def test_retry_until_success(self):
        """``retry`` re-invokes the function until it returns a value."""
        max_attempts = 20
        succeed_on = 10  # must not exceed max_attempts for foo() to succeed
        attempt = 0
        expected_result = object()

        @retry(max_attempts=max_attempts, initial_delay=0)
        def foo():
            nonlocal attempt
            attempt += 1
            if attempt < succeed_on:
                raise Exception()
            return expected_result

        self.assertIs(expected_result, foo())
        self.assertEqual(succeed_on, attempt)

    def test_retry_until_max_attempt(self):
        """``retry`` gives up and raises after exactly ``max_attempts`` calls."""
        max_attempts = 20
        attempt = 0

        @retry(max_attempts=max_attempts, initial_delay=0)
        def bar():
            nonlocal attempt
            attempt += 1
            raise Exception()

        with self.assertRaises(Exception):
            bar()
        self.assertEqual(max_attempts, attempt)

    @async_test
    async def test_async_retry_until_success(self):
        """``async_retry`` re-awaits the coroutine until it returns a value."""
        max_attempts = 20
        succeed_on = 10  # must not exceed max_attempts for foo() to succeed
        attempt = 0
        expected_result = object()

        @async_retry(max_attempts=max_attempts, initial_delay=0)
        async def foo():
            nonlocal attempt
            attempt += 1
            if attempt < succeed_on:
                raise Exception()
            return expected_result

        self.assertIs(expected_result, await foo())
        self.assertEqual(succeed_on, attempt)

    @async_test
    async def test_async_retry_until_max_attempt(self):
        """``async_retry`` gives up and raises after ``max_attempts`` awaits."""
        max_attempts = 20
        attempt = 0

        @async_retry(max_attempts=max_attempts, initial_delay=0)
        async def bar():
            nonlocal attempt
            attempt += 1
            raise Exception()

        with self.assertRaises(Exception):
            await bar()
        self.assertEqual(max_attempts, attempt)

    @async_test
    async def test_async_queue_iterator(self):
        """Items come out in FIFO order and iteration stops at ``None``.

        The test enqueues ``queue_size`` objects followed by ``None`` and
        expects the iterator to yield exactly those objects, in order.
        """
        queue = Queue()
        queue_size = 10
        objects = [object() for _ in range(queue_size)]

        async def poll() -> int:
            completed = 0
            async for i, obj in aenumerate(AsyncQueueIterator(queue)):
                self.assertIs(objects[i], obj)
                completed += 1
            return completed

        handle = poll()
        for o in objects:
            queue.put_nowait(o)
        queue.put_nowait(None)  # end-of-stream sentinel

        # Bounded wait: a missing/ignored sentinel fails fast rather than
        # blocking the test run forever.
        processed = await wait_for(handle, timeout=_QUEUE_TEST_TIMEOUT)
        self.assertEqual(queue_size, processed)


if __name__ == "__main__":
    unittest.main()