fixed bug with throttling handler when requests are delayed

This commit is contained in:
2025-06-13 13:49:13 +08:00
parent 3774ab8ef0
commit 206bcd6319
24 changed files with 350 additions and 130 deletions

View File

@@ -2,7 +2,7 @@ org.gradle.configuration-cache=false
org.gradle.parallel=true
org.gradle.caching=true
rbcs.version = 0.3.0-SNAPSHOT
rbcs.version = 0.3.1
lys.version = 2025.06.10

View File

@@ -20,6 +20,8 @@ public class Configuration {
@NonNull
EventExecutor eventExecutor;
@NonNull
RateLimiter rateLimiter;
@NonNull
Connection connection;
Map<String, User> users;
Map<String, Group> groups;
@@ -27,6 +29,13 @@ public class Configuration {
Authentication authentication;
Tls tls;
@Value
public static class RateLimiter {
boolean delayRequest;
int messageBufferSize;
int maxQueuedMessages;
}
@Value
public static class EventExecutor {
boolean useVirtualThreads;
@@ -133,6 +142,7 @@ public class Configuration {
int incomingConnectionsBacklogSize,
String serverPath,
EventExecutor eventExecutor,
RateLimiter rateLimiter,
Connection connection,
Map<String, User> users,
Map<String, Group> groups,
@@ -146,6 +156,7 @@ public class Configuration {
incomingConnectionsBacklogSize,
serverPath != null && !serverPath.isEmpty() && !serverPath.equals("/") ? serverPath : null,
eventExecutor,
rateLimiter,
connection,
users,
groups,

View File

@@ -14,17 +14,26 @@ public sealed interface CacheMessage {
private final String key;
}
@Getter
@RequiredArgsConstructor
abstract sealed class CacheGetResponse implements CacheMessage {
private final String key;
}
@Getter
@RequiredArgsConstructor
final class CacheValueFoundResponse extends CacheGetResponse {
private final String key;
private final CacheValueMetadata metadata;
public CacheValueFoundResponse(String key, CacheValueMetadata metadata) {
super(key);
this.metadata = metadata;
}
}
final class CacheValueNotFoundResponse extends CacheGetResponse {
public CacheValueNotFoundResponse(String key) {
super(key);
}
}
@Getter

View File

@@ -112,7 +112,7 @@ tasks.named(NativeImagePlugin.CONFIGURE_NATIVE_IMAGE_TASK_NAME, NativeImageConfi
nativeImage {
toolchain {
languageVersion = JavaLanguageVersion.of(23)
languageVersion = JavaLanguageVersion.of(24)
vendor = JvmVendorSpec.GRAAL_VM
}
mainClass = mainClassName

View File

@@ -99,6 +99,9 @@ object GraalNativeImageConfiguration {
100,
null,
Configuration.EventExecutor(true),
Configuration.RateLimiter(
false, 0x100000, 10
),
Configuration.Connection(
Duration.ofSeconds(10),
Duration.ofSeconds(15),

View File

@@ -101,6 +101,7 @@ class BenchmarkCommand : RbcsCommand() {
"Starting retrieval"
}
if (entries.isNotEmpty()) {
val errorCounter = AtomicLong(0)
val completionCounter = AtomicLong(0)
val semaphore = Semaphore(profile.maxConnections * 5)
val start = Instant.now()
@@ -109,14 +110,20 @@ class BenchmarkCommand : RbcsCommand() {
if (it.hasNext()) {
val entry = it.next()
semaphore.acquire()
val future = client.get(entry.first).thenApply {
if (it == null) {
val future = client.get(entry.first).handle { response, ex ->
if(ex != null) {
errorCounter.incrementAndGet()
log.error(ex.message, ex)
} else if (response == null) {
errorCounter.incrementAndGet()
log.error {
"Missing entry for key '${entry.first}'"
}
} else if (!entry.second.contentEquals(it)) {
} else if (!entry.second.contentEquals(response)) {
errorCounter.incrementAndGet()
log.error {
"Retrieved a value different from what was inserted for key '${entry.first}'"
"Retrieved a value different from what was inserted for key '${entry.first}': " +
"expected '${JWO.bytesToHex(entry.second)}', got '${JWO.bytesToHex(response)}' instead"
}
}
}
@@ -134,6 +141,12 @@ class BenchmarkCommand : RbcsCommand() {
}
}
val end = Instant.now()
val errors = errorCounter.get()
val successfulRetrievals = entries.size - errors
val successRate = successfulRetrievals.toDouble() / entries.size
log.info {
"Successfully retrieved ${entries.size - errors}/${entries.size} (${String.format("%.1f", successRate * 100)}%)"
}
log.info {
val elapsed = Duration.between(start, end).toMillis()
val opsPerSecond = String.format("%.2f", entries.size.toDouble() / elapsed * 1000)

View File

@@ -152,7 +152,7 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC
}
val pipeline: ChannelPipeline = ch.pipeline()
profile.connection?.also { conn ->
profile.connection.also { conn ->
val readIdleTimeout = conn.readIdleTimeout.toMillis()
val writeIdleTimeout = conn.writeIdleTimeout.toMillis()
val idleTimeout = conn.idleTimeout.toMillis()
@@ -295,7 +295,6 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC
): CompletableFuture<FullHttpResponse> {
val responseFuture = CompletableFuture<FullHttpResponse>()
// Custom handler for processing responses
pool.acquire().addListener(object : GenericFutureListener<NettyFuture<Channel>> {
override fun operationComplete(channelFuture: Future<Channel>) {
@@ -320,7 +319,7 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC
) {
pipeline.remove(this)
responseFuture.complete(response)
if(!profile.connection.requestPipelining) {
if (!profile.connection.requestPipelining) {
pool.release(channel)
}
}
@@ -337,21 +336,15 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC
override fun channelInactive(ctx: ChannelHandlerContext) {
responseFuture.completeExceptionally(IOException("The remote server closed the connection"))
if(!profile.connection.requestPipelining) {
pool.release(channel)
}
super.channelInactive(ctx)
pool.release(channel)
}
override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) {
if (evt is IdleStateEvent) {
val te = when (evt.state()) {
IdleState.READER_IDLE -> TimeoutException(
"Read timeout",
)
IdleState.READER_IDLE -> TimeoutException("Read timeout")
IdleState.WRITER_IDLE -> TimeoutException("Write timeout")
IdleState.ALL_IDLE -> TimeoutException("Idle timeout")
null -> throw IllegalStateException("This should never happen")
}
@@ -360,7 +353,7 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC
if (this === pipeline.last()) {
ctx.close()
}
if(!profile.connection.requestPipelining) {
if (!profile.connection.requestPipelining) {
pool.release(channel)
}
} else {
@@ -408,11 +401,11 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC
// Send the request
channel.writeAndFlush(request).addListener {
if(!it.isSuccess) {
if (!it.isSuccess) {
val ex = it.cause()
log.warn(ex.message, ex)
}
if(profile.connection.requestPipelining) {
if (profile.connection.requestPipelining) {
pool.release(channel)
}
}

View File

@@ -265,7 +265,7 @@ class MemcacheCacheHandler(
log.debug(ctx) {
"Cache miss for key ${msg.key} on memcache"
}
sendMessageAndFlush(ctx, CacheValueNotFoundResponse())
sendMessageAndFlush(ctx, CacheValueNotFoundResponse(msg.key))
}
}
}

View File

@@ -72,8 +72,8 @@ import net.woggioni.rbcs.server.auth.RoleAuthorizer
import net.woggioni.rbcs.server.configuration.Parser
import net.woggioni.rbcs.server.configuration.Serializer
import net.woggioni.rbcs.server.exception.ExceptionHandler
import net.woggioni.rbcs.server.handler.BlackHoleRequestHandler
import net.woggioni.rbcs.server.handler.MaxRequestSizeHandler
import net.woggioni.rbcs.server.handler.ReadTriggerDuplexHandler
import net.woggioni.rbcs.server.handler.ServerHandler
import net.woggioni.rbcs.server.throttling.BucketManager
import net.woggioni.rbcs.server.throttling.ThrottlingHandler
@@ -298,7 +298,7 @@ class RemoteBuildCacheServer(private val cfg: Configuration) {
"Closed connection ${ch.id().asShortText()} with ${ch.remoteAddress()}"
}
}
ch.config().setAutoRead(false)
ch.config().isAutoRead = false
val pipeline = ch.pipeline()
cfg.connection.also { conn ->
val readIdleTimeout = conn.readIdleTimeout.toMillis()
@@ -345,13 +345,14 @@ class RemoteBuildCacheServer(private val cfg: Configuration) {
maxChunkSize = cfg.connection.chunkSize
}
pipeline.addLast(HttpServerCodec(httpDecoderConfig))
pipeline.addLast(ReadTriggerDuplexHandler.NAME, ReadTriggerDuplexHandler)
pipeline.addLast(MaxRequestSizeHandler.NAME, MaxRequestSizeHandler(cfg.connection.maxRequestSize))
pipeline.addLast(HttpChunkContentCompressor(1024))
pipeline.addLast(ChunkedWriteHandler())
authenticator?.let {
pipeline.addLast(it)
}
pipeline.addLast(ThrottlingHandler(bucketManager, cfg.connection))
pipeline.addLast(ThrottlingHandler(bucketManager,cfg.rateLimiter, cfg.connection))
val serverHandler = let {
val prefix = Path.of("/").resolve(Path.of(cfg.serverPath ?: "/"))
@@ -361,7 +362,6 @@ class RemoteBuildCacheServer(private val cfg: Configuration) {
}
pipeline.addLast(ServerHandler.NAME, serverHandler)
pipeline.addLast(ExceptionHandler.NAME, ExceptionHandler)
pipeline.addLast(BlackHoleRequestHandler.NAME, BlackHoleRequestHandler())
}
override fun asyncClose() = cacheHandlerFactory.asyncClose()

View File

@@ -125,7 +125,7 @@ class FileSystemCacheHandler(
sendMessageAndFlush(ctx, LastHttpContent.EMPTY_LAST_CONTENT)
}
}
} ?: sendMessageAndFlush(ctx, CacheValueNotFoundResponse())
} ?: sendMessageAndFlush(ctx, CacheValueNotFoundResponse(key))
}
}
}

View File

@@ -11,6 +11,7 @@ import java.util.concurrent.CompletableFuture
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.withLock
import net.woggioni.rbcs.common.debug
private class CacheKey(private val value: ByteArray) {
override fun equals(other: Any?) = if (other is CacheKey) {

View File

@@ -118,7 +118,7 @@ class InMemoryCacheHandler(
} else {
sendMessage(ctx, LastCacheContent(value.content))
}
} ?: sendMessage(ctx, CacheValueNotFoundResponse())
} ?: sendMessage(ctx, CacheValueNotFoundResponse(req.request.key))
}
is InProgressPutRequest -> {
this.inProgressRequest = null

View File

@@ -33,6 +33,7 @@ object Parser {
0x4000000,
0x10000
)
var rateLimiter = Configuration.RateLimiter(false, 0x100000, 100)
var eventExecutor: Configuration.EventExecutor = Configuration.EventExecutor(true)
var cache: Cache? = null
var host = "127.0.0.1"
@@ -132,11 +133,24 @@ object Parser {
}
"event-executor" -> {
val useVirtualThread = root.renderAttribute("use-virtual-threads")
val useVirtualThread = child.renderAttribute("use-virtual-threads")
?.let(String::toBoolean) ?: true
eventExecutor = Configuration.EventExecutor(useVirtualThread)
}
"rate-limiter" -> {
val delayResponse = child.renderAttribute("delay-response")
?.let(String::toBoolean)
?: false
val messageBufferSize = child.renderAttribute("message-buffer-size")
?.let(Integer::decode)
?: 0x100000
val maxQueuedMessages = child.renderAttribute("max-queued-messages")
?.let(Integer::decode)
?: 100
rateLimiter = Configuration.RateLimiter(delayResponse, messageBufferSize, maxQueuedMessages)
}
"tls" -> {
var keyStore: KeyStore? = null
var trustStore: TrustStore? = null
@@ -184,6 +198,7 @@ object Parser {
incomingConnectionsBacklogSize,
serverPath,
eventExecutor,
rateLimiter,
connection,
users,
groups,

View File

@@ -46,6 +46,11 @@ object Serializer {
node("event-executor") {
attr("use-virtual-threads", conf.eventExecutor.isUseVirtualThreads.toString())
}
node("rate-limiter") {
attr("delay-response", conf.rateLimiter.isDelayRequest.toString())
attr("max-queued-messages", conf.rateLimiter.maxQueuedMessages.toString())
attr("message-buffer-size", conf.rateLimiter.messageBufferSize.toString())
}
val cache = conf.cache
val serializer : CacheProvider<Configuration.Cache> =
(CacheSerializers.index[cache.namespaceURI to cache.typeName] as? CacheProvider<Configuration.Cache>) ?: throw NotImplementedError()

View File

@@ -0,0 +1,36 @@
package net.woggioni.rbcs.server.handler
import io.netty.channel.ChannelDuplexHandler
import io.netty.channel.ChannelHandler.Sharable
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelPromise
import io.netty.handler.codec.http.LastHttpContent
import net.woggioni.rbcs.common.createLogger
import net.woggioni.rbcs.common.debug
@Sharable
object ReadTriggerDuplexHandler : ChannelDuplexHandler() {
val NAME = ReadTriggerDuplexHandler::class.java.name
override fun handlerAdded(ctx: ChannelHandlerContext) {
ctx.read()
}
override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
super.channelRead(ctx, msg)
if(msg !is LastHttpContent) {
ctx.read()
}
}
override fun write(
ctx: ChannelHandlerContext,
msg: Any,
promise: ChannelPromise
) {
super.write(ctx, msg, promise)
if(msg is LastHttpContent) {
ctx.read()
}
}
}

View File

@@ -43,16 +43,6 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
private var httpVersion = HttpVersion.HTTP_1_1
private var keepAlive = true
private var pipelinedRequests = 0
private fun newRequest() {
pipelinedRequests += 1
}
private fun requestCompleted(ctx : ChannelHandlerContext) {
pipelinedRequests -= 1
if(pipelinedRequests == 0) ctx.read()
}
private fun resetRequestMetadata() {
httpVersion = HttpVersion.HTTP_1_1
@@ -74,10 +64,6 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
private var cacheRequestInProgress : Boolean = false
override fun handlerAdded(ctx: ChannelHandlerContext) {
ctx.read()
}
override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
when (msg) {
is HttpRequest -> handleRequest(ctx, msg)
@@ -98,18 +84,14 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
}
}
override fun channelReadComplete(ctx: ChannelHandlerContext) {
super.channelReadComplete(ctx)
if(cacheRequestInProgress) {
ctx.read()
}
}
override fun write(ctx: ChannelHandlerContext, msg: Any, promise: ChannelPromise?) {
if (msg is CacheMessage) {
try {
when (msg) {
is CachePutResponse -> {
log.debug(ctx) {
"Added value for key '${msg.key}' to build cache"
}
val response = DefaultFullHttpResponse(httpVersion, HttpResponseStatus.CREATED)
val keyBytes = msg.key.toByteArray(Charsets.UTF_8)
response.headers().apply {
@@ -121,21 +103,23 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
val buf = ctx.alloc().buffer(keyBytes.size).apply {
writeBytes(keyBytes)
}
ctx.writeAndFlush(DefaultLastHttpContent(buf)).also {
requestCompleted(ctx)
}
ctx.writeAndFlush(DefaultLastHttpContent(buf))
}
is CacheValueNotFoundResponse -> {
log.debug(ctx) {
"Value not found for key '${msg.key}'"
}
val response = DefaultFullHttpResponse(httpVersion, HttpResponseStatus.NOT_FOUND)
response.headers()[HttpHeaderNames.CONTENT_LENGTH] = 0
setKeepAliveHeader(response.headers())
ctx.writeAndFlush(response).also {
requestCompleted(ctx)
}
ctx.writeAndFlush(response)
}
is CacheValueFoundResponse -> {
log.debug(ctx) {
"Retrieved value for key '${msg.key}'"
}
val response = DefaultHttpResponse(httpVersion, HttpResponseStatus.OK)
response.headers().apply {
set(HttpHeaderNames.CONTENT_TYPE, msg.metadata.mimeType ?: HttpHeaderValues.APPLICATION_OCTET_STREAM)
@@ -149,9 +133,7 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
}
is LastCacheContent -> {
ctx.writeAndFlush(DefaultLastHttpContent(msg.content())).also {
requestCompleted(ctx)
}
ctx.writeAndFlush(DefaultLastHttpContent(msg.content()))
}
is CacheContent -> {
@@ -172,7 +154,6 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
}
} else if(msg is LastHttpContent) {
ctx.write(msg, promise)
requestCompleted(ctx)
} else super.write(ctx, msg, promise)
}
@@ -186,13 +167,13 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
cacheRequestInProgress = true
val relativePath = serverPrefix.relativize(path)
val key : String = relativePath.toString()
newRequest()
val cacheHandler = cacheHandlerSupplier()
ctx.pipeline().addBefore(ExceptionHandler.NAME, null, cacheHandler)
key.let(::CacheGetRequest)
.let(ctx::fireChannelRead)
?: ctx.channel().write(CacheValueNotFoundResponse())
?: ctx.channel().write(CacheValueNotFoundResponse(key))
} else {
cacheRequestInProgress = false
log.warn(ctx) {
"Got request for unhandled path '${msg.uri()}'"
}
@@ -206,12 +187,8 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
cacheRequestInProgress = true
val relativePath = serverPrefix.relativize(path)
val key = relativePath.toString()
log.debug(ctx) {
"Added value for key '$key' to build cache"
}
newRequest()
val cacheHandler = cacheHandlerSupplier()
ctx.pipeline().addBefore(ExceptionHandler.NAME, null, cacheHandler)
ctx.pipeline().addAfter(NAME, null, cacheHandler)
path.fileName?.toString()
?.let {
@@ -219,8 +196,9 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
CachePutRequest(key, CacheValueMetadata(msg.headers().get(HttpHeaderNames.CONTENT_DISPOSITION), mimeType))
}
?.let(ctx::fireChannelRead)
?: ctx.channel().write(CacheValueNotFoundResponse())
?: ctx.channel().write(CacheValueNotFoundResponse(key))
} else {
cacheRequestInProgress = false
log.warn(ctx) {
"Got request for unhandled path '${msg.uri()}'"
}
@@ -229,10 +207,11 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
ctx.writeAndFlush(response)
}
} else if (method == HttpMethod.TRACE) {
newRequest()
ctx.pipeline().addBefore(ExceptionHandler.NAME, null, TraceHandler)
cacheRequestInProgress = false
ctx.pipeline().addAfter(NAME, null, TraceHandler)
super.channelRead(ctx, msg)
} else {
cacheRequestInProgress = false
log.warn(ctx) {
"Got request with unhandled method '${msg.method().name()}'"
}
@@ -245,4 +224,4 @@ class ServerHandler(private val serverPrefix: Path, private val cacheHandlerSupp
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
super.exceptionCaught(ctx, cause)
}
}
}

View File

@@ -1,32 +1,54 @@
package net.woggioni.rbcs.server.throttling
import io.netty.buffer.ByteBufHolder
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelInboundHandlerAdapter
import io.netty.handler.codec.http.DefaultFullHttpResponse
import io.netty.handler.codec.http.FullHttpMessage
import io.netty.handler.codec.http.HttpContent
import io.netty.handler.codec.http.HttpHeaderNames
import io.netty.handler.codec.http.HttpRequest
import io.netty.handler.codec.http.HttpResponseStatus
import io.netty.handler.codec.http.HttpVersion
import io.netty.handler.codec.http.LastHttpContent
import java.net.InetSocketAddress
import java.time.Duration
import java.time.temporal.ChronoUnit
import java.util.ArrayDeque
import java.util.LinkedList
import java.util.concurrent.TimeUnit
import kotlin.collections.forEach
import kotlin.collections.isNotEmpty
import net.woggioni.jwo.Bucket
import net.woggioni.jwo.LongMath
import net.woggioni.rbcs.api.Configuration
import net.woggioni.rbcs.common.createLogger
import net.woggioni.rbcs.common.debug
import net.woggioni.rbcs.server.RemoteBuildCacheServer
import java.net.InetSocketAddress
import java.time.Duration
import java.time.temporal.ChronoUnit
import java.util.concurrent.TimeUnit
class ThrottlingHandler(private val bucketManager : BucketManager,
private val connectionConfiguration : Configuration.Connection) : ChannelInboundHandlerAdapter() {
class ThrottlingHandler(
private val bucketManager: BucketManager,
rateLimiterConfiguration: Configuration.RateLimiter,
connectionConfiguration: Configuration.Connection
) : ChannelInboundHandlerAdapter() {
private companion object {
private val log = createLogger<ThrottlingHandler>()
fun nextAttemptIsWithinThreshold(nextAttemptNanos : Long, waitThreshold : Duration) : Boolean {
val waitDuration = Duration.of(LongMath.ceilDiv(nextAttemptNanos, 100_000_000L) * 100L, ChronoUnit.MILLIS)
return waitDuration < waitThreshold
}
}
private var queuedContent : MutableList<HttpContent>? = null
private class RefusedRequest
private val maxMessageBufferSize = rateLimiterConfiguration.messageBufferSize
private val maxQueuedMessages = rateLimiterConfiguration.maxQueuedMessages
private val delayRequests = rateLimiterConfiguration.isDelayRequest
private var requestBufferSize : Int = 0
private var valveClosed = false
private var queuedContent = ArrayDeque<Any>()
/**
* If the suggested waiting time from the bucket is lower than this
@@ -39,38 +61,155 @@ class ThrottlingHandler(private val bucketManager : BucketManager,
connectionConfiguration.writeIdleTimeout
).dividedBy(2)
override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
if(msg is HttpRequest) {
val buckets = mutableListOf<Bucket>()
val user = ctx.channel().attr(RemoteBuildCacheServer.userAttribute).get()
if (user != null) {
bucketManager.getBucketByUser(user)?.let(buckets::addAll)
}
val groups = ctx.channel().attr(RemoteBuildCacheServer.groupAttribute).get() ?: emptySet()
if (groups.isNotEmpty()) {
groups.forEach { group ->
bucketManager.getBucketByGroup(group)?.let(buckets::add)
if(valveClosed) {
if(msg !is HttpRequest && msg is ByteBufHolder) {
val newBufferSize = requestBufferSize + msg.content().readableBytes()
if(newBufferSize > maxMessageBufferSize || queuedContent.size + 1 > maxQueuedMessages) {
log.debug {
if (newBufferSize > maxMessageBufferSize) {
"New message part exceeds maxMessageBufferSize, removing previous chunks"
} else {
"New message part exceeds maxQueuedMessages, removing previous chunks"
}
}
// If this message overflows the maxMessageBufferSize,
// then remove the previously enqueued chunks of the request from the deque,
// then discard the message
while(true) {
val tail = queuedContent.last()
if(tail is ByteBufHolder) {
requestBufferSize -= tail.content().readableBytes()
tail.release()
}
queuedContent.removeLast()
if(tail is HttpRequest) {
break
}
}
msg.release()
//Add a placeholder to remember to return a 429 response corresponding to this request
queuedContent.addLast(RefusedRequest())
} else {
//If the message does not overflow maxMessageBufferSize, just add it to the deque
queuedContent.addLast(msg)
requestBufferSize = newBufferSize
}
} else if(msg is HttpRequest && msg is FullHttpMessage){
val newBufferSize = requestBufferSize + msg.content().readableBytes()
// If this message overflows the maxMessageBufferSize,
// discard the message
if(newBufferSize > maxMessageBufferSize || queuedContent.size + 1 > maxQueuedMessages) {
log.debug {
if (newBufferSize > maxMessageBufferSize) {
"New message exceeds maxMessageBufferSize, discarding it"
} else {
"New message exceeds maxQueuedMessages, discarding it"
}
}
msg.release()
//Add a placeholder to remember to return a 429 response corresponding to this request
queuedContent.addLast(RefusedRequest())
} else {
//If the message does not exceed maxMessageBufferSize or maxQueuedMessages, just add it to the deque
queuedContent.addLast(msg)
requestBufferSize = newBufferSize
}
}
if (user == null && groups.isEmpty()) {
bucketManager.getBucketByAddress(ctx.channel().remoteAddress() as InetSocketAddress)?.let(buckets::add)
}
if (buckets.isEmpty()) {
super.channelRead(ctx, msg)
} else {
handleBuckets(buckets, ctx, msg, true)
queuedContent.addLast(msg)
}
ctx.channel().id()
} else if(msg is HttpContent) {
queuedContent?.add(msg) ?: super.channelRead(ctx, msg)
} else {
super.channelRead(ctx, msg)
entryPoint(ctx, msg)
}
}
private fun handleBuckets(buckets: List<Bucket>, ctx: ChannelHandlerContext, msg: Any, delayResponse: Boolean) {
private fun entryPoint(ctx : ChannelHandlerContext, msg : Any) {
if(msg is RefusedRequest) {
sendThrottledResponse(ctx, null)
if(queuedContent.isEmpty()) {
valveClosed = false
} else {
val head = queuedContent.poll()
if(head is ByteBufHolder) {
requestBufferSize -= head.content().readableBytes()
}
entryPoint(ctx, head)
}
} else if(msg is HttpRequest) {
val nextAttempt = getNextAttempt(ctx)
if (nextAttempt < 0) {
super.channelRead(ctx, msg)
if(msg !is LastHttpContent) {
while (true) {
val head = queuedContent.poll() ?: break
if(head is ByteBufHolder) {
requestBufferSize -= head.content().readableBytes()
}
super.channelRead(ctx, head)
if (head is LastHttpContent) break
}
}
log.debug {
"Queue size: ${queuedContent.stream().filter { it !is RefusedRequest }.count()}"
}
if(queuedContent.isEmpty()) {
valveClosed = false
} else {
val head = queuedContent.poll()
if(head is ByteBufHolder) {
requestBufferSize -= head.content().readableBytes()
}
entryPoint(ctx, head)
}
} else {
val waitDuration = Duration.of(LongMath.ceilDiv(nextAttempt, 100_000_000L) * 100L, ChronoUnit.MILLIS)
if (delayRequests && nextAttemptIsWithinThreshold(nextAttempt, waitThreshold)) {
valveClosed = true
ctx.executor().schedule({
entryPoint(ctx, msg)
}, waitDuration.toMillis(), TimeUnit.MILLISECONDS)
} else {
sendThrottledResponse(ctx, waitDuration)
if(queuedContent.isEmpty()) {
valveClosed = false
} else {
val head = queuedContent.poll()
if(head is ByteBufHolder) {
requestBufferSize -= head.content().readableBytes()
}
entryPoint(ctx, head)
}
}
}
} else {
super.channelRead(ctx, msg)
log.debug {
"Queue size: ${queuedContent.stream().filter { it !is RefusedRequest }.count()}"
}
}
}
/**
* Returns the number amount of milliseconds to wait before the requests can be processed
* or -1 if the request can be performed immediately
*/
private fun getNextAttempt(ctx : ChannelHandlerContext) : Long {
val buckets = mutableListOf<Bucket>()
val user = ctx.channel().attr(RemoteBuildCacheServer.userAttribute).get()
if (user != null) {
bucketManager.getBucketByUser(user)?.let(buckets::addAll)
}
val groups = ctx.channel().attr(RemoteBuildCacheServer.groupAttribute).get() ?: emptySet()
if (groups.isNotEmpty()) {
groups.forEach { group ->
bucketManager.getBucketByGroup(group)?.let(buckets::add)
}
}
if (user == null && groups.isEmpty()) {
bucketManager.getBucketByAddress(ctx.channel().remoteAddress() as InetSocketAddress)?.let(buckets::add)
}
var nextAttempt = -1L
for (bucket in buckets) {
val bucketNextAttempt = bucket.removeTokensWithEstimate(1)
@@ -78,41 +217,19 @@ class ThrottlingHandler(private val bucketManager : BucketManager,
nextAttempt = bucketNextAttempt
}
}
if (nextAttempt < 0) {
super.channelRead(ctx, msg)
queuedContent?.let {
for(content in it) {
super.channelRead(ctx, content)
}
queuedContent = null
}
} else {
val waitDuration = Duration.of(LongMath.ceilDiv(nextAttempt, 100_000_000L) * 100L, ChronoUnit.MILLIS)
if (delayResponse && waitDuration < waitThreshold) {
this.queuedContent = mutableListOf()
ctx.executor().schedule({
handleBuckets(buckets, ctx, msg, false)
}, waitDuration.toMillis(), TimeUnit.MILLISECONDS)
} else {
queuedContent?.let { qc ->
qc.forEach { it.release() }
}
this.queuedContent = null
sendThrottledResponse(ctx, waitDuration)
}
}
return nextAttempt
}
private fun sendThrottledResponse(ctx: ChannelHandlerContext, retryAfter: Duration) {
private fun sendThrottledResponse(ctx: ChannelHandlerContext, retryAfter: Duration?) {
val response = DefaultFullHttpResponse(
HttpVersion.HTTP_1_1,
HttpResponseStatus.TOO_MANY_REQUESTS
)
response.headers()[HttpHeaderNames.CONTENT_LENGTH] = 0
retryAfter.seconds.takeIf {
retryAfter?.seconds?.takeIf {
it > 0
}?.let {
response.headers()[HttpHeaderNames.RETRY_AFTER] = retryAfter.seconds
response.headers()[HttpHeaderNames.RETRY_AFTER] = it
}
ctx.writeAndFlush(response)

View File

@@ -16,6 +16,7 @@
<xs:element name="bind" type="rbcs:bindType" maxOccurs="1"/>
<xs:element name="connection" type="rbcs:connectionType" minOccurs="0" maxOccurs="1"/>
<xs:element name="event-executor" type="rbcs:eventExecutorType" minOccurs="0" maxOccurs="1"/>
<xs:element name="rate-limiter" type="rbcs:rateLimiterType" minOccurs="0" maxOccurs="1"/>
<xs:element name="cache" type="rbcs:cacheType" maxOccurs="1">
<xs:annotation>
<xs:documentation>
@@ -136,6 +137,37 @@
</xs:attribute>
</xs:complexType>
<xs:complexType name="rateLimiterType">
<xs:attribute name="delay-response" type="xs:boolean" use="optional" default="false">
<xs:annotation>
<xs:documentation>
If set to true, the server will delay responses to meet user quotas, otherwise it will simply
return an immediate 429 status code to all requests that exceed the configured quota
</xs:documentation>
</xs:annotation>
</xs:attribute>
<xs:attribute name="max-queued-messages" type="xs:nonNegativeInteger" use="optional" default="100">
<xs:annotation>
<xs:documentation>
Only meaningful when "delay-response" is set to "true",
when a request is delayed, it and all the following messages are queued
as long as "max-queued-messages" is not crossed, all requests that would exceed the
max-queued-message limit are instead discarded and responded with a 429 status code
</xs:documentation>
</xs:annotation>
</xs:attribute>
<xs:attribute name="message-buffer-size" type="rbcs:byteSizeType" use="optional" default="0x100000">
<xs:annotation>
<xs:documentation>
Only meaningful when "delay-response" is set to "true",
when a request is delayed, it and all the following requests are buffered
as long as "message-buffer-size" is not crossed, all requests that would exceed the buffer
size are instead discarded and responded with a 429 status code
</xs:documentation>
</xs:annotation>
</xs:attribute>
</xs:complexType>
<xs:complexType name="cacheType" abstract="true"/>
<xs:complexType name="inMemoryCacheType">

View File

@@ -37,6 +37,7 @@ abstract class AbstractBasicAuthServerTest : AbstractServerTest() {
50,
serverPath,
Configuration.EventExecutor(false),
Configuration.RateLimiter(true, 0x100000, 50),
Configuration.Connection(
Duration.of(60, ChronoUnit.SECONDS),
Duration.of(30, ChronoUnit.SECONDS),

View File

@@ -143,6 +143,7 @@ abstract class AbstractTlsServerTest : AbstractServerTest() {
100,
serverPath,
Configuration.EventExecutor(false),
Configuration.RateLimiter(true, 0x100000, 50),
Configuration.Connection(
Duration.of(60, ChronoUnit.SECONDS),
Duration.of(30, ChronoUnit.SECONDS),

View File

@@ -37,6 +37,7 @@ class NoAuthServerTest : AbstractServerTest() {
100,
serverPath,
Configuration.EventExecutor(false),
Configuration.RateLimiter(true, 0x100000, 50),
Configuration.Connection(
Duration.of(60, ChronoUnit.SECONDS),
Duration.of(30, ChronoUnit.SECONDS),

View File

@@ -10,6 +10,7 @@
max-request-size="101325"
chunk-size="0xa910"/>
<event-executor use-virtual-threads="false"/>
<rate-limiter delay-response="false" message-buffer-size="0x1234" max-queued-messages="13"/>
<cache xs:type="rbcs:fileSystemCacheType" path="/tmp/rbcs" max-age="P7D"/>
<authentication>
<none/>

View File

@@ -12,6 +12,7 @@
write-idle-timeout="PT60S"
chunk-size="123"/>
<event-executor use-virtual-threads="true"/>
<rate-limiter delay-response="false" message-buffer-size="12000" max-queued-messages="53"/>
<cache xs:type="rbcs-memcache:memcacheCacheType" max-age="P7D">
<server host="memcached" port="11211"/>
</cache>

View File

@@ -11,6 +11,7 @@
max-request-size="101325"
chunk-size="456"/>
<event-executor use-virtual-threads="false"/>
<rate-limiter delay-response="true" message-buffer-size="65432" max-queued-messages="21"/>
<cache xs:type="rbcs-memcache:memcacheCacheType" max-age="P7D" digest="SHA-256" compression-mode="deflate" compression-level="7">
<server host="127.0.0.1" port="11211" max-connections="10" connection-timeout="PT20S"/>
</cache>