fixed bug with throttling handler when requests are delayed

This commit is contained in:
2025-06-13 13:49:13 +08:00
parent 3774ab8ef0
commit 206bcd6319
24 changed files with 350 additions and 130 deletions

View File

@@ -2,7 +2,7 @@ org.gradle.configuration-cache=false
org.gradle.parallel=true org.gradle.parallel=true
org.gradle.caching=true org.gradle.caching=true
rbcs.version = 0.3.0-SNAPSHOT rbcs.version = 0.3.1
lys.version = 2025.06.10 lys.version = 2025.06.10

View File

@@ -20,6 +20,8 @@ public class Configuration {
@NonNull @NonNull
EventExecutor eventExecutor; EventExecutor eventExecutor;
@NonNull @NonNull
RateLimiter rateLimiter;
@NonNull
Connection connection; Connection connection;
Map<String, User> users; Map<String, User> users;
Map<String, Group> groups; Map<String, Group> groups;
@@ -27,6 +29,13 @@ public class Configuration {
Authentication authentication; Authentication authentication;
Tls tls; Tls tls;
@Value
public static class RateLimiter {
boolean delayRequest;
int messageBufferSize;
int maxQueuedMessages;
}
@Value @Value
public static class EventExecutor { public static class EventExecutor {
boolean useVirtualThreads; boolean useVirtualThreads;
@@ -133,6 +142,7 @@ public class Configuration {
int incomingConnectionsBacklogSize, int incomingConnectionsBacklogSize,
String serverPath, String serverPath,
EventExecutor eventExecutor, EventExecutor eventExecutor,
RateLimiter rateLimiter,
Connection connection, Connection connection,
Map<String, User> users, Map<String, User> users,
Map<String, Group> groups, Map<String, Group> groups,
@@ -146,6 +156,7 @@ public class Configuration {
incomingConnectionsBacklogSize, incomingConnectionsBacklogSize,
serverPath != null && !serverPath.isEmpty() && !serverPath.equals("/") ? serverPath : null, serverPath != null && !serverPath.isEmpty() && !serverPath.equals("/") ? serverPath : null,
eventExecutor, eventExecutor,
rateLimiter,
connection, connection,
users, users,
groups, groups,

View File

@@ -14,17 +14,26 @@ public sealed interface CacheMessage {
private final String key; private final String key;
} }
@Getter
@RequiredArgsConstructor
abstract sealed class CacheGetResponse implements CacheMessage { abstract sealed class CacheGetResponse implements CacheMessage {
private final String key;
} }
@Getter @Getter
@RequiredArgsConstructor
final class CacheValueFoundResponse extends CacheGetResponse { final class CacheValueFoundResponse extends CacheGetResponse {
private final String key;
private final CacheValueMetadata metadata; private final CacheValueMetadata metadata;
public CacheValueFoundResponse(String key, CacheValueMetadata metadata) {
super(key);
this.metadata = metadata;
}
} }
final class CacheValueNotFoundResponse extends CacheGetResponse { final class CacheValueNotFoundResponse extends CacheGetResponse {
public CacheValueNotFoundResponse(String key) {
super(key);
}
} }
@Getter @Getter

View File

@@ -112,7 +112,7 @@ tasks.named(NativeImagePlugin.CONFIGURE_NATIVE_IMAGE_TASK_NAME, NativeImageConfi
nativeImage { nativeImage {
toolchain { toolchain {
languageVersion = JavaLanguageVersion.of(23) languageVersion = JavaLanguageVersion.of(24)
vendor = JvmVendorSpec.GRAAL_VM vendor = JvmVendorSpec.GRAAL_VM
} }
mainClass = mainClassName mainClass = mainClassName

View File

@@ -99,6 +99,9 @@ object GraalNativeImageConfiguration {
100, 100,
null, null,
Configuration.EventExecutor(true), Configuration.EventExecutor(true),
Configuration.RateLimiter(
false, 0x100000, 10
),
Configuration.Connection( Configuration.Connection(
Duration.ofSeconds(10), Duration.ofSeconds(10),
Duration.ofSeconds(15), Duration.ofSeconds(15),

View File

@@ -101,6 +101,7 @@ class BenchmarkCommand : RbcsCommand() {
"Starting retrieval" "Starting retrieval"
} }
if (entries.isNotEmpty()) { if (entries.isNotEmpty()) {
val errorCounter = AtomicLong(0)
val completionCounter = AtomicLong(0) val completionCounter = AtomicLong(0)
val semaphore = Semaphore(profile.maxConnections * 5) val semaphore = Semaphore(profile.maxConnections * 5)
val start = Instant.now() val start = Instant.now()
@@ -109,14 +110,20 @@ class BenchmarkCommand : RbcsCommand() {
if (it.hasNext()) { if (it.hasNext()) {
val entry = it.next() val entry = it.next()
semaphore.acquire() semaphore.acquire()
val future = client.get(entry.first).thenApply { val future = client.get(entry.first).handle { response, ex ->
if (it == null) { if(ex != null) {
errorCounter.incrementAndGet()
log.error(ex.message, ex)
} else if (response == null) {
errorCounter.incrementAndGet()
log.error { log.error {
"Missing entry for key '${entry.first}'" "Missing entry for key '${entry.first}'"
} }
} else if (!entry.second.contentEquals(it)) { } else if (!entry.second.contentEquals(response)) {
errorCounter.incrementAndGet()
log.error { log.error {
"Retrieved a value different from what was inserted for key '${entry.first}'" "Retrieved a value different from what was inserted for key '${entry.first}': " +
"expected '${JWO.bytesToHex(entry.second)}', got '${JWO.bytesToHex(response)}' instead"
} }
} }
} }
@@ -134,6 +141,12 @@ class BenchmarkCommand : RbcsCommand() {
} }
} }
val end = Instant.now() val end = Instant.now()
val errors = errorCounter.get()
val successfulRetrievals = entries.size - errors
val successRate = successfulRetrievals.toDouble() / entries.size
log.info {
"Successfully retrieved ${entries.size - errors}/${entries.size} (${String.format("%.1f", successRate * 100)}%)"
}
log.info { log.info {
val elapsed = Duration.between(start, end).toMillis() val elapsed = Duration.between(start, end).toMillis()
val opsPerSecond = String.format("%.2f", entries.size.toDouble() / elapsed * 1000) val opsPerSecond = String.format("%.2f", entries.size.toDouble() / elapsed * 1000)

View File

@@ -152,7 +152,7 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC
} }
val pipeline: ChannelPipeline = ch.pipeline() val pipeline: ChannelPipeline = ch.pipeline()
profile.connection?.also { conn -> profile.connection.also { conn ->
val readIdleTimeout = conn.readIdleTimeout.toMillis() val readIdleTimeout = conn.readIdleTimeout.toMillis()
val writeIdleTimeout = conn.writeIdleTimeout.toMillis() val writeIdleTimeout = conn.writeIdleTimeout.toMillis()
val idleTimeout = conn.idleTimeout.toMillis() val idleTimeout = conn.idleTimeout.toMillis()
@@ -295,7 +295,6 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC
): CompletableFuture<FullHttpResponse> { ): CompletableFuture<FullHttpResponse> {
val responseFuture = CompletableFuture<FullHttpResponse>() val responseFuture = CompletableFuture<FullHttpResponse>()
// Custom handler for processing responses // Custom handler for processing responses
pool.acquire().addListener(object : GenericFutureListener<NettyFuture<Channel>> { pool.acquire().addListener(object : GenericFutureListener<NettyFuture<Channel>> {
override fun operationComplete(channelFuture: Future<Channel>) { override fun operationComplete(channelFuture: Future<Channel>) {
@@ -320,7 +319,7 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC
) { ) {
pipeline.remove(this) pipeline.remove(this)
responseFuture.complete(response) responseFuture.complete(response)
if(!profile.connection.requestPipelining) { if (!profile.connection.requestPipelining) {
pool.release(channel) pool.release(channel)
} }
} }
@@ -337,21 +336,15 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC
override fun channelInactive(ctx: ChannelHandlerContext) { override fun channelInactive(ctx: ChannelHandlerContext) {
responseFuture.completeExceptionally(IOException("The remote server closed the connection")) responseFuture.completeExceptionally(IOException("The remote server closed the connection"))
if(!profile.connection.requestPipelining) {
pool.release(channel)
}
super.channelInactive(ctx) super.channelInactive(ctx)
pool.release(channel)
} }
override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) { override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) {
if (evt is IdleStateEvent) { if (evt is IdleStateEvent) {
val te = when (evt.state()) { val te = when (evt.state()) {
IdleState.READER_IDLE -> TimeoutException( IdleState.READER_IDLE -> TimeoutException("Read timeout")
"Read timeout",
)
IdleState.WRITER_IDLE -> TimeoutException("Write timeout") IdleState.WRITER_IDLE -> TimeoutException("Write timeout")
IdleState.ALL_IDLE -> TimeoutException("Idle timeout") IdleState.ALL_IDLE -> TimeoutException("Idle timeout")
null -> throw IllegalStateException("This should never happen") null -> throw IllegalStateException("This should never happen")
} }
@@ -360,7 +353,7 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC
if (this === pipeline.last()) { if (this === pipeline.last()) {
ctx.close() ctx.close()
} }
if(!profile.connection.requestPipelining) { if (!profile.connection.requestPipelining) {
pool.release(channel) pool.release(channel)
} }
} else { } else {
@@ -408,11 +401,11 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC
// Send the request // Send the request
channel.writeAndFlush(request).addListener { channel.writeAndFlush(request).addListener {
if(!it.isSuccess) { if (!it.isSuccess) {
val ex = it.cause() val ex = it.cause()
log.warn(ex.message, ex) log.warn(ex.message, ex)
} }
if(profile.connection.requestPipelining) { if (profile.connection.requestPipelining) {
pool.release(channel) pool.release(channel)
} }
} }

View File

@@ -265,7 +265,7 @@ class MemcacheCacheHandler(
log.debug(ctx) { log.debug(ctx) {
"Cache miss for key ${msg.key} on memcache" "Cache miss for key ${msg.key} on memcache"
} }
sendMessageAndFlush(ctx, CacheValueNotFoundResponse()) sendMessageAndFlush(ctx, CacheValueNotFoundResponse(msg.key))
} }
} }
} }

View File

@@ -72,8 +72,8 @@ import net.woggioni.rbcs.server.auth.RoleAuthorizer
import net.woggioni.rbcs.server.configuration.Parser import net.woggioni.rbcs.server.configuration.Parser
import net.woggioni.rbcs.server.configuration.Serializer import net.woggioni.rbcs.server.configuration.Serializer
import net.woggioni.rbcs.server.exception.ExceptionHandler import net.woggioni.rbcs.server.exception.ExceptionHandler
import net.woggioni.rbcs.server.handler.BlackHoleRequestHandler
import net.woggioni.rbcs.server.handler.MaxRequestSizeHandler import net.woggioni.rbcs.server.handler.MaxRequestSizeHandler
import net.woggioni.rbcs.server.handler.ReadTriggerDuplexHandler
import net.woggioni.rbcs.server.handler.ServerHandler import net.woggioni.rbcs.server.handler.ServerHandler
import net.woggioni.rbcs.server.throttling.BucketManager import net.woggioni.rbcs.server.throttling.BucketManager
import net.woggioni.rbcs.server.throttling.ThrottlingHandler import net.woggioni.rbcs.server.throttling.ThrottlingHandler
@@ -298,7 +298,7 @@ class RemoteBuildCacheServer(private val cfg: Configuration) {
"Closed connection ${ch.id().asShortText()} with ${ch.remoteAddress()}" "Closed connection ${ch.id().asShortText()} with ${ch.remoteAddress()}"
} }
} }
ch.config().setAutoRead(false) ch.config().isAutoRead = false
val pipeline = ch.pipeline() val pipeline = ch.pipeline()
cfg.connection.also { conn -> cfg.connection.also { conn ->
val readIdleTimeout = conn.readIdleTimeout.toMillis() val readIdleTimeout = conn.readIdleTimeout.toMillis()
@@ -345,13 +345,14 @@ class RemoteBuildCacheServer(private val cfg: Configuration) {
maxChunkSize = cfg.connection.chunkSize maxChunkSize = cfg.connection.chunkSize
} }
pipeline.addLast(HttpServerCodec(httpDecoderConfig)) pipeline.addLast(HttpServerCodec(httpDecoderConfig))
pipeline.addLast(ReadTriggerDuplexHandler.NAME, ReadTriggerDuplexHandler)
pipeline.addLast(MaxRequestSizeHandler.NAME, MaxRequestSizeHandler(cfg.connection.maxRequestSize)) pipeline.addLast(MaxRequestSizeHandler.NAME, MaxRequestSizeHandler(cfg.connection.maxRequestSize))
pipeline.addLast(HttpChunkContentCompressor(1024)) pipeline.addLast(HttpChunkContentCompressor(1024))
pipeline.addLast(ChunkedWriteHandler()) pipeline.addLast(ChunkedWriteHandler())
authenticator?.let { authenticator?.let {
pipeline.addLast(it) pipeline.addLast(it)
} }
pipeline.addLast(ThrottlingHandler(bucketManager, cfg.connection)) pipeline.addLast(ThrottlingHandler(bucketManager,cfg.rateLimiter, cfg.connection))
val serverHandler = let { val serverHandler = let {
val prefix = Path.of("/").resolve(Path.of(cfg.serverPath ?: "/")) val prefix = Path.of("/").resolve(Path.of(cfg.serverPath ?: "/"))
@@ -361,7 +362,6 @@ class RemoteBuildCacheServer(private val cfg: Configuration) {
} }
pipeline.addLast(ServerHandler.NAME, serverHandler) pipeline.addLast(ServerHandler.NAME, serverHandler)
pipeline.addLast(ExceptionHandler.NAME, ExceptionHandler) pipeline.addLast(ExceptionHandler.NAME, ExceptionHandler)
pipeline.addLast(BlackHoleRequestHandler.NAME, BlackHoleRequestHandler())
} }
override fun asyncClose() = cacheHandlerFactory.asyncClose() override fun asyncClose() = cacheHandlerFactory.asyncClose()

View File

@@ -125,7 +125,7 @@ class FileSystemCacheHandler(
sendMessageAndFlush(ctx, LastHttpContent.EMPTY_LAST_CONTENT) sendMessageAndFlush(ctx, LastHttpContent.EMPTY_LAST_CONTENT)
} }
} }
} ?: sendMessageAndFlush(ctx, CacheValueNotFoundResponse()) } ?: sendMessageAndFlush(ctx, CacheValueNotFoundResponse(key))
} }
} }
} }

View File

@@ -11,6 +11,7 @@ import java.util.concurrent.CompletableFuture
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantReadWriteLock import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.withLock import kotlin.concurrent.withLock
import net.woggioni.rbcs.common.debug
private class CacheKey(private val value: ByteArray) { private class CacheKey(private val value: ByteArray) {
override fun equals(other: Any?) = if (other is CacheKey) { override fun equals(other: Any?) = if (other is CacheKey) {

View File

@@ -118,7 +118,7 @@ class InMemoryCacheHandler(
} else { } else {
sendMessage(ctx, LastCacheContent(value.content)) sendMessage(ctx, LastCacheContent(value.content))
} }
} ?: sendMessage(ctx, CacheValueNotFoundResponse()) } ?: sendMessage(ctx, CacheValueNotFoundResponse(req.request.key))
} }
is InProgressPutRequest -> { is InProgressPutRequest -> {
this.inProgressRequest = null this.inProgressRequest = null

View File

@@ -33,6 +33,7 @@ object Parser {
0x4000000, 0x4000000,
0x10000 0x10000
) )
var rateLimiter = Configuration.RateLimiter(false, 0x100000, 100)
var eventExecutor: Configuration.EventExecutor = Configuration.EventExecutor(true) var eventExecutor: Configuration.EventExecutor = Configuration.EventExecutor(true)
var cache: Cache? = null var cache: Cache? = null
var host = "127.0.0.1" var host = "127.0.0.1"
@@ -132,11 +133,24 @@ object Parser {
} }
"event-executor" -> { "event-executor" -> {
val useVirtualThread = root.renderAttribute("use-virtual-threads") val useVirtualThread = child.renderAttribute("use-virtual-threads")
?.let(String::toBoolean) ?: true ?.let(String::toBoolean) ?: true
eventExecutor = Configuration.EventExecutor(useVirtualThread) eventExecutor = Configuration.EventExecutor(useVirtualThread)
} }
"rate-limiter" -> {
val delayResponse = child.renderAttribute("delay-response")
?.let(String::toBoolean)
?: false
val messageBufferSize = child.renderAttribute("message-buffer-size")
?.let(Integer::decode)
?: 0x100000
val maxQueuedMessages = child.renderAttribute("max-queued-messages")
?.let(Integer::decode)
?: 100
rateLimiter = Configuration.RateLimiter(delayResponse, messageBufferSize, maxQueuedMessages)
}
"tls" -> { "tls" -> {
var keyStore: KeyStore? = null var keyStore: KeyStore? = null
var trustStore: TrustStore? = null var trustStore: TrustStore? = null
@@ -184,6 +198,7 @@ object Parser {
incomingConnectionsBacklogSize, incomingConnectionsBacklogSize,
serverPath, serverPath,
eventExecutor, eventExecutor,
rateLimiter,
connection, connection,
users, users,
groups, groups,

View File

@@ -46,6 +46,11 @@ object Serializer {
node("event-executor") { node("event-executor") {
attr("use-virtual-threads", conf.eventExecutor.isUseVirtualThreads.toString()) attr("use-virtual-threads", conf.eventExecutor.isUseVirtualThreads.toString())
} }
node("rate-limiter") {
attr("delay-response", conf.rateLimiter.isDelayRequest.toString())
attr("max-queued-messages", conf.rateLimiter.maxQueuedMessages.toString())
attr("message-buffer-size", conf.rateLimiter.messageBufferSize.toString())
}
val cache = conf.cache val cache = conf.cache
val serializer : CacheProvider<Configuration.Cache> = val serializer : CacheProvider<Configuration.Cache> =
(CacheSerializers.index[cache.namespaceURI to cache.typeName] as? CacheProvider<Configuration.Cache>) ?: throw NotImplementedError() (CacheSerializers.index[cache.namespaceURI to cache.typeName] as? CacheProvider<Configuration.Cache>) ?: throw NotImplementedError()

View File

@@ -0,0 +1,36 @@
package net.woggioni.rbcs.server.handler
import io.netty.channel.ChannelDuplexHandler
import io.netty.channel.ChannelHandler.Sharable
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelPromise
import io.netty.handler.codec.http.LastHttpContent
import net.woggioni.rbcs.common.createLogger
import net.woggioni.rbcs.common.debug
@Sharable
object ReadTriggerDuplexHandler : ChannelDuplexHandler() {
val NAME = ReadTriggerDuplexHandler::class.java.name
override fun handlerAdded(ctx: ChannelHandlerContext) {
ctx.read()
}
override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
super.channelRead(ctx, msg)
if(msg !is LastHttpContent) {
ctx.read()
}
}
override fun write(
ctx: ChannelHandlerContext,
msg: Any,
promise: ChannelPromise
) {
super.write(ctx, msg, promise)
if(msg is LastHttpContent) {
ctx.read()
}
}
}

View File

@@ -43,16 +43,6 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
private var httpVersion = HttpVersion.HTTP_1_1 private var httpVersion = HttpVersion.HTTP_1_1
private var keepAlive = true private var keepAlive = true
private var pipelinedRequests = 0
private fun newRequest() {
pipelinedRequests += 1
}
private fun requestCompleted(ctx : ChannelHandlerContext) {
pipelinedRequests -= 1
if(pipelinedRequests == 0) ctx.read()
}
private fun resetRequestMetadata() { private fun resetRequestMetadata() {
httpVersion = HttpVersion.HTTP_1_1 httpVersion = HttpVersion.HTTP_1_1
@@ -74,10 +64,6 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
private var cacheRequestInProgress : Boolean = false private var cacheRequestInProgress : Boolean = false
override fun handlerAdded(ctx: ChannelHandlerContext) {
ctx.read()
}
override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
when (msg) { when (msg) {
is HttpRequest -> handleRequest(ctx, msg) is HttpRequest -> handleRequest(ctx, msg)
@@ -98,18 +84,14 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
} }
} }
override fun channelReadComplete(ctx: ChannelHandlerContext) {
super.channelReadComplete(ctx)
if(cacheRequestInProgress) {
ctx.read()
}
}
override fun write(ctx: ChannelHandlerContext, msg: Any, promise: ChannelPromise?) { override fun write(ctx: ChannelHandlerContext, msg: Any, promise: ChannelPromise?) {
if (msg is CacheMessage) { if (msg is CacheMessage) {
try { try {
when (msg) { when (msg) {
is CachePutResponse -> { is CachePutResponse -> {
log.debug(ctx) {
"Added value for key '${msg.key}' to build cache"
}
val response = DefaultFullHttpResponse(httpVersion, HttpResponseStatus.CREATED) val response = DefaultFullHttpResponse(httpVersion, HttpResponseStatus.CREATED)
val keyBytes = msg.key.toByteArray(Charsets.UTF_8) val keyBytes = msg.key.toByteArray(Charsets.UTF_8)
response.headers().apply { response.headers().apply {
@@ -121,21 +103,23 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
val buf = ctx.alloc().buffer(keyBytes.size).apply { val buf = ctx.alloc().buffer(keyBytes.size).apply {
writeBytes(keyBytes) writeBytes(keyBytes)
} }
ctx.writeAndFlush(DefaultLastHttpContent(buf)).also { ctx.writeAndFlush(DefaultLastHttpContent(buf))
requestCompleted(ctx)
}
} }
is CacheValueNotFoundResponse -> { is CacheValueNotFoundResponse -> {
log.debug(ctx) {
"Value not found for key '${msg.key}'"
}
val response = DefaultFullHttpResponse(httpVersion, HttpResponseStatus.NOT_FOUND) val response = DefaultFullHttpResponse(httpVersion, HttpResponseStatus.NOT_FOUND)
response.headers()[HttpHeaderNames.CONTENT_LENGTH] = 0 response.headers()[HttpHeaderNames.CONTENT_LENGTH] = 0
setKeepAliveHeader(response.headers()) setKeepAliveHeader(response.headers())
ctx.writeAndFlush(response).also { ctx.writeAndFlush(response)
requestCompleted(ctx)
}
} }
is CacheValueFoundResponse -> { is CacheValueFoundResponse -> {
log.debug(ctx) {
"Retrieved value for key '${msg.key}'"
}
val response = DefaultHttpResponse(httpVersion, HttpResponseStatus.OK) val response = DefaultHttpResponse(httpVersion, HttpResponseStatus.OK)
response.headers().apply { response.headers().apply {
set(HttpHeaderNames.CONTENT_TYPE, msg.metadata.mimeType ?: HttpHeaderValues.APPLICATION_OCTET_STREAM) set(HttpHeaderNames.CONTENT_TYPE, msg.metadata.mimeType ?: HttpHeaderValues.APPLICATION_OCTET_STREAM)
@@ -149,9 +133,7 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
} }
is LastCacheContent -> { is LastCacheContent -> {
ctx.writeAndFlush(DefaultLastHttpContent(msg.content())).also { ctx.writeAndFlush(DefaultLastHttpContent(msg.content()))
requestCompleted(ctx)
}
} }
is CacheContent -> { is CacheContent -> {
@@ -172,7 +154,6 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
} }
} else if(msg is LastHttpContent) { } else if(msg is LastHttpContent) {
ctx.write(msg, promise) ctx.write(msg, promise)
requestCompleted(ctx)
} else super.write(ctx, msg, promise) } else super.write(ctx, msg, promise)
} }
@@ -186,13 +167,13 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
cacheRequestInProgress = true cacheRequestInProgress = true
val relativePath = serverPrefix.relativize(path) val relativePath = serverPrefix.relativize(path)
val key : String = relativePath.toString() val key : String = relativePath.toString()
newRequest()
val cacheHandler = cacheHandlerSupplier() val cacheHandler = cacheHandlerSupplier()
ctx.pipeline().addBefore(ExceptionHandler.NAME, null, cacheHandler) ctx.pipeline().addBefore(ExceptionHandler.NAME, null, cacheHandler)
key.let(::CacheGetRequest) key.let(::CacheGetRequest)
.let(ctx::fireChannelRead) .let(ctx::fireChannelRead)
?: ctx.channel().write(CacheValueNotFoundResponse()) ?: ctx.channel().write(CacheValueNotFoundResponse(key))
} else { } else {
cacheRequestInProgress = false
log.warn(ctx) { log.warn(ctx) {
"Got request for unhandled path '${msg.uri()}'" "Got request for unhandled path '${msg.uri()}'"
} }
@@ -206,12 +187,8 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
cacheRequestInProgress = true cacheRequestInProgress = true
val relativePath = serverPrefix.relativize(path) val relativePath = serverPrefix.relativize(path)
val key = relativePath.toString() val key = relativePath.toString()
log.debug(ctx) {
"Added value for key '$key' to build cache"
}
newRequest()
val cacheHandler = cacheHandlerSupplier() val cacheHandler = cacheHandlerSupplier()
ctx.pipeline().addBefore(ExceptionHandler.NAME, null, cacheHandler) ctx.pipeline().addAfter(NAME, null, cacheHandler)
path.fileName?.toString() path.fileName?.toString()
?.let { ?.let {
@@ -219,8 +196,9 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
CachePutRequest(key, CacheValueMetadata(msg.headers().get(HttpHeaderNames.CONTENT_DISPOSITION), mimeType)) CachePutRequest(key, CacheValueMetadata(msg.headers().get(HttpHeaderNames.CONTENT_DISPOSITION), mimeType))
} }
?.let(ctx::fireChannelRead) ?.let(ctx::fireChannelRead)
?: ctx.channel().write(CacheValueNotFoundResponse()) ?: ctx.channel().write(CacheValueNotFoundResponse(key))
} else { } else {
cacheRequestInProgress = false
log.warn(ctx) { log.warn(ctx) {
"Got request for unhandled path '${msg.uri()}'" "Got request for unhandled path '${msg.uri()}'"
} }
@@ -229,10 +207,11 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
ctx.writeAndFlush(response) ctx.writeAndFlush(response)
} }
} else if (method == HttpMethod.TRACE) { } else if (method == HttpMethod.TRACE) {
newRequest() cacheRequestInProgress = false
ctx.pipeline().addBefore(ExceptionHandler.NAME, null, TraceHandler) ctx.pipeline().addAfter(NAME, null, TraceHandler)
super.channelRead(ctx, msg) super.channelRead(ctx, msg)
} else { } else {
cacheRequestInProgress = false
log.warn(ctx) { log.warn(ctx) {
"Got request with unhandled method '${msg.method().name()}'" "Got request with unhandled method '${msg.method().name()}'"
} }
@@ -245,4 +224,4 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
super.exceptionCaught(ctx, cause) super.exceptionCaught(ctx, cause)
} }
} }

View File

@@ -1,32 +1,54 @@
package net.woggioni.rbcs.server.throttling package net.woggioni.rbcs.server.throttling
import io.netty.buffer.ByteBufHolder
import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelInboundHandlerAdapter import io.netty.channel.ChannelInboundHandlerAdapter
import io.netty.handler.codec.http.DefaultFullHttpResponse import io.netty.handler.codec.http.DefaultFullHttpResponse
import io.netty.handler.codec.http.FullHttpMessage
import io.netty.handler.codec.http.HttpContent import io.netty.handler.codec.http.HttpContent
import io.netty.handler.codec.http.HttpHeaderNames import io.netty.handler.codec.http.HttpHeaderNames
import io.netty.handler.codec.http.HttpRequest import io.netty.handler.codec.http.HttpRequest
import io.netty.handler.codec.http.HttpResponseStatus import io.netty.handler.codec.http.HttpResponseStatus
import io.netty.handler.codec.http.HttpVersion import io.netty.handler.codec.http.HttpVersion
import io.netty.handler.codec.http.LastHttpContent
import java.net.InetSocketAddress
import java.time.Duration
import java.time.temporal.ChronoUnit
import java.util.ArrayDeque
import java.util.LinkedList
import java.util.concurrent.TimeUnit
import kotlin.collections.forEach
import kotlin.collections.isNotEmpty
import net.woggioni.jwo.Bucket import net.woggioni.jwo.Bucket
import net.woggioni.jwo.LongMath import net.woggioni.jwo.LongMath
import net.woggioni.rbcs.api.Configuration import net.woggioni.rbcs.api.Configuration
import net.woggioni.rbcs.common.createLogger import net.woggioni.rbcs.common.createLogger
import net.woggioni.rbcs.common.debug
import net.woggioni.rbcs.server.RemoteBuildCacheServer import net.woggioni.rbcs.server.RemoteBuildCacheServer
import java.net.InetSocketAddress
import java.time.Duration
import java.time.temporal.ChronoUnit
import java.util.concurrent.TimeUnit
class ThrottlingHandler(
class ThrottlingHandler(private val bucketManager : BucketManager, private val bucketManager: BucketManager,
private val connectionConfiguration : Configuration.Connection) : ChannelInboundHandlerAdapter() { rateLimiterConfiguration: Configuration.RateLimiter,
connectionConfiguration: Configuration.Connection
) : ChannelInboundHandlerAdapter() {
private companion object { private companion object {
private val log = createLogger<ThrottlingHandler>() private val log = createLogger<ThrottlingHandler>()
fun nextAttemptIsWithinThreshold(nextAttemptNanos : Long, waitThreshold : Duration) : Boolean {
val waitDuration = Duration.of(LongMath.ceilDiv(nextAttemptNanos, 100_000_000L) * 100L, ChronoUnit.MILLIS)
return waitDuration < waitThreshold
}
} }
private var queuedContent : MutableList<HttpContent>? = null private class RefusedRequest
private val maxMessageBufferSize = rateLimiterConfiguration.messageBufferSize
private val maxQueuedMessages = rateLimiterConfiguration.maxQueuedMessages
private val delayRequests = rateLimiterConfiguration.isDelayRequest
private var requestBufferSize : Int = 0
private var valveClosed = false
private var queuedContent = ArrayDeque<Any>()
/** /**
* If the suggested waiting time from the bucket is lower than this * If the suggested waiting time from the bucket is lower than this
@@ -39,38 +61,155 @@ class ThrottlingHandler(private val bucketManager : BucketManager,
connectionConfiguration.writeIdleTimeout connectionConfiguration.writeIdleTimeout
).dividedBy(2) ).dividedBy(2)
override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
if(msg is HttpRequest) { if(valveClosed) {
val buckets = mutableListOf<Bucket>() if(msg !is HttpRequest && msg is ByteBufHolder) {
val user = ctx.channel().attr(RemoteBuildCacheServer.userAttribute).get() val newBufferSize = requestBufferSize + msg.content().readableBytes()
if (user != null) { if(newBufferSize > maxMessageBufferSize || queuedContent.size + 1 > maxQueuedMessages) {
bucketManager.getBucketByUser(user)?.let(buckets::addAll) log.debug {
} if (newBufferSize > maxMessageBufferSize) {
val groups = ctx.channel().attr(RemoteBuildCacheServer.groupAttribute).get() ?: emptySet() "New message part exceeds maxMessageBufferSize, removing previous chunks"
if (groups.isNotEmpty()) { } else {
groups.forEach { group -> "New message part exceeds maxQueuedMessages, removing previous chunks"
bucketManager.getBucketByGroup(group)?.let(buckets::add) }
}
// If this message overflows the maxMessageBufferSize,
// then remove the previously enqueued chunks of the request from the deque,
// then discard the message
while(true) {
val tail = queuedContent.last()
if(tail is ByteBufHolder) {
requestBufferSize -= tail.content().readableBytes()
tail.release()
}
queuedContent.removeLast()
if(tail is HttpRequest) {
break
}
}
msg.release()
//Add a placeholder to remember to return a 429 response corresponding to this request
queuedContent.addLast(RefusedRequest())
} else {
//If the message does not overflow maxMessageBufferSize, just add it to the deque
queuedContent.addLast(msg)
requestBufferSize = newBufferSize
}
} else if(msg is HttpRequest && msg is FullHttpMessage){
val newBufferSize = requestBufferSize + msg.content().readableBytes()
// If this message overflows the maxMessageBufferSize,
// discard the message
if(newBufferSize > maxMessageBufferSize || queuedContent.size + 1 > maxQueuedMessages) {
log.debug {
if (newBufferSize > maxMessageBufferSize) {
"New message exceeds maxMessageBufferSize, discarding it"
} else {
"New message exceeds maxQueuedMessages, discarding it"
}
}
msg.release()
//Add a placeholder to remember to return a 429 response corresponding to this request
queuedContent.addLast(RefusedRequest())
} else {
//If the message does not exceed maxMessageBufferSize or maxQueuedMessages, just add it to the deque
queuedContent.addLast(msg)
requestBufferSize = newBufferSize
} }
}
if (user == null && groups.isEmpty()) {
bucketManager.getBucketByAddress(ctx.channel().remoteAddress() as InetSocketAddress)?.let(buckets::add)
}
if (buckets.isEmpty()) {
super.channelRead(ctx, msg)
} else { } else {
handleBuckets(buckets, ctx, msg, true) queuedContent.addLast(msg)
} }
ctx.channel().id()
} else if(msg is HttpContent) {
queuedContent?.add(msg) ?: super.channelRead(ctx, msg)
} else { } else {
super.channelRead(ctx, msg) entryPoint(ctx, msg)
} }
} }
private fun handleBuckets(buckets: List<Bucket>, ctx: ChannelHandlerContext, msg: Any, delayResponse: Boolean) { private fun entryPoint(ctx : ChannelHandlerContext, msg : Any) {
if(msg is RefusedRequest) {
sendThrottledResponse(ctx, null)
if(queuedContent.isEmpty()) {
valveClosed = false
} else {
val head = queuedContent.poll()
if(head is ByteBufHolder) {
requestBufferSize -= head.content().readableBytes()
}
entryPoint(ctx, head)
}
} else if(msg is HttpRequest) {
val nextAttempt = getNextAttempt(ctx)
if (nextAttempt < 0) {
super.channelRead(ctx, msg)
if(msg !is LastHttpContent) {
while (true) {
val head = queuedContent.poll() ?: break
if(head is ByteBufHolder) {
requestBufferSize -= head.content().readableBytes()
}
super.channelRead(ctx, head)
if (head is LastHttpContent) break
}
}
log.debug {
"Queue size: ${queuedContent.stream().filter { it !is RefusedRequest }.count()}"
}
if(queuedContent.isEmpty()) {
valveClosed = false
} else {
val head = queuedContent.poll()
if(head is ByteBufHolder) {
requestBufferSize -= head.content().readableBytes()
}
entryPoint(ctx, head)
}
} else {
val waitDuration = Duration.of(LongMath.ceilDiv(nextAttempt, 100_000_000L) * 100L, ChronoUnit.MILLIS)
if (delayRequests && nextAttemptIsWithinThreshold(nextAttempt, waitThreshold)) {
valveClosed = true
ctx.executor().schedule({
entryPoint(ctx, msg)
}, waitDuration.toMillis(), TimeUnit.MILLISECONDS)
} else {
sendThrottledResponse(ctx, waitDuration)
if(queuedContent.isEmpty()) {
valveClosed = false
} else {
val head = queuedContent.poll()
if(head is ByteBufHolder) {
requestBufferSize -= head.content().readableBytes()
}
entryPoint(ctx, head)
}
}
}
} else {
super.channelRead(ctx, msg)
log.debug {
"Queue size: ${queuedContent.stream().filter { it !is RefusedRequest }.count()}"
}
}
}
/**
* Returns the number amount of milliseconds to wait before the requests can be processed
* or -1 if the request can be performed immediately
*/
private fun getNextAttempt(ctx : ChannelHandlerContext) : Long {
val buckets = mutableListOf<Bucket>()
val user = ctx.channel().attr(RemoteBuildCacheServer.userAttribute).get()
if (user != null) {
bucketManager.getBucketByUser(user)?.let(buckets::addAll)
}
val groups = ctx.channel().attr(RemoteBuildCacheServer.groupAttribute).get() ?: emptySet()
if (groups.isNotEmpty()) {
groups.forEach { group ->
bucketManager.getBucketByGroup(group)?.let(buckets::add)
}
}
if (user == null && groups.isEmpty()) {
bucketManager.getBucketByAddress(ctx.channel().remoteAddress() as InetSocketAddress)?.let(buckets::add)
}
var nextAttempt = -1L var nextAttempt = -1L
for (bucket in buckets) { for (bucket in buckets) {
val bucketNextAttempt = bucket.removeTokensWithEstimate(1) val bucketNextAttempt = bucket.removeTokensWithEstimate(1)
@@ -78,41 +217,19 @@ class ThrottlingHandler(private val bucketManager : BucketManager,
nextAttempt = bucketNextAttempt nextAttempt = bucketNextAttempt
} }
} }
if (nextAttempt < 0) { return nextAttempt
super.channelRead(ctx, msg)
queuedContent?.let {
for(content in it) {
super.channelRead(ctx, content)
}
queuedContent = null
}
} else {
val waitDuration = Duration.of(LongMath.ceilDiv(nextAttempt, 100_000_000L) * 100L, ChronoUnit.MILLIS)
if (delayResponse && waitDuration < waitThreshold) {
this.queuedContent = mutableListOf()
ctx.executor().schedule({
handleBuckets(buckets, ctx, msg, false)
}, waitDuration.toMillis(), TimeUnit.MILLISECONDS)
} else {
queuedContent?.let { qc ->
qc.forEach { it.release() }
}
this.queuedContent = null
sendThrottledResponse(ctx, waitDuration)
}
}
} }
private fun sendThrottledResponse(ctx: ChannelHandlerContext, retryAfter: Duration) { private fun sendThrottledResponse(ctx: ChannelHandlerContext, retryAfter: Duration?) {
val response = DefaultFullHttpResponse( val response = DefaultFullHttpResponse(
HttpVersion.HTTP_1_1, HttpVersion.HTTP_1_1,
HttpResponseStatus.TOO_MANY_REQUESTS HttpResponseStatus.TOO_MANY_REQUESTS
) )
response.headers()[HttpHeaderNames.CONTENT_LENGTH] = 0 response.headers()[HttpHeaderNames.CONTENT_LENGTH] = 0
retryAfter.seconds.takeIf { retryAfter?.seconds?.takeIf {
it > 0 it > 0
}?.let { }?.let {
response.headers()[HttpHeaderNames.RETRY_AFTER] = retryAfter.seconds response.headers()[HttpHeaderNames.RETRY_AFTER] = it
} }
ctx.writeAndFlush(response) ctx.writeAndFlush(response)

View File

@@ -16,6 +16,7 @@
<xs:element name="bind" type="rbcs:bindType" maxOccurs="1"/> <xs:element name="bind" type="rbcs:bindType" maxOccurs="1"/>
<xs:element name="connection" type="rbcs:connectionType" minOccurs="0" maxOccurs="1"/> <xs:element name="connection" type="rbcs:connectionType" minOccurs="0" maxOccurs="1"/>
<xs:element name="event-executor" type="rbcs:eventExecutorType" minOccurs="0" maxOccurs="1"/> <xs:element name="event-executor" type="rbcs:eventExecutorType" minOccurs="0" maxOccurs="1"/>
<xs:element name="rate-limiter" type="rbcs:rateLimiterType" minOccurs="0" maxOccurs="1"/>
<xs:element name="cache" type="rbcs:cacheType" maxOccurs="1"> <xs:element name="cache" type="rbcs:cacheType" maxOccurs="1">
<xs:annotation> <xs:annotation>
<xs:documentation> <xs:documentation>
@@ -136,6 +137,37 @@
</xs:attribute> </xs:attribute>
</xs:complexType> </xs:complexType>
<xs:complexType name="rateLimiterType">
<xs:attribute name="delay-response" type="xs:boolean" use="optional" default="false">
<xs:annotation>
<xs:documentation>
If set to true, the server will delay responses to meet user quotas, otherwise it will simply
return an immediate 429 status code to all requests that exceed the configured quota
</xs:documentation>
</xs:annotation>
</xs:attribute>
<xs:attribute name="max-queued-messages" type="xs:nonNegativeInteger" use="optional" default="100">
<xs:annotation>
<xs:documentation>
Only meaningful when "delay-response" is set to "true",
when a request is delayed, it and all the following messages are queued
as long as "max-queued-messages" is not crossed, all requests that would exceed the
max-queued-message limit are instead discarded and responded with a 429 status code
</xs:documentation>
</xs:annotation>
</xs:attribute>
<xs:attribute name="message-buffer-size" type="rbcs:byteSizeType" use="optional" default="0x100000">
<xs:annotation>
<xs:documentation>
Only meaningful when "delay-response" is set to "true",
when a request is delayed, it and all the following requests are buffered
as long as "message-buffer-size" is not crossed, all requests that would exceed the buffer
size are instead discarded and responded with a 429 status code
</xs:documentation>
</xs:annotation>
</xs:attribute>
</xs:complexType>
<xs:complexType name="cacheType" abstract="true"/> <xs:complexType name="cacheType" abstract="true"/>
<xs:complexType name="inMemoryCacheType"> <xs:complexType name="inMemoryCacheType">

View File

@@ -37,6 +37,7 @@ abstract class AbstractBasicAuthServerTest : AbstractServerTest() {
50, 50,
serverPath, serverPath,
Configuration.EventExecutor(false), Configuration.EventExecutor(false),
Configuration.RateLimiter(true, 0x100000, 50),
Configuration.Connection( Configuration.Connection(
Duration.of(60, ChronoUnit.SECONDS), Duration.of(60, ChronoUnit.SECONDS),
Duration.of(30, ChronoUnit.SECONDS), Duration.of(30, ChronoUnit.SECONDS),

View File

@@ -143,6 +143,7 @@ abstract class AbstractTlsServerTest : AbstractServerTest() {
100, 100,
serverPath, serverPath,
Configuration.EventExecutor(false), Configuration.EventExecutor(false),
Configuration.RateLimiter(true, 0x100000, 50),
Configuration.Connection( Configuration.Connection(
Duration.of(60, ChronoUnit.SECONDS), Duration.of(60, ChronoUnit.SECONDS),
Duration.of(30, ChronoUnit.SECONDS), Duration.of(30, ChronoUnit.SECONDS),

View File

@@ -37,6 +37,7 @@ class NoAuthServerTest : AbstractServerTest() {
100, 100,
serverPath, serverPath,
Configuration.EventExecutor(false), Configuration.EventExecutor(false),
Configuration.RateLimiter(true, 0x100000, 50),
Configuration.Connection( Configuration.Connection(
Duration.of(60, ChronoUnit.SECONDS), Duration.of(60, ChronoUnit.SECONDS),
Duration.of(30, ChronoUnit.SECONDS), Duration.of(30, ChronoUnit.SECONDS),

View File

@@ -10,6 +10,7 @@
max-request-size="101325" max-request-size="101325"
chunk-size="0xa910"/> chunk-size="0xa910"/>
<event-executor use-virtual-threads="false"/> <event-executor use-virtual-threads="false"/>
<rate-limiter delay-response="false" message-buffer-size="0x1234" max-queued-messages="13"/>
<cache xs:type="rbcs:fileSystemCacheType" path="/tmp/rbcs" max-age="P7D"/> <cache xs:type="rbcs:fileSystemCacheType" path="/tmp/rbcs" max-age="P7D"/>
<authentication> <authentication>
<none/> <none/>

View File

@@ -12,6 +12,7 @@
write-idle-timeout="PT60S" write-idle-timeout="PT60S"
chunk-size="123"/> chunk-size="123"/>
<event-executor use-virtual-threads="true"/> <event-executor use-virtual-threads="true"/>
<rate-limiter delay-response="false" message-buffer-size="12000" max-queued-messages="53"/>
<cache xs:type="rbcs-memcache:memcacheCacheType" max-age="P7D"> <cache xs:type="rbcs-memcache:memcacheCacheType" max-age="P7D">
<server host="memcached" port="11211"/> <server host="memcached" port="11211"/>
</cache> </cache>

View File

@@ -11,6 +11,7 @@
max-request-size="101325" max-request-size="101325"
chunk-size="456"/> chunk-size="456"/>
<event-executor use-virtual-threads="false"/> <event-executor use-virtual-threads="false"/>
<rate-limiter delay-response="true" message-buffer-size="65432" max-queued-messages="21"/>
<cache xs:type="rbcs-memcache:memcacheCacheType" max-age="P7D" digest="SHA-256" compression-mode="deflate" compression-level="7"> <cache xs:type="rbcs-memcache:memcacheCacheType" max-age="P7D" digest="SHA-256" compression-mode="deflate" compression-level="7">
<server host="127.0.0.1" port="11211" max-connections="10" connection-timeout="PT20S"/> <server host="127.0.0.1" port="11211" max-connections="10" connection-timeout="PT20S"/>
</cache> </cache>