Compare commits

...

1 Commits

Author SHA1 Message Date
c3c4bbe5e2 fixed throttling retry-after estimation
All checks were successful
CI / build (push) Successful in 3m38s
2025-01-25 00:39:31 +08:00
2 changed files with 45 additions and 17 deletions

View File

@@ -11,8 +11,10 @@ import net.woggioni.gbcs.api.Configuration
import net.woggioni.gbcs.common.contextLogger import net.woggioni.gbcs.common.contextLogger
import net.woggioni.gbcs.server.GradleBuildCacheServer import net.woggioni.gbcs.server.GradleBuildCacheServer
import net.woggioni.jwo.Bucket import net.woggioni.jwo.Bucket
import net.woggioni.jwo.LongMath
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.time.Duration import java.time.Duration
import java.time.temporal.ChronoUnit
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
@@ -54,25 +56,31 @@ class ThrottlingHandler(cfg: Configuration) :
if (buckets.isEmpty()) { if (buckets.isEmpty()) {
return super.channelRead(ctx, msg) return super.channelRead(ctx, msg)
} else { } else {
var nextAttempt = Long.MAX_VALUE handleBuckets(buckets, ctx, msg, true)
}
}
private fun handleBuckets(buckets : List<Bucket>, ctx : ChannelHandlerContext, msg : Any, delayResponse : Boolean) {
var nextAttempt = -1L
for (bucket in buckets) { for (bucket in buckets) {
val bucketNextAttempt = bucket.removeTokensWithEstimate(1) val bucketNextAttempt = bucket.removeTokensWithEstimate(1)
if (bucketNextAttempt < 0) { if (bucketNextAttempt > nextAttempt) {
return super.channelRead(ctx, msg)
} else if (bucketNextAttempt < nextAttempt) {
nextAttempt = bucketNextAttempt nextAttempt = bucketNextAttempt
} }
} }
val waitDuration = Duration.ofNanos(nextAttempt) if(nextAttempt < 0) {
if (waitDuration < waitThreshold) { super.channelRead(ctx, msg)
return
}
val waitDuration = Duration.of(LongMath.ceilDiv(nextAttempt, 100_000_000L) * 100L, ChronoUnit.MILLIS)
if (delayResponse && waitDuration < waitThreshold) {
ctx.executor().schedule({ ctx.executor().schedule({
ctx.fireChannelRead(msg) handleBuckets(buckets, ctx, msg, false)
}, waitDuration.toNanos(), TimeUnit.NANOSECONDS) }, waitDuration.toMillis(), TimeUnit.MILLISECONDS)
} else { } else {
sendThrottledResponse(ctx, waitDuration) sendThrottledResponse(ctx, waitDuration)
} }
} }
}
private fun sendThrottledResponse(ctx: ChannelHandlerContext, retryAfter: Duration) { private fun sendThrottledResponse(ctx: ChannelHandlerContext, retryAfter: Duration) {
val response = DefaultFullHttpResponse( val response = DefaultFullHttpResponse(
@@ -80,7 +88,12 @@ class ThrottlingHandler(cfg: Configuration) :
HttpResponseStatus.TOO_MANY_REQUESTS HttpResponseStatus.TOO_MANY_REQUESTS
) )
response.headers()[HttpHeaderNames.CONTENT_LENGTH] = 0 response.headers()[HttpHeaderNames.CONTENT_LENGTH] = 0
retryAfter.seconds.takeIf {
it > 0
}?.let {
response.headers()[HttpHeaderNames.RETRY_AFTER] = retryAfter.seconds response.headers()[HttpHeaderNames.RETRY_AFTER] = retryAfter.seconds
}
ctx.writeAndFlush(response) ctx.writeAndFlush(response)
} }
} }

View File

@@ -133,4 +133,19 @@ class TlsServerTest : AbstractTlsServerTest() {
val response: HttpResponse<String> = client.send(requestBuilder.build(), HttpResponse.BodyHandlers.ofString()) val response: HttpResponse<String> = client.send(requestBuilder.build(), HttpResponse.BodyHandlers.ofString())
Assertions.assertEquals(HttpResponseStatus.FORBIDDEN.code(), response.statusCode()) Assertions.assertEquals(HttpResponseStatus.FORBIDDEN.code(), response.statusCode())
} }
@Test
@Order(8)
fun traceAsAnonymousUser() {
val client: HttpClient = getHttpClient(null)
val requestBuilder = newRequestBuilder("").method(
"TRACE",
HttpRequest.BodyPublishers.ofByteArray("sfgsdgfaiousfiuhsd".toByteArray())
)
val response: HttpResponse<ByteArray> =
client.send(requestBuilder.build(), HttpResponse.BodyHandlers.ofByteArray())
Assertions.assertEquals(HttpResponseStatus.OK.code(), response.statusCode())
println(String(response.body()))
}
} }