128 lines
3.5 KiB
Python
128 lines
3.5 KiB
Python
import unittest
|
|
|
|
from pwo import retry, async_retry, async_test, AsyncQueueIterator, aenumerate, index_of_with_escape
|
|
from asyncio import Queue
|
|
|
|
|
|
class PrivateTest(unittest.TestCase):
    """Tests for the retry decorators and the async queue helpers from ``pwo``."""

    def test_retry_until_success(self):
        """``retry`` keeps re-invoking a failing function until it returns a value."""
        max_attempts = 20
        # The wrapped function raises on every call before this one.
        succeed_on_call = 10
        calls = 0
        sentinel = object()

        @retry(max_attempts=max_attempts, initial_delay=0)
        def foo():
            nonlocal calls
            calls += 1
            if calls >= succeed_on_call:
                return sentinel
            raise Exception()

        self.assertEqual(sentinel, foo())
        self.assertEqual(succeed_on_call, calls)

    def test_retry_until_max_attempt(self):
        """After ``max_attempts`` failing calls, an exception reaches the caller."""
        max_attempts = 20
        calls = 0

        @retry(max_attempts=max_attempts, initial_delay=0)
        def bar():
            nonlocal calls
            calls += 1
            raise Exception()

        self.assertRaises(Exception, bar)
        self.assertEqual(max_attempts, calls)

    @async_test
    async def test_async_retry_until_success(self):
        """``async_retry`` keeps awaiting a failing coroutine until it returns a value."""
        max_attempts = 20
        # The wrapped coroutine raises on every call before this one.
        succeed_on_call = 10
        calls = 0
        sentinel = object()

        @async_retry(max_attempts=max_attempts, initial_delay=0)
        async def foo():
            nonlocal calls
            calls += 1
            if calls >= succeed_on_call:
                return sentinel
            raise Exception()

        self.assertEqual(sentinel, await foo())
        self.assertEqual(succeed_on_call, calls)

    @async_test
    async def test_async_retry_until_max_attempt(self):
        """After ``max_attempts`` failing awaits, an exception reaches the caller."""
        max_attempts = 20
        calls = 0

        @async_retry(max_attempts=max_attempts, initial_delay=0)
        async def bar():
            nonlocal calls
            calls += 1
            raise Exception()

        with self.assertRaises(Exception):
            await bar()
        self.assertEqual(max_attempts, calls)

    @async_test
    async def test_async_queue_iterator(self):
        """``AsyncQueueIterator`` yields queued items in order and ends at ``None``."""
        queue = Queue()
        queue_size = 10
        objects = [object() for _ in range(queue_size)]

        async def poll() -> int:
            seen = 0
            async for index, item in aenumerate(AsyncQueueIterator(queue)):
                self.assertIs(objects[index], item)
                seen += 1
            return seen

        # Creating the coroutine does not run it; it only starts consuming at the await below.
        consumer = poll()
        for entry in objects:
            queue.put_nowait(entry)
        # Trailing None: the test expects iteration to stop here without yielding it.
        queue.put_nowait(None)
        self.assertEqual(queue_size, await consumer)
|
|
|
|
|
|
class TestIndexOfWithEscape(unittest.TestCase):
    """Tests for ``index_of_with_escape``: locating every unescaped needle in a string."""

    def run_test_case(self, haystack, needle, escape, expected_solution):
        """Collect every index reported by ``index_of_with_escape`` and compare.

        The search starts at 0 and restarts one position past each hit. It stops
        when the function returns a negative value.

        :param haystack: the string to search.
        :param needle: the needle passed to ``index_of_with_escape``.
        :param escape: the escape character passed to ``index_of_with_escape``.
        :param expected_solution: indices expected to be reported, in order.
        """
        solution = []
        start = 0
        end = len(haystack)
        while True:
            i = index_of_with_escape(haystack, needle, escape, start, end)
            if i < 0:
                break
            # A hit outside [start, end) means the next search would not
            # advance. Without these checks a faulty implementation would hang
            # this loop forever instead of failing the test.
            self.assertGreaterEqual(i, start)
            self.assertLess(i, end)
            solution.append(i)
            start = i + 1
        self.assertEqual(solution, expected_solution)

    def test_simple(self):
        # "\\$" (index 12-13) is escaped; "\\\\$" escapes the backslash, so index 25 counts.
        self.run_test_case(" dsds $sdsa \\$dfivbdsf \\\\$sdgsga", '$', '\\', [6, 25])

    def test_simple2(self):
        # Needle and escape are the same character: "$$" is an escaped "$".
        self.run_test_case("asdasd$$vdfv$", '$', '$', [12])

    def test_no_needle(self):
        self.run_test_case("asdasd$$vdfv$", '#', '\\', [])

    def test_escaped_needle(self):
        # The only '#' is preceded by a single '$', which escapes it.
        self.run_test_case("asdasd$$vdfv$#sdfs", '#', '$', [])

    def test_not_escaped_needle(self):
        # "$$" escapes the escape, so the '#' after it is reported; "$#" later is not.
        self.run_test_case("asdasd$$#vdfv$#sdfs", '#', '$', [8])

    def test_special_case(self):
        # No escape characters present: both ':' occurrences are reported.
        self.run_test_case("\n${sys:user.home}${env:HOME}", ':', '\\', [6, 22])
|
|
|
|
|