added server support for request pipelining
Some checks failed
CI / build (push) Has been cancelled

This commit is contained in:
2025-03-07 11:39:17 +08:00
parent fc298de548
commit 59a12d6218
9 changed files with 169 additions and 55 deletions

View File

@@ -39,7 +39,6 @@ allprojects { subproject ->
modularity.inferModulePath = true
toolchain {
languageVersion = JavaLanguageVersion.of(21)
vendor = JvmVendorSpec.GRAAL_VM
}
}

View File

@@ -91,6 +91,10 @@ Provider<EnvelopeJarTask> envelopeJarTaskProvider = tasks.named(EnvelopePlugin.E
}
tasks.named(NativeImagePlugin.CONFIGURE_NATIVE_IMAGE_TASK_NAME, NativeImageConfigurationTask) {
toolchain {
languageVersion = JavaLanguageVersion.of(21)
vendor = JvmVendorSpec.GRAAL_VM
}
mainClass = "net.woggioni.rbcs.cli.graal.GraalNativeImageConfiguration"
classpath = project.files(
configurations.configureNativeImageRuntimeClasspath,

View File

@@ -56,7 +56,6 @@ import net.woggioni.rbcs.server.configuration.Serializer
import net.woggioni.rbcs.server.exception.ExceptionHandler
import net.woggioni.rbcs.server.handler.MaxRequestSizeHandler
import net.woggioni.rbcs.server.handler.ServerHandler
import net.woggioni.rbcs.server.handler.TraceHandler
import net.woggioni.rbcs.server.throttling.BucketManager
import net.woggioni.rbcs.server.throttling.ThrottlingHandler
import java.io.OutputStream
@@ -355,13 +354,12 @@ class RemoteBuildCacheServer(private val cfg: Configuration) {
val serverHandler = let {
val prefix = Path.of("/").resolve(Path.of(cfg.serverPath ?: "/"))
ServerHandler(prefix)
ServerHandler(prefix) {
cacheHandlerFactory.newHandler(cfg, ch.eventLoop(), channelFactory, datagramChannelFactory)
}
}
pipeline.addLast(eventExecutorGroup, ServerHandler.NAME, serverHandler)
pipeline.addLast(cacheHandlerFactory.newHandler(cfg, ch.eventLoop(), channelFactory, datagramChannelFactory))
pipeline.addLast(TraceHandler)
pipeline.addLast(ExceptionHandler)
pipeline.addLast(ExceptionHandler.NAME, ExceptionHandler)
}
override fun asyncClose() = cacheHandlerFactory.asyncClose()

View File

@@ -0,0 +1,4 @@
package net.woggioni.rbcs.server.event
class RequestCompletedEvent {
}

View File

@@ -27,6 +27,9 @@ import javax.net.ssl.SSLPeerUnverifiedException
@Sharable
object ExceptionHandler : ChannelDuplexHandler() {
val NAME : String = this::class.java.name
private val log = contextLogger()
private val NOT_AUTHORIZED: FullHttpResponse = DefaultFullHttpResponse(

View File

@@ -1,28 +1,79 @@
package net.woggioni.rbcs.server.handler
import io.netty.channel.ChannelHandler
import io.netty.channel.ChannelHandler.Sharable
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelOutboundHandler
import io.netty.channel.ChannelPromise
import io.netty.channel.SimpleChannelInboundHandler
import io.netty.handler.codec.http.HttpContent
import io.netty.handler.codec.http.LastHttpContent
import net.woggioni.rbcs.api.message.CacheMessage.CacheValueNotFoundResponse
import net.woggioni.rbcs.api.message.CacheMessage.CacheContent
import net.woggioni.rbcs.api.message.CacheMessage.CachePutResponse
import net.woggioni.rbcs.api.message.CacheMessage.LastCacheContent
import java.net.SocketAddress
@Sharable
object CacheContentHandler : SimpleChannelInboundHandler<HttpContent>() {
val NAME = this::class.java.name
class CacheContentHandler(private val pairedHandler : ChannelHandler) : SimpleChannelInboundHandler<HttpContent>(), ChannelOutboundHandler {
private var requestFinished = false
override fun channelRead0(ctx: ChannelHandlerContext, msg: HttpContent) {
when(msg) {
is LastHttpContent -> {
ctx.fireChannelRead(LastCacheContent(msg.content().retain()))
ctx.pipeline().remove(this)
if(requestFinished) {
ctx.fireChannelRead(msg.retain())
} else {
when (msg) {
is LastHttpContent -> {
ctx.fireChannelRead(LastCacheContent(msg.content().retain()))
requestFinished = true
}
else -> ctx.fireChannelRead(CacheContent(msg.content().retain()))
}
else -> ctx.fireChannelRead(CacheContent(msg.content().retain()))
}
}
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
super.exceptionCaught(ctx, cause)
}
override fun bind(ctx: ChannelHandlerContext, localAddress: SocketAddress, promise: ChannelPromise) {
ctx.bind(localAddress, promise)
}
override fun connect(
ctx: ChannelHandlerContext,
remoteAddress: SocketAddress,
localAddress: SocketAddress,
promise: ChannelPromise
) {
ctx.connect(remoteAddress, localAddress, promise)
}
override fun disconnect(ctx: ChannelHandlerContext, promise: ChannelPromise) {
ctx.disconnect(promise)
}
override fun close(ctx: ChannelHandlerContext, promise: ChannelPromise) {
ctx.close(promise)
}
override fun deregister(ctx: ChannelHandlerContext, promise: ChannelPromise) {
ctx.deregister(promise)
}
override fun read(ctx: ChannelHandlerContext) {
ctx.read()
}
override fun flush(ctx: ChannelHandlerContext) {
ctx.flush()
}
override fun write(ctx: ChannelHandlerContext, msg: Any, promise: ChannelPromise) {
ctx.write(msg, promise)
if(msg is LastCacheContent || msg is CachePutResponse || msg is CacheValueNotFoundResponse || msg is LastHttpContent) {
ctx.pipeline().remove(pairedHandler)
ctx.pipeline().remove(this)
}
}
}

View File

@@ -0,0 +1,56 @@
package net.woggioni.rbcs.server.handler
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelInboundHandlerAdapter
import io.netty.channel.ChannelOutboundHandler
import io.netty.channel.ChannelPromise
import net.woggioni.rbcs.server.event.RequestCompletedEvent
import java.net.SocketAddress
class ResponseCapHandler : ChannelInboundHandlerAdapter(), ChannelOutboundHandler {
val bufferedMessages = mutableListOf<Any>()
override fun bind(ctx: ChannelHandlerContext, localAddress: SocketAddress, promise: ChannelPromise) {
ctx.bind(localAddress, promise)
}
override fun connect(
ctx: ChannelHandlerContext,
remoteAddress: SocketAddress,
localAddress: SocketAddress,
promise: ChannelPromise
) {
ctx.connect(remoteAddress, localAddress, promise)
}
override fun disconnect(ctx: ChannelHandlerContext, promise: ChannelPromise) {
ctx.disconnect(promise)
}
override fun close(ctx: ChannelHandlerContext, promise: ChannelPromise) {
ctx.close(promise)
}
override fun deregister(ctx: ChannelHandlerContext, promise: ChannelPromise) {
ctx.deregister(promise)
}
override fun read(ctx: ChannelHandlerContext) {
ctx.read()
}
override fun write(ctx: ChannelHandlerContext, msg: Any, promise: ChannelPromise) {
bufferedMessages.add(msg)
}
override fun flush(ctx: ChannelHandlerContext) {
}
override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) {
if(evt is RequestCompletedEvent) {
for(msg in bufferedMessages) ctx.write(msg)
ctx.flush()
ctx.pipeline().remove(this)
}
}
}

View File

@@ -1,6 +1,7 @@
package net.woggioni.rbcs.server.handler
import io.netty.channel.ChannelDuplexHandler
import io.netty.channel.ChannelHandler
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelPromise
import io.netty.handler.codec.http.DefaultFullHttpResponse
@@ -15,6 +16,8 @@ import io.netty.handler.codec.http.HttpRequest
import io.netty.handler.codec.http.HttpResponseStatus
import io.netty.handler.codec.http.HttpUtil
import io.netty.handler.codec.http.HttpVersion
import io.netty.handler.codec.http.LastHttpContent
import net.woggioni.rbcs.api.CacheHandlerFactory
import net.woggioni.rbcs.api.CacheValueMetadata
import net.woggioni.rbcs.api.message.CacheMessage
import net.woggioni.rbcs.api.message.CacheMessage.CacheContent
@@ -27,19 +30,22 @@ import net.woggioni.rbcs.api.message.CacheMessage.LastCacheContent
import net.woggioni.rbcs.common.createLogger
import net.woggioni.rbcs.common.debug
import net.woggioni.rbcs.common.warn
import net.woggioni.rbcs.server.event.RequestCompletedEvent
import net.woggioni.rbcs.server.exception.ExceptionHandler
import java.nio.file.Path
import java.util.Locale
class ServerHandler(private val serverPrefix: Path) :
class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupplier : () -> ChannelHandler) :
ChannelDuplexHandler() {
companion object {
private val log = createLogger<ServerHandler>()
val NAME = this::class.java.name
val NAME = ServerHandler::class.java.name
}
private var httpVersion = HttpVersion.HTTP_1_1
private var keepAlive = true
private var pipelinedRequests = 0
private fun resetRequestMetadata() {
httpVersion = HttpVersion.HTTP_1_1
@@ -73,6 +79,8 @@ class ServerHandler(private val serverPrefix: Path) :
try {
when (msg) {
is CachePutResponse -> {
pipelinedRequests -= 1
ctx.fireUserEventTriggered(RequestCompletedEvent())
val response = DefaultFullHttpResponse(httpVersion, HttpResponseStatus.CREATED)
val keyBytes = msg.key.toByteArray(Charsets.UTF_8)
response.headers().apply {
@@ -88,6 +96,8 @@ class ServerHandler(private val serverPrefix: Path) :
}
is CacheValueNotFoundResponse -> {
pipelinedRequests -= 1
ctx.fireUserEventTriggered(RequestCompletedEvent())
val response = DefaultFullHttpResponse(httpVersion, HttpResponseStatus.NOT_FOUND)
response.headers()[HttpHeaderNames.CONTENT_LENGTH] = 0
setKeepAliveHeader(response.headers())
@@ -108,6 +118,8 @@ class ServerHandler(private val serverPrefix: Path) :
}
is LastCacheContent -> {
pipelinedRequests -= 1
ctx.fireUserEventTriggered(RequestCompletedEvent())
ctx.writeAndFlush(DefaultLastHttpContent(msg.content()))
}
@@ -127,6 +139,10 @@ class ServerHandler(private val serverPrefix: Path) :
} finally {
resetRequestMetadata()
}
} else if(msg is LastHttpContent) {
pipelinedRequests -= 1
ctx.fireUserEventTriggered(RequestCompletedEvent())
ctx.write(msg, promise)
} else super.write(ctx, msg, promise)
}
@@ -139,7 +155,13 @@ class ServerHandler(private val serverPrefix: Path) :
if (path.startsWith(serverPrefix)) {
val relativePath = serverPrefix.relativize(path)
val key = relativePath.toString()
ctx.pipeline().addAfter(NAME, CacheContentHandler.NAME, CacheContentHandler)
if(pipelinedRequests > 0) {
ctx.pipeline().addBefore(ExceptionHandler.NAME, null, ResponseCapHandler())
}
val cacheHandler = cacheHandlerSupplier()
ctx.pipeline().addBefore(ExceptionHandler.NAME, null, CacheContentHandler(cacheHandler))
ctx.pipeline().addBefore(ExceptionHandler.NAME, null, cacheHandler)
pipelinedRequests += 1
key.let(::CacheGetRequest)
.let(ctx::fireChannelRead)
?: ctx.channel().write(CacheValueNotFoundResponse())
@@ -159,7 +181,14 @@ class ServerHandler(private val serverPrefix: Path) :
log.debug(ctx) {
"Added value for key '$key' to build cache"
}
ctx.pipeline().addAfter(NAME, CacheContentHandler.NAME, CacheContentHandler)
if(pipelinedRequests > 0) {
ctx.pipeline().addBefore(ExceptionHandler.NAME, null, ResponseCapHandler())
}
val cacheHandler = cacheHandlerSupplier()
ctx.pipeline().addBefore(ExceptionHandler.NAME, null, CacheContentHandler(cacheHandler))
ctx.pipeline().addBefore(ExceptionHandler.NAME, null, cacheHandler)
pipelinedRequests += 1
path.fileName?.toString()
?.let {
val mimeType = HttpUtil.getMimeType(msg)?.toString()
@@ -176,6 +205,11 @@ class ServerHandler(private val serverPrefix: Path) :
ctx.writeAndFlush(response)
}
} else if (method == HttpMethod.TRACE) {
if(pipelinedRequests > 0) {
ctx.pipeline().addBefore(ExceptionHandler.NAME, null, ResponseCapHandler())
}
ctx.pipeline().addBefore(ExceptionHandler.NAME, null, TraceHandler)
pipelinedRequests += 1
super.channelRead(ctx, msg)
} else {
log.warn(ctx) {
@@ -187,42 +221,6 @@ class ServerHandler(private val serverPrefix: Path) :
}
}
data class ContentDisposition(val type: Type?, val fileName: String?) {
enum class Type {
attachment, `inline`;
companion object {
@JvmStatic
fun parse(maybeString: String?) = maybeString.let { s ->
try {
java.lang.Enum.valueOf(Type::class.java, s)
} catch (ex: IllegalArgumentException) {
null
}
}
}
}
companion object {
@JvmStatic
fun parse(contentDisposition: String) : ContentDisposition {
val parts = contentDisposition.split(";").dropLastWhile { it.isEmpty() }.toTypedArray()
val dispositionType = parts[0].trim { it <= ' ' }.let(Type::parse) // Get the type (e.g., attachment)
var filename: String? = null
for (i in 1..<parts.size) {
val part = parts[i].trim { it <= ' ' }
if (part.lowercase(Locale.getDefault()).startsWith("filename=")) {
filename = part.substring("filename=".length).trim { it <= ' ' }.replace("\"", "")
break
}
}
return ContentDisposition(dispositionType, filename)
}
}
}
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
super.exceptionCaught(ctx, cause)
}

View File

@@ -42,6 +42,7 @@ object TraceHandler : ChannelInboundHandlerAdapter() {
}
is LastHttpContent -> {
ctx.writeAndFlush(msg)
ctx.pipeline().remove(this)
}
is HttpContent -> ctx.writeAndFlush(msg)
else -> super.channelRead(ctx, msg)