Compare commits

2 Commits

Author SHA1 Message Date
8f0320f262 Fixed mypy 2024-11-13 12:12:49 +08:00
ee6e645cc1 tmp 2024-11-13 11:02:13 +08:00
11 changed files with 500 additions and 170 deletions

View File

@@ -1,7 +1,8 @@
from ._app import BugisApp from ._app import BugisApp
from ._http_method import HttpMethod from ._http_method import HttpMethod
from ._http_context import HttpContext from ._http_context import HttpContext
from ._tree import Tree, PathHandler, PathIterator from ._tree import Tree, PathIterator
from ._path_handler import PathHandler
__all__ = [ __all__ = [

View File

@@ -2,16 +2,17 @@ from abc import ABC, abstractmethod
from asyncio import Queue, AbstractEventLoop from asyncio import Queue, AbstractEventLoop
from asyncio import get_running_loop from asyncio import get_running_loop
from logging import getLogger from logging import getLogger
from typing import Callable, Awaitable, Any, Mapping, Sequence, Optional from typing import Callable, Awaitable, Any, Mapping, Sequence, Optional, Unpack, Tuple
from pwo import Maybe, AsyncQueueIterator from pwo import Maybe, AsyncQueueIterator
from ._http_context import HttpContext from ._http_context import HttpContext
from ._http_method import HttpMethod from ._http_method import HttpMethod
from ._types import StrOrStrings
try: try:
from ._rsgi import RsgiContext from ._rsgi import RsgiContext
from granian._granian import RSGIHTTPProtocol, RSGIHTTPScope # type: ignore from granian._granian import RSGIHTTPProtocol, RSGIHTTPScope # type: ignore
except ImportError: except ImportError:
pass pass
@@ -21,11 +22,12 @@ from ._types.asgi import LifespanScope, HTTPScope as ASGIHTTPScope, WebSocketSco
log = getLogger(__name__) log = getLogger(__name__)
type HttpHandler = Callable[[HttpContext], Awaitable[None]] type HttpHandler = Callable[[HttpContext, Unpack[Any]], Awaitable[None]]
class AbstractBugisApp(ABC): class AbstractBugisApp(ABC):
async def __call__(self, async def __call__(self,
scope: ASGIHTTPScope|WebSocketScope|LifespanScope, scope: ASGIHTTPScope | WebSocketScope | LifespanScope,
receive: Callable[[], Awaitable[Any]], receive: Callable[[], Awaitable[Any]],
send: Callable[[Mapping[str, Any]], Awaitable[None]]) -> None: send: Callable[[Mapping[str, Any]], Awaitable[None]]) -> None:
loop = get_running_loop() loop = get_running_loop()
@@ -84,43 +86,58 @@ class BugisApp(AbstractBugisApp):
self._tree = Tree() self._tree = Tree()
async def handle_request(self, ctx: HttpContext) -> None: async def handle_request(self, ctx: HttpContext) -> None:
handler = self._tree.get_handler(ctx.path, ctx.method) result = self._tree.get_handler(ctx.path, ctx.method)
if handler is not None: if result is not None:
await handler.handle_request(ctx) handler, captured = result
await handler.handle_request(ctx, captured)
else: else:
await ctx.send_empty(404) await ctx.send_empty(404)
pass pass
def route(self, def route(self,
path: str, paths: StrOrStrings,
methods: Optional[Sequence[HttpMethod]] = None) -> Callable[[HttpHandler], HttpHandler]: methods: Optional[HttpMethod | Sequence[HttpMethod]] = None,
recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]:
def wrapped(handler: HttpHandler) -> HttpHandler: def wrapped(handler: HttpHandler) -> HttpHandler:
if methods is not None: nonlocal methods
for method in methods: nonlocal paths
self._tree.register(path, method, handler) _methods: Tuple[Optional[HttpMethod], ...]
if methods is None:
_methods = (None,)
elif isinstance(methods, HttpMethod):
_methods = (methods,)
else: else:
self._tree.register(path, None, handler) _methods = tuple(methods)
_paths: Tuple[str, ...]
if isinstance(paths, str):
_paths = (paths,)
else:
_paths = tuple(paths)
for method in _methods:
for path in _paths:
self._tree.register(path, method, handler, recursive)
return handler return handler
return wrapped return wrapped
def GET(self, path: str) -> Callable[[HttpHandler], HttpHandler]: def GET(self, path: str, recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]:
return self.route(path, (HttpMethod.GET,)) return self.route(path, (HttpMethod.GET,), recursive)
def POST(self, path: str) -> Callable[[HttpHandler], HttpHandler]: def POST(self, path: str, recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]:
return self.route(path, (HttpMethod.POST,)) return self.route(path, (HttpMethod.POST,), recursive)
def PUT(self, path: str) -> Callable[[HttpHandler], HttpHandler]: def PUT(self, path: str, recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]:
return self.route(path, (HttpMethod.PUT,)) return self.route(path, (HttpMethod.PUT,), recursive)
def DELETE(self, path: str) -> Callable[[HttpHandler], HttpHandler]: def DELETE(self, path: str, recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]:
return self.route(path, (HttpMethod.DELETE,)) return self.route(path, (HttpMethod.DELETE,), recursive)
def OPTIONS(self, path: str) -> Callable[[HttpHandler], HttpHandler]: def OPTIONS(self, path: str, recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]:
return self.route(path, (HttpMethod.OPTIONS,)) return self.route(path, (HttpMethod.OPTIONS,), recursive)
def HEAD(self, path: str) -> Callable[[HttpHandler], HttpHandler]: def HEAD(self, path: str, recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]:
return self.route(path, (HttpMethod.HEAD,)) return self.route(path, (HttpMethod.HEAD,), recursive)
def PATCH(self, path: str) -> Callable[[HttpHandler], HttpHandler]: def PATCH(self, path: str, recursive: bool = False) -> Callable[[HttpHandler], HttpHandler]:
return self.route(path, (HttpMethod.PATCH,)) return self.route(path, (HttpMethod.PATCH,), recursive)

View File

@@ -0,0 +1,12 @@
from dataclasses import dataclass
from typing import (
Optional,
Dict,
List,
)
from ._types import NodeType
from ._path_handler import PathHandler
from ._path_matcher import PathMatcher

View File

@@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from typing import (
Sequence,
Dict,
Optional
)
from dataclasses import dataclass, field
from ._http_context import HttpContext
@dataclass
class Matches:
kwargs: Dict[str, str] = field(default_factory=dict)
path: Optional[Sequence[str]] = None
unmatched_paths: Sequence[str] = field(default_factory=list)
class PathHandler(ABC):
@abstractmethod
async def handle_request(self, ctx: HttpContext, captured: Matches) -> None:
pass
@property
@abstractmethod
def recursive(self) -> bool:
raise NotImplementedError()
type PathHandlers = (PathHandler | Sequence[PathHandler])

View File

@@ -0,0 +1,97 @@
from fnmatch import fnmatch
from abc import ABC, abstractmethod
from typing import Optional, Sequence, Dict, List, Union
from dataclasses import dataclass
from ._path_handler import PathHandler
from ._types import NodeType, PathMatcherResult
@dataclass
class Node:
key: NodeType
parent: Optional[Union['Node', 'PathMatcher']]
children: Dict[NodeType, 'Node']
handlers: List[PathHandler]
path_matchers: List['PathMatcher']
class PathMatcher(ABC):
parent: Optional[Union['Node', 'PathMatcher']]
children: Dict[NodeType, Node]
handlers: List[PathHandler]
path_matchers: List['PathMatcher']
def __init__(self,
parent: Optional[Union['Node', 'PathMatcher']],
children: Dict[NodeType, Node],
handlers: List[PathHandler],
path_matchers: List['PathMatcher']
):
self.parent = parent
self.children = children
self.handlers = handlers
self.path_matchers = path_matchers
@abstractmethod
def match(self, path: Sequence[str]) -> Optional[PathMatcherResult]:
pass
class StrMatcher(PathMatcher):
name: str
def __init__(self,
name: str,
parent: Optional[Node | PathMatcher],
children: Dict[NodeType, Node],
handlers: List[PathHandler],
path_matchers: List[PathMatcher],
):
super().__init__(parent, children, handlers, path_matchers)
self.name = name
def match(self, path: Sequence[str]) -> Optional[PathMatcherResult]:
if len(path):
return {self.name: path[0]}
else:
return None
class IntMatcher(PathMatcher):
name: str
def __init__(self,
name: str,
parent: Optional[Node | PathMatcher],
children: Dict[NodeType, Node],
handlers: List[PathHandler],
path_matchers: List[PathMatcher],
):
super().__init__(parent, children, handlers, path_matchers)
self.name = name
def match(self, path: Sequence[str]) -> Optional[PathMatcherResult]:
if len(path) > 0:
try:
return {self.name: int(path[0])}
except ValueError:
return None
else:
return None
class GlobMatcher(PathMatcher):
pattern: str
def __init__(self,
pattern: str,
parent: Optional[Node | PathMatcher],
children: Dict[NodeType, Node],
handlers: List[PathHandler],
path_matchers: List[PathMatcher],
):
super().__init__(parent, children, handlers, path_matchers)
self.pattern = pattern
def match(self, path: Sequence[str]) -> Optional[PathMatcherResult]:
return path if fnmatch('/'.join(path), self.pattern) else None

View File

@@ -14,8 +14,7 @@ from typing import (
cast cast
) )
from granian.rsgi import Scope from granian._granian import RSGIHTTPProtocol, RSGIHTTPScope
from granian._granian import RSGIHTTPProtocol
from pwo import Maybe from pwo import Maybe
from ._http_context import HttpContext from ._http_context import HttpContext
@@ -34,7 +33,7 @@ class RsgiContext(HttpContext):
request_body: AsyncIterator[bytes] request_body: AsyncIterator[bytes]
head = Optional[Tuple[int, Sequence[Tuple[str, str]]]] head = Optional[Tuple[int, Sequence[Tuple[str, str]]]]
def __init__(self, scope: Scope, protocol: RSGIHTTPProtocol): def __init__(self, scope: RSGIHTTPScope, protocol: RSGIHTTPProtocol):
self.scheme = scope.scheme self.scheme = scope.scheme
self.path = scope.path self.path = scope.path
self.method = HttpMethod(scope.method) self.method = HttpMethod(scope.method)

View File

@@ -1,64 +1,73 @@
from typing import Sequence, Dict, Awaitable, Callable, Optional, Generator, Self, List
from ._http_method import HttpMethod
from ._http_context import HttpContext
from dataclasses import dataclass
from abc import ABC, abstractmethod
from itertools import chain from itertools import chain
from typing import (
Sequence,
Awaitable,
Callable,
Optional,
Generator,
Self,
List,
Tuple,
Mapping,
Any,
)
from typing_extensions import Unpack
from urllib.parse import urlparse from urllib.parse import urlparse
from pwo import Maybe
type NodeType = (str | HttpMethod) from pwo import Maybe, index_of_with_escape
type PathHandlers = (PathHandler | Sequence[PathHandler]) from ._http_context import HttpContext
from ._http_method import HttpMethod
from ._path_handler import PathHandler
class PathHandler(ABC): from ._path_matcher import PathMatcher, IntMatcher, GlobMatcher, StrMatcher, Node
from ._types import NodeType, Matches
@abstractmethod
def match(self, subpath: Sequence[str], method: HttpMethod) -> bool:
raise NotImplementedError()
@abstractmethod
async def handle_request(self, ctx: HttpContext) -> None:
pass
@dataclass
class Node:
key: NodeType
parent: Optional['Node']
children: Dict[NodeType, 'Node']
handlers: Sequence[PathHandler]
class Tree: class Tree:
def __init__(self) -> None: def __init__(self) -> None:
self.root = Node('/', None, {}, []) self.root = Node('/', None, {}, [], [])
def search(self, path: Generator[str, None, None], method: HttpMethod) -> Optional[Node]: def search(self, path: Generator[str, None, None], method: HttpMethod) \
lineage: Generator[NodeType, None, None] = (it for it in chain(path, (method,))) -> Optional[Tuple[Node | PathMatcher, Matches]]:
result = self.root paths: List[str] = list(path)
it = iter(lineage) result: Node | PathMatcher = self.root
matches = Matches()
it, i = iter((it for it in paths)), -1
while True: while True:
node = result node = result
leaf = next(it, None) leaf, i = next(it, None), i + 1
if leaf is None: if leaf is None:
break break
child = node.children.get(leaf) child = node.children.get(leaf)
if child is None: if child is None and isinstance(leaf, str):
break for matcher in node.path_matchers:
match = matcher.match(paths[i:])
if match is not None:
if isinstance(match, Mapping):
matches.kwargs.update(match)
elif isinstance(match, Sequence):
matches.path = match
result = matcher
break
else:
break
else: else:
result = child result = child
return None if result == self.root else result child = result.children.get(method)
if child is not None:
result = child
matches.unmatched_paths = paths[i:]
return None if result == self.root else (result, matches)
def add(self, path: Generator[str, None, None], method: Optional[HttpMethod], *path_handlers: PathHandler) -> Node: def add(self, path: Generator[str, None, None], method: Optional[HttpMethod], *path_handlers: PathHandler) -> Node | PathMatcher:
lineage: Generator[NodeType, None, None] = (it for it in lineage: Generator[NodeType, None, None] = (it for it in
chain(path, chain(path,
Maybe.of_nullable(method) Maybe.of_nullable(method)
.map(lambda it: [it]) .map(lambda it: [it])
.or_else([]))) .or_else([])))
result = self.root result: Node | PathMatcher = self.root
it = iter(lineage) it = iter(lineage)
while True: while True:
@@ -73,49 +82,92 @@ class Tree:
result = child result = child
key = leaf key = leaf
while key is not None: while key is not None:
new_node = Node(key=key, parent=result, children={}, handlers=[]) new_node = self.parse(key, result)
result.children[key] = new_node if isinstance(new_node, Node):
result.children[key] = new_node
else:
result.path_matchers.append(new_node)
result = new_node result = new_node
key = next(it, None) key = next(it, None)
result.handlers = tuple(chain(result.handlers, path_handlers)) result.handlers = list(chain(result.handlers, path_handlers))
return result return result
def register(self, def register(self,
path: str, path: str,
method: Optional[HttpMethod], method: Optional[HttpMethod],
callback: Callable[[HttpContext], Awaitable[None]]) -> None: callback: Callable[[HttpContext, Unpack[Any]], Awaitable[None]],
recursive: bool) -> None:
class Handler(PathHandler): class Handler(PathHandler):
def match(self, subpath: Sequence[str], method: HttpMethod) -> bool: async def handle_request(self, ctx: HttpContext, captured: Matches) -> None:
return len(subpath) == 0 args = Maybe.of_nullable(captured.path).map(lambda it: [it]).or_else([])
await callback(ctx, *args, **captured.kwargs)
@property
def recursive(self) -> bool:
return recursive
async def handle_request(self, ctx: HttpContext) -> None:
await callback(ctx)
handler = Handler() handler = Handler()
self.add((p for p in PathIterator(path)), method, handler) self.add((p for p in PathIterator(path)), method, handler)
def find_node(self, path: Generator[str, None, None], method: HttpMethod = HttpMethod.GET) -> Optional[Node]: def find_node(self, path: Generator[str, None, None], method: HttpMethod = HttpMethod.GET) \
-> Optional[Tuple[Node | PathMatcher, Matches]]:
return (Maybe.of_nullable(self.search(path, method)) return (Maybe.of_nullable(self.search(path, method))
.filter(lambda it: len(it.handlers) > 0) .filter(lambda it: len(it[0].handlers) > 0)
.or_none()) .or_none())
def get_handler(self, url: str, method: HttpMethod = HttpMethod.GET) -> Optional[PathHandler]: def get_handler(self, url: str, method: HttpMethod = HttpMethod.GET) \
-> Optional[Tuple[PathHandler, Matches]]:
path = urlparse(url).path path = urlparse(url).path
node = self.find_node((p for p in PathIterator(path)), method) result: Optional[Tuple[Node | PathMatcher, Matches]] = self.find_node((p for p in PathIterator(path)), method)
if node is None: if result is None:
return None return None
requested = (p for p in PathIterator(path)) node, captured = result
found = reversed([n.key for n in NodeAncestryIterator(node) if n.key != '/']) # requested = (p for p in PathIterator(path))
unmatched: List[str] = [] # found = reversed([n for n in NodeAncestryIterator(node) if n != self.root])
for r, f in zip(requested, found): # unmatched: List[str] = []
if f is None: # for r, f in zip(requested, found):
unmatched.append(r) # if f is None:
# unmatched.append(r)
for handler in node.handlers: for handler in node.handlers:
if handler.match(unmatched, method): if len(captured.unmatched_paths) == 0:
return handler return handler, captured
elif handler.recursive:
return handler, captured
# if handler.match(unmatched, method):
# return (handler, unmatched)
return None return None
def parse(self, leaf: str, parent: Optional[Node | PathMatcher]) -> Node | PathMatcher:
start = 0
result = index_of_with_escape(leaf, '${', '\\', 0)
if result >= 0:
start = result + 2
end = leaf.index('}', start + 2)
definition = leaf[start:end]
try:
colon = definition.index(':')
except ValueError:
colon = None
if colon is None:
key = definition
kind = 'str'
else:
key = definition[:colon]
kind = definition[colon+1:] if colon is not None else 'str'
if kind == 'str':
return StrMatcher(name=key, parent=parent, children={}, handlers=[], path_matchers=[])
elif kind == 'int':
return IntMatcher(name=key, parent=parent, children={}, handlers=[], path_matchers=[])
else:
raise ValueError(f"Unknown kind: '{kind}'")
result = index_of_with_escape(leaf, '*', '\\', 0)
if result >= 0:
return GlobMatcher(pattern=leaf, parent=parent, children={}, handlers=[], path_matchers=[])
else:
return Node(key=leaf, parent=parent, children={}, handlers=[], path_matchers=[])
class PathIterator: class PathIterator:
path: str path: str
@@ -154,7 +206,7 @@ class PathIterator:
class NodeAncestryIterator: class NodeAncestryIterator:
node: Node node: Node | PathMatcher
def __init__(self, node: Node): def __init__(self, node: Node):
self.node = node self.node = node
@@ -162,7 +214,7 @@ class NodeAncestryIterator:
def __iter__(self) -> Self: def __iter__(self) -> Self:
return self return self
def __next__(self) -> Node: def __next__(self) -> Node | PathMatcher:
parent = self.node.parent parent = self.node.parent
if parent is None: if parent is None:
raise StopIteration() raise StopIteration()

View File

@@ -1,70 +0,0 @@
from typing import (
Sequence,
TypedDict,
Literal,
Iterable,
Tuple,
Optional,
NotRequired,
Dict,
Any,
Union
)
type StrOrStrings = (str | Sequence[str])
class ASGIVersions(TypedDict):
spec_version: str
version: Union[Literal["2.0"], Literal["3.0"]]
class HTTPScope(TypedDict):
type: Literal["http"]
asgi: ASGIVersions
http_version: str
method: str
scheme: str
path: str
raw_path: bytes
query_string: bytes
root_path: str
headers: Iterable[Tuple[bytes, bytes]]
client: Optional[Tuple[str, int]]
server: Optional[Tuple[str, Optional[int]]]
state: NotRequired[Dict[str, Any]]
extensions: Optional[Dict[str, Dict[object, object]]]
class WebSocketScope(TypedDict):
type: Literal["websocket"]
asgi: ASGIVersions
http_version: str
scheme: str
path: str
raw_path: bytes
query_string: bytes
root_path: str
headers: Iterable[Tuple[bytes, bytes]]
client: Optional[Tuple[str, int]]
server: Optional[Tuple[str, Optional[int]]]
subprotocols: Iterable[str]
state: NotRequired[Dict[str, Any]]
extensions: Optional[Dict[str, Dict[object, object]]]
class LifespanScope(TypedDict):
type: Literal["lifespan"]
asgi: ASGIVersions
state: NotRequired[Dict[str, Any]]
class RSGI:
class Scope(TypedDict):
proto: Literal['http'] = 'http'
rsgi_version: str
http_version: str
server: str
client: str
scheme: str
method: str
path: str
query_string: str
headers: Mapping[str, str]
authority: Optional[str]

View File

@@ -1,3 +1,96 @@
from typing import Sequence from typing import (
TypedDict,
Literal,
Iterable,
Tuple,
Optional,
NotRequired,
Dict,
Any,
Union,
Mapping,
Sequence
)
from bugis.core._http_method import HttpMethod
from bugis.core._path_handler import PathHandler, Matches
type StrOrStrings = (str | Sequence[str]) type StrOrStrings = (str | Sequence[str])
type NodeType = (str | HttpMethod)
type PathMatcherResult = Mapping[str, Any] | Sequence[str]
class ASGIVersions(TypedDict):
spec_version: str
version: Union[Literal["2.0"], Literal["3.0"]]
class HTTPScope(TypedDict):
type: Literal["http"]
asgi: ASGIVersions
http_version: str
method: str
scheme: str
path: str
raw_path: bytes
query_string: bytes
root_path: str
headers: Iterable[Tuple[bytes, bytes]]
client: Optional[Tuple[str, int]]
server: Optional[Tuple[str, Optional[int]]]
state: NotRequired[Dict[str, Any]]
extensions: Optional[Dict[str, Dict[object, object]]]
class WebSocketScope(TypedDict):
type: Literal["websocket"]
asgi: ASGIVersions
http_version: str
scheme: str
path: str
raw_path: bytes
query_string: bytes
root_path: str
headers: Iterable[Tuple[bytes, bytes]]
client: Optional[Tuple[str, int]]
server: Optional[Tuple[str, Optional[int]]]
subprotocols: Iterable[str]
state: NotRequired[Dict[str, Any]]
extensions: Optional[Dict[str, Dict[object, object]]]
class LifespanScope(TypedDict):
type: Literal["lifespan"]
asgi: ASGIVersions
state: NotRequired[Dict[str, Any]]
class RSGI:
class Scope(TypedDict):
proto: Literal['http'] # = 'http'
rsgi_version: str
http_version: str
server: str
client: str
scheme: str
method: str
path: str
query_string: str
headers: Mapping[str, str]
authority: Optional[str]
__all__ = [
'HttpMethod',
'HTTPScope',
'LifespanScope',
'RSGI',
'ASGIVersions',
'WebSocketScope',
'PathHandler',
'NodeType',
'Matches'
]

View File

@@ -1,7 +1,9 @@
import unittest import unittest
import json
import httpx import httpx
from pwo import async_test from pwo import async_test
from bugis.core import BugisApp, HttpContext from bugis.core import BugisApp, HttpContext, HttpMethod
from typing import Sequence
class AsgiTest(unittest.TestCase): class AsgiTest(unittest.TestCase):
@@ -18,6 +20,39 @@ class AsgiTest(unittest.TestCase):
print(chunk) print(chunk)
await ctx.send_str(200, 'Hello World!') await ctx.send_str(200, 'Hello World!')
@self.app.route(('/foo/bar',), HttpMethod.PUT, recursive=True)
async def handle_request(ctx: HttpContext) -> None:
async for chunk in ctx.request_body:
print(chunk)
await ctx.send_str(200, ctx.path)
@self.app.route(('/foo/*',), HttpMethod.PUT, recursive=True)
async def handle_request(ctx: HttpContext, path: Sequence[str]) -> None:
async for chunk in ctx.request_body:
print(chunk)
await ctx.send_str(200, json.dumps(path))
@self.app.GET('/employee/${employee_id}')
async def handle_request(ctx: HttpContext, employee_id: str) -> None:
async for chunk in ctx.request_body:
print(chunk)
await ctx.send_str(200, employee_id)
@self.app.GET('/square/${x:int}')
async def handle_request(ctx: HttpContext, x: int) -> None:
async for chunk in ctx.request_body:
print(chunk)
await ctx.send_str(200, str(x * x))
@self.app.GET('/department/${department_id:int}/employee/${employee_id:int}')
async def handle_request(ctx: HttpContext, department_id: int, employee_id: int) -> None:
async for chunk in ctx.request_body:
print(chunk)
await ctx.send_str(200, json.dumps({
'department_id': department_id,
'employee_id': employee_id
}))
@async_test @async_test
async def test_hello(self): async def test_hello(self):
transport = httpx.ASGITransport(app=self.app) transport = httpx.ASGITransport(app=self.app)
@@ -38,3 +73,55 @@ class AsgiTest(unittest.TestCase):
r = await client.get("/hello4") r = await client.get("/hello4")
self.assertEqual(r.status_code, 404) self.assertEqual(r.status_code, 404)
self.assertTrue(len(r.text) == 0) self.assertTrue(len(r.text) == 0)
@async_test
async def test_foo(self):
transport = httpx.ASGITransport(app=self.app)
async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1:80") as client:
r = await client.put("/foo/fizz/baz")
self.assertEqual(r.status_code, 200)
response = json.loads(r.text)
self.assertEqual(['fizz', 'baz'], response)
@async_test
async def test_foo_bar(self):
transport = httpx.ASGITransport(app=self.app)
async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1:80") as client:
r = await client.put("/foo/bar/baz")
self.assertEqual(r.status_code, 200)
self.assertEqual('/foo/bar/baz', r.text)
@async_test
async def test_employee(self):
transport = httpx.ASGITransport(app=self.app)
async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1:80") as client:
r = await client.get("/employee/101325")
self.assertEqual(r.status_code, 200)
self.assertEqual(r.text, '101325')
@async_test
async def test_square(self):
transport = httpx.ASGITransport(app=self.app)
async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1:80") as client:
x = 30
r = await client.get(f"/square/{x}")
self.assertEqual(r.status_code, 200)
self.assertEqual(r.text, str(x * x))
@async_test
async def test_department_employee(self):
transport = httpx.ASGITransport(app=self.app)
async with httpx.AsyncClient(transport=transport, base_url="http://127.0.0.1:80") as client:
r = await client.get("department/189350/employee/101325")
self.assertEqual(r.status_code, 200)
response = json.loads(r.text)
self.assertEqual({
'department_id': 189350,
'employee_id': 101325
}, response)

View File

@@ -32,13 +32,14 @@ class TreeTest(unittest.TestCase):
class TestHandler(PathHandler): class TestHandler(PathHandler):
def match(self, subpath: Sequence[str], method: HttpMethod) -> bool:
return True
def handle_request(self, ctx: HttpContext): def handle_request(self, ctx: HttpContext):
pass pass
self.handlers = [TestHandler() for _ in range(10)] @property
def recursive(self) -> bool:
return True
self.handlers = [TestHandler() for _ in range(20)]
routes: Tuple[Tuple[Tuple[str, ...], Optional[HttpMethod], PathHandler], ...] = ( routes: Tuple[Tuple[Tuple[str, ...], Optional[HttpMethod], PathHandler], ...] = (
(('home', 'something'), HttpMethod.GET, self.handlers[0]), (('home', 'something'), HttpMethod.GET, self.handlers[0]),
@@ -49,6 +50,10 @@ class TreeTest(unittest.TestCase):
(('home',), HttpMethod.GET, self.handlers[5]), (('home',), HttpMethod.GET, self.handlers[5]),
(('home',), HttpMethod.POST, self.handlers[6]), (('home',), HttpMethod.POST, self.handlers[6]),
(('home',), None, self.handlers[7]), (('home',), None, self.handlers[7]),
(('home', '*.md'), None, self.handlers[8]),
(('home', 'something', '*', 'blah', '*.md'), None, self.handlers[9]),
(('home', 'bar', '*'), None, self.handlers[10]),
) )
for path, method, handler in routes: for path, method, handler in routes:
@@ -66,9 +71,13 @@ class TreeTest(unittest.TestCase):
('http://localhost:127.0.0.1:5432/home', HttpMethod.GET, 5), ('http://localhost:127.0.0.1:5432/home', HttpMethod.GET, 5),
('http://localhost:127.0.0.1:5432/home', HttpMethod.POST, 6), ('http://localhost:127.0.0.1:5432/home', HttpMethod.POST, 6),
('http://localhost:127.0.0.1:5432/home', HttpMethod.PUT, 7), ('http://localhost:127.0.0.1:5432/home', HttpMethod.PUT, 7),
('http://localhost:127.0.0.1:5432/home/README.md', HttpMethod.GET, 8),
('http://localhost:127.0.0.1:5432/home/something/ciao/blah/README.md', HttpMethod.GET, 9),
('http://localhost:127.0.0.1:5432/home/bar/ciao/blah/README.md', HttpMethod.GET, 10),
) )
for url, method, handler_num in cases: for url, method, handler_num in cases:
with self.subTest(f"{str(method)} {url}"): with self.subTest(f"{str(method)} {url}"):
res = self.tree.get_handler(url, method) res = self.tree.get_handler(url, method)
self.assertIs(Maybe.of(handler_num).map(self.handlers.__getitem__).or_none(), res) self.assertIs(Maybe.of(handler_num).map(self.handlers.__getitem__).or_none(),
Maybe.of_nullable(res).map(lambda it: it[0]).or_none())