"""Unit tests for the ``retry`` / ``async_retry`` decorators from ``src.pwo``."""

import unittest

from src.pwo import retry, async_retry, async_test


class PrivateTest(unittest.TestCase):
    """Exercise the sync and async retry decorators.

    Every decorated function is built with ``initial_delay=0``, presumably so
    that no real time is spent waiting between attempts (TODO confirm against
    ``src.pwo``). Each test counts invocations through a ``nonlocal`` counter.
    """

    def test_retry_until_success(self):
        """Fails on calls 1-9, succeeds on call 10, and returns that value."""
        limit = 20
        calls = 0
        sentinel = object()

        @retry(max_attempts=limit, initial_delay=0)
        def foo():
            nonlocal calls
            calls += 1
            # Guard clause: from the 10th call onward, succeed.
            if calls >= 10:
                return sentinel
            raise Exception()

        outcome = foo()
        self.assertEqual(sentinel, outcome)
        self.assertEqual(10, calls)

    def test_retry_until_max_attempt(self):
        """Always fails, so the error surfaces after exactly `limit` calls."""
        limit = 20
        calls = 0

        @retry(max_attempts=limit, initial_delay=0)
        def bar():
            nonlocal calls
            calls += 1
            raise Exception()

        with self.assertRaises(Exception):
            bar()
        self.assertEqual(limit, calls)

    @async_test
    async def test_async_retry_until_success(self):
        """Async variant: fails on calls 1-9 and succeeds on call 10."""
        limit = 20
        calls = 0
        sentinel = object()

        @async_retry(max_attempts=limit, initial_delay=0)
        async def foo():
            nonlocal calls
            calls += 1
            # Guard clause: from the 10th call onward, succeed.
            if calls >= 10:
                return sentinel
            raise Exception()

        outcome = await foo()
        self.assertEqual(sentinel, outcome)
        self.assertEqual(10, calls)

    @async_test
    async def test_async_retry_until_max_attempt(self):
        """Async variant: always fails, so the error surfaces after `limit` calls."""
        limit = 20
        calls = 0

        @async_retry(max_attempts=limit, initial_delay=0)
        async def bar():
            nonlocal calls
            calls += 1
            raise Exception()

        with self.assertRaises(Exception):
            await bar()
        self.assertEqual(limit, calls)


if __name__ == '__main__':
    unittest.main()