import asyncio
import unittest
from asyncio import Queue

from pwo import retry, async_retry, async_test, AsyncQueueIterator, aenumerate, index_of_with_escape

# Upper bound (seconds) for draining the queue in test_async_queue_iterator.
# Without it, a regression in end-of-stream handling would hang the whole
# suite instead of producing a failure.
_QUEUE_DRAIN_TIMEOUT = 5.0


class PrivateTest(unittest.TestCase):
    """Tests for the retry decorators and async queue helpers exported by ``pwo``."""

    def test_retry_until_success(self):
        """``retry`` re-invokes a failing function and returns its eventual result."""
        max_attempts = 20
        attempt = 0
        expected_result = object()

        @retry(max_attempts=max_attempts, initial_delay=0)
        def foo():
            nonlocal attempt
            attempt += 1
            if attempt < 10:
                raise Exception()
            else:
                return expected_result

        self.assertEqual(expected_result, foo())
        # Exactly 10 calls: 9 failures followed by the successful one.
        self.assertEqual(10, attempt)

    def test_retry_until_max_attempt(self):
        """``retry`` gives up after ``max_attempts`` calls and lets an exception escape."""
        max_attempts = 20
        attempt = 0

        @retry(max_attempts=max_attempts, initial_delay=0)
        def bar():
            nonlocal attempt
            attempt += 1
            raise Exception()

        with self.assertRaises(Exception):
            bar()
        # Checked outside the ``with`` block: a statement after ``bar()`` inside
        # the context manager would never execute.
        self.assertEqual(max_attempts, attempt)

    @async_test
    async def test_async_retry_until_success(self):
        """``async_retry`` re-awaits a failing coroutine and returns its eventual result."""
        max_attempts = 20
        attempt = 0
        expected_result = object()

        @async_retry(max_attempts=max_attempts, initial_delay=0)
        async def foo():
            nonlocal attempt
            attempt += 1
            if attempt < 10:
                raise Exception()
            else:
                return expected_result

        self.assertEqual(expected_result, await foo())
        self.assertEqual(10, attempt)

    @async_test
    async def test_async_retry_until_max_attempt(self):
        """``async_retry`` gives up after ``max_attempts`` awaits and lets an exception escape."""
        max_attempts = 20
        attempt = 0

        @async_retry(max_attempts=max_attempts, initial_delay=0)
        async def bar():
            nonlocal attempt
            attempt += 1
            raise Exception()

        with self.assertRaises(Exception):
            await bar()
        # Checked outside the ``with`` block (see the synchronous variant).
        self.assertEqual(max_attempts, attempt)

    @async_test
    async def test_async_queue_iterator(self):
        """``AsyncQueueIterator`` yields queued items in order and stops at the ``None`` marker."""
        queue = Queue()
        queue_size = 10
        objects = [object() for _ in range(queue_size)]

        async def poll() -> int:
            completed = 0
            async for i, obj in aenumerate(AsyncQueueIterator(queue)):
                self.assertIs(objects[i], obj)
                completed += 1
            return completed

        for o in objects:
            queue.put_nowait(o)
        # The test relies on AsyncQueueIterator treating ``None`` as end-of-stream:
        # exactly ``queue_size`` items are expected back.
        queue.put_nowait(None)

        # Bounded wait: if the iterator stops recognising the end marker, fail
        # with a TimeoutError instead of blocking forever.
        processed = await asyncio.wait_for(poll(), timeout=_QUEUE_DRAIN_TIMEOUT)
        self.assertEqual(queue_size, processed)


class TestIndexOfWithEscape(unittest.TestCase):
    """Tests for ``index_of_with_escape`` (first unescaped occurrence of a needle)."""

    def run_test_case(self, haystack, needle, escape, expected_solution):
        """Collect every index ``index_of_with_escape`` reports and compare.

        Scans ``haystack`` repeatedly, restarting one past each reported index.
        A negative return value ends the scan.

        :param haystack: string to search.
        :param needle: character being searched for.
        :param escape: escape character passed through to ``index_of_with_escape``.
        :param expected_solution: list of indices the scan should report, in order.
        """
        solution = []
        start = 0
        while True:
            found = index_of_with_escape(haystack, needle, escape, start, len(haystack))
            if found < 0:
                break
            # Guard against a regression that ignores ``start``. Without it this
            # loop would keep re-finding the same index and never terminate.
            self.assertGreaterEqual(
                found, start,
                f"index_of_with_escape returned {found} before start={start}",
            )
            solution.append(found)
            start = found + 1
        self.assertEqual(solution, expected_solution)

    def test_simple(self):
        self.run_test_case(" dsds $sdsa \\$dfivbdsf \\\\$sdgsga", '$', '\\', [6, 25])

    def test_simple2(self):
        self.run_test_case("asdasd$$vdfv$", '$', '$', [12])

    def test_no_needle(self):
        self.run_test_case("asdasd$$vdfv$", '#', '\\', [])

    def test_escaped_needle(self):
        self.run_test_case("asdasd$$vdfv$#sdfs", '#', '$', [])

    def test_not_escaped_needle(self):
        self.run_test_case("asdasd$$#vdfv$#sdfs", '#', '$', [8])

    def test_special_case(self):
        self.run_test_case("\n${sys:user.home}${env:HOME}", ':', '\\', [6, 22])


if __name__ == "__main__":
    unittest.main()