From ee6e645cc15063f7b64f8983b86d8152a2d20711 Mon Sep 17 00:00:00 2001 From: Walter Oggioni Date: Tue, 12 Nov 2024 07:22:45 +0800 Subject: [PATCH] tmp --- core/src/bugis/core/_app.py | 70 ++++++----- core/src/bugis/core/_node.py | 18 +++ core/src/bugis/core/_path_handler.py | 21 ++++ core/src/bugis/core/_path_matcher.py | 87 +++++++++++++ core/src/bugis/core/_rsgi.py | 5 +- core/src/bugis/core/_tree.py | 164 ++++++++++++++++--------- core/src/bugis/core/_types.py | 70 ----------- core/src/bugis/core/_types/__init__.py | 96 ++++++++++++++- core/tests/test_asgi.py | 89 +++++++++++++- core/tests/test_tree.py | 19 ++- 10 files changed, 475 insertions(+), 164 deletions(-) create mode 100644 core/src/bugis/core/_node.py create mode 100644 core/src/bugis/core/_path_handler.py create mode 100644 core/src/bugis/core/_path_matcher.py delete mode 100644 core/src/bugis/core/_types.py diff --git a/core/src/bugis/core/_app.py b/core/src/bugis/core/_app.py index a41f836..f69c2b9 100644 --- a/core/src/bugis/core/_app.py +++ b/core/src/bugis/core/_app.py @@ -2,16 +2,17 @@ from abc import ABC, abstractmethod from asyncio import Queue, AbstractEventLoop from asyncio import get_running_loop from logging import getLogger -from typing import Callable, Awaitable, Any, Mapping, Sequence, Optional +from typing import Callable, Awaitable, Any, Mapping, Sequence, Optional, Unpack from pwo import Maybe, AsyncQueueIterator from ._http_context import HttpContext from ._http_method import HttpMethod +from ._types import StrOrStrings try: from ._rsgi import RsgiContext - from granian._granian import RSGIHTTPProtocol, RSGIHTTPScope # type: ignore + from granian._granian import RSGIHTTPProtocol, RSGIHTTPScope # type: ignore except ImportError: pass @@ -21,11 +22,12 @@ from ._types.asgi import LifespanScope, HTTPScope as ASGIHTTPScope, WebSocketSco log = getLogger(__name__) -type HttpHandler = Callable[[HttpContext], Awaitable[None]] +type HttpHandler = Callable[[HttpContext, Unpack], Awaitable[None]] + class AbstractBugisApp(ABC): async def __call__(self, - scope: ASGIHTTPScope|WebSocketScope|LifespanScope, + scope: ASGIHTTPScope | WebSocketScope | LifespanScope, receive: Callable[[], Awaitable[Any]], send: Callable[[Mapping[str, Any]], Awaitable[None]]) -> None: loop = get_running_loop() @@ -84,43 +86,55 @@ class BugisApp(AbstractBugisApp): self._tree = Tree() async def handle_request(self, ctx: HttpContext) -> None: - handler = self._tree.get_handler(ctx.path, ctx.method) - if handler is not None: - await handler.handle_request(ctx) + result = self._tree.get_handler(ctx.path, ctx.method) + if result is not None: + handler, captured = result + await handler.handle_request(ctx, captured) else: await ctx.send_empty(404) pass def route(self, - path: str, - methods: Optional[Sequence[HttpMethod]] = None) -> Callable[[HttpHandler], HttpHandler]: + paths: StrOrStrings, + methods: Optional[HttpMethod | Sequence[HttpMethod]] = None, + recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]: + def wrapped(handler: HttpHandler) -> HttpHandler: - if methods is not None: - for method in methods: - self._tree.register(path, method, handler) - else: - self._tree.register(path, None, handler) + nonlocal methods + nonlocal paths + if methods is None: + methods = (None,) + elif isinstance(methods, HttpMethod): + methods = (methods,) + if isinstance(paths, str): + paths = (paths,) + for method in methods: + if isinstance(paths, str): + self._tree.register(paths, method, handler, recursive) + else: + for path in paths: + self._tree.register(path, method, handler, recursive) return handler return wrapped - def GET(self, path: str) -> Callable[[HttpHandler], HttpHandler]: - return self.route(path, (HttpMethod.GET,)) + def GET(self, path: str, recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]: + return self.route(path, (HttpMethod.GET,), recursive) - def POST(self, path: str) -> Callable[[HttpHandler], HttpHandler]: - return self.route(path, (HttpMethod.POST,)) + def POST(self, path: str, recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]: + return self.route(path, (HttpMethod.POST,), recursive) - def PUT(self, path: str) -> Callable[[HttpHandler], HttpHandler]: - return self.route(path, (HttpMethod.PUT,)) + def PUT(self, path: str, recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]: + return self.route(path, (HttpMethod.PUT,), recursive) - def DELETE(self, path: str) -> Callable[[HttpHandler], HttpHandler]: - return self.route(path, (HttpMethod.DELETE,)) + def DELETE(self, path: str, recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]: + return self.route(path, (HttpMethod.DELETE,), recursive) - def OPTIONS(self, path: str) -> Callable[[HttpHandler], HttpHandler]: - return self.route(path, (HttpMethod.OPTIONS,)) + def OPTIONS(self, path: str, recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]: + return self.route(path, (HttpMethod.OPTIONS,), recursive) - def HEAD(self, path: str) -> Callable[[HttpHandler], HttpHandler]: - return self.route(path, (HttpMethod.HEAD,)) + def HEAD(self, path: str, recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]: + return self.route(path, (HttpMethod.HEAD,), recursive) - def PATCH(self, path: str) -> Callable[[HttpHandler], HttpHandler]: - return self.route(path, (HttpMethod.PATCH,)) + def PATCH(self, path: str, recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]: + return self.route(path, (HttpMethod.PATCH,), recursive) diff --git a/core/src/bugis/core/_node.py b/core/src/bugis/core/_node.py new file mode 100644 index 0000000..741df27 --- /dev/null +++ b/core/src/bugis/core/_node.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass +from typing import ( + Optional, + Dict, + List, +) +from ._types import NodeType +# from ._path_handler import PathHandler +# from ._path_matcher import PathMatcher + + +@dataclass +class Node: + key: NodeType + parent: Optional['Node'] + children: Dict[NodeType, 'Node'] + handlers: List['PathHandler'] + path_matchers: List['PathMatcher'] diff --git a/core/src/bugis/core/_path_handler.py b/core/src/bugis/core/_path_handler.py new file mode 100644 index 0000000..ea9240a --- /dev/null +++ b/core/src/bugis/core/_path_handler.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod +from typing import Sequence, Mapping, Any +from ._http_method import HttpMethod +from ._http_context import HttpContext +from ._types import PathMatcherResult, Matches + + +class PathHandler(ABC): + recursive: bool + + @abstractmethod + async def handle_request(self, ctx: HttpContext, captured: Matches) -> None: + pass + + @property + @abstractmethod + def recursive(self) -> bool: + raise NotImplementedError() + + +type PathHandlers = ('PathHandler' | Sequence['PathHandler']) diff --git a/core/src/bugis/core/_path_matcher.py b/core/src/bugis/core/_path_matcher.py new file mode 100644 index 0000000..724074d --- /dev/null +++ b/core/src/bugis/core/_path_matcher.py @@ -0,0 +1,87 @@ +from fnmatch import fnmatch +from abc import ABC, abstractmethod +from typing import Optional, Sequence, Dict, Mapping + +from ._path_handler import PathHandler +from ._types import NodeType, PathMatcherResult +from ._node import Node + + +class PathMatcher(ABC): + parent: Optional[Node] + children: Dict[NodeType, Node] + handlers: Sequence[PathHandler] + path_matchers: Sequence['PathMatcher'] + + def __init__(self, + parent: Optional[Node], + children: Dict[NodeType, Node], + handlers: Sequence[PathHandler], + path_matchers: Sequence['PathMatcher'] + ): + self.parent = parent + self.children = children + self.handlers = handlers + self.path_matchers = path_matchers + + @abstractmethod + def match(self, path: Sequence[str]) -> Optional[PathMatcherResult]: + pass + + +class StrMatcher(PathMatcher): + name: str + + def __init__(self, + name: str, + parent: Optional[Node], + children: Dict[NodeType, Node], + handlers: Sequence[PathHandler], + path_matchers: Sequence[PathMatcher], + ): + super().__init__(parent, children, handlers, path_matchers) + self.name = name + + def match(self, path: Sequence[str]) -> Optional[PathMatcherResult]: + if len(path): + return {self.name: path[0]} + else: + return None + + +class IntMatcher(PathMatcher): + name: str + + def __init__(self, + name: str, + parent: Optional[Node], + children: Dict[NodeType, Node], + handlers: Sequence[PathHandler], + path_matchers: Sequence[PathMatcher], + ): + super().__init__(parent, children, handlers, path_matchers) + self.name = name + + def match(self, path: Sequence[str]) -> Optional[PathMatcherResult]: + if len(path) > 0: + try: + return {self.name: int(path[0])} + except ValueError: + return None + + +class GlobMatcher(PathMatcher): + pattern: str + + def __init__(self, + pattern: str, + parent: Optional[Node], + children: Dict[NodeType, Node], + handlers: Sequence[PathHandler], + path_matchers: Sequence[PathMatcher], + ): + super().__init__(parent, children, handlers, path_matchers) + self.pattern = pattern + + def match(self, path: Sequence[str]) -> Optional[PathMatcherResult]: + return path if fnmatch('/'.join(path), self.pattern) else None diff --git a/core/src/bugis/core/_rsgi.py b/core/src/bugis/core/_rsgi.py index e19980b..5289cab 100644 --- a/core/src/bugis/core/_rsgi.py +++ b/core/src/bugis/core/_rsgi.py @@ -14,8 +14,7 @@ from typing import ( cast ) -from granian.rsgi import Scope -from granian._granian import RSGIHTTPProtocol +from granian._granian import RSGIHTTPProtocol, RSGIHTTPScope from pwo import Maybe from ._http_context import HttpContext @@ -34,7 +33,7 @@ class RsgiContext(HttpContext): request_body: AsyncIterator[bytes] head = Optional[Tuple[int, Sequence[Tuple[str, str]]]] - def __init__(self, scope: Scope, protocol: RSGIHTTPProtocol): + def __init__(self, scope: RSGIHTTPScope, protocol: RSGIHTTPProtocol): self.scheme = scope.scheme self.path = scope.path self.method = HttpMethod(scope.method) diff --git a/core/src/bugis/core/_tree.py b/core/src/bugis/core/_tree.py index 5506798..813b7d2 100644 --- a/core/src/bugis/core/_tree.py +++ b/core/src/bugis/core/_tree.py @@ -1,56 +1,67 @@ -from typing import Sequence, Dict, Awaitable, Callable, Optional, Generator, Self, List -from ._http_method import HttpMethod -from ._http_context import HttpContext -from dataclasses import dataclass -from abc import ABC, abstractmethod from itertools import chain +from typing import ( + Sequence, + Awaitable, + Callable, + Optional, + Generator, + Self, + List, + Tuple, + Mapping, + Any, + Dict +) +from typing_extensions import Unpack from urllib.parse import urlparse -from pwo import Maybe -type NodeType = (str | HttpMethod) +from pwo import Maybe, index_of_with_escape -type PathHandlers = (PathHandler | Sequence[PathHandler]) - - -class PathHandler(ABC): - - @abstractmethod - def match(self, subpath: Sequence[str], method: HttpMethod) -> bool: - raise NotImplementedError() - - @abstractmethod - async def handle_request(self, ctx: HttpContext) -> None: - pass - - -@dataclass -class Node: - key: NodeType - parent: Optional['Node'] - children: Dict[NodeType, 'Node'] - handlers: Sequence[PathHandler] +from ._http_context import HttpContext +from ._http_method import HttpMethod +from ._node import Node +from ._path_handler import PathHandler +from ._path_matcher import PathMatcher, IntMatcher, GlobMatcher, StrMatcher +from ._types import NodeType, PathMatcherResult, Matches class Tree: def __init__(self) -> None: - self.root = Node('/', None, {}, []) + self.root = Node('/', None, {}, [], []) - def search(self, path: Generator[str, None, None], method: HttpMethod) -> Optional[Node]: - lineage: Generator[NodeType, None, None] = (it for it in chain(path, (method,))) + def search(self, path: Generator[str, None, None], method: HttpMethod) \ + -> Optional[Tuple[Node | PathMatcher, Matches]]: + path: List = list(path) result = self.root - it = iter(lineage) + + matches = Matches() + it, i = iter((it for it in path)), -1 while True: node = result - leaf = next(it, None) + leaf, i = next(it, None), i + 1 if leaf is None: break child = node.children.get(leaf) - if child is None: - break + if child is None and isinstance(leaf, str): + for matcher in node.path_matchers: + match = matcher.match(path[i:]) + if match is not None: + if isinstance(match, Mapping): + matches.kwargs.update(match) + elif isinstance(match, Sequence): + matches.path = match + result = matcher + break + else: + break else: result = child - return None if result == self.root else result + child = result.children.get(method) + if child is not None: + result = child + matches.unmatched_paths = path[i:] + return None if result == self.root else (result, matches) def add(self, path: Generator[str, None, None], method: Optional[HttpMethod], *path_handlers: PathHandler) -> Node: lineage: Generator[NodeType, None, None] = (it for it in @@ -73,8 +84,11 @@ class Tree: result = child key = leaf while key is not None: - new_node = Node(key=key, parent=result, children={}, handlers=[]) - result.children[key] = new_node + new_node = self.parse(key, result) + if isinstance(new_node, Node): + result.children[key] = new_node + else: + result.path_matchers.append(new_node) result = new_node key = next(it, None) @@ -84,38 +98,78 @@ class Tree: def register(self, path: str, method: Optional[HttpMethod], - callback: Callable[[HttpContext], Awaitable[None]]) -> None: + callback: Callable[[HttpContext, Unpack], Awaitable[None]], + recursive) -> None: class Handler(PathHandler): - def match(self, subpath: Sequence[str], method: HttpMethod) -> bool: - return len(subpath) == 0 + async def handle_request(self, ctx: HttpContext, captured: Matches) -> None: + args = Maybe.of_nullable(captured.path).map(lambda it: [it]).or_else([]) + await callback(ctx, *args, **captured.kwargs) + + @property + def recursive(self) -> bool: + return recursive - async def handle_request(self, ctx: HttpContext) -> None: - await callback(ctx) handler = Handler() self.add((p for p in PathIterator(path)), method, handler) - def find_node(self, path: Generator[str, None, None], method: HttpMethod = HttpMethod.GET) -> Optional[Node]: + def find_node(self, path: Generator[str, None, None], method: HttpMethod = HttpMethod.GET) \ + -> Optional[Tuple[Node | PathMatcher, Matches]]: return (Maybe.of_nullable(self.search(path, method)) - .filter(lambda it: len(it.handlers) > 0) + .filter(lambda it: len(it[0].handlers) > 0) .or_none()) - def get_handler(self, url: str, method: HttpMethod = HttpMethod.GET) -> Optional[PathHandler]: + def get_handler(self, url: str, method: HttpMethod = HttpMethod.GET) \ + -> Optional[Tuple[PathHandler, Matches]]: path = urlparse(url).path - node = self.find_node((p for p in PathIterator(path)), method) - if node is None: + result: Optional[Tuple[Node | PathMatcher, Matches]] = self.find_node((p for p in PathIterator(path)), method) + if result is None: return None - requested = (p for p in PathIterator(path)) - found = reversed([n.key for n in NodeAncestryIterator(node) if n.key != '/']) - unmatched: List[str] = [] - for r, f in zip(requested, found): - if f is None: - unmatched.append(r) + node, captured = result + # requested = (p for p in PathIterator(path)) + # found = reversed([n for n in NodeAncestryIterator(node) if n != self.root]) + # unmatched: List[str] = [] + # for r, f in zip(requested, found): + # if f is None: + # unmatched.append(r) for handler in node.handlers: - if handler.match(unmatched, method): - return handler + if len(captured.unmatched_paths) == 0: + return handler, captured + elif handler.recursive: + return handler, captured + # if handler.match(unmatched, method): + # return (handler, unmatched) return None + def parse(self, leaf : str, parent : Node | PathMatcher) -> Node | PathMatcher: + start = 0 + result = index_of_with_escape(leaf, '${', '\\', 0) + if result >= 0: + start = result + 2 + end = leaf.index('}', start + 2) + definition = leaf[start:end] + try: + colon = definition.index(':') + except ValueError: + colon = None + if colon is None: + key = definition + kind = 'str' + else: + key = definition[:colon] + kind = definition[colon+1:] if colon is not None else 'str' + if kind == 'str': + return StrMatcher(name=key, parent=parent, children={}, handlers=[], path_matchers=[]) + elif kind == 'int': + return IntMatcher(name=key, parent=parent, children={}, handlers=[], path_matchers=[]) + else: + raise ValueError(f"Unknown kind: '{kind}'") + result = index_of_with_escape(leaf, '*', '\\', 0) + if result >= 0: + return GlobMatcher(pattern=leaf, parent=parent, children={}, handlers=[], path_matchers=[]) + else: + return Node(key=leaf, parent=parent, children={}, handlers=[], path_matchers=[]) + class PathIterator: path: str diff --git a/core/src/bugis/core/_types.py b/core/src/bugis/core/_types.py deleted file mode 100644 index 667197d..0000000 --- a/core/src/bugis/core/_types.py +++ /dev/null @@ -1,70 +0,0 @@ -from typing import ( - Sequence, - TypedDict, - Literal, - Iterable, - Tuple, - Optional, - NotRequired, - Dict, - Any, - Union -) - -type StrOrStrings = (str | Sequence[str]) - -class ASGIVersions(TypedDict): - spec_version: str - version: Union[Literal["2.0"], Literal["3.0"]] - - -class HTTPScope(TypedDict): - type: Literal["http"] - asgi: ASGIVersions - http_version: str - method: str - scheme: str - path: str - raw_path: bytes - query_string: bytes - root_path: str - headers: Iterable[Tuple[bytes, bytes]] - client: Optional[Tuple[str, int]] - server: Optional[Tuple[str, Optional[int]]] - state: NotRequired[Dict[str, Any]] - extensions: Optional[Dict[str, Dict[object, object]]] -class WebSocketScope(TypedDict): - type: Literal["websocket"] - asgi: ASGIVersions - http_version: str - scheme: str - path: str - raw_path: bytes - query_string: bytes - root_path: str - headers: Iterable[Tuple[bytes, bytes]] - client: Optional[Tuple[str, int]] - server: Optional[Tuple[str, Optional[int]]] - subprotocols: Iterable[str] - state: NotRequired[Dict[str, Any]] - extensions: Optional[Dict[str, Dict[object, object]]] - - -class LifespanScope(TypedDict): - type: Literal["lifespan"] - asgi: ASGIVersions - state: NotRequired[Dict[str, Any]] - -class RSGI: - class Scope(TypedDict): - proto: Literal['http'] = 'http' - rsgi_version: str - http_version: str - server: str - client: str - scheme: str - method: str - path: str - query_string: str - headers: Mapping[str, str] - authority: Optional[str] \ No newline at end of file diff --git a/core/src/bugis/core/_types/__init__.py b/core/src/bugis/core/_types/__init__.py index 8e6574f..eae7b4d 100644 --- a/core/src/bugis/core/_types/__init__.py +++ b/core/src/bugis/core/_types/__init__.py @@ -1,3 +1,95 @@ -from typing import Sequence +from typing import ( + TypedDict, + Literal, + Iterable, + Tuple, + Optional, + NotRequired, + Dict, + Any, + Union, + Mapping, + Sequence +) +from dataclasses import dataclass, field + +from bugis.core._http_method import HttpMethod + +type StrOrStrings = (str | Sequence[str]) + +type NodeType = (str | HttpMethod) + +type PathHandlers = ('PathHandler' | Sequence['PathHandler']) + +type PathMatcherResult = Mapping[str, str] | Sequence[str] + + +@dataclass +class Matches: + + kwargs: Dict[str, str] = field(default_factory=dict) + + path: Optional[Sequence[str]] = None + + unmatched_paths: [Sequence[str]] = field(default_factory=list) + + +class ASGIVersions(TypedDict): + spec_version: str + version: Union[Literal["2.0"], Literal["3.0"]] + + +class HTTPScope(TypedDict): + type: Literal["http"] + asgi: ASGIVersions + http_version: str + method: str + scheme: str + path: str + raw_path: bytes + query_string: bytes + root_path: str + headers: Iterable[Tuple[bytes, bytes]] + client: Optional[Tuple[str, int]] + server: Optional[Tuple[str, Optional[int]]] + state: NotRequired[Dict[str, Any]] + extensions: Optional[Dict[str, Dict[object, object]]] + + +class WebSocketScope(TypedDict): + type: Literal["websocket"] + asgi: ASGIVersions + http_version: str + scheme: str + path: str + raw_path: bytes + query_string: bytes + root_path: str + headers: Iterable[Tuple[bytes, bytes]] + client: Optional[Tuple[str, int]] + server: Optional[Tuple[str, Optional[int]]] + subprotocols: Iterable[str] + state: NotRequired[Dict[str, Any]] + extensions: Optional[Dict[str, Dict[object, object]]] + + +class LifespanScope(TypedDict): + type: Literal["lifespan"] + asgi: ASGIVersions + state: NotRequired[Dict[str, Any]] + + +class RSGI: + class Scope(TypedDict): + proto: Literal['http'] = 'http' + rsgi_version: str + http_version: str + server: str + client: str + scheme: str + method: str + path: str + query_string: str + headers: Mapping[str, str] + authority: Optional[str] -type StrOrStrings = (str | Sequence[str]) \ No newline at end of file diff --git a/core/tests/test_asgi.py b/core/tests/test_asgi.py index ccd4ef1..72e74b4 100644 --- a/core/tests/test_asgi.py +++ b/core/tests/test_asgi.py @@ -1,7 +1,9 @@ import unittest +import json import httpx from pwo import async_test -from bugis.core import BugisApp, HttpContext +from bugis.core import BugisApp, HttpContext, HttpMethod +from typing import Sequence class AsgiTest(unittest.TestCase): @@ -18,6 +20,39 @@ class AsgiTest(unittest.TestCase): print(chunk) await ctx.send_str(200, 'Hello World!') + @self.app.route(('/foo/bar',), HttpMethod.PUT, recursive=True) + async def handle_request(ctx: HttpContext) -> None: + async for chunk in ctx.request_body: + print(chunk) + await ctx.send_str(200, ctx.path) + + @self.app.route(('/foo/*',), HttpMethod.PUT, recursive=True) + async def handle_request(ctx: HttpContext, path: Sequence[str]) -> None: + async for chunk in ctx.request_body: + print(chunk) + await ctx.send_str(200, json.dumps(path)) + + @self.app.GET('/employee/${employee_id}') + async def handle_request(ctx: HttpContext, employee_id: str) -> None: + async for chunk in ctx.request_body: + print(chunk) + await ctx.send_str(200, employee_id) + + @self.app.GET('/square/${x:int}') + async def handle_request(ctx: HttpContext, x: int) -> None: + async for chunk in ctx.request_body: + print(chunk) + await ctx.send_str(200, str(x * x)) + + @self.app.GET('/department/${department_id:int}/employee/${employee_id:int}') + async def handle_request(ctx: HttpContext, department_id: int, employee_id: int) -> None: + async for chunk in ctx.request_body: + print(chunk) + await ctx.send_str(200, json.dumps({ + 'department_id': department_id, + 'employee_id': employee_id + })) + @async_test async def test_hello(self): transport = httpx.ASGITransport(app=self.app) @@ -38,3 +73,55 @@ class AsgiTest(unittest.TestCase): r = await client.get("/hello4") self.assertEqual(r.status_code, 404) self.assertTrue(len(r.text) == 0) + + @async_test + async def test_foo(self): + transport = httpx.ASGITransport(app=self.app) + + async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1:80") as client: + r = await client.put("/foo/fizz/baz") + self.assertEqual(r.status_code, 200) + response = json.loads(r.text) + self.assertEqual(['fizz', 'baz'], response) + + @async_test + async def test_foo_bar(self): + transport = httpx.ASGITransport(app=self.app) + + async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1:80") as client: + r = await client.put("/foo/bar/baz") + self.assertEqual(r.status_code, 200) + self.assertEqual('/foo/bar/baz', r.text) + + @async_test + async def test_employee(self): + transport = httpx.ASGITransport(app=self.app) + + async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1:80") as client: + r = await client.get("/employee/101325") + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, '101325') + + @async_test + async def test_square(self): + transport = httpx.ASGITransport(app=self.app) + + async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1:80") as client: + x = 30 + r = await client.get(f"/square/{x}") + self.assertEqual(r.status_code, 200) + self.assertEqual(r.text, str(x * x)) + + @async_test + async def test_department_employee(self): + transport = httpx.ASGITransport(app=self.app) + + async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1:80") as client: + r = await client.get("department/189350/employee/101325") + self.assertEqual(r.status_code, 200) + response = json.loads(r.text) + self.assertEqual({ + 'department_id': 189350, + 'employee_id': 101325 + }, response) + diff --git a/core/tests/test_tree.py b/core/tests/test_tree.py index 6f0145c..d8a1501 100644 --- a/core/tests/test_tree.py +++ b/core/tests/test_tree.py @@ -32,13 +32,14 @@ class TreeTest(unittest.TestCase): class TestHandler(PathHandler): - def match(self, subpath: Sequence[str], method: HttpMethod) -> bool: - return True - def handle_request(self, ctx: HttpContext): pass - self.handlers = [TestHandler() for _ in range(10)] + @property + def recursive(self) -> bool: + return True + + self.handlers = [TestHandler() for _ in range(20)] routes: Tuple[Tuple[Tuple[str, ...], Optional[HttpMethod], PathHandler], ...] = ( (('home', 'something'), HttpMethod.GET, self.handlers[0]), @@ -49,6 +50,10 @@ class TreeTest(unittest.TestCase): (('home',), HttpMethod.GET, self.handlers[5]), (('home',), HttpMethod.POST, self.handlers[6]), (('home',), None, self.handlers[7]), + (('home', '*.md'), None, self.handlers[8]), + (('home', 'something', '*', 'blah', '*.md'), None, self.handlers[9]), + (('home', 'bar', '*'), None, self.handlers[10]), + ) for path, method, handler in routes: @@ -66,9 +71,13 @@ class TreeTest(unittest.TestCase): ('http://localhost:127.0.0.1:5432/home', HttpMethod.GET, 5), ('http://localhost:127.0.0.1:5432/home', HttpMethod.POST, 6), ('http://localhost:127.0.0.1:5432/home', HttpMethod.PUT, 7), + ('http://localhost:127.0.0.1:5432/home/README.md', HttpMethod.GET, 8), + ('http://localhost:127.0.0.1:5432/home/something/ciao/blah/README.md', HttpMethod.GET, 9), + ('http://localhost:127.0.0.1:5432/home/bar/ciao/blah/README.md', HttpMethod.GET, 10), ) for url, method, handler_num in cases: with self.subTest(f"{str(method)} {url}"): res = self.tree.get_handler(url, method) - self.assertIs(Maybe.of(handler_num).map(self.handlers.__getitem__).or_none(), res) + self.assertIs(Maybe.of(handler_num).map(self.handlers.__getitem__).or_none(), + Maybe.of_nullable(res).map(lambda it: it[0]).or_none())