diff --git a/gradle.properties b/gradle.properties index 04d7e57..c3bb330 100644 --- a/gradle.properties +++ b/gradle.properties @@ -2,7 +2,7 @@ org.gradle.configuration-cache=false org.gradle.parallel=true org.gradle.caching=true -rbcs.version = 0.1.6 +rbcs.version = 0.2.0 lys.version = 2025.02.08 diff --git a/rbcs-api/build.gradle b/rbcs-api/build.gradle index ac6a484..a99ec9a 100644 --- a/rbcs-api/build.gradle +++ b/rbcs-api/build.gradle @@ -6,6 +6,7 @@ plugins { dependencies { api catalog.netty.buffer + api catalog.netty.handler } publishing { diff --git a/rbcs-api/src/main/java/module-info.java b/rbcs-api/src/main/java/module-info.java index 42abdf2..78cdc90 100644 --- a/rbcs-api/src/main/java/module-info.java +++ b/rbcs-api/src/main/java/module-info.java @@ -2,7 +2,9 @@ module net.woggioni.rbcs.api { requires static lombok; requires java.xml; requires io.netty.buffer; + requires io.netty.handler; + requires io.netty.transport; exports net.woggioni.rbcs.api; exports net.woggioni.rbcs.api.exception; - exports net.woggioni.rbcs.api.event; + exports net.woggioni.rbcs.api.message; } \ No newline at end of file diff --git a/rbcs-api/src/main/java/net/woggioni/rbcs/api/Cache.java b/rbcs-api/src/main/java/net/woggioni/rbcs/api/Cache.java deleted file mode 100644 index 06af410..0000000 --- a/rbcs-api/src/main/java/net/woggioni/rbcs/api/Cache.java +++ /dev/null @@ -1,17 +0,0 @@ -package net.woggioni.rbcs.api; - -import io.netty.buffer.ByteBufAllocator; - -import java.util.concurrent.CompletableFuture; - - -public interface Cache extends AutoCloseable { - - default void get(String key, ResponseHandle responseHandle, ByteBufAllocator alloc) { - throw new UnsupportedOperationException(); - } - - default CompletableFuture put(String key, ResponseHandle responseHandle, ByteBufAllocator alloc) { - throw new UnsupportedOperationException(); - } -} diff --git a/rbcs-api/src/main/java/net/woggioni/rbcs/api/CacheHandlerFactory.java b/rbcs-api/src/main/java/net/woggioni/rbcs/api/CacheHandlerFactory.java new file mode 100644 index 0000000..aeda6a5 --- /dev/null +++ b/rbcs-api/src/main/java/net/woggioni/rbcs/api/CacheHandlerFactory.java @@ -0,0 +1,7 @@ +package net.woggioni.rbcs.api; + +import io.netty.channel.ChannelHandler; + +public interface CacheHandlerFactory extends AutoCloseable { + ChannelHandler newHandler(); +} diff --git a/rbcs-api/src/main/java/net/woggioni/rbcs/api/CacheValueMetadata.java b/rbcs-api/src/main/java/net/woggioni/rbcs/api/CacheValueMetadata.java new file mode 100644 index 0000000..dce5407 --- /dev/null +++ b/rbcs-api/src/main/java/net/woggioni/rbcs/api/CacheValueMetadata.java @@ -0,0 +1,14 @@ +package net.woggioni.rbcs.api; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +import java.io.Serializable; + +@Getter +@RequiredArgsConstructor +public class CacheValueMetadata implements Serializable { + private final String contentDisposition; + private final String mimeType; +} + diff --git a/rbcs-api/src/main/java/net/woggioni/rbcs/api/Configuration.java b/rbcs-api/src/main/java/net/woggioni/rbcs/api/Configuration.java index 080c65c..1f492d7 100644 --- a/rbcs-api/src/main/java/net/woggioni/rbcs/api/Configuration.java +++ b/rbcs-api/src/main/java/net/woggioni/rbcs/api/Configuration.java @@ -1,6 +1,7 @@ package net.woggioni.rbcs.api; +import io.netty.channel.ChannelInboundHandler; import lombok.EqualsAndHashCode; import lombok.NonNull; import lombok.Value; @@ -10,6 +11,7 @@ import java.security.cert.X509Certificate; import java.time.Duration; import java.util.Map; import java.util.Set; +import java.util.function.Supplier; import java.util.stream.Collectors; @Value @@ -135,7 +137,7 @@ public class Configuration { } public interface Cache { - net.woggioni.rbcs.api.Cache materialize(); + CacheHandlerFactory materialize(); String getNamespaceURI(); String getTypeName(); } diff --git a/rbcs-api/src/main/java/net/woggioni/rbcs/api/RequestHandle.java b/rbcs-api/src/main/java/net/woggioni/rbcs/api/RequestHandle.java deleted file mode 100644 index 7d3c43d..0000000 --- a/rbcs-api/src/main/java/net/woggioni/rbcs/api/RequestHandle.java +++ /dev/null @@ -1,8 +0,0 @@ -package net.woggioni.rbcs.api; - -import net.woggioni.rbcs.api.event.RequestStreamingEvent; - -@FunctionalInterface -public interface RequestHandle { - void handleEvent(RequestStreamingEvent evt); -} diff --git a/rbcs-api/src/main/java/net/woggioni/rbcs/api/ResponseHandle.java b/rbcs-api/src/main/java/net/woggioni/rbcs/api/ResponseHandle.java deleted file mode 100644 index fa6d285..0000000 --- a/rbcs-api/src/main/java/net/woggioni/rbcs/api/ResponseHandle.java +++ /dev/null @@ -1,8 +0,0 @@ -package net.woggioni.rbcs.api; - -import net.woggioni.rbcs.api.event.ResponseStreamingEvent; - -@FunctionalInterface -public interface ResponseHandle { - void handleEvent(ResponseStreamingEvent evt); -} diff --git a/rbcs-api/src/main/java/net/woggioni/rbcs/api/event/RequestStreamingEvent.java b/rbcs-api/src/main/java/net/woggioni/rbcs/api/event/RequestStreamingEvent.java deleted file mode 100644 index 46fc534..0000000 --- a/rbcs-api/src/main/java/net/woggioni/rbcs/api/event/RequestStreamingEvent.java +++ /dev/null @@ -1,26 +0,0 @@ -package net.woggioni.rbcs.api.event; - -import io.netty.buffer.ByteBuf; -import lombok.Getter; -import lombok.RequiredArgsConstructor; - -public sealed interface RequestStreamingEvent { - - @Getter - @RequiredArgsConstructor - non-sealed class ChunkReceived implements RequestStreamingEvent { - private final ByteBuf chunk; - } - - final class LastChunkReceived extends ChunkReceived { - public LastChunkReceived(ByteBuf chunk) { - super(chunk); - } - } - - @Getter - @RequiredArgsConstructor - final class ExceptionCaught implements RequestStreamingEvent { - private final Throwable exception; - } -} diff --git a/rbcs-api/src/main/java/net/woggioni/rbcs/api/event/ResponseStreamingEvent.java b/rbcs-api/src/main/java/net/woggioni/rbcs/api/event/ResponseStreamingEvent.java deleted file mode 100644 index cced1e3..0000000 --- a/rbcs-api/src/main/java/net/woggioni/rbcs/api/event/ResponseStreamingEvent.java +++ /dev/null @@ -1,42 +0,0 @@ -package net.woggioni.rbcs.api.event; - -import io.netty.buffer.ByteBuf; -import lombok.Getter; -import lombok.RequiredArgsConstructor; - -import java.nio.channels.FileChannel; - -public sealed interface ResponseStreamingEvent { - - final class ResponseReceived implements ResponseStreamingEvent { - } - - @Getter - @RequiredArgsConstructor - non-sealed class ChunkReceived implements ResponseStreamingEvent { - private final ByteBuf chunk; - } - - @Getter - @RequiredArgsConstructor - non-sealed class FileReceived implements ResponseStreamingEvent { - private final FileChannel file; - } - - final class LastChunkReceived extends ChunkReceived { - public LastChunkReceived(ByteBuf chunk) { - super(chunk); - } - } - - @Getter - @RequiredArgsConstructor - final class ExceptionCaught implements ResponseStreamingEvent { - private final Throwable exception; - } - - final class NotFound implements ResponseStreamingEvent { } - - NotFound NOT_FOUND = new NotFound(); - ResponseReceived RESPONSE_RECEIVED = new ResponseReceived(); -} diff --git a/rbcs-api/src/main/java/net/woggioni/rbcs/api/message/CacheMessage.java b/rbcs-api/src/main/java/net/woggioni/rbcs/api/message/CacheMessage.java new file mode 100644 index 0000000..08e20d1 --- /dev/null +++ b/rbcs-api/src/main/java/net/woggioni/rbcs/api/message/CacheMessage.java @@ -0,0 +1,161 @@ +package net.woggioni.rbcs.api.message; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import net.woggioni.rbcs.api.CacheValueMetadata; + +public sealed interface CacheMessage { + + @Getter + @RequiredArgsConstructor + final class CacheGetRequest implements CacheMessage { + private final String key; + } + + abstract sealed class CacheGetResponse implements CacheMessage { + } + + @Getter + @RequiredArgsConstructor + final class CacheValueFoundResponse extends CacheGetResponse { + private final String key; + private final CacheValueMetadata metadata; + } + + final class CacheValueNotFoundResponse extends CacheGetResponse { + } + + @Getter + @RequiredArgsConstructor + final class CachePutRequest implements CacheMessage { + private final String key; + private final CacheValueMetadata metadata; + } + + @Getter + @RequiredArgsConstructor + final class CachePutResponse implements CacheMessage { + private final String key; + } + + @RequiredArgsConstructor + non-sealed class CacheContent implements CacheMessage, ByteBufHolder { + protected final ByteBuf chunk; + + @Override + public ByteBuf content() { + return chunk; + } + + @Override + public CacheContent copy() { + return replace(chunk.copy()); + } + + @Override + public CacheContent duplicate() { + return new CacheContent(chunk.duplicate()); + } + + @Override + public CacheContent retainedDuplicate() { + return new CacheContent(chunk.retainedDuplicate()); + } + + @Override + public CacheContent replace(ByteBuf content) { + return new CacheContent(content); + } + + @Override + public CacheContent retain() { + chunk.retain(); + return this; + } + + @Override + public CacheContent retain(int increment) { + chunk.retain(increment); + return this; + } + + @Override + public CacheContent touch() { + chunk.touch(); + return this; + } + + @Override + public CacheContent touch(Object hint) { + chunk.touch(hint); + return this; + } + + @Override + public int refCnt() { + return chunk.refCnt(); + } + + @Override + public boolean release() { + return chunk.release(); + } + + @Override + public boolean release(int decrement) { + return chunk.release(decrement); + } + } + + final class LastCacheContent extends CacheContent { + public LastCacheContent(ByteBuf chunk) { + super(chunk); + } + + @Override + public LastCacheContent copy() { + return replace(chunk.copy()); + } + + @Override + public LastCacheContent duplicate() { + return new LastCacheContent(chunk.duplicate()); + } + + @Override + public LastCacheContent retainedDuplicate() { + return new LastCacheContent(chunk.retainedDuplicate()); + } + + @Override + public LastCacheContent replace(ByteBuf content) { + return new LastCacheContent(chunk); + } + + @Override + public LastCacheContent retain() { + super.retain(); + return this; + } + + @Override + public LastCacheContent retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public LastCacheContent touch() { + super.touch(); + return this; + } + + @Override + public LastCacheContent touch(Object hint) { + super.touch(hint); + return this; + } + } +} diff --git a/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/BenchmarkCommand.kt b/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/BenchmarkCommand.kt index 47e237b..47f3ad0 100644 --- a/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/BenchmarkCommand.kt +++ b/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/BenchmarkCommand.kt @@ -1,17 +1,20 @@ package net.woggioni.rbcs.cli.impl.commands -import net.woggioni.rbcs.cli.impl.RbcsCommand -import net.woggioni.rbcs.client.RemoteBuildCacheClient -import net.woggioni.rbcs.common.contextLogger -import net.woggioni.rbcs.common.error -import net.woggioni.rbcs.common.info import net.woggioni.jwo.JWO import net.woggioni.jwo.LongMath +import net.woggioni.rbcs.api.CacheValueMetadata +import net.woggioni.rbcs.cli.impl.RbcsCommand +import net.woggioni.rbcs.cli.impl.converters.ByteSizeConverter +import net.woggioni.rbcs.client.RemoteBuildCacheClient +import net.woggioni.rbcs.common.createLogger import net.woggioni.rbcs.common.debug +import net.woggioni.rbcs.common.error +import net.woggioni.rbcs.common.info import picocli.CommandLine import java.security.SecureRandom import java.time.Duration import java.time.Instant +import java.time.temporal.ChronoUnit import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.Semaphore import java.util.concurrent.atomic.AtomicLong @@ -23,7 +26,9 @@ import kotlin.random.Random showDefaultValues = true ) class BenchmarkCommand : RbcsCommand() { - private val log = contextLogger() + companion object{ + private val log = createLogger() + } @CommandLine.Spec private lateinit var spec: CommandLine.Model.CommandSpec @@ -38,10 +43,17 @@ class BenchmarkCommand : RbcsCommand() { @CommandLine.Option( names = ["-s", "--size"], description = ["Size of a cache value in bytes"], - paramLabel = "SIZE" + paramLabel = "SIZE", + converter = [ByteSizeConverter::class] ) private var size = 0x1000 + @CommandLine.Option( + names = ["-r", "--random"], + description = ["Insert completely random byte values"] + ) + private var randomValues = false + override fun run() { val clientCommand = spec.parent().userObject() as ClientCommand val profile = clientCommand.profileName.let { profileName -> @@ -55,8 +67,12 @@ class BenchmarkCommand : RbcsCommand() { val random = Random(SecureRandom.getInstance("NativePRNGNonBlocking").nextLong()) while (true) { val key = JWO.bytesToHex(random.nextBytes(16)) - val content = random.nextInt().toByte() - val value = ByteArray(size, { _ -> content }) + val value = if(randomValues) { + random.nextBytes(size) + } else { + val byteValue = random.nextInt().toByte() + ByteArray(size) {_ -> byteValue} + } yield(key to value) } } @@ -68,13 +84,13 @@ class BenchmarkCommand : RbcsCommand() { val completionCounter = AtomicLong(0) val completionQueue = LinkedBlockingQueue>(numberOfEntries) val start = Instant.now() - val semaphore = Semaphore(profile.maxConnections * 3) + val semaphore = Semaphore(profile.maxConnections * 5) val iterator = entryGenerator.take(numberOfEntries).iterator() while (completionCounter.get() < numberOfEntries) { if (iterator.hasNext()) { val entry = iterator.next() semaphore.acquire() - val future = client.put(entry.first, entry.second).thenApply { entry } + val future = client.put(entry.first, entry.second, CacheValueMetadata(null, null)).thenApply { entry } future.whenComplete { result, ex -> if (ex != null) { log.error(ex.message, ex) @@ -90,7 +106,7 @@ class BenchmarkCommand : RbcsCommand() { } } } else { - Thread.sleep(0) + Thread.sleep(Duration.of(500, ChronoUnit.MILLIS)) } } @@ -111,12 +127,13 @@ class BenchmarkCommand : RbcsCommand() { } if (entries.isNotEmpty()) { val completionCounter = AtomicLong(0) - val semaphore = Semaphore(profile.maxConnections * 3) + val semaphore = Semaphore(profile.maxConnections * 5) val start = Instant.now() val it = entries.iterator() while (completionCounter.get() < entries.size) { if (it.hasNext()) { val entry = it.next() + semaphore.acquire() val future = client.get(entry.first).thenApply { if (it == null) { log.error { @@ -138,7 +155,7 @@ class BenchmarkCommand : RbcsCommand() { semaphore.release() } } else { - Thread.sleep(0) + Thread.sleep(Duration.of(500, ChronoUnit.MILLIS)) } } val end = Instant.now() diff --git a/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/GetCommand.kt b/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/GetCommand.kt index 9c23d8e..2aeae17 100644 --- a/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/GetCommand.kt +++ b/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/GetCommand.kt @@ -2,7 +2,7 @@ package net.woggioni.rbcs.cli.impl.commands import net.woggioni.rbcs.cli.impl.RbcsCommand import net.woggioni.rbcs.client.RemoteBuildCacheClient -import net.woggioni.rbcs.common.contextLogger +import net.woggioni.rbcs.common.createLogger import picocli.CommandLine import java.nio.file.Files import java.nio.file.Path @@ -13,7 +13,9 @@ import java.nio.file.Path showDefaultValues = true ) class GetCommand : RbcsCommand() { - private val log = contextLogger() + companion object{ + private val log = createLogger() + } @CommandLine.Spec private lateinit var spec: CommandLine.Model.CommandSpec diff --git a/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/HealthCheckCommand.kt b/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/HealthCheckCommand.kt index d14ea6a..6b0bbab 100644 --- a/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/HealthCheckCommand.kt +++ b/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/HealthCheckCommand.kt @@ -3,6 +3,7 @@ package net.woggioni.rbcs.cli.impl.commands import net.woggioni.rbcs.cli.impl.RbcsCommand import net.woggioni.rbcs.client.RemoteBuildCacheClient import net.woggioni.rbcs.common.contextLogger +import net.woggioni.rbcs.common.createLogger import picocli.CommandLine import java.security.SecureRandom import kotlin.random.Random @@ -13,7 +14,9 @@ import kotlin.random.Random showDefaultValues = true ) class HealthCheckCommand : RbcsCommand() { - private val log = contextLogger() + companion object{ + private val log = createLogger() + } @CommandLine.Spec private lateinit var spec: CommandLine.Model.CommandSpec @@ -32,11 +35,12 @@ class HealthCheckCommand : RbcsCommand() { if(value == null) { throw IllegalStateException("Empty response from server") } + val offset = value.size - nonce.size for(i in 0 until nonce.size) { - for(j in value.size - nonce.size until nonce.size) { - if(nonce[i] != value[j]) { - throw IllegalStateException("Server nonce does not match") - } + val a = nonce[i] + val b = value[offset + i] + if(a != b) { + throw IllegalStateException("Server nonce does not match") } } }.get() diff --git a/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/PutCommand.kt b/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/PutCommand.kt index 6f39748..e201674 100644 --- a/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/PutCommand.kt +++ b/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/PutCommand.kt @@ -1,11 +1,21 @@ package net.woggioni.rbcs.cli.impl.commands +import net.woggioni.jwo.Hash +import net.woggioni.jwo.JWO +import net.woggioni.jwo.NullOutputStream +import net.woggioni.rbcs.api.CacheValueMetadata import net.woggioni.rbcs.cli.impl.RbcsCommand -import net.woggioni.rbcs.cli.impl.converters.InputStreamConverter import net.woggioni.rbcs.client.RemoteBuildCacheClient import net.woggioni.rbcs.common.contextLogger +import net.woggioni.rbcs.common.createLogger import picocli.CommandLine import java.io.InputStream +import java.nio.file.Files +import java.nio.file.Path +import java.security.DigestInputStream +import java.security.MessageDigest +import java.util.UUID +import kotlin.io.encoding.decodingWith @CommandLine.Command( name = "put", @@ -13,25 +23,41 @@ import java.io.InputStream showDefaultValues = true ) class PutCommand : RbcsCommand() { - private val log = contextLogger() + companion object{ + private val log = createLogger() + } + @CommandLine.Spec private lateinit var spec: CommandLine.Model.CommandSpec @CommandLine.Option( names = ["-k", "--key"], - description = ["The key for the new value"], + description = ["The key for the new value, randomly generated if omitted"], paramLabel = "KEY" ) - private var key : String = "" + private var key : String? = null + + @CommandLine.Option( + names = ["-i", "--inline"], + description = ["File is to be displayed in the browser"], + paramLabel = "INLINE", + ) + private var inline : Boolean = false + + @CommandLine.Option( + names = ["-t", "--type"], + description = ["File mime type"], + paramLabel = "MIME_TYPE", + ) + private var mimeType : String? = null @CommandLine.Option( names = ["-v", "--value"], description = ["Path to a file containing the value to be added (defaults to stdin)"], paramLabel = "VALUE_FILE", - converter = [InputStreamConverter::class] ) - private var value : InputStream = System.`in` + private var value : Path? = null override fun run() { val clientCommand = spec.parent().userObject() as ClientCommand @@ -40,9 +66,40 @@ class PutCommand : RbcsCommand() { ?: throw IllegalArgumentException("Profile $profileName does not exist in configuration") } RemoteBuildCacheClient(profile).use { client -> - value.use { - client.put(key, it.readAllBytes()) + val inputStream : InputStream + val mimeType : String? + val contentDisposition : String? + val valuePath = value + val actualKey : String? + if(valuePath != null) { + inputStream = Files.newInputStream(valuePath) + mimeType = this.mimeType ?: Files.probeContentType(valuePath) + contentDisposition = if(inline) { + "inline" + } else { + "attachment; filename=\"${valuePath.fileName}\"" + } + actualKey = key ?: let { + val md = Hash.Algorithm.SHA512.newInputStream(Files.newInputStream(valuePath)).use { + JWO.copy(it, NullOutputStream()) + it.messageDigest + } + UUID.nameUUIDFromBytes(md.digest()).toString() + } + } else { + inputStream = System.`in` + mimeType = this.mimeType + contentDisposition = if(inline) { + "inline" + } else { + null + } + actualKey = key ?: UUID.randomUUID().toString() + } + inputStream.use { + client.put(actualKey, it.readAllBytes(), CacheValueMetadata(contentDisposition, mimeType)) }.get() + println(profile.serverURI.resolve(actualKey)) } } } \ No newline at end of file diff --git a/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/ServerCommand.kt b/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/ServerCommand.kt index f6418da..d569a2b 100644 --- a/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/ServerCommand.kt +++ b/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/commands/ServerCommand.kt @@ -1,19 +1,20 @@ package net.woggioni.rbcs.cli.impl.commands +import net.woggioni.jwo.Application +import net.woggioni.jwo.JWO import net.woggioni.rbcs.cli.impl.RbcsCommand import net.woggioni.rbcs.cli.impl.converters.DurationConverter -import net.woggioni.rbcs.common.contextLogger +import net.woggioni.rbcs.common.createLogger import net.woggioni.rbcs.common.debug import net.woggioni.rbcs.common.info import net.woggioni.rbcs.server.RemoteBuildCacheServer import net.woggioni.rbcs.server.RemoteBuildCacheServer.Companion.DEFAULT_CONFIGURATION_URL -import net.woggioni.jwo.Application -import net.woggioni.jwo.JWO import picocli.CommandLine import java.io.ByteArrayOutputStream import java.nio.file.Files import java.nio.file.Path import java.time.Duration +import java.util.concurrent.TimeUnit @CommandLine.Command( name = "server", @@ -21,8 +22,9 @@ import java.time.Duration showDefaultValues = true ) class ServerCommand(app : Application) : RbcsCommand() { - - private val log = contextLogger() + companion object { + private val log = createLogger() + } private fun createDefaultConfigurationFile(configurationFile: Path) { log.info { @@ -66,11 +68,20 @@ class ServerCommand(app : Application) : RbcsCommand() { } } val server = RemoteBuildCacheServer(configuration) - server.run().use { server -> - timeout?.let { - Thread.sleep(it) - server.shutdown() + val handle = server.run() + val shutdownHook = Thread.ofPlatform().unstarted { + handle.sendShutdownSignal() + try { + handle.get(60, TimeUnit.SECONDS) + } catch (ex : Throwable) { + log.warn(ex.message, ex) } } + Runtime.getRuntime().addShutdownHook(shutdownHook) + if(timeout != null) { + Thread.sleep(timeout) + handle.sendShutdownSignal() + } + handle.get() } } \ No newline at end of file diff --git a/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/converters/ByteSizeConverter.kt b/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/converters/ByteSizeConverter.kt new file mode 100644 index 0000000..e9d323e --- /dev/null +++ b/rbcs-cli/src/main/kotlin/net/woggioni/rbcs/cli/impl/converters/ByteSizeConverter.kt @@ -0,0 +1,10 @@ +package net.woggioni.rbcs.cli.impl.converters + +import picocli.CommandLine + + +class ByteSizeConverter : CommandLine.ITypeConverter { + override fun convert(value: String): Int { + return Integer.decode(value) + } +} \ No newline at end of file diff --git a/rbcs-client/src/main/kotlin/net/woggioni/rbcs/client/Client.kt b/rbcs-client/src/main/kotlin/net/woggioni/rbcs/client/Client.kt index e3e119f..b125050 100644 --- a/rbcs-client/src/main/kotlin/net/woggioni/rbcs/client/Client.kt +++ b/rbcs-client/src/main/kotlin/net/woggioni/rbcs/client/Client.kt @@ -4,7 +4,9 @@ import io.netty.bootstrap.Bootstrap import io.netty.buffer.ByteBuf import io.netty.buffer.Unpooled import io.netty.channel.Channel +import io.netty.channel.ChannelHandler import io.netty.channel.ChannelHandlerContext +import io.netty.channel.ChannelInboundHandlerAdapter import io.netty.channel.ChannelOption import io.netty.channel.ChannelPipeline import io.netty.channel.SimpleChannelInboundHandler @@ -28,13 +30,18 @@ import io.netty.handler.codec.http.HttpVersion import io.netty.handler.ssl.SslContext import io.netty.handler.ssl.SslContextBuilder import io.netty.handler.stream.ChunkedWriteHandler +import io.netty.handler.timeout.IdleState +import io.netty.handler.timeout.IdleStateEvent +import io.netty.handler.timeout.IdleStateHandler import io.netty.util.concurrent.Future import io.netty.util.concurrent.GenericFutureListener +import net.woggioni.rbcs.api.CacheValueMetadata import net.woggioni.rbcs.client.impl.Parser import net.woggioni.rbcs.common.Xml -import net.woggioni.rbcs.common.contextLogger +import net.woggioni.rbcs.common.createLogger import net.woggioni.rbcs.common.debug import net.woggioni.rbcs.common.trace +import java.io.IOException import java.net.InetSocketAddress import java.net.URI import java.nio.file.Files @@ -44,14 +51,19 @@ import java.security.cert.X509Certificate import java.time.Duration import java.util.Base64 import java.util.concurrent.CompletableFuture +import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException import java.util.concurrent.atomic.AtomicInteger import kotlin.random.Random import io.netty.util.concurrent.Future as NettyFuture class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoCloseable { + companion object{ + private val log = createLogger() + } + private val group: NioEventLoopGroup private var sslContext: SslContext - private val log = contextLogger() private val pool: ChannelPool data class Configuration( @@ -72,11 +84,21 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC val exp: Double ) + class Connection( + val readTimeout: Duration, + val writeTimeout: Duration, + val idleTimeout: Duration, + val readIdleTimeout: Duration, + val writeIdleTimeout: Duration + ) + data class Profile( val serverURI: URI, + val connection: Connection?, val authentication: Authentication?, val connectionTimeout: Duration?, val maxConnections: Int, + val compressionEnabled: Boolean, val retryPolicy: RetryPolicy?, ) @@ -141,18 +163,50 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC } override fun channelCreated(ch: Channel) { - val connectionId = connectionCount.getAndIncrement() + val connectionId = connectionCount.incrementAndGet() log.debug { - "Created connection $connectionId, total number of active connections: $connectionId" + "Created connection ${ch.id().asShortText()}, total number of active connections: $connectionId" } ch.closeFuture().addListener { val activeConnections = connectionCount.decrementAndGet() log.debug { - "Closed connection $connectionId, total number of active connections: $activeConnections" + "Closed connection ${ + ch.id().asShortText() + }, total number of active connections: $activeConnections" } } val pipeline: ChannelPipeline = ch.pipeline() + profile.connection?.also { conn -> + val readTimeout = conn.readTimeout.toMillis() + val writeTimeout = conn.writeTimeout.toMillis() + if (readTimeout > 0 || writeTimeout > 0) { + pipeline.addLast( + IdleStateHandler( + false, + readTimeout, + writeTimeout, + 0, + TimeUnit.MILLISECONDS + ) + ) + } + val readIdleTimeout = conn.readIdleTimeout.toMillis() + val writeIdleTimeout = conn.writeIdleTimeout.toMillis() + val idleTimeout = conn.idleTimeout.toMillis() + if (readIdleTimeout > 0 || writeIdleTimeout > 0 || idleTimeout > 0) { + pipeline.addLast( + IdleStateHandler( + true, + readIdleTimeout, + writeIdleTimeout, + idleTimeout, + TimeUnit.MILLISECONDS + ) + ) + } + } + // Add SSL handler if needed if ("https".equals(scheme, ignoreCase = true)) { pipeline.addLast("ssl", sslContext.newHandler(ch.alloc(), host, port)) @@ -160,7 +214,9 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC // HTTP handlers pipeline.addLast("codec", HttpClientCodec()) - pipeline.addLast("decompressor", HttpContentDecompressor()) + if(profile.compressionEnabled) { + pipeline.addLast("decompressor", HttpContentDecompressor()) + } pipeline.addLast("aggregator", HttpObjectAggregator(134217728)) pipeline.addLast("chunked", ChunkedWriteHandler()) } @@ -254,9 +310,13 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC } } - fun put(key: String, content: ByteArray): CompletableFuture { + fun put(key: String, content: ByteArray, metadata: CacheValueMetadata): CompletableFuture { return executeWithRetry { - sendRequest(profile.serverURI.resolve(key), HttpMethod.PUT, content) + val extraHeaders = sequenceOf( + metadata.mimeType?.let { HttpHeaderNames.CONTENT_TYPE to it }, + metadata.contentDisposition?.let { HttpHeaderNames.CONTENT_DISPOSITION to it } + ).filterNotNull() + sendRequest(profile.serverURI.resolve(key), HttpMethod.PUT, content, extraHeaders.asIterable()) }.thenApply { val status = it.status() if (it.status() != HttpResponseStatus.CREATED && it.status() != HttpResponseStatus.OK) { @@ -265,35 +325,83 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC } } - private fun sendRequest(uri: URI, method: HttpMethod, body: ByteArray?): CompletableFuture { + private fun sendRequest( + uri: URI, + method: HttpMethod, + body: ByteArray?, + extraHeaders: Iterable>? = null + ): CompletableFuture { val responseFuture = CompletableFuture() // Custom handler for processing responses + pool.acquire().addListener(object : GenericFutureListener> { + private val handlers = mutableListOf() + + fun cleanup(channel: Channel, pipeline: ChannelPipeline) { + handlers.forEach(pipeline::remove) + pool.release(channel) + } + override fun operationComplete(channelFuture: Future) { if (channelFuture.isSuccess) { val channel = channelFuture.now val pipeline = channel.pipeline() - channel.pipeline().addLast("handler", object : SimpleChannelInboundHandler() { + val timeoutHandler = object : ChannelInboundHandlerAdapter() { + override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) { + if (evt is IdleStateEvent) { + val te = when (evt.state()) { + IdleState.READER_IDLE -> TimeoutException( + "Read timeout", + ) + + IdleState.WRITER_IDLE -> TimeoutException("Write timeout") + + IdleState.ALL_IDLE -> TimeoutException("Idle timeout") + null -> throw IllegalStateException("This should never happen") + } + responseFuture.completeExceptionally(te) + ctx.close() + } + } + } + val closeListener = GenericFutureListener> { + responseFuture.completeExceptionally(IOException("The remote server closed the connection")) + pool.release(channel) + } + + val responseHandler = object : SimpleChannelInboundHandler() { override fun channelRead0( ctx: ChannelHandlerContext, response: FullHttpResponse ) { - pipeline.removeLast() - pool.release(channel) + channel.closeFuture().removeListener(closeListener) + cleanup(channel, pipeline) responseFuture.complete(response) } override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { + ctx.newPromise() val ex = when (cause) { is DecoderException -> cause.cause else -> cause } responseFuture.completeExceptionally(ex) ctx.close() - pipeline.removeLast() - pool.release(channel) } - }) + + override fun channelInactive(ctx: ChannelHandlerContext) { + pool.release(channel) + responseFuture.completeExceptionally(IOException("The remote server closed the connection")) + super.channelInactive(ctx) + } + } + for (handler in arrayOf(timeoutHandler, responseHandler)) { + handlers.add(handler) + } + pipeline.addLast(timeoutHandler, responseHandler) + channel.closeFuture().addListener(closeListener) + + // Prepare the HTTP request val request: FullHttpRequest = let { val content: ByteBuf? = body?.takeIf(ByteArray::isNotEmpty)?.let(Unpooled::wrappedBuffer) @@ -305,15 +413,19 @@ class RemoteBuildCacheClient(private val profile: Configuration.Profile) : AutoC ).apply { headers().apply { if (content != null) { - set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_OCTET_STREAM) set(HttpHeaderNames.CONTENT_LENGTH, content.readableBytes()) } set(HttpHeaderNames.HOST, profile.serverURI.host) set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE) - set( - HttpHeaderNames.ACCEPT_ENCODING, - HttpHeaderValues.GZIP.toString() + "," + HttpHeaderValues.DEFLATE.toString() - ) + if(profile.compressionEnabled) { + set( + HttpHeaderNames.ACCEPT_ENCODING, + HttpHeaderValues.GZIP.toString() + "," + HttpHeaderValues.DEFLATE.toString() + ) + } + extraHeaders?.forEach { (k, v) -> + add(k, v) + } // Add basic auth if configured (profile.authentication as? Configuration.Authentication.BasicAuthenticationCredentials)?.let { credentials -> val auth = "${credentials.username}:${credentials.password}" diff --git a/rbcs-client/src/main/kotlin/net/woggioni/rbcs/client/impl/Parser.kt b/rbcs-client/src/main/kotlin/net/woggioni/rbcs/client/impl/Parser.kt index dcc6bee..7c955d8 100644 --- a/rbcs-client/src/main/kotlin/net/woggioni/rbcs/client/impl/Parser.kt +++ b/rbcs-client/src/main/kotlin/net/woggioni/rbcs/client/impl/Parser.kt @@ -1,5 +1,6 @@ package net.woggioni.rbcs.client.impl +import net.woggioni.rbcs.api.Configuration import net.woggioni.rbcs.api.exception.ConfigurationException import net.woggioni.rbcs.client.RemoteBuildCacheClient import net.woggioni.rbcs.common.Xml.Companion.asIterable @@ -12,6 +13,7 @@ import java.security.KeyStore import java.security.PrivateKey import java.security.cert.X509Certificate import java.time.Duration +import java.time.temporal.ChronoUnit object Parser { @@ -29,6 +31,7 @@ object Parser { ?: throw ConfigurationException("base-url attribute is required") var authentication: RemoteBuildCacheClient.Configuration.Authentication? = null var retryPolicy: RemoteBuildCacheClient.Configuration.RetryPolicy? = null + var connection : RemoteBuildCacheClient.Configuration.Connection? = null for (gchild in child.asIterable()) { when (gchild.localName) { "tls-client-auth" -> { @@ -86,6 +89,26 @@ object Parser { exp.toDouble() ) } + + "connection" -> { + val writeTimeout = gchild.renderAttribute("write-timeout") + ?.let(Duration::parse) ?: Duration.of(0, ChronoUnit.SECONDS) + val readTimeout = gchild.renderAttribute("read-timeout") + ?.let(Duration::parse) ?: Duration.of(0, ChronoUnit.SECONDS) + val idleTimeout = gchild.renderAttribute("idle-timeout") + ?.let(Duration::parse) ?: Duration.of(30, ChronoUnit.SECONDS) + val readIdleTimeout = gchild.renderAttribute("read-idle-timeout") + ?.let(Duration::parse) ?: Duration.of(60, ChronoUnit.SECONDS) + val writeIdleTimeout = gchild.renderAttribute("write-idle-timeout") + ?.let(Duration::parse) ?: Duration.of(60, ChronoUnit.SECONDS) + connection = RemoteBuildCacheClient.Configuration.Connection( + readTimeout, + writeTimeout, + idleTimeout, + readIdleTimeout, + writeIdleTimeout, + ) + } } } val maxConnections = child.renderAttribute("max-connections") @@ -93,11 +116,17 @@ object Parser { ?: 50 val connectionTimeout = child.renderAttribute("connection-timeout") ?.let(Duration::parse) + val compressionEnabled = child.renderAttribute("enable-compression") + ?.let(String::toBoolean) + ?: true + profiles[name] = RemoteBuildCacheClient.Configuration.Profile( uri, + connection, authentication, connectionTimeout, maxConnections, + compressionEnabled, retryPolicy ) } diff --git a/rbcs-client/src/main/resources/net/woggioni/rbcs/client/schema/rbcs-client.xsd b/rbcs-client/src/main/resources/net/woggioni/rbcs/client/schema/rbcs-client.xsd index 2fea379..e6feb53 100644 --- a/rbcs-client/src/main/resources/net/woggioni/rbcs/client/schema/rbcs-client.xsd +++ b/rbcs-client/src/main/resources/net/woggioni/rbcs/client/schema/rbcs-client.xsd @@ -19,12 +19,22 @@ + + + + + + + + + + diff --git a/rbcs-common/src/main/java/module-info.java b/rbcs-common/src/main/java/module-info.java index c08df9d..53fbe61 100644 --- a/rbcs-common/src/main/java/module-info.java +++ b/rbcs-common/src/main/java/module-info.java @@ -5,6 +5,7 @@ module net.woggioni.rbcs.common { requires kotlin.stdlib; requires net.woggioni.jwo; requires io.netty.buffer; + requires io.netty.transport; provides java.net.spi.URLStreamHandlerProvider with net.woggioni.rbcs.common.RbcsUrlStreamHandlerFactory; exports net.woggioni.rbcs.common; diff --git a/rbcs-common/src/main/kotlin/net/woggioni/rbcs/common/Logging.kt b/rbcs-common/src/main/kotlin/net/woggioni/rbcs/common/Logging.kt index b17cb07..c928e5e 100644 --- a/rbcs-common/src/main/kotlin/net/woggioni/rbcs/common/Logging.kt +++ b/rbcs-common/src/main/kotlin/net/woggioni/rbcs/common/Logging.kt @@ -1,90 +1,173 @@ package net.woggioni.rbcs.common +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.Channel import org.slf4j.Logger import org.slf4j.LoggerFactory +import org.slf4j.MDC import org.slf4j.event.Level +import org.slf4j.spi.LoggingEventBuilder import java.nio.file.Files import java.nio.file.Path import java.util.logging.LogManager inline fun T.contextLogger() = LoggerFactory.getLogger(T::class.java) +inline fun createLogger() = LoggerFactory.getLogger(T::class.java) -inline fun Logger.traceParam(messageBuilder : () -> Pair>) { - if(isTraceEnabled) { +inline fun Logger.traceParam(messageBuilder: () -> Pair>) { + if (isTraceEnabled) { val (format, params) = messageBuilder() trace(format, params) } } -inline fun Logger.debugParam(messageBuilder : () -> Pair>) { - if(isDebugEnabled) { +inline fun Logger.debugParam(messageBuilder: () -> Pair>) { + if (isDebugEnabled) { val (format, params) = messageBuilder() info(format, params) } } -inline fun Logger.infoParam(messageBuilder : () -> Pair>) { - if(isInfoEnabled) { +inline fun Logger.infoParam(messageBuilder: () -> Pair>) { + if (isInfoEnabled) { val (format, params) = messageBuilder() info(format, params) } } -inline fun Logger.warnParam(messageBuilder : () -> Pair>) { - if(isWarnEnabled) { +inline fun Logger.warnParam(messageBuilder: () -> Pair>) { + if (isWarnEnabled) { val (format, params) = messageBuilder() warn(format, params) } } -inline fun Logger.errorParam(messageBuilder : () -> Pair>) { - if(isErrorEnabled) { +inline fun Logger.errorParam(messageBuilder: () -> Pair>) { + if (isErrorEnabled) { val (format, params) = messageBuilder() error(format, params) } } -inline fun log(log : Logger, - filter : Logger.() -> Boolean, - loggerMethod : Logger.(String) -> Unit, messageBuilder : () -> String) { - if(log.filter()) { +inline fun log( + log: Logger, + filter: Logger.() -> Boolean, + loggerMethod: Logger.(String) -> Unit, messageBuilder: () -> String +) { + if (log.filter()) { log.loggerMethod(messageBuilder()) } } -inline fun Logger.log(level : Level, messageBuilder : () -> String) { - if(isEnabledForLevel(level)) { +fun withMDC(params: Array>, cb: () -> Unit) { + object : AutoCloseable { + override fun close() { + for ((key, _) in params) MDC.remove(key) + } + }.use { + for ((key, value) in params) MDC.put(key, value) + cb() + } +} + +inline fun Logger.log(level: Level, channel: Channel, crossinline messageBuilder: (LoggingEventBuilder) -> Unit ) { + if (isEnabledForLevel(level)) { + val params = arrayOf>( + "channel-id-short" to channel.id().asShortText(), + "channel-id-long" to channel.id().asLongText(), + "remote-address" to channel.remoteAddress().toString(), + "local-address" to channel.localAddress().toString(), + ) + withMDC(params) { + val builder = makeLoggingEventBuilder(level) +// for ((key, value) in params) { +// builder.addKeyValue(key, value) +// } + messageBuilder(builder) + builder.log() + } + } +} +inline fun Logger.log(level: Level, channel: Channel, crossinline messageBuilder: () -> String) { + log(level, channel) { builder -> + builder.setMessage(messageBuilder()) + } +} + +inline fun Logger.trace(ch: Channel, crossinline messageBuilder: () -> String) { + log(Level.TRACE, ch, messageBuilder) +} + +inline fun Logger.debug(ch: Channel, crossinline messageBuilder: () -> String) { + log(Level.DEBUG, ch, messageBuilder) +} + +inline fun Logger.info(ch: Channel, crossinline messageBuilder: () -> String) { + log(Level.INFO, ch, messageBuilder) +} + +inline fun Logger.warn(ch: Channel, crossinline messageBuilder: () -> String) { + log(Level.WARN, ch, messageBuilder) +} + +inline fun Logger.error(ch: Channel, crossinline messageBuilder: () -> String) { + log(Level.ERROR, ch, messageBuilder) +} + +inline fun Logger.trace(ctx: ChannelHandlerContext, crossinline messageBuilder: () -> String) { + log(Level.TRACE, ctx.channel(), messageBuilder) +} + +inline fun Logger.debug(ctx: ChannelHandlerContext, crossinline messageBuilder: () -> String) { + log(Level.DEBUG, ctx.channel(), messageBuilder) +} + +inline fun Logger.info(ctx: ChannelHandlerContext, crossinline messageBuilder: () -> String) { + log(Level.INFO, ctx.channel(), messageBuilder) +} + +inline fun Logger.warn(ctx: ChannelHandlerContext, crossinline messageBuilder: () -> String) { + log(Level.WARN, ctx.channel(), messageBuilder) +} + +inline fun Logger.error(ctx: ChannelHandlerContext, crossinline messageBuilder: () -> String) { + log(Level.ERROR, ctx.channel(), messageBuilder) +} + + +inline fun Logger.log(level: Level, messageBuilder: () -> String) { + if (isEnabledForLevel(level)) { makeLoggingEventBuilder(level).log(messageBuilder()) } } -inline fun Logger.trace(messageBuilder : () -> String) { - if(isTraceEnabled) { +inline fun Logger.trace(messageBuilder: () -> String) { + if (isTraceEnabled) { trace(messageBuilder()) } } -inline fun Logger.debug(messageBuilder : () -> String) { - if(isDebugEnabled) { +inline fun Logger.debug(messageBuilder: () -> String) { + if (isDebugEnabled) { debug(messageBuilder()) } } -inline fun Logger.info(messageBuilder : () -> String) { - if(isInfoEnabled) { +inline fun Logger.info(messageBuilder: () -> String) { + if (isInfoEnabled) { info(messageBuilder()) } } -inline fun Logger.warn(messageBuilder : () -> String) { - if(isWarnEnabled) { +inline fun Logger.warn(messageBuilder: () -> String) { + if (isWarnEnabled) { warn(messageBuilder()) } } -inline fun Logger.error(messageBuilder : () -> String) { - if(isErrorEnabled) { +inline fun Logger.error(messageBuilder: () -> String) { + if (isErrorEnabled) { error(messageBuilder()) } } @@ -94,9 +177,9 @@ class LoggingConfig { init { val logManager = LogManager.getLogManager() - System.getProperty("log.config.source")?.let withSource@ { source -> + System.getProperty("log.config.source")?.let withSource@{ source -> val urls = LoggingConfig::class.java.classLoader.getResources(source) - while(urls.hasMoreElements()) { + while (urls.hasMoreElements()) { val url = urls.nextElement() url.openStream().use { inputStream -> logManager.readConfiguration(inputStream) diff --git a/rbcs-common/src/main/kotlin/net/woggioni/rbcs/common/RBCS.kt b/rbcs-common/src/main/kotlin/net/woggioni/rbcs/common/RBCS.kt index 599f091..86ff9f2 100644 --- a/rbcs-common/src/main/kotlin/net/woggioni/rbcs/common/RBCS.kt +++ b/rbcs-common/src/main/kotlin/net/woggioni/rbcs/common/RBCS.kt @@ -44,4 +44,18 @@ object RBCS { ): String { return JWO.bytesToHex(digest(data, md)) } + + fun processCacheKey(key: String, digestAlgorithm: String?) = digestAlgorithm + ?.let(MessageDigest::getInstance) + ?.let { md -> + digest(key.toByteArray(), md) + } ?: key.toByteArray(Charsets.UTF_8) + + fun Long.toIntOrNull(): Int? { + return if (this >= Int.MIN_VALUE && this <= Int.MAX_VALUE) { + toInt() + } else { + null + } + } } \ No newline at end of file diff --git a/rbcs-common/src/main/kotlin/net/woggioni/rbcs/common/Xml.kt b/rbcs-common/src/main/kotlin/net/woggioni/rbcs/common/Xml.kt index 0a05122..f317226 100644 --- a/rbcs-common/src/main/kotlin/net/woggioni/rbcs/common/Xml.kt +++ b/rbcs-common/src/main/kotlin/net/woggioni/rbcs/common/Xml.kt @@ -79,7 +79,7 @@ class Xml(val doc: Document, val element: Element) { class ErrorHandler(private val fileURL: URL) : ErrHandler { companion object { - private val log = LoggerFactory.getLogger(ErrorHandler::class.java) + private val log = createLogger() } override fun warning(ex: SAXParseException)= err(ex, Level.WARN) diff --git a/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/MemcacheCache.kt b/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/MemcacheCache.kt deleted file mode 100644 index c5d5a84..0000000 --- a/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/MemcacheCache.kt +++ /dev/null @@ -1,235 +0,0 @@ -package net.woggioni.rbcs.server.memcache - -import io.netty.buffer.ByteBufAllocator -import io.netty.buffer.Unpooled -import io.netty.handler.codec.memcache.binary.BinaryMemcacheOpcodes -import io.netty.handler.codec.memcache.binary.BinaryMemcacheResponseStatus -import io.netty.handler.codec.memcache.binary.DefaultBinaryMemcacheRequest -import net.woggioni.rbcs.api.Cache -import net.woggioni.rbcs.api.RequestHandle -import net.woggioni.rbcs.api.ResponseHandle -import net.woggioni.rbcs.api.event.RequestStreamingEvent -import net.woggioni.rbcs.api.event.ResponseStreamingEvent -import net.woggioni.rbcs.api.exception.ContentTooLargeException -import net.woggioni.rbcs.common.ByteBufOutputStream -import net.woggioni.rbcs.common.RBCS.digest -import net.woggioni.rbcs.common.contextLogger -import net.woggioni.rbcs.common.debug -import net.woggioni.rbcs.common.extractChunk -import net.woggioni.rbcs.server.memcache.client.MemcacheClient -import net.woggioni.rbcs.server.memcache.client.MemcacheResponseHandle -import net.woggioni.rbcs.server.memcache.client.StreamingRequestEvent -import net.woggioni.rbcs.server.memcache.client.StreamingResponseEvent -import java.security.MessageDigest -import java.time.Duration -import java.time.Instant -import java.util.concurrent.CompletableFuture -import java.util.zip.Deflater -import java.util.zip.DeflaterOutputStream -import java.util.zip.Inflater -import java.util.zip.InflaterOutputStream - -class MemcacheCache(private val cfg: MemcacheCacheConfiguration) : Cache { - - companion object { - @JvmStatic - private val log = contextLogger() - } - - private val memcacheClient = MemcacheClient(cfg) - - override fun get(key: String, responseHandle: ResponseHandle, alloc: ByteBufAllocator) { - val compressionMode = cfg.compressionMode - val buf = alloc.compositeBuffer() - val stream = ByteBufOutputStream(buf).let { outputStream -> - if (compressionMode != null) { - when (compressionMode) { - MemcacheCacheConfiguration.CompressionMode.DEFLATE -> { - InflaterOutputStream( - outputStream, - Inflater() - ) - } - } - } else { - outputStream - } - } - val memcacheResponseHandle = object : MemcacheResponseHandle { - override fun handleEvent(evt: StreamingResponseEvent) { - when (evt) { - is StreamingResponseEvent.ResponseReceived -> { - if (evt.response.status() == BinaryMemcacheResponseStatus.SUCCESS) { - responseHandle.handleEvent(ResponseStreamingEvent.RESPONSE_RECEIVED) - } else if (evt.response.status() == BinaryMemcacheResponseStatus.KEY_ENOENT) { - responseHandle.handleEvent(ResponseStreamingEvent.NOT_FOUND) - } else { - responseHandle.handleEvent(ResponseStreamingEvent.ExceptionCaught(MemcacheException(evt.response.status()))) - } - } - - is StreamingResponseEvent.LastContentReceived -> { - evt.content.content().let { content -> - content.readBytes(stream, content.readableBytes()) - } - buf.retain() - stream.close() - val chunk = extractChunk(buf, alloc) - buf.release() - responseHandle.handleEvent( - ResponseStreamingEvent.LastChunkReceived( - chunk - ) - ) - } - - is StreamingResponseEvent.ContentReceived -> { - evt.content.content().let { content -> - content.readBytes(stream, content.readableBytes()) - } - if (buf.readableBytes() >= cfg.chunkSize) { - val chunk = extractChunk(buf, alloc) - responseHandle.handleEvent(ResponseStreamingEvent.ChunkReceived(chunk)) - } - } - - is StreamingResponseEvent.ExceptionCaught -> { - responseHandle.handleEvent(ResponseStreamingEvent.ExceptionCaught(evt.exception)) - } - } - } - } - memcacheClient.sendRequest(Unpooled.wrappedBuffer(key.toByteArray()), memcacheResponseHandle) - .thenApply { memcacheRequestHandle -> - val request = (cfg.digestAlgorithm - ?.let(MessageDigest::getInstance) - ?.let { md -> - digest(key.toByteArray(), md) - } ?: key.toByteArray(Charsets.UTF_8) - ).let { digest -> - DefaultBinaryMemcacheRequest(Unpooled.wrappedBuffer(digest)).apply { - setOpcode(BinaryMemcacheOpcodes.GET) - } - } - memcacheRequestHandle.handleEvent(StreamingRequestEvent.SendRequest(request)) - }.exceptionally { ex -> - responseHandle.handleEvent(ResponseStreamingEvent.ExceptionCaught(ex)) - } - } - - private fun encodeExpiry(expiry: Duration): Int { - val expirySeconds = expiry.toSeconds() - return expirySeconds.toInt().takeIf { it.toLong() == expirySeconds } - ?: Instant.ofEpochSecond(expirySeconds).epochSecond.toInt() - } - - override fun put( - key: String, - responseHandle: ResponseHandle, - alloc: ByteBufAllocator - ): CompletableFuture { - val memcacheResponseHandle = object : MemcacheResponseHandle { - override fun handleEvent(evt: StreamingResponseEvent) { - when (evt) { - is StreamingResponseEvent.ResponseReceived -> { - when (evt.response.status()) { - BinaryMemcacheResponseStatus.SUCCESS -> { - responseHandle.handleEvent(ResponseStreamingEvent.RESPONSE_RECEIVED) - } - - BinaryMemcacheResponseStatus.KEY_ENOENT -> { - responseHandle.handleEvent(ResponseStreamingEvent.NOT_FOUND) - } - - BinaryMemcacheResponseStatus.E2BIG -> { - responseHandle.handleEvent( - ResponseStreamingEvent.ExceptionCaught( - ContentTooLargeException("Request payload is too big", null) - ) - ) - } - - else -> { - responseHandle.handleEvent(ResponseStreamingEvent.ExceptionCaught(MemcacheException(evt.response.status()))) - } - } - } - - is StreamingResponseEvent.LastContentReceived -> { - responseHandle.handleEvent( - ResponseStreamingEvent.LastChunkReceived( - evt.content.content().retain() - ) - ) - } - - is StreamingResponseEvent.ContentReceived -> { - responseHandle.handleEvent(ResponseStreamingEvent.ChunkReceived(evt.content.content().retain())) - } - - is StreamingResponseEvent.ExceptionCaught -> { - responseHandle.handleEvent(ResponseStreamingEvent.ExceptionCaught(evt.exception)) - } - } - } - } - val result: CompletableFuture = - memcacheClient.sendRequest(Unpooled.wrappedBuffer(key.toByteArray()), memcacheResponseHandle) - .thenApply { memcacheRequestHandle -> - val request = (cfg.digestAlgorithm - ?.let(MessageDigest::getInstance) - ?.let { md -> - digest(key.toByteArray(), md) - } ?: key.toByteArray(Charsets.UTF_8)).let { digest -> - val extras = Unpooled.buffer(8, 8) - extras.writeInt(0) - extras.writeInt(encodeExpiry(cfg.maxAge)) - DefaultBinaryMemcacheRequest(Unpooled.wrappedBuffer(digest), extras).apply { - setOpcode(BinaryMemcacheOpcodes.SET) - } - } -// memcacheRequestHandle.handleEvent(StreamingRequestEvent.SendRequest(request)) - val compressionMode = cfg.compressionMode - val buf = alloc.heapBuffer() - val stream = ByteBufOutputStream(buf).let { outputStream -> - if (compressionMode != null) { - when (compressionMode) { - MemcacheCacheConfiguration.CompressionMode.DEFLATE -> { - DeflaterOutputStream( - outputStream, - Deflater(Deflater.DEFAULT_COMPRESSION, false) - ) - } - } - } else { - outputStream - } - } - RequestHandle { evt -> - when (evt) { - is RequestStreamingEvent.LastChunkReceived -> { - evt.chunk.readBytes(stream, evt.chunk.readableBytes()) - buf.retain() - stream.close() - request.setTotalBodyLength(buf.readableBytes() + request.keyLength() + request.extrasLength()) - memcacheRequestHandle.handleEvent(StreamingRequestEvent.SendRequest(request)) - memcacheRequestHandle.handleEvent(StreamingRequestEvent.SendLastChunk(buf)) - } - - is RequestStreamingEvent.ChunkReceived -> { - evt.chunk.readBytes(stream, evt.chunk.readableBytes()) - } - - is RequestStreamingEvent.ExceptionCaught -> { - stream.close() - } - } - } - } - return result - } - - override fun close() { - memcacheClient.close() - } -} diff --git a/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/MemcacheCacheConfiguration.kt b/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/MemcacheCacheConfiguration.kt index 0ff4f42..d725584 100644 --- a/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/MemcacheCacheConfiguration.kt +++ b/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/MemcacheCacheConfiguration.kt @@ -1,15 +1,17 @@ package net.woggioni.rbcs.server.memcache +import net.woggioni.rbcs.api.CacheHandlerFactory import net.woggioni.rbcs.api.Configuration import net.woggioni.rbcs.common.HostAndPort +import net.woggioni.rbcs.server.memcache.client.MemcacheClient import java.time.Duration data class MemcacheCacheConfiguration( val servers: List, val maxAge: Duration = Duration.ofDays(1), - val maxSize: Int = 0x100000, val digestAlgorithm: String? = null, val compressionMode: CompressionMode? = null, + val compressionLevel: Int, val chunkSize : Int ) : Configuration.Cache { @@ -27,7 +29,14 @@ data class MemcacheCacheConfiguration( ) - override fun materialize() = MemcacheCache(this) + override fun materialize() = object : CacheHandlerFactory { + private val client = MemcacheClient(this@MemcacheCacheConfiguration.servers, chunkSize) + override fun close() { + client.close() + } + + override fun newHandler() = MemcacheCacheHandler(client, digestAlgorithm, compressionMode != null, compressionLevel, chunkSize, maxAge) + } override fun getNamespaceURI() = "urn:net.woggioni.rbcs.server.memcache" diff --git a/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/MemcacheCacheHandler.kt b/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/MemcacheCacheHandler.kt new file mode 100644 index 0000000..1451010 --- /dev/null +++ b/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/MemcacheCacheHandler.kt @@ -0,0 +1,409 @@ +package net.woggioni.rbcs.server.memcache + +import io.netty.buffer.ByteBuf +import io.netty.buffer.ByteBufAllocator +import io.netty.buffer.CompositeByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.SimpleChannelInboundHandler +import io.netty.handler.codec.memcache.DefaultLastMemcacheContent +import io.netty.handler.codec.memcache.DefaultMemcacheContent +import io.netty.handler.codec.memcache.LastMemcacheContent +import io.netty.handler.codec.memcache.MemcacheContent +import io.netty.handler.codec.memcache.binary.BinaryMemcacheOpcodes +import io.netty.handler.codec.memcache.binary.BinaryMemcacheResponse +import io.netty.handler.codec.memcache.binary.BinaryMemcacheResponseStatus +import io.netty.handler.codec.memcache.binary.DefaultBinaryMemcacheRequest +import net.woggioni.rbcs.api.CacheValueMetadata +import net.woggioni.rbcs.api.exception.ContentTooLargeException +import net.woggioni.rbcs.api.message.CacheMessage +import net.woggioni.rbcs.api.message.CacheMessage.CacheContent +import net.woggioni.rbcs.api.message.CacheMessage.CacheGetRequest +import net.woggioni.rbcs.api.message.CacheMessage.CachePutRequest +import net.woggioni.rbcs.api.message.CacheMessage.CachePutResponse +import net.woggioni.rbcs.api.message.CacheMessage.CacheValueFoundResponse +import net.woggioni.rbcs.api.message.CacheMessage.CacheValueNotFoundResponse +import net.woggioni.rbcs.api.message.CacheMessage.LastCacheContent +import net.woggioni.rbcs.common.ByteBufInputStream +import net.woggioni.rbcs.common.ByteBufOutputStream +import net.woggioni.rbcs.common.RBCS.processCacheKey +import net.woggioni.rbcs.common.RBCS.toIntOrNull +import net.woggioni.rbcs.common.createLogger +import net.woggioni.rbcs.common.debug +import net.woggioni.rbcs.common.extractChunk +import net.woggioni.rbcs.common.trace +import net.woggioni.rbcs.server.memcache.client.MemcacheClient +import net.woggioni.rbcs.server.memcache.client.MemcacheRequestController +import net.woggioni.rbcs.server.memcache.client.MemcacheResponseHandler +import java.io.ByteArrayOutputStream +import java.io.ObjectInputStream +import java.io.ObjectOutputStream +import java.nio.ByteBuffer +import java.nio.channels.Channels +import java.nio.channels.FileChannel +import java.nio.channels.ReadableByteChannel +import java.nio.file.Files +import java.nio.file.StandardOpenOption +import java.time.Duration +import java.time.Instant +import java.util.concurrent.CompletableFuture +import java.util.zip.Deflater +import java.util.zip.DeflaterOutputStream +import java.util.zip.InflaterOutputStream +import io.netty.channel.Channel as NettyChannel + +class MemcacheCacheHandler( + private val client: MemcacheClient, + private val digestAlgorithm: String?, + private val compressionEnabled: Boolean, + private val compressionLevel: Int, + private val chunkSize: Int, + private val maxAge: Duration +) : SimpleChannelInboundHandler() { + companion object { + private val log = createLogger() + + private fun encodeExpiry(expiry: Duration): Int { + val expirySeconds = expiry.toSeconds() + return expirySeconds.toInt().takeIf { it.toLong() == expirySeconds } + ?: Instant.ofEpochSecond(expirySeconds).epochSecond.toInt() + } + } + + private inner class InProgressGetRequest( + private val key: String, + private val ctx: ChannelHandlerContext + ) { + private val acc = ctx.alloc().compositeBuffer() + private val chunk = ctx.alloc().compositeBuffer() + private val outputStream = ByteBufOutputStream(chunk).let { + if (compressionEnabled) { + InflaterOutputStream(it) + } else { + it + } + } + private var responseSent = false + private var metadataSize: Int? = null + + fun write(buf: ByteBuf) { + acc.addComponent(true, buf.retain()) + if (metadataSize == null && acc.readableBytes() >= Int.SIZE_BYTES) { + metadataSize = acc.readInt() + } + metadataSize + ?.takeIf { !responseSent } + ?.takeIf { acc.readableBytes() >= it } + ?.let { mSize -> + val metadata = ObjectInputStream(ByteBufInputStream(acc)).use { + acc.retain() + it.readObject() as CacheValueMetadata + } + ctx.writeAndFlush(CacheValueFoundResponse(key, metadata)) + responseSent = true + acc.readerIndex(Int.SIZE_BYTES + mSize) + } + if (responseSent) { + acc.readBytes(outputStream, acc.readableBytes()) + if(acc.readableBytes() >= chunkSize) { + flush(false) + } + } + } + + private fun flush(last : Boolean) { + val toSend = extractChunk(chunk, ctx.alloc()) + val msg = if(last) { + log.trace(ctx) { + "Sending last chunk to client on channel ${ctx.channel().id().asShortText()}" + } + LastCacheContent(toSend) + } else { + log.trace(ctx) { + "Sending chunk to client on channel ${ctx.channel().id().asShortText()}" + } + CacheContent(toSend) + } + ctx.writeAndFlush(msg) + } + + fun commit() { + acc.release() + chunk.retain() + outputStream.close() + flush(true) + chunk.release() + } + + fun rollback() { + acc.release() + outputStream.close() + } + } + + private inner class InProgressPutRequest( + private val ch : NettyChannel, + metadata : CacheValueMetadata, + val digest : ByteBuf, + val requestController: CompletableFuture, + private val alloc: ByteBufAllocator + ) { + private var totalSize = 0 + private var tmpFile : FileChannel? = null + private val accumulator = alloc.compositeBuffer() + private val stream = ByteBufOutputStream(accumulator).let { + if (compressionEnabled) { + DeflaterOutputStream(it, Deflater(compressionLevel)) + } else { + it + } + } + + init { + ByteArrayOutputStream().let { baos -> + ObjectOutputStream(baos).use { + it.writeObject(metadata) + } + val serializedBytes = baos.toByteArray() + accumulator.writeInt(serializedBytes.size) + accumulator.writeBytes(serializedBytes) + } + } + + fun write(buf: ByteBuf) { + totalSize += buf.readableBytes() + buf.readBytes(stream, buf.readableBytes()) + tmpFile?.let { + flushToDisk(it, accumulator) + } + if(accumulator.readableBytes() > 0x100000) { + log.debug(ch) { + "Entry is too big, buffering it into a file" + } + val opts = arrayOf( + StandardOpenOption.DELETE_ON_CLOSE, + StandardOpenOption.READ, + StandardOpenOption.WRITE, + StandardOpenOption.TRUNCATE_EXISTING + ) + FileChannel.open(Files.createTempFile("rbcs-memcache", ".tmp"), *opts).let { fc -> + tmpFile = fc + flushToDisk(fc, accumulator) + } + } + } + + private fun flushToDisk(fc : FileChannel, buf : CompositeByteBuf) { + val chunk = extractChunk(buf, alloc) + fc.write(chunk.nioBuffer()) + chunk.release() + } + + fun commit() : Pair { + digest.release() + accumulator.retain() + stream.close() + val fileChannel = tmpFile + return if(fileChannel != null) { + flushToDisk(fileChannel, accumulator) + accumulator.release() + fileChannel.position(0) + val fileSize = fileChannel.size().toIntOrNull() ?: let { + fileChannel.close() + throw ContentTooLargeException("Request body is too large", null) + } + fileSize to fileChannel + } else { + accumulator.readableBytes() to Channels.newChannel(ByteBufInputStream(accumulator)) + } + } + + fun rollback() { + stream.close() + digest.release() + tmpFile?.close() + } + } + + private var inProgressPutRequest: InProgressPutRequest? = null + private var inProgressGetRequest: InProgressGetRequest? = null + + override fun channelRead0(ctx: ChannelHandlerContext, msg: CacheMessage) { + when (msg) { + is CacheGetRequest -> handleGetRequest(ctx, msg) + is CachePutRequest -> handlePutRequest(ctx, msg) + is LastCacheContent -> handleLastCacheContent(ctx, msg) + is CacheContent -> handleCacheContent(ctx, msg) + else -> ctx.fireChannelRead(msg) + } + } + + private fun handleGetRequest(ctx: ChannelHandlerContext, msg: CacheGetRequest) { + log.debug(ctx) { + "Fetching ${msg.key} from memcache" + } + val key = ctx.alloc().buffer().also { + it.writeBytes(processCacheKey(msg.key, digestAlgorithm)) + } + val responseHandler = object : MemcacheResponseHandler { + override fun responseReceived(response: BinaryMemcacheResponse) { + val status = response.status() + when (status) { + BinaryMemcacheResponseStatus.SUCCESS -> { + log.debug(ctx) { + "Cache hit for key ${msg.key} on memcache" + } + inProgressGetRequest = InProgressGetRequest(msg.key, ctx) + } + + BinaryMemcacheResponseStatus.KEY_ENOENT -> { + log.debug(ctx) { + "Cache miss for key ${msg.key} on memcache" + } + ctx.writeAndFlush(CacheValueNotFoundResponse()) + } + } + } + + override fun contentReceived(content: MemcacheContent) { + log.trace(ctx) { + "${if(content is LastMemcacheContent) "Last chunk" else "Chunk"} of ${content.content().readableBytes()} bytes received from memcache for key ${msg.key}" + } + inProgressGetRequest?.write(content.content()) + if (content is LastMemcacheContent) { + inProgressGetRequest?.commit() + } + } + + override fun exceptionCaught(ex: Throwable) { + inProgressGetRequest?.let { + inProgressGetRequest = null + it.rollback() + } + this@MemcacheCacheHandler.exceptionCaught(ctx, ex) + } + } + client.sendRequest(key.retainedDuplicate(), responseHandler).thenAccept { requestHandle -> + log.trace(ctx) { + "Sending GET request for key ${msg.key} to memcache" + } + val request = DefaultBinaryMemcacheRequest(key).apply { + setOpcode(BinaryMemcacheOpcodes.GET) + } + requestHandle.sendRequest(request) + } + } + + private fun handlePutRequest(ctx: ChannelHandlerContext, msg: CachePutRequest) { + val key = ctx.alloc().buffer().also { + it.writeBytes(processCacheKey(msg.key, digestAlgorithm)) + } + val responseHandler = object : MemcacheResponseHandler { + override fun responseReceived(response: BinaryMemcacheResponse) { + val status = response.status() + when (status) { + BinaryMemcacheResponseStatus.SUCCESS -> { + log.debug(ctx) { + "Inserted key ${msg.key} into memcache" + } + ctx.writeAndFlush(CachePutResponse(msg.key)) + } + else -> this@MemcacheCacheHandler.exceptionCaught(ctx, MemcacheException(status)) + } + } + + override fun contentReceived(content: MemcacheContent) {} + + override fun exceptionCaught(ex: Throwable) { + this@MemcacheCacheHandler.exceptionCaught(ctx, ex) + } + } + + val requestController = client.sendRequest(key.retainedDuplicate(), responseHandler).whenComplete { _, ex -> + ex?.let { + this@MemcacheCacheHandler.exceptionCaught(ctx, ex) + } + } + inProgressPutRequest = InProgressPutRequest(ctx.channel(), msg.metadata, key, requestController, ctx.alloc()) + } + + private fun handleCacheContent(ctx: ChannelHandlerContext, msg: CacheContent) { + inProgressPutRequest?.let { request -> + log.trace(ctx) { + "Received chunk of ${msg.content().readableBytes()} bytes for memcache" + } + request.write(msg.content()) + } + } + + private fun handleLastCacheContent(ctx: ChannelHandlerContext, msg: LastCacheContent) { + inProgressPutRequest?.let { request -> + inProgressPutRequest = null + log.trace(ctx) { + "Received last chunk of ${msg.content().readableBytes()} bytes for memcache" + } + request.write(msg.content()) + val key = request.digest.retainedDuplicate() + val (payloadSize, payloadSource) = request.commit() + val extras = ctx.alloc().buffer(8, 8) + extras.writeInt(0) + extras.writeInt(encodeExpiry(maxAge)) + val totalBodyLength = request.digest.readableBytes() + extras.readableBytes() + payloadSize + request.requestController.whenComplete { requestController, ex -> + if(ex == null) { + log.trace(ctx) { + "Sending SET request to memcache" + } + requestController.sendRequest(DefaultBinaryMemcacheRequest().apply { + setOpcode(BinaryMemcacheOpcodes.SET) + setKey(key) + setExtras(extras) + setTotalBodyLength(totalBodyLength) + }) + log.trace(ctx) { + "Sending request payload to memcache" + } + payloadSource.use { source -> + val bb = ByteBuffer.allocate(chunkSize) + while (true) { + val read = source.read(bb) + bb.limit() + if(read >= 0 && bb.position() < chunkSize && bb.hasRemaining()) { + continue + } + val chunk = ctx.alloc().buffer(chunkSize) + bb.flip() + chunk.writeBytes(bb) + bb.clear() + log.trace(ctx) { + "Sending ${chunk.readableBytes()} bytes chunk to memcache" + } + if(read < 0) { + requestController.sendContent(DefaultLastMemcacheContent(chunk)) + break + } else { + requestController.sendContent(DefaultMemcacheContent(chunk)) + } + } + } + } else { + payloadSource.close() + } + } + } + } + + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { + inProgressGetRequest?.let { + inProgressGetRequest = null + it.rollback() + } + inProgressPutRequest?.let { + inProgressPutRequest = null + it.requestController.thenAccept { controller -> + controller.exceptionCaught(cause) + } + it.rollback() + } + super.exceptionCaught(ctx, cause) + } +} \ No newline at end of file diff --git a/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/MemcacheCacheProvider.kt b/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/MemcacheCacheProvider.kt index 445e03e..4c7064b 100644 --- a/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/MemcacheCacheProvider.kt +++ b/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/MemcacheCacheProvider.kt @@ -28,12 +28,12 @@ class MemcacheCacheProvider : CacheProvider { val maxAge = el.renderAttribute("max-age") ?.let(Duration::parse) ?: Duration.ofDays(1) - val maxSize = el.renderAttribute("max-size") - ?.let(Integer::decode) - ?: 0x100000 val chunkSize = el.renderAttribute("chunk-size") ?.let(Integer::decode) - ?: 0x4000 + ?: 0x10000 + val compressionLevel = el.renderAttribute("compression-level") + ?.let(Integer::decode) + ?: -1 val compressionMode = el.renderAttribute("compression-mode") ?.let { when (it) { @@ -41,7 +41,6 @@ class MemcacheCacheProvider : CacheProvider { else -> MemcacheCacheConfiguration.CompressionMode.DEFLATE } } - ?: MemcacheCacheConfiguration.CompressionMode.DEFLATE val digestAlgorithm = el.renderAttribute("digest") for (child in el.asIterable()) { when (child.nodeName) { @@ -62,9 +61,9 @@ class MemcacheCacheProvider : CacheProvider { return MemcacheCacheConfiguration( servers, maxAge, - maxSize, digestAlgorithm, compressionMode, + compressionLevel, chunkSize ) } @@ -85,7 +84,6 @@ class MemcacheCacheProvider : CacheProvider { } } attr("max-age", maxAge.toString()) - attr("max-size", maxSize.toString()) attr("chunk-size", chunkSize.toString()) digestAlgorithm?.let { digestAlgorithm -> attr("digest", digestAlgorithm) @@ -97,6 +95,7 @@ class MemcacheCacheProvider : CacheProvider { } ) } + attr("compression-level", compressionLevel.toString()) } result } diff --git a/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/client/Event.kt b/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/client/Event.kt deleted file mode 100644 index 06afbf5..0000000 --- a/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/client/Event.kt +++ /dev/null @@ -1,30 +0,0 @@ -package net.woggioni.rbcs.server.memcache.client - -import io.netty.buffer.ByteBuf -import io.netty.handler.codec.memcache.LastMemcacheContent -import io.netty.handler.codec.memcache.MemcacheContent -import io.netty.handler.codec.memcache.binary.BinaryMemcacheRequest -import io.netty.handler.codec.memcache.binary.BinaryMemcacheResponse - -sealed interface StreamingRequestEvent { - class SendRequest(val request : BinaryMemcacheRequest) : StreamingRequestEvent - open class SendChunk(val chunk : ByteBuf) : StreamingRequestEvent - class SendLastChunk(chunk : ByteBuf) : SendChunk(chunk) - class ExceptionCaught(val exception : Throwable) : StreamingRequestEvent -} - -sealed interface StreamingResponseEvent { - class ResponseReceived(val response : BinaryMemcacheResponse) : StreamingResponseEvent - open class ContentReceived(val content : MemcacheContent) : StreamingResponseEvent - class LastContentReceived(val lastContent : LastMemcacheContent) : ContentReceived(lastContent) - class ExceptionCaught(val exception : Throwable) : StreamingResponseEvent -} - -interface MemcacheRequestHandle { - fun handleEvent(evt : StreamingRequestEvent) -} - -interface MemcacheResponseHandle { - fun handleEvent(evt : StreamingResponseEvent) -} - diff --git a/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/client/MemcacheClient.kt b/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/client/MemcacheClient.kt index 5c6e4f1..950b26a 100644 --- a/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/client/MemcacheClient.kt +++ b/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/client/MemcacheClient.kt @@ -4,6 +4,7 @@ package net.woggioni.rbcs.server.memcache.client import io.netty.bootstrap.Bootstrap import io.netty.buffer.ByteBuf import io.netty.channel.Channel +import io.netty.channel.ChannelFutureListener import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelOption import io.netty.channel.ChannelPipeline @@ -13,31 +14,29 @@ import io.netty.channel.pool.AbstractChannelPoolHandler import io.netty.channel.pool.ChannelPool import io.netty.channel.pool.FixedChannelPool import io.netty.channel.socket.nio.NioSocketChannel -import io.netty.handler.codec.memcache.DefaultLastMemcacheContent -import io.netty.handler.codec.memcache.DefaultMemcacheContent import io.netty.handler.codec.memcache.LastMemcacheContent import io.netty.handler.codec.memcache.MemcacheContent import io.netty.handler.codec.memcache.MemcacheObject import io.netty.handler.codec.memcache.binary.BinaryMemcacheClientCodec +import io.netty.handler.codec.memcache.binary.BinaryMemcacheRequest import io.netty.handler.codec.memcache.binary.BinaryMemcacheResponse -import io.netty.handler.logging.LoggingHandler import io.netty.util.concurrent.GenericFutureListener import net.woggioni.rbcs.common.HostAndPort -import net.woggioni.rbcs.common.contextLogger -import net.woggioni.rbcs.common.debug +import net.woggioni.rbcs.common.createLogger +import net.woggioni.rbcs.common.warn import net.woggioni.rbcs.server.memcache.MemcacheCacheConfiguration +import net.woggioni.rbcs.server.memcache.MemcacheCacheHandler +import java.io.IOException import java.net.InetSocketAddress import java.util.concurrent.CompletableFuture import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicLong import io.netty.util.concurrent.Future as NettyFuture -class MemcacheClient(private val cfg: MemcacheCacheConfiguration) : AutoCloseable { +class MemcacheClient(private val servers: List, private val chunkSize : Int) : AutoCloseable { private companion object { - @JvmStatic - private val log = contextLogger() + private val log = createLogger() } private val group: NioEventLoopGroup @@ -47,8 +46,6 @@ class MemcacheClient(private val cfg: MemcacheCacheConfiguration) : AutoCloseabl group = NioEventLoopGroup() } - private val counter = AtomicLong(0) - private fun newConnectionPool(server: MemcacheCacheConfiguration.Server): FixedChannelPool { val bootstrap = Bootstrap().apply { group(group) @@ -63,32 +60,33 @@ class MemcacheClient(private val cfg: MemcacheCacheConfiguration) : AutoCloseabl override fun channelCreated(ch: Channel) { val pipeline: ChannelPipeline = ch.pipeline() - pipeline.addLast(BinaryMemcacheClientCodec()) - pipeline.addLast(LoggingHandler()) + pipeline.addLast(BinaryMemcacheClientCodec(chunkSize, true)) } } return FixedChannelPool(bootstrap, channelPoolHandler, server.maxConnections) } - fun sendRequest(key: ByteBuf, responseHandle: MemcacheResponseHandle): CompletableFuture { - val server = cfg.servers.let { servers -> - if (servers.size > 1) { - var checksum = 0 - while (key.readableBytes() > 4) { - val byte = key.readInt() - checksum = checksum xor byte - } - while (key.readableBytes() > 0) { - val byte = key.readByte() - checksum = checksum xor byte.toInt() - } - servers[checksum % servers.size] - } else { - servers.first() + fun sendRequest( + key: ByteBuf, + responseHandler: MemcacheResponseHandler + ): CompletableFuture { + val server = if (servers.size > 1) { + var checksum = 0 + while (key.readableBytes() > 4) { + val byte = key.readInt() + checksum = checksum xor byte } + while (key.readableBytes() > 0) { + val byte = key.readByte() + checksum = checksum xor byte.toInt() + } + servers[checksum % servers.size] + } else { + servers.first() } + key.release() - val response = CompletableFuture() + val response = CompletableFuture() // Custom handler for processing responses val pool = connectionPool.computeIfAbsent(server.endpoint) { newConnectionPool(server) @@ -96,74 +94,107 @@ class MemcacheClient(private val cfg: MemcacheCacheConfiguration) : AutoCloseabl pool.acquire().addListener(object : GenericFutureListener> { override fun operationComplete(channelFuture: NettyFuture) { if (channelFuture.isSuccess) { + + var requestSent = false + var requestBodySent = false + var requestFinished = false + var responseReceived = false + var responseBodyReceived = false + var responseFinished = false + var requestBodySize = 0 + var requestBodyBytesSent = 0 + + + val channel = channelFuture.now + var connectionClosedByTheRemoteServer = true + val closeCallback = { + if (connectionClosedByTheRemoteServer) { + val ex = IOException("The memcache server closed the connection") + val completed = response.completeExceptionally(ex) + if(!completed) responseHandler.exceptionCaught(ex) + log.warn { + "RequestSent: $requestSent, RequestBodySent: $requestBodySent, " + + "RequestFinished: $requestFinished, ResponseReceived: $responseReceived, " + + "ResponseBodyReceived: $responseBodyReceived, ResponseFinished: $responseFinished, " + + "RequestBodySize: $requestBodySize, RequestBodyBytesSent: $requestBodyBytesSent" + } + } + pool.release(channel) + } + val closeListener = ChannelFutureListener { + closeCallback() + } + channel.closeFuture().addListener(closeListener) val pipeline = channel.pipeline() val handler = object : SimpleChannelInboundHandler() { + + override fun handlerAdded(ctx: ChannelHandlerContext) { + channel.closeFuture().removeListener(closeListener) + } + override fun channelRead0( ctx: ChannelHandlerContext, msg: MemcacheObject ) { when (msg) { - is BinaryMemcacheResponse -> responseHandle.handleEvent( - StreamingResponseEvent.ResponseReceived( - msg - ) - ) + is BinaryMemcacheResponse -> { + responseHandler.responseReceived(msg) + responseReceived = true + } is LastMemcacheContent -> { - responseHandle.handleEvent( - StreamingResponseEvent.LastContentReceived( - msg - ) - ) - pipeline.removeLast() + responseFinished = true + responseHandler.contentReceived(msg) + pipeline.remove(this) pool.release(channel) } - is MemcacheContent -> responseHandle.handleEvent( - StreamingResponseEvent.ContentReceived( - msg - ) - ) + is MemcacheContent -> { + responseBodyReceived = true + responseHandler.contentReceived(msg) + } } } + override fun channelInactive(ctx: ChannelHandlerContext) { + closeCallback() + ctx.fireChannelInactive() + } + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { - responseHandle.handleEvent(StreamingResponseEvent.ExceptionCaught(cause)) + connectionClosedByTheRemoteServer = false ctx.close() - pipeline.removeLast() pool.release(channel) + responseHandler.exceptionCaught(cause) } } + channel.pipeline() .addLast("client-handler", handler) - response.complete(object : MemcacheRequestHandle { - override fun handleEvent(evt: StreamingRequestEvent) { - when (evt) { - is StreamingRequestEvent.SendRequest -> { - channel.writeAndFlush(evt.request) - } + response.complete(object : MemcacheRequestController { - is StreamingRequestEvent.SendLastChunk -> { - channel.writeAndFlush(DefaultLastMemcacheContent(evt.chunk)) - val value = counter.incrementAndGet() - log.debug { - "Finished request counter: $value" - } - } + override fun sendRequest(request: BinaryMemcacheRequest) { + requestBodySize = request.totalBodyLength() - request.keyLength() - request.extrasLength() + channel.writeAndFlush(request) + requestSent = true + } - is StreamingRequestEvent.SendChunk -> { - channel.writeAndFlush(DefaultMemcacheContent(evt.chunk)) - } - - is StreamingRequestEvent.ExceptionCaught -> { - responseHandle.handleEvent(StreamingResponseEvent.ExceptionCaught(evt.exception)) - channel.close() - pipeline.removeLast() - pool.release(channel) + override fun sendContent(content: MemcacheContent) { + val size = content.content().readableBytes() + channel.writeAndFlush(content).addListener { + requestBodyBytesSent += size + requestBodySent = true + if(content is LastMemcacheContent) { + requestFinished = true } } } + + override fun exceptionCaught(ex: Throwable) { + connectionClosedByTheRemoteServer = false + channel.close() + } }) } else { response.completeExceptionally(channelFuture.cause()) diff --git a/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/client/MemcacheRequestController.kt b/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/client/MemcacheRequestController.kt new file mode 100644 index 0000000..06cc772 --- /dev/null +++ b/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/client/MemcacheRequestController.kt @@ -0,0 +1,13 @@ +package net.woggioni.rbcs.server.memcache.client + +import io.netty.handler.codec.memcache.MemcacheContent +import io.netty.handler.codec.memcache.binary.BinaryMemcacheRequest + +interface MemcacheRequestController { + + fun sendRequest(request : BinaryMemcacheRequest) + + fun sendContent(content : MemcacheContent) + + fun exceptionCaught(ex : Throwable) +} diff --git a/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/client/MemcacheResponseHandler.kt b/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/client/MemcacheResponseHandler.kt new file mode 100644 index 0000000..19bb930 --- /dev/null +++ b/rbcs-server-memcache/src/main/kotlin/net/woggioni/rbcs/server/memcache/client/MemcacheResponseHandler.kt @@ -0,0 +1,14 @@ +package net.woggioni.rbcs.server.memcache.client + +import io.netty.handler.codec.memcache.MemcacheContent +import io.netty.handler.codec.memcache.binary.BinaryMemcacheResponse + +interface MemcacheResponseHandler { + + + fun responseReceived(response : BinaryMemcacheResponse) + + fun contentReceived(content : MemcacheContent) + + fun exceptionCaught(ex : Throwable) +} diff --git a/rbcs-server-memcache/src/main/resources/net/woggioni/rbcs/server/memcache/schema/rbcs-memcache.xsd b/rbcs-server-memcache/src/main/resources/net/woggioni/rbcs/server/memcache/schema/rbcs-memcache.xsd index acb1db4..66a57a5 100644 --- a/rbcs-server-memcache/src/main/resources/net/woggioni/rbcs/server/memcache/schema/rbcs-memcache.xsd +++ b/rbcs-server-memcache/src/main/resources/net/woggioni/rbcs/server/memcache/schema/rbcs-memcache.xsd @@ -20,10 +20,10 @@ - - + + diff --git a/rbcs-server-memcache/src/test/kotlin/net/woggioni/rbcs/server/memcache/client/ByteBufferTest.kt b/rbcs-server-memcache/src/test/kotlin/net/woggioni/rbcs/server/memcache/client/ByteBufferTest.kt new file mode 100644 index 0000000..5086dde --- /dev/null +++ b/rbcs-server-memcache/src/test/kotlin/net/woggioni/rbcs/server/memcache/client/ByteBufferTest.kt @@ -0,0 +1,27 @@ +package net.woggioni.rbcs.server.memcache.client + +import io.netty.buffer.ByteBufUtil +import io.netty.buffer.Unpooled +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Test +import java.io.ByteArrayInputStream +import java.nio.ByteBuffer +import java.nio.channels.Channels +import kotlin.random.Random + +class ByteBufferTest { + + @Test + fun test() { + val byteBuffer = ByteBuffer.allocate(0x100) + val originalBytes = Random(101325).nextBytes(0x100) + Channels.newChannel(ByteArrayInputStream(originalBytes)).use { source -> + source.read(byteBuffer) + } + byteBuffer.flip() + val buf = Unpooled.buffer() + buf.writeBytes(byteBuffer) + val finalBytes = ByteBufUtil.getBytes(buf) + Assertions.assertArrayEquals(originalBytes, finalBytes) + } +} \ No newline at end of file diff --git a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/Logging.kt b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/Logging.kt deleted file mode 100644 index 31ef121..0000000 --- a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/Logging.kt +++ /dev/null @@ -1,30 +0,0 @@ -package net.woggioni.rbcs.server - -import io.netty.channel.ChannelHandlerContext -import org.slf4j.Logger -import java.net.InetSocketAddress - -inline fun Logger.trace(ctx : ChannelHandlerContext, messageBuilder : () -> String) { - log(this, ctx, { isTraceEnabled }, { trace(it) } , messageBuilder) -} -inline fun Logger.debug(ctx : ChannelHandlerContext, messageBuilder : () -> String) { - log(this, ctx, { isDebugEnabled }, { debug(it) } , messageBuilder) -} -inline fun Logger.info(ctx : ChannelHandlerContext, messageBuilder : () -> String) { - log(this, ctx, { isInfoEnabled }, { info(it) } , messageBuilder) -} -inline fun Logger.warn(ctx : ChannelHandlerContext, messageBuilder : () -> String) { - log(this, ctx, { isWarnEnabled }, { warn(it) } , messageBuilder) -} -inline fun Logger.error(ctx : ChannelHandlerContext, messageBuilder : () -> String) { - log(this, ctx, { isErrorEnabled }, { error(it) } , messageBuilder) -} - -inline fun log(log : Logger, ctx : ChannelHandlerContext, - filter : Logger.() -> Boolean, - loggerMethod : Logger.(String) -> Unit, messageBuilder : () -> String) { - if(log.filter()) { - val clientAddress = (ctx.channel().remoteAddress() as InetSocketAddress).address.hostAddress - log.loggerMethod(clientAddress + " - " + messageBuilder()) - } -} \ No newline at end of file diff --git a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/RemoteBuildCacheServer.kt b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/RemoteBuildCacheServer.kt index 4971620..018b75a 100644 --- a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/RemoteBuildCacheServer.kt +++ b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/RemoteBuildCacheServer.kt @@ -37,7 +37,7 @@ import net.woggioni.rbcs.common.PasswordSecurity.decodePasswordHash import net.woggioni.rbcs.common.PasswordSecurity.hashPassword import net.woggioni.rbcs.common.RBCS.toUrl import net.woggioni.rbcs.common.Xml -import net.woggioni.rbcs.common.contextLogger +import net.woggioni.rbcs.common.createLogger import net.woggioni.rbcs.common.debug import net.woggioni.rbcs.common.info import net.woggioni.rbcs.server.auth.AbstractNettyHttpAuthenticator @@ -47,7 +47,10 @@ import net.woggioni.rbcs.server.auth.RoleAuthorizer import net.woggioni.rbcs.server.configuration.Parser import net.woggioni.rbcs.server.configuration.Serializer import net.woggioni.rbcs.server.exception.ExceptionHandler +import net.woggioni.rbcs.server.handler.MaxRequestSizeHandler import net.woggioni.rbcs.server.handler.ServerHandler +import net.woggioni.rbcs.server.handler.TraceHandler +import net.woggioni.rbcs.server.throttling.BucketManager import net.woggioni.rbcs.server.throttling.ThrottlingHandler import java.io.OutputStream import java.net.InetSocketAddress @@ -56,19 +59,23 @@ import java.nio.file.Path import java.security.KeyStore import java.security.PrivateKey import java.security.cert.X509Certificate +import java.time.Duration +import java.time.Instant import java.util.Arrays import java.util.Base64 +import java.util.concurrent.CompletableFuture import java.util.concurrent.Future import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException import java.util.regex.Matcher import java.util.regex.Pattern import javax.naming.ldap.LdapName import javax.net.ssl.SSLPeerUnverifiedException class RemoteBuildCacheServer(private val cfg: Configuration) { - private val log = contextLogger() companion object { + private val log = createLogger() val userAttribute: AttributeKey = AttributeKey.valueOf("user") val groupAttribute: AttributeKey> = AttributeKey.valueOf("group") @@ -142,7 +149,9 @@ class RemoteBuildCacheServer(private val cfg: Configuration) { private class NettyHttpBasicAuthenticator( private val users: Map, authorizer: Authorizer ) : AbstractNettyHttpAuthenticator(authorizer) { - private val log = contextLogger() + companion object { + private val log = createLogger() + } override fun authenticate(ctx: ChannelHandlerContext, req: HttpRequest): AuthenticationResult? { val authorizationHeader = req.headers()[HttpHeaderNames.AUTHORIZATION] ?: let { @@ -242,13 +251,13 @@ class RemoteBuildCacheServer(private val cfg: Configuration) { } return keystore } + + private val log = createLogger() } - private val log = contextLogger() + private val cacheHandlerFactory = cfg.cache.materialize() - private val cache = cfg.cache.materialize() - - private val exceptionHandler = ExceptionHandler() + private val bucketManager = BucketManager.from(cfg) private val authenticator = when (val auth = cfg.authentication) { is Configuration.BasicAuthentication -> NettyHttpBasicAuthenticator(cfg.users, RoleAuthorizer()) @@ -359,59 +368,94 @@ class RemoteBuildCacheServer(private val cfg: Configuration) { pipeline.addLast(SSL_HANDLER_NAME, it) } pipeline.addLast(HttpServerCodec()) + pipeline.addLast(MaxRequestSizeHandler.NAME, MaxRequestSizeHandler(cfg.connection.maxRequestSize)) pipeline.addLast(HttpChunkContentCompressor(1024)) pipeline.addLast(ChunkedWriteHandler()) authenticator?.let { pipeline.addLast(it) } - pipeline.addLast(ThrottlingHandler(cfg)) + pipeline.addLast(ThrottlingHandler(bucketManager, cfg.connection)) val serverHandler = let { val prefix = Path.of("/").resolve(Path.of(cfg.serverPath ?: "/")) - ServerHandler(cache, prefix) + ServerHandler(prefix) } - pipeline.addLast(eventExecutorGroup, serverHandler) - pipeline.addLast(exceptionHandler) + pipeline.addLast(eventExecutorGroup, ServerHandler.NAME, serverHandler) + pipeline.addLast(cacheHandlerFactory.newHandler()) + pipeline.addLast(TraceHandler) + pipeline.addLast(ExceptionHandler) } override fun close() { - cache.close() + cacheHandlerFactory.close() } } class ServerHandle( - httpChannelFuture: ChannelFuture, + closeFuture: ChannelFuture, + private val bossGroup: EventExecutorGroup, private val executorGroups: Iterable, - private val serverInitializer: AutoCloseable - ) : AutoCloseable { - private val httpChannel: Channel = httpChannelFuture.channel() - private val closeFuture: ChannelFuture = httpChannel.closeFuture() - private val log = contextLogger() + private val serverInitializer: AutoCloseable, + ) : Future by from(closeFuture, executorGroups, serverInitializer) { - fun shutdown(): Future { - return httpChannel.close() - } + companion object { + private val log = createLogger() - override fun close() { - try { - closeFuture.sync() - } catch (ex: Throwable) { - log.error(ex.message, ex) - } - try { - serverInitializer.close() - } catch (ex: Throwable) { - log.error(ex.message, ex) - } - executorGroups.forEach { - try { - it.shutdownGracefully().sync() - } catch (ex: Throwable) { - log.error(ex.message, ex) + private fun from( + closeFuture: ChannelFuture, + executorGroups: Iterable, + serverInitializer: AutoCloseable + ): CompletableFuture { + val result = CompletableFuture() + closeFuture.addListener { + val errors = mutableListOf() + val deadline = Instant.now().plusSeconds(20) + + + for (executorGroup in executorGroups) { + val future = executorGroup.terminationFuture() + try { + val now = Instant.now() + if (now > deadline) { + future.get(0, TimeUnit.SECONDS) + } else { + future.get(Duration.between(now, deadline).toMillis(), TimeUnit.MILLISECONDS) + } + } + catch (te: TimeoutException) { + errors.addLast(te) + log.warn("Timeout while waiting for shutdown of $executorGroup", te) + } catch (ex: Throwable) { + log.warn(ex.message, ex) + errors.addLast(ex) + } + } + try { + serverInitializer.close() + } catch (ex: Throwable) { + log.error(ex.message, ex) + errors.addLast(ex) + } + if(errors.isEmpty()) { + result.complete(null) + } else { + result.completeExceptionally(errors.first()) + } + } + + return result.thenAccept { + log.info { + "RemoteBuildCacheServer has been gracefully shut down" + } } } - log.info { - "RemoteBuildCacheServer has been gracefully shut down" + } + + + fun sendShutdownSignal() { + bossGroup.shutdownGracefully() + executorGroups.map { + it.shutdownGracefully() } } } @@ -442,10 +486,16 @@ class RemoteBuildCacheServer(private val cfg: Configuration) { // Bind and start to accept incoming connections. val bindAddress = InetSocketAddress(cfg.host, cfg.port) - val httpChannel = bootstrap.bind(bindAddress).sync() + val httpChannel = bootstrap.bind(bindAddress).sync().channel() log.info { "RemoteBuildCacheServer is listening on ${cfg.host}:${cfg.port}" } - return ServerHandle(httpChannel, setOf(bossGroup, workerGroup, eventExecutorGroup), serverInitializer) + + return ServerHandle( + httpChannel.closeFuture(), + bossGroup, + setOf(workerGroup, eventExecutorGroup), + serverInitializer + ) } } diff --git a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/FileSystemCache.kt b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/FileSystemCache.kt index 6ea08cf..833bb16 100644 --- a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/FileSystemCache.kt +++ b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/FileSystemCache.kt @@ -1,42 +1,33 @@ package net.woggioni.rbcs.server.cache -import io.netty.buffer.ByteBufAllocator import net.woggioni.jwo.JWO -import net.woggioni.rbcs.api.Cache -import net.woggioni.rbcs.api.RequestHandle -import net.woggioni.rbcs.api.ResponseHandle -import net.woggioni.rbcs.api.event.RequestStreamingEvent -import net.woggioni.rbcs.api.event.ResponseStreamingEvent -import net.woggioni.rbcs.common.ByteBufOutputStream -import net.woggioni.rbcs.common.RBCS.digestString -import net.woggioni.rbcs.common.contextLogger -import net.woggioni.rbcs.common.extractChunk +import net.woggioni.rbcs.api.CacheValueMetadata +import net.woggioni.rbcs.common.createLogger +import java.io.ByteArrayOutputStream +import java.io.InputStream +import java.io.ObjectInputStream +import java.io.ObjectOutputStream +import java.io.Serializable +import java.nio.ByteBuffer +import java.nio.channels.Channels import java.nio.channels.FileChannel import java.nio.file.Files import java.nio.file.Path import java.nio.file.StandardCopyOption import java.nio.file.StandardOpenOption import java.nio.file.attribute.BasicFileAttributes -import java.security.MessageDigest import java.time.Duration import java.time.Instant -import java.util.concurrent.CompletableFuture -import java.util.zip.Deflater -import java.util.zip.DeflaterOutputStream -import java.util.zip.InflaterInputStream class FileSystemCache( val root: Path, - val maxAge: Duration, - val digestAlgorithm: String?, - val compressionEnabled: Boolean, - val compressionLevel: Int, - val chunkSize: Int -) : Cache { + val maxAge: Duration +) : AutoCloseable { + + class EntryValue(val metadata: CacheValueMetadata, val channel : FileChannel, val offset : Long, val size : Long) : Serializable private companion object { - @JvmStatic - private val log = contextLogger() + private val log = createLogger() } init { @@ -48,111 +39,77 @@ class FileSystemCache( private var nextGc = Instant.now() - override fun get(key: String, responseHandle: ResponseHandle, alloc: ByteBufAllocator) { - (digestAlgorithm - ?.let(MessageDigest::getInstance) - ?.let { md -> - digestString(key.toByteArray(), md) - } ?: key).let { digest -> - root.resolve(digest).takeIf(Files::exists) - ?.let { file -> - file.takeIf(Files::exists)?.let { file -> - responseHandle.handleEvent(ResponseStreamingEvent.RESPONSE_RECEIVED) - if (compressionEnabled) { - val compositeBuffer = alloc.compositeBuffer() - ByteBufOutputStream(compositeBuffer).use { outputStream -> - InflaterInputStream(Files.newInputStream(file)).use { inputStream -> - val ioBuffer = alloc.buffer(chunkSize) - try { - while (true) { - val read = ioBuffer.writeBytes(inputStream, chunkSize) - val last = read < 0 - if (read > 0) { - ioBuffer.readBytes(outputStream, read) - } - if (last) { - compositeBuffer.retain() - outputStream.close() - } - if (compositeBuffer.readableBytes() >= chunkSize || last) { - val chunk = extractChunk(compositeBuffer, alloc) - val evt = if (last) { - ResponseStreamingEvent.LastChunkReceived(chunk) - } else { - ResponseStreamingEvent.ChunkReceived(chunk) - } - responseHandle.handleEvent(evt) - } - if (last) break - } - } finally { - ioBuffer.release() - } - } - } - } else { - responseHandle.handleEvent( - ResponseStreamingEvent.FileReceived( - FileChannel.open(file, StandardOpenOption.READ) - ) - ) + fun get(key: String): EntryValue? = + root.resolve(key).takeIf(Files::exists) + ?.let { file -> + val size = Files.size(file) + val channel = FileChannel.open(file, StandardOpenOption.READ) + val source = Channels.newInputStream(channel) + val tmp = ByteArray(Integer.BYTES) + val buffer = ByteBuffer.wrap(tmp) + source.read(tmp) + buffer.rewind() + val offset = (Integer.BYTES + buffer.getInt()).toLong() + var count = 0 + val wrapper = object : InputStream() { + override fun read(): Int { + return source.read().also { + if (it > 0) count += it } } - } ?: responseHandle.handleEvent(ResponseStreamingEvent.NOT_FOUND) + + override fun read(b: ByteArray, off: Int, len: Int): Int { + return source.read(b, off, len).also { + if (it > 0) count += it + } + } + + override fun close() { + } + } + val metadata = ObjectInputStream(wrapper).use { ois -> + ois.readObject() as CacheValueMetadata + } + EntryValue(metadata, channel, offset, size) + } + + class FileSink(metadata: CacheValueMetadata, private val path: Path, private val tmpFile: Path) { + val channel: FileChannel + + init { + val baos = ByteArrayOutputStream() + ObjectOutputStream(baos).use { + it.writeObject(metadata) + } + Files.newOutputStream(tmpFile).use { + val bytes = baos.toByteArray() + val buffer = ByteBuffer.allocate(Integer.BYTES) + buffer.putInt(bytes.size) + buffer.rewind() + it.write(buffer.array()) + it.write(bytes) + } + channel = FileChannel.open(tmpFile, StandardOpenOption.APPEND) + } + + fun commit() { + channel.close() + Files.move(tmpFile, path, StandardCopyOption.ATOMIC_MOVE) + } + + fun rollback() { + channel.close() + Files.delete(path) } } - override fun put( + fun put( key: String, - responseHandle: ResponseHandle, - alloc: ByteBufAllocator - ): CompletableFuture { - try { - (digestAlgorithm - ?.let(MessageDigest::getInstance) - ?.let { md -> - digestString(key.toByteArray(), md) - } ?: key).let { digest -> - val file = root.resolve(digest) - val tmpFile = Files.createTempFile(root, null, ".tmp") - val stream = Files.newOutputStream(tmpFile).let { - if (compressionEnabled) { - val deflater = Deflater(compressionLevel) - DeflaterOutputStream(it, deflater) - } else { - it - } - } - return CompletableFuture.completedFuture(object : RequestHandle { - override fun handleEvent(evt: RequestStreamingEvent) { - try { - when (evt) { - is RequestStreamingEvent.LastChunkReceived -> { - evt.chunk.readBytes(stream, evt.chunk.readableBytes()) - stream.close() - Files.move(tmpFile, file, StandardCopyOption.ATOMIC_MOVE) - responseHandle.handleEvent(ResponseStreamingEvent.RESPONSE_RECEIVED) - } - - is RequestStreamingEvent.ChunkReceived -> { - evt.chunk.readBytes(stream, evt.chunk.readableBytes()) - } - - is RequestStreamingEvent.ExceptionCaught -> { - Files.delete(tmpFile) - stream.close() - } - } - } catch (ex: Throwable) { - responseHandle.handleEvent(ResponseStreamingEvent.ExceptionCaught(ex)) - } - } - }) - } - } catch (ex: Throwable) { - responseHandle.handleEvent(ResponseStreamingEvent.ExceptionCaught(ex)) - return CompletableFuture.failedFuture(ex) - } + metadata: CacheValueMetadata, + ): FileSink { + val file = root.resolve(key) + val tmpFile = Files.createTempFile(root, null, ".tmp") + return FileSink(metadata, file, tmpFile) } private val garbageCollector = Thread.ofVirtual().name("file-system-cache-gc").start { diff --git a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/FileSystemCacheConfiguration.kt b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/FileSystemCacheConfiguration.kt index 7a029af..081cdf1 100644 --- a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/FileSystemCacheConfiguration.kt +++ b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/FileSystemCacheConfiguration.kt @@ -1,8 +1,9 @@ package net.woggioni.rbcs.server.cache +import net.woggioni.jwo.Application +import net.woggioni.rbcs.api.CacheHandlerFactory import net.woggioni.rbcs.api.Configuration import net.woggioni.rbcs.common.RBCS -import net.woggioni.jwo.Application import java.nio.file.Path import java.time.Duration @@ -14,14 +15,16 @@ data class FileSystemCacheConfiguration( val compressionLevel: Int, val chunkSize: Int, ) : Configuration.Cache { - override fun materialize() = FileSystemCache( - root ?: Application.builder("rbcs").build().computeCacheDirectory(), - maxAge, - digestAlgorithm, - compressionEnabled, - compressionLevel, - chunkSize, - ) + + override fun materialize() = object : CacheHandlerFactory { + private val cache = FileSystemCache(root ?: Application.builder("rbcs").build().computeCacheDirectory(), maxAge) + + override fun close() { + cache.close() + } + + override fun newHandler() = FileSystemCacheHandler(cache, digestAlgorithm, compressionEnabled, compressionLevel, chunkSize) + } override fun getNamespaceURI() = RBCS.RBCS_NAMESPACE_URI diff --git a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/FileSystemCacheHandler.kt b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/FileSystemCacheHandler.kt new file mode 100644 index 0000000..fca3bc6 --- /dev/null +++ b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/FileSystemCacheHandler.kt @@ -0,0 +1,124 @@ +package net.woggioni.rbcs.server.cache + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.DefaultFileRegion +import io.netty.channel.SimpleChannelInboundHandler +import io.netty.handler.codec.http.LastHttpContent +import io.netty.handler.stream.ChunkedNioFile +import net.woggioni.rbcs.api.CacheValueMetadata +import net.woggioni.rbcs.api.message.CacheMessage +import net.woggioni.rbcs.api.message.CacheMessage.CacheContent +import net.woggioni.rbcs.api.message.CacheMessage.CacheGetRequest +import net.woggioni.rbcs.api.message.CacheMessage.CachePutRequest +import net.woggioni.rbcs.api.message.CacheMessage.CachePutResponse +import net.woggioni.rbcs.api.message.CacheMessage.CacheValueFoundResponse +import net.woggioni.rbcs.api.message.CacheMessage.CacheValueNotFoundResponse +import net.woggioni.rbcs.api.message.CacheMessage.LastCacheContent +import net.woggioni.rbcs.common.RBCS.processCacheKey +import java.nio.channels.Channels +import java.util.Base64 +import java.util.zip.Deflater +import java.util.zip.DeflaterOutputStream +import java.util.zip.InflaterInputStream + +class FileSystemCacheHandler( + private val cache: FileSystemCache, + private val digestAlgorithm: String?, + private val compressionEnabled: Boolean, + private val compressionLevel: Int, + private val chunkSize: Int +) : SimpleChannelInboundHandler() { + + private inner class InProgressPutRequest( + val key : String, + private val fileSink : FileSystemCache.FileSink + ) { + + private val stream = Channels.newOutputStream(fileSink.channel).let { + if (compressionEnabled) { + DeflaterOutputStream(it, Deflater(compressionLevel)) + } else { + it + } + } + + fun write(buf: ByteBuf) { + buf.readBytes(stream, buf.readableBytes()) + } + + fun commit() { + stream.close() + fileSink.commit() + } + + fun rollback() { + fileSink.rollback() + } + } + + private var inProgressPutRequest: InProgressPutRequest? = null + + override fun channelRead0(ctx: ChannelHandlerContext, msg: CacheMessage) { + when (msg) { + is CacheGetRequest -> handleGetRequest(ctx, msg) + is CachePutRequest -> handlePutRequest(ctx, msg) + is LastCacheContent -> handleLastCacheContent(ctx, msg) + is CacheContent -> handleCacheContent(ctx, msg) + else -> ctx.fireChannelRead(msg) + } + } + + private fun handleGetRequest(ctx: ChannelHandlerContext, msg: CacheGetRequest) { + val key = String(Base64.getUrlEncoder().encode(processCacheKey(msg.key, digestAlgorithm))) + cache.get(key)?.also { entryValue -> + ctx.writeAndFlush(CacheValueFoundResponse(msg.key, entryValue.metadata)) + entryValue.channel.let { channel -> + if(compressionEnabled) { + InflaterInputStream(Channels.newInputStream(channel)).use { stream -> + + outerLoop@ + while (true) { + val buf = ctx.alloc().heapBuffer(chunkSize) + while(buf.readableBytes() < chunkSize) { + val read = buf.writeBytes(stream, chunkSize) + if(read < 0) { + ctx.writeAndFlush(LastCacheContent(buf)) + break@outerLoop + } + } + ctx.writeAndFlush(CacheContent(buf)) + } + } + } else { + ctx.writeAndFlush(ChunkedNioFile(channel, entryValue.offset, entryValue.size - entryValue.offset, chunkSize)) + ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT) + } + } + } ?: ctx.writeAndFlush(CacheValueNotFoundResponse()) + } + + private fun handlePutRequest(ctx: ChannelHandlerContext, msg: CachePutRequest) { + val key = String(Base64.getUrlEncoder().encode(processCacheKey(msg.key, digestAlgorithm))) + val sink = cache.put(key, msg.metadata) + inProgressPutRequest = InProgressPutRequest(msg.key, sink) + } + + private fun handleCacheContent(ctx: ChannelHandlerContext, msg: CacheContent) { + inProgressPutRequest!!.write(msg.content()) + } + + private fun handleLastCacheContent(ctx: ChannelHandlerContext, msg: LastCacheContent) { + inProgressPutRequest?.let { request -> + inProgressPutRequest = null + request.write(msg.content()) + request.commit() + ctx.writeAndFlush(CachePutResponse(request.key)) + } + } + + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { + inProgressPutRequest?.rollback() + super.exceptionCaught(ctx, cause) + } +} \ No newline at end of file diff --git a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/FileSystemCacheProvider.kt b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/FileSystemCacheProvider.kt index c1e524f..32092f5 100644 --- a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/FileSystemCacheProvider.kt +++ b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/FileSystemCacheProvider.kt @@ -33,7 +33,7 @@ class FileSystemCacheProvider : CacheProvider { val digestAlgorithm = el.renderAttribute("digest") ?: "MD5" val chunkSize = el.renderAttribute("chunk-size") ?.let(Integer::decode) - ?: 0x4000 + ?: 0x10000 return FileSystemCacheConfiguration( path, @@ -50,7 +50,9 @@ class FileSystemCacheProvider : CacheProvider { Xml.of(doc, result) { val prefix = doc.lookupPrefix(RBCS.RBCS_NAMESPACE_URI) attr("xs:type", "${prefix}:fileSystemCacheType", RBCS.XML_SCHEMA_NAMESPACE_URI) - attr("path", root.toString()) + root?.let { + attr("path", it.toString()) + } attr("max-age", maxAge.toString()) digestAlgorithm?.let { digestAlgorithm -> attr("digest", digestAlgorithm) diff --git a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/InMemoryCache.kt b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/InMemoryCache.kt index d1b66e1..9d14862 100644 --- a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/InMemoryCache.kt +++ b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/InMemoryCache.kt @@ -1,47 +1,41 @@ package net.woggioni.rbcs.server.cache import io.netty.buffer.ByteBuf -import io.netty.buffer.ByteBufAllocator -import net.woggioni.rbcs.api.Cache -import net.woggioni.rbcs.api.RequestHandle -import net.woggioni.rbcs.api.ResponseHandle -import net.woggioni.rbcs.api.event.RequestStreamingEvent -import net.woggioni.rbcs.api.event.ResponseStreamingEvent -import net.woggioni.rbcs.common.ByteBufOutputStream -import net.woggioni.rbcs.common.RBCS.digestString -import net.woggioni.rbcs.common.contextLogger -import net.woggioni.rbcs.common.extractChunk -import java.security.MessageDigest +import net.woggioni.rbcs.api.CacheValueMetadata +import net.woggioni.rbcs.common.createLogger import java.time.Duration import java.time.Instant -import java.util.concurrent.CompletableFuture import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.PriorityBlockingQueue import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong -import java.util.zip.Deflater -import java.util.zip.DeflaterOutputStream -import java.util.zip.Inflater -import java.util.zip.InflaterOutputStream + +private class CacheKey(private val value: ByteArray) { + override fun equals(other: Any?) = if (other is CacheKey) { + value.contentEquals(other.value) + } else false + + override fun hashCode() = value.contentHashCode() +} + +class CacheEntry( + val metadata: CacheValueMetadata, + val content: ByteBuf +) class InMemoryCache( private val maxAge: Duration, - private val maxSize: Long, - private val digestAlgorithm: String?, - private val compressionEnabled: Boolean, - private val compressionLevel: Int, - private val chunkSize : Int -) : Cache { + private val maxSize: Long +) : AutoCloseable { companion object { - @JvmStatic - private val log = contextLogger() + private val log = createLogger() } private val size = AtomicLong() - private val map = ConcurrentHashMap() + private val map = ConcurrentHashMap() - private class RemovalQueueElement(val key: String, val value: ByteBuf, val expiry: Instant) : + private class RemovalQueueElement(val key: CacheKey, val value: CacheEntry, val expiry: Instant) : Comparable { override fun compareTo(other: RemovalQueueElement) = expiry.compareTo(other.expiry) } @@ -54,14 +48,14 @@ class InMemoryCache( private val garbageCollector = Thread.ofVirtual().name("in-memory-cache-gc").start { while (running) { val el = removalQueue.poll(1, TimeUnit.SECONDS) ?: continue - val buf = el.value + val value = el.value val now = Instant.now() if (now > el.expiry) { - val removed = map.remove(el.key, buf) + val removed = map.remove(el.key, value) if (removed) { - updateSizeAfterRemoval(buf) + updateSizeAfterRemoval(value.content) //Decrease the reference count for map - buf.release() + value.content.release() } } else { removalQueue.put(el) @@ -73,12 +67,12 @@ class InMemoryCache( private fun removeEldest(): Long { while (true) { val el = removalQueue.take() - val buf = el.value - val removed = map.remove(el.key, buf) + val value = el.value + val removed = map.remove(el.key, value) if (removed) { - val newSize = updateSizeAfterRemoval(buf) + val newSize = updateSizeAfterRemoval(value.content) //Decrease the reference count for map - buf.release() + value.content.release() return newSize } } @@ -95,114 +89,27 @@ class InMemoryCache( garbageCollector.join() } - override fun get(key: String, responseHandle: ResponseHandle, alloc: ByteBufAllocator) { - try { - (digestAlgorithm - ?.let(MessageDigest::getInstance) - ?.let { md -> - digestString(key.toByteArray(), md) - } ?: key - ).let { digest -> - map[digest] - ?.let { value -> - val copy = value.retainedDuplicate() - responseHandle.handleEvent(ResponseStreamingEvent.RESPONSE_RECEIVED) - val output = alloc.compositeBuffer() - if (compressionEnabled) { - try { - val stream = ByteBufOutputStream(output).let { - val inflater = Inflater() - InflaterOutputStream(it, inflater) - } - stream.use { os -> - var readable = copy.readableBytes() - while (true) { - copy.readBytes(os, chunkSize.coerceAtMost(readable)) - readable = copy.readableBytes() - val last = readable == 0 - if (last) stream.flush() - if (output.readableBytes() >= chunkSize || last) { - val chunk = extractChunk(output, alloc) - val evt = if (last) { - ResponseStreamingEvent.LastChunkReceived(chunk) - } else { - ResponseStreamingEvent.ChunkReceived(chunk) - } - responseHandle.handleEvent(evt) - } - if (last) break - } - } - } finally { - copy.release() - } - } else { - responseHandle.handleEvent( - ResponseStreamingEvent.LastChunkReceived(copy) - ) - } - } ?: responseHandle.handleEvent(ResponseStreamingEvent.NOT_FOUND) - } - } catch (ex: Throwable) { - responseHandle.handleEvent(ResponseStreamingEvent.ExceptionCaught(ex)) + fun get(key: ByteArray) = map[CacheKey(key)]?.run { + CacheEntry(metadata, content.retainedDuplicate()) + } + + fun put( + key: ByteArray, + value: CacheEntry, + ) { + val cacheKey = CacheKey(key) + val oldSize = map.put(cacheKey, value)?.let { old -> + val result = old.content.readableBytes() + old.content.release() + result + } ?: 0 + val delta = value.content.readableBytes() - oldSize + var newSize = size.updateAndGet { currentSize: Long -> + currentSize + delta + } + removalQueue.put(RemovalQueueElement(cacheKey, value, Instant.now().plus(maxAge))) + while (newSize > maxSize) { + newSize = removeEldest() } } - - override fun put( - key: String, - responseHandle: ResponseHandle, - alloc: ByteBufAllocator - ): CompletableFuture { - return CompletableFuture.completedFuture(object : RequestHandle { - val buf = alloc.heapBuffer() - val stream = ByteBufOutputStream(buf).let { - if (compressionEnabled) { - val deflater = Deflater(compressionLevel) - DeflaterOutputStream(it, deflater) - } else { - it - } - } - - override fun handleEvent(evt: RequestStreamingEvent) { - when (evt) { - is RequestStreamingEvent.ChunkReceived -> { - evt.chunk.readBytes(stream, evt.chunk.readableBytes()) - if (evt is RequestStreamingEvent.LastChunkReceived) { - (digestAlgorithm - ?.let(MessageDigest::getInstance) - ?.let { md -> - digestString(key.toByteArray(), md) - } ?: key - ).let { digest -> - val oldSize = map.put(digest, buf.retain())?.let { old -> - val result = old.readableBytes() - old.release() - result - } ?: 0 - val delta = buf.readableBytes() - oldSize - var newSize = size.updateAndGet { currentSize : Long -> - currentSize + delta - } - removalQueue.put(RemovalQueueElement(digest, buf, Instant.now().plus(maxAge))) - while(newSize > maxSize) { - newSize = removeEldest() - } - stream.close() - responseHandle.handleEvent(ResponseStreamingEvent.RESPONSE_RECEIVED) - } - } - } - - is RequestStreamingEvent.ExceptionCaught -> { - stream.close() - } - - else -> { - - } - } - } - }) - } } \ No newline at end of file diff --git a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/InMemoryCacheConfiguration.kt b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/InMemoryCacheConfiguration.kt index 4793988..a6030e0 100644 --- a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/InMemoryCacheConfiguration.kt +++ b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/InMemoryCacheConfiguration.kt @@ -1,5 +1,6 @@ package net.woggioni.rbcs.server.cache +import net.woggioni.rbcs.api.CacheHandlerFactory import net.woggioni.rbcs.api.Configuration import net.woggioni.rbcs.common.RBCS import java.time.Duration @@ -12,14 +13,15 @@ data class InMemoryCacheConfiguration( val compressionLevel: Int, val chunkSize : Int ) : Configuration.Cache { - override fun materialize() = InMemoryCache( - maxAge, - maxSize, - digestAlgorithm, - compressionEnabled, - compressionLevel, - chunkSize - ) + override fun materialize() = object : CacheHandlerFactory { + private val cache = InMemoryCache(maxAge, maxSize) + + override fun close() { + cache.close() + } + + override fun newHandler() = InMemoryCacheHandler(cache, digestAlgorithm, compressionEnabled, compressionLevel) + } override fun getNamespaceURI() = RBCS.RBCS_NAMESPACE_URI diff --git a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/InMemoryCacheHandler.kt b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/InMemoryCacheHandler.kt new file mode 100644 index 0000000..6796021 --- /dev/null +++ b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/InMemoryCacheHandler.kt @@ -0,0 +1,135 @@ +package net.woggioni.rbcs.server.cache + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.SimpleChannelInboundHandler +import net.woggioni.rbcs.api.message.CacheMessage +import net.woggioni.rbcs.api.message.CacheMessage.CacheContent +import net.woggioni.rbcs.api.message.CacheMessage.CacheGetRequest +import net.woggioni.rbcs.api.message.CacheMessage.CachePutRequest +import net.woggioni.rbcs.api.message.CacheMessage.CachePutResponse +import net.woggioni.rbcs.api.message.CacheMessage.CacheValueFoundResponse +import net.woggioni.rbcs.api.message.CacheMessage.CacheValueNotFoundResponse +import net.woggioni.rbcs.api.message.CacheMessage.LastCacheContent +import net.woggioni.rbcs.common.ByteBufOutputStream +import net.woggioni.rbcs.common.RBCS.processCacheKey +import java.util.zip.Deflater +import java.util.zip.DeflaterOutputStream +import java.util.zip.InflaterOutputStream + +class InMemoryCacheHandler( + private val cache: InMemoryCache, + private val digestAlgorithm: String?, + private val compressionEnabled: Boolean, + private val compressionLevel: Int +) : SimpleChannelInboundHandler() { + + private interface InProgressPutRequest : AutoCloseable { + val request: CachePutRequest + val buf: ByteBuf + + fun append(buf: ByteBuf) + } + + private inner class InProgressPlainPutRequest(ctx: ChannelHandlerContext, override val request: CachePutRequest) : + InProgressPutRequest { + override val buf = ctx.alloc().compositeBuffer() + + private val stream = ByteBufOutputStream(buf).let { + if (compressionEnabled) { + DeflaterOutputStream(it, Deflater(compressionLevel)) + } else { + it + } + } + + override fun append(buf: ByteBuf) { + this.buf.addComponent(true, buf.retain()) + } + + override fun close() { + buf.release() + } + } + + private inner class InProgressCompressedPutRequest( + ctx: ChannelHandlerContext, + override val request: CachePutRequest + ) : InProgressPutRequest { + + override val buf = ctx.alloc().heapBuffer() + + private val stream = ByteBufOutputStream(buf).let { + DeflaterOutputStream(it, Deflater(compressionLevel)) + } + + override fun append(buf: ByteBuf) { + buf.readBytes(stream, buf.readableBytes()) + } + + override fun close() { + stream.close() + } + } + + private var inProgressPutRequest: InProgressPutRequest? = null + + override fun channelRead0(ctx: ChannelHandlerContext, msg: CacheMessage) { + when (msg) { + is CacheGetRequest -> handleGetRequest(ctx, msg) + is CachePutRequest -> handlePutRequest(ctx, msg) + is LastCacheContent -> handleLastCacheContent(ctx, msg) + is CacheContent -> handleCacheContent(ctx, msg) + else -> ctx.fireChannelRead(msg) + } + } + + private fun handleGetRequest(ctx: ChannelHandlerContext, msg: CacheGetRequest) { + cache.get(processCacheKey(msg.key, digestAlgorithm))?.let { value -> + ctx.writeAndFlush(CacheValueFoundResponse(msg.key, value.metadata)) + if (compressionEnabled) { + val buf = ctx.alloc().heapBuffer() + InflaterOutputStream(ByteBufOutputStream(buf)).use { + value.content.readBytes(it, value.content.readableBytes()) + buf.retain() + } + ctx.writeAndFlush(LastCacheContent(buf)) + } else { + ctx.writeAndFlush(LastCacheContent(value.content)) + } + } ?: ctx.writeAndFlush(CacheValueNotFoundResponse()) + } + + private fun handlePutRequest(ctx: ChannelHandlerContext, msg: CachePutRequest) { + inProgressPutRequest = if(compressionEnabled) { + InProgressCompressedPutRequest(ctx, msg) + } else { + InProgressPlainPutRequest(ctx, msg) + } + } + + private fun handleCacheContent(ctx: ChannelHandlerContext, msg: CacheContent) { + inProgressPutRequest?.append(msg.content()) + } + + private fun handleLastCacheContent(ctx: ChannelHandlerContext, msg: LastCacheContent) { + handleCacheContent(ctx, msg) + inProgressPutRequest?.let { inProgressRequest -> + inProgressPutRequest = null + val buf = inProgressRequest.buf + buf.retain() + inProgressRequest.close() + val cacheKey = processCacheKey(inProgressRequest.request.key, digestAlgorithm) + cache.put(cacheKey, CacheEntry(inProgressRequest.request.metadata, buf)) + ctx.writeAndFlush(CachePutResponse(inProgressRequest.request.key)) + } + } + + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { + inProgressPutRequest?.let { req -> + req.buf.release() + inProgressPutRequest = null + } + super.exceptionCaught(ctx, cause) + } +} \ No newline at end of file diff --git a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/InMemoryCacheProvider.kt b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/InMemoryCacheProvider.kt index c316a4f..f6987a4 100644 --- a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/InMemoryCacheProvider.kt +++ b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/cache/InMemoryCacheProvider.kt @@ -33,7 +33,7 @@ class InMemoryCacheProvider : CacheProvider { val digestAlgorithm = el.renderAttribute("digest") ?: "MD5" val chunkSize = el.renderAttribute("chunk-size") ?.let(Integer::decode) - ?: 0x4000 + ?: 0x10000 return InMemoryCacheConfiguration( maxAge, maxSize, diff --git a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/exception/ExceptionHandler.kt b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/exception/ExceptionHandler.kt index 05e5719..aac98ef 100644 --- a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/exception/ExceptionHandler.kt +++ b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/exception/ExceptionHandler.kt @@ -3,7 +3,7 @@ package net.woggioni.rbcs.server.exception import io.netty.buffer.Unpooled import io.netty.channel.ChannelDuplexHandler import io.netty.channel.ChannelFutureListener -import io.netty.channel.ChannelHandler +import io.netty.channel.ChannelHandler.Sharable import io.netty.channel.ChannelHandlerContext import io.netty.handler.codec.DecoderException import io.netty.handler.codec.http.DefaultFullHttpResponse @@ -17,12 +17,16 @@ import net.woggioni.rbcs.api.exception.CacheException import net.woggioni.rbcs.api.exception.ContentTooLargeException import net.woggioni.rbcs.common.contextLogger import net.woggioni.rbcs.common.debug +import net.woggioni.rbcs.common.log +import org.slf4j.event.Level +import org.slf4j.spi.LoggingEventBuilder +import java.net.ConnectException import java.net.SocketException import javax.net.ssl.SSLException import javax.net.ssl.SSLPeerUnverifiedException -@ChannelHandler.Sharable -class ExceptionHandler : ChannelDuplexHandler() { +@Sharable +object ExceptionHandler : ChannelDuplexHandler() { private val log = contextLogger() private val NOT_AUTHORIZED: FullHttpResponse = DefaultFullHttpResponse( @@ -31,12 +35,6 @@ class ExceptionHandler : ChannelDuplexHandler() { headers()[HttpHeaderNames.CONTENT_LENGTH] = "0" } - private val TOO_BIG: FullHttpResponse = DefaultFullHttpResponse( - HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER - ).apply { - headers()[HttpHeaderNames.CONTENT_LENGTH] = "0" - } - private val NOT_AVAILABLE: FullHttpResponse = DefaultFullHttpResponse( HttpVersion.HTTP_1_1, HttpResponseStatus.SERVICE_UNAVAILABLE, Unpooled.EMPTY_BUFFER ).apply { @@ -49,6 +47,12 @@ class ExceptionHandler : ChannelDuplexHandler() { headers()[HttpHeaderNames.CONTENT_LENGTH] = "0" } + private val TOO_BIG: FullHttpResponse = DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER + ).apply { + headers()[HttpHeaderNames.CONTENT_LENGTH] = "0" + } + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { when (cause) { is DecoderException -> { @@ -56,6 +60,11 @@ class ExceptionHandler : ChannelDuplexHandler() { ctx.close() } + is ConnectException -> { + log.error(cause.message, cause) + ctx.writeAndFlush(SERVER_ERROR.retainedDuplicate()) + } + is SocketException -> { log.debug(cause.message, cause) ctx.close() @@ -72,6 +81,9 @@ class ExceptionHandler : ChannelDuplexHandler() { } is ContentTooLargeException -> { + log.log(Level.DEBUG, ctx.channel()) { builder : LoggingEventBuilder -> + builder.setMessage("Request body is too large") + } ctx.writeAndFlush(TOO_BIG.retainedDuplicate()) .addListener(ChannelFutureListener.CLOSE_ON_FAILURE) } diff --git a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/handler/CacheContentHandler.kt b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/handler/CacheContentHandler.kt new file mode 100644 index 0000000..5c89269 --- /dev/null +++ b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/handler/CacheContentHandler.kt @@ -0,0 +1,28 @@ +package net.woggioni.rbcs.server.handler + +import io.netty.channel.ChannelHandler.Sharable +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.SimpleChannelInboundHandler +import io.netty.handler.codec.http.HttpContent +import io.netty.handler.codec.http.LastHttpContent +import net.woggioni.rbcs.api.message.CacheMessage.CacheContent +import net.woggioni.rbcs.api.message.CacheMessage.LastCacheContent + +@Sharable +object CacheContentHandler : SimpleChannelInboundHandler() { + val NAME = this::class.java.name + + override fun channelRead0(ctx: ChannelHandlerContext, msg: HttpContent) { + when(msg) { + is LastHttpContent -> { + ctx.fireChannelRead(LastCacheContent(msg.content().retain())) + ctx.pipeline().remove(this) + } + else -> ctx.fireChannelRead(CacheContent(msg.content().retain())) + } + } + + override fun exceptionCaught(ctx: ChannelHandlerContext?, cause: Throwable?) { + super.exceptionCaught(ctx, cause) + } +} \ No newline at end of file diff --git a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/handler/MaxRequestSizeHandler.kt b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/handler/MaxRequestSizeHandler.kt new file mode 100644 index 0000000..e5babeb --- /dev/null +++ b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/handler/MaxRequestSizeHandler.kt @@ -0,0 +1,40 @@ +package net.woggioni.rbcs.server.handler + +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.ChannelInboundHandlerAdapter +import io.netty.handler.codec.http.HttpContent +import io.netty.handler.codec.http.HttpRequest +import net.woggioni.rbcs.api.exception.ContentTooLargeException + + +class MaxRequestSizeHandler(private val maxRequestSize : Int) : ChannelInboundHandlerAdapter() { + companion object { + val NAME = MaxRequestSizeHandler::class.java.name + } + + private var cumulativeSize = 0 + + override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { + when(msg) { + is HttpRequest -> { + cumulativeSize = 0 + ctx.fireChannelRead(msg) + } + is HttpContent -> { + val exceeded = cumulativeSize > maxRequestSize + if(!exceeded) { + cumulativeSize += msg.content().readableBytes() + } + if(cumulativeSize > maxRequestSize) { + msg.release() + if(!exceeded) { + ctx.fireExceptionCaught(ContentTooLargeException("Request body is too large", null)) + } + } else { + ctx.fireChannelRead(msg) + } + } + else -> ctx.fireChannelRead(msg) + } + } +} \ No newline at end of file diff --git a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/handler/ServerHandler.kt b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/handler/ServerHandler.kt index e1add8c..67f5845 100644 --- a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/handler/ServerHandler.kt +++ b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/handler/ServerHandler.kt @@ -1,128 +1,148 @@ package net.woggioni.rbcs.server.handler -import io.netty.buffer.Unpooled -import io.netty.channel.ChannelFutureListener +import io.netty.channel.ChannelDuplexHandler import io.netty.channel.ChannelHandlerContext -import io.netty.channel.DefaultFileRegion -import io.netty.channel.SimpleChannelInboundHandler +import io.netty.channel.ChannelPromise import io.netty.handler.codec.http.DefaultFullHttpResponse import io.netty.handler.codec.http.DefaultHttpContent import io.netty.handler.codec.http.DefaultHttpResponse import io.netty.handler.codec.http.DefaultLastHttpContent -import io.netty.handler.codec.http.HttpContent import io.netty.handler.codec.http.HttpHeaderNames import io.netty.handler.codec.http.HttpHeaderValues +import io.netty.handler.codec.http.HttpHeaders import io.netty.handler.codec.http.HttpMethod -import io.netty.handler.codec.http.HttpObject import io.netty.handler.codec.http.HttpRequest import io.netty.handler.codec.http.HttpResponseStatus import io.netty.handler.codec.http.HttpUtil -import io.netty.handler.codec.http.LastHttpContent -import net.woggioni.rbcs.api.Cache -import net.woggioni.rbcs.api.RequestHandle -import net.woggioni.rbcs.api.ResponseHandle -import net.woggioni.rbcs.api.event.RequestStreamingEvent -import net.woggioni.rbcs.api.event.ResponseStreamingEvent -import net.woggioni.rbcs.common.contextLogger +import io.netty.handler.codec.http.HttpVersion +import net.woggioni.rbcs.api.CacheValueMetadata +import net.woggioni.rbcs.api.message.CacheMessage +import net.woggioni.rbcs.api.message.CacheMessage.CacheContent +import net.woggioni.rbcs.api.message.CacheMessage.CacheGetRequest +import net.woggioni.rbcs.api.message.CacheMessage.CachePutRequest +import net.woggioni.rbcs.api.message.CacheMessage.CachePutResponse +import net.woggioni.rbcs.api.message.CacheMessage.CacheValueFoundResponse +import net.woggioni.rbcs.api.message.CacheMessage.CacheValueNotFoundResponse +import net.woggioni.rbcs.api.message.CacheMessage.LastCacheContent +import net.woggioni.rbcs.common.createLogger import net.woggioni.rbcs.common.debug -import net.woggioni.rbcs.server.debug -import net.woggioni.rbcs.server.warn +import net.woggioni.rbcs.common.warn import java.nio.file.Path -import java.util.concurrent.CompletableFuture +import java.util.Locale -class ServerHandler(private val cache: Cache, private val serverPrefix: Path) : - SimpleChannelInboundHandler() { +class ServerHandler(private val serverPrefix: Path) : + ChannelDuplexHandler() { - private val log = contextLogger() + companion object { + private val log = createLogger() + val NAME = this::class.java.name + } - override fun channelRead0(ctx: ChannelHandlerContext, msg: HttpObject) { - when(msg) { - is HttpRequest -> handleRequest(ctx, msg) - is HttpContent -> handleContent(msg) + private var httpVersion = HttpVersion.HTTP_1_1 + private var keepAlive = true + + private fun resetRequestMetadata() { + httpVersion = HttpVersion.HTTP_1_1 + keepAlive = true + } + + private fun setRequestMetadata(req: HttpRequest) { + httpVersion = req.protocolVersion() + keepAlive = HttpUtil.isKeepAlive(req) + } + + private fun setKeepAliveHeader(headers: HttpHeaders) { + if (!keepAlive) { + headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) + } else { + headers.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE) } } - private var requestHandle : CompletableFuture = CompletableFuture.completedFuture(null) - private fun handleContent(content : HttpContent) { - content.retain() - requestHandle.thenAccept { handle -> - handle?.let { - val evt = if(content is LastHttpContent) { - RequestStreamingEvent.LastChunkReceived(content.content()) - - } else { - RequestStreamingEvent.ChunkReceived(content.content()) - } - it.handleEvent(evt) - content.release() - } ?: content.release() - } + override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { + when (msg) { + is HttpRequest -> handleRequest(ctx, msg) + else -> super.channelRead(ctx, msg) } + } - private fun handleRequest(ctx : ChannelHandlerContext, msg : HttpRequest) { - val keepAlive: Boolean = HttpUtil.isKeepAlive(msg) + override fun write(ctx: ChannelHandlerContext, msg: Any, promise: ChannelPromise?) { + if (msg is CacheMessage) { + try { + when (msg) { + is CachePutResponse -> { + val response = DefaultFullHttpResponse(httpVersion, HttpResponseStatus.CREATED) + val keyBytes = msg.key.toByteArray(Charsets.UTF_8) + response.headers().apply { + set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN) + set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED) + } + setKeepAliveHeader(response.headers()) + ctx.write(response) + val buf = ctx.alloc().buffer(keyBytes.size).apply { + writeBytes(keyBytes) + } + ctx.writeAndFlush(DefaultLastHttpContent(buf)) + } + + is CacheValueNotFoundResponse -> { + val response = DefaultFullHttpResponse(httpVersion, HttpResponseStatus.NOT_FOUND) + response.headers()[HttpHeaderNames.CONTENT_LENGTH] = 0 + setKeepAliveHeader(response.headers()) + ctx.writeAndFlush(response) + } + + is CacheValueFoundResponse -> { + val response = DefaultHttpResponse(httpVersion, HttpResponseStatus.OK) + response.headers().apply { + set(HttpHeaderNames.CONTENT_TYPE, msg.metadata.mimeType ?: HttpHeaderValues.APPLICATION_OCTET_STREAM) + msg.metadata.contentDisposition?.let { contentDisposition -> + set(HttpHeaderNames.CONTENT_DISPOSITION, contentDisposition) + } + } + setKeepAliveHeader(response.headers()) + response.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED) + ctx.writeAndFlush(response) + } + + is LastCacheContent -> { + ctx.writeAndFlush(DefaultLastHttpContent(msg.content())) + } + + is CacheContent -> { + ctx.writeAndFlush(DefaultHttpContent(msg.content())) + } + + else -> throw UnsupportedOperationException("This should never happen") + }.let { channelFuture -> + if (promise != null) { + channelFuture.addListener { + if (it.isSuccess) promise.setSuccess() + else promise.setFailure(it.cause()) + } + } + } + } finally { + resetRequestMetadata() + } + } else super.write(ctx, msg, promise) + } + + + private fun handleRequest(ctx: ChannelHandlerContext, msg: HttpRequest) { + setRequestMetadata(msg) val method = msg.method() if (method === HttpMethod.GET) { val path = Path.of(msg.uri()) val prefix = path.parent - val key = path.fileName?.toString() ?: let { - val response = DefaultFullHttpResponse(msg.protocolVersion(), HttpResponseStatus.NOT_FOUND) - response.headers()[HttpHeaderNames.CONTENT_LENGTH] = 0 - ctx.writeAndFlush(response) - return - } if (serverPrefix == prefix) { - val responseHandle = ResponseHandle { evt -> - when (evt) { - is ResponseStreamingEvent.ResponseReceived -> { - val response = DefaultHttpResponse(msg.protocolVersion(), HttpResponseStatus.OK) - response.headers()[HttpHeaderNames.CONTENT_TYPE] = HttpHeaderValues.APPLICATION_OCTET_STREAM - if (!keepAlive) { - response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) - } else { - response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE) - } - response.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED) - ctx.writeAndFlush(response) - } - - is ResponseStreamingEvent.LastChunkReceived -> { - val channelFuture = ctx.writeAndFlush(DefaultLastHttpContent(evt.chunk)) - if (!keepAlive) { - channelFuture - .addListener(ChannelFutureListener.CLOSE) - } - } - - is ResponseStreamingEvent.ChunkReceived -> { - ctx.writeAndFlush(DefaultHttpContent(evt.chunk)) - } - - is ResponseStreamingEvent.ExceptionCaught -> { - ctx.fireExceptionCaught(evt.exception) - } - - is ResponseStreamingEvent.NotFound -> { - val response = DefaultFullHttpResponse(msg.protocolVersion(), HttpResponseStatus.NOT_FOUND) - response.headers()[HttpHeaderNames.CONTENT_LENGTH] = 0 - ctx.writeAndFlush(response) - } - - is ResponseStreamingEvent.FileReceived -> { - val content = DefaultFileRegion(evt.file, 0, evt.file.size()) - if (keepAlive) { - ctx.write(content) - ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT.retainedDuplicate()) - } else { - ctx.writeAndFlush(content) - .addListener(ChannelFutureListener.CLOSE) - } - } - } - } - cache.get(key, responseHandle, ctx.alloc()) + ctx.pipeline().addAfter(NAME, CacheContentHandler.NAME, CacheContentHandler) + path.fileName?.toString() + ?.let(::CacheGetRequest) + ?.let(ctx::fireChannelRead) + ?: ctx.channel().write(CacheValueNotFoundResponse()) } else { log.warn(ctx) { "Got request for unhandled path '${msg.uri()}'" @@ -140,33 +160,14 @@ class ServerHandler(private val cache: Cache, private val serverPrefix: Path) : log.debug(ctx) { "Added value for key '$key' to build cache" } - val responseHandle = ResponseHandle { evt -> - when (evt) { - is ResponseStreamingEvent.ResponseReceived -> { - val response = DefaultFullHttpResponse( - msg.protocolVersion(), HttpResponseStatus.CREATED, - Unpooled.copiedBuffer(key.toByteArray()) - ) - response.headers()[HttpHeaderNames.CONTENT_LENGTH] = response.content().readableBytes() - ctx.writeAndFlush(response) - this.requestHandle = CompletableFuture.completedFuture(null) - } - is ResponseStreamingEvent.ChunkReceived -> { - evt.chunk.release() - } - is ResponseStreamingEvent.ExceptionCaught -> { - ctx.fireExceptionCaught(evt.exception) - } - else -> {} + ctx.pipeline().addAfter(NAME, CacheContentHandler.NAME, CacheContentHandler) + path.fileName?.toString() + ?.let { + val mimeType = HttpUtil.getMimeType(msg)?.toString() + CachePutRequest(key, CacheValueMetadata(msg.headers().get(HttpHeaderNames.CONTENT_DISPOSITION), mimeType)) } - } - - this.requestHandle = cache.put(key, responseHandle, ctx.alloc()).exceptionally { ex -> - ctx.fireExceptionCaught(ex) - null - }.also { - log.debug { "Replacing request handle with $it"} - } + ?.let(ctx::fireChannelRead) + ?: ctx.channel().write(CacheValueNotFoundResponse()) } else { log.warn(ctx) { "Got request for unhandled path '${msg.uri()}'" @@ -176,40 +177,7 @@ class ServerHandler(private val cache: Cache, private val serverPrefix: Path) : ctx.writeAndFlush(response) } } else if (method == HttpMethod.TRACE) { - val replayedRequestHead = ctx.alloc().buffer() - replayedRequestHead.writeCharSequence( - "TRACE ${Path.of(msg.uri())} ${msg.protocolVersion().text()}\r\n", - Charsets.US_ASCII - ) - msg.headers().forEach { (key, value) -> - replayedRequestHead.apply { - writeCharSequence(key, Charsets.US_ASCII) - writeCharSequence(": ", Charsets.US_ASCII) - writeCharSequence(value, Charsets.UTF_8) - writeCharSequence("\r\n", Charsets.US_ASCII) - } - } - replayedRequestHead.writeCharSequence("\r\n", Charsets.US_ASCII) - this.requestHandle = CompletableFuture.completedFuture(RequestHandle { evt -> - when(evt) { - is RequestStreamingEvent.LastChunkReceived -> { - ctx.writeAndFlush(DefaultLastHttpContent(evt.chunk.retain())) - this.requestHandle = CompletableFuture.completedFuture(null) - } - is RequestStreamingEvent.ChunkReceived -> ctx.writeAndFlush(DefaultHttpContent(evt.chunk.retain())) - is RequestStreamingEvent.ExceptionCaught -> ctx.fireExceptionCaught(evt.exception) - else -> { - - } - } - }).also { - log.debug { "Replacing request handle with $it"} - } - val response = DefaultHttpResponse(msg.protocolVersion(), HttpResponseStatus.OK) - response.headers().apply { - set(HttpHeaderNames.CONTENT_TYPE, "message/http") - } - ctx.writeAndFlush(response) + super.channelRead(ctx, msg) } else { log.warn(ctx) { "Got request with unhandled method '${msg.method().name()}'" @@ -220,10 +188,43 @@ class ServerHandler(private val cache: Cache, private val serverPrefix: Path) : } } - override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { - requestHandle.thenAccept { handle -> - handle?.handleEvent(RequestStreamingEvent.ExceptionCaught(cause)) + + data class ContentDisposition(val type: Type?, val fileName: String?) { + enum class Type { + attachment, `inline`; + + companion object { + @JvmStatic + fun parse(maybeString: String?) = maybeString.let { s -> + try { + java.lang.Enum.valueOf(Type::class.java, s) + } catch (ex: IllegalArgumentException) { + null + } + } + } } + + companion object { + @JvmStatic + fun parse(contentDisposition: String) : ContentDisposition { + val parts = contentDisposition.split(";").dropLastWhile { it.isEmpty() }.toTypedArray() + val dispositionType = parts[0].trim { it <= ' ' }.let(Type::parse) // Get the type (e.g., attachment) + + var filename: String? = null + for (i in 1.. { + val response = DefaultHttpResponse(msg.protocolVersion(), HttpResponseStatus.OK) + response.headers().apply { + set(HttpHeaderNames.CONTENT_TYPE, "message/http") + set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED) + } + ctx.write(response) + val replayedRequestHead = ctx.alloc().buffer() + replayedRequestHead.writeCharSequence( + "TRACE ${Path.of(msg.uri())} ${msg.protocolVersion().text()}\r\n", + Charsets.US_ASCII + ) + msg.headers().forEach { (key, value) -> + replayedRequestHead.apply { + writeCharSequence(key, Charsets.US_ASCII) + writeCharSequence(": ", Charsets.US_ASCII) + writeCharSequence(value, Charsets.UTF_8) + writeCharSequence("\r\n", Charsets.US_ASCII) + } + } + replayedRequestHead.writeCharSequence("\r\n", Charsets.US_ASCII) + ctx.writeAndFlush(replayedRequestHead) + } + is LastHttpContent -> { + ctx.writeAndFlush(msg) + } + is HttpContent -> ctx.writeAndFlush(msg) + else -> super.channelRead(ctx, msg) + } + } + + override fun exceptionCaught(ctx: ChannelHandlerContext?, cause: Throwable?) { + super.exceptionCaught(ctx, cause) + } +} \ No newline at end of file diff --git a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/throttling/ThrottlingHandler.kt b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/throttling/ThrottlingHandler.kt index 029b409..1302719 100644 --- a/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/throttling/ThrottlingHandler.kt +++ b/rbcs-server/src/main/kotlin/net/woggioni/rbcs/server/throttling/ThrottlingHandler.kt @@ -13,23 +13,20 @@ import net.woggioni.rbcs.common.contextLogger import net.woggioni.rbcs.server.RemoteBuildCacheServer import net.woggioni.jwo.Bucket import net.woggioni.jwo.LongMath +import net.woggioni.rbcs.common.createLogger import java.net.InetSocketAddress import java.time.Duration import java.time.temporal.ChronoUnit import java.util.concurrent.TimeUnit -class ThrottlingHandler(cfg: Configuration) : ChannelInboundHandlerAdapter() { +class ThrottlingHandler(private val bucketManager : BucketManager, + private val connectionConfiguration : Configuration.Connection) : ChannelInboundHandlerAdapter() { private companion object { - @JvmStatic - private val log = contextLogger() + private val log = createLogger() } - private val bucketManager = BucketManager.from(cfg) - - private val connectionConfiguration = cfg.connection - private var queuedContent : MutableList? = null /** @@ -98,6 +95,7 @@ class ThrottlingHandler(cfg: Configuration) : ChannelInboundHandlerAdapter() { handleBuckets(buckets, ctx, msg, false) }, waitDuration.toMillis(), TimeUnit.MILLISECONDS) } else { + this.queuedContent = null sendThrottledResponse(ctx, waitDuration) } } diff --git a/rbcs-server/src/main/resources/net/woggioni/rbcs/server/schema/rbcs.xsd b/rbcs-server/src/main/resources/net/woggioni/rbcs/server/schema/rbcs.xsd index caeac74..16cd8b2 100644 --- a/rbcs-server/src/main/resources/net/woggioni/rbcs/server/schema/rbcs.xsd +++ b/rbcs-server/src/main/resources/net/woggioni/rbcs/server/schema/rbcs.xsd @@ -39,7 +39,7 @@ - + @@ -52,11 +52,11 @@ - + - - + + @@ -64,12 +64,12 @@ - + - - + + @@ -222,10 +222,17 @@ - + + + + + + + + diff --git a/rbcs-server/src/test/kotlin/net/woggioni/rbcs/server/test/AbstractServerTest.kt b/rbcs-server/src/test/kotlin/net/woggioni/rbcs/server/test/AbstractServerTest.kt index a00cc94..df4383d 100644 --- a/rbcs-server/src/test/kotlin/net/woggioni/rbcs/server/test/AbstractServerTest.kt +++ b/rbcs-server/src/test/kotlin/net/woggioni/rbcs/server/test/AbstractServerTest.kt @@ -43,8 +43,9 @@ abstract class AbstractServerTest { } private fun stopServer() { - this.serverHandle?.use { - it.shutdown() + this.serverHandle?.let { + it.sendShutdownSignal() + it.get() } } } \ No newline at end of file diff --git a/rbcs-server/src/test/kotlin/net/woggioni/rbcs/server/test/AbstractTlsServerTest.kt b/rbcs-server/src/test/kotlin/net/woggioni/rbcs/server/test/AbstractTlsServerTest.kt index ac50be4..8477b5a 100644 --- a/rbcs-server/src/test/kotlin/net/woggioni/rbcs/server/test/AbstractTlsServerTest.kt +++ b/rbcs-server/src/test/kotlin/net/woggioni/rbcs/server/test/AbstractTlsServerTest.kt @@ -154,7 +154,7 @@ abstract class AbstractTlsServerTest : AbstractServerTest() { sequenceOf(writersGroup, readersGroup).map { it.name to it }.toMap(), FileSystemCacheConfiguration(this.cacheDir, maxAge = Duration.ofSeconds(3600 * 24), - compressionEnabled = true, + compressionEnabled = false, compressionLevel = Deflater.DEFAULT_COMPRESSION, digestAlgorithm = "MD5", chunkSize = 0x1000 diff --git a/rbcs-server/src/test/resources/net/woggioni/rbcs/server/test/valid/rbcs-memcached-tls.xml b/rbcs-server/src/test/resources/net/woggioni/rbcs/server/test/valid/rbcs-memcached-tls.xml index 2972235..57631cf 100644 --- a/rbcs-server/src/test/resources/net/woggioni/rbcs/server/test/valid/rbcs-memcached-tls.xml +++ b/rbcs-server/src/test/resources/net/woggioni/rbcs/server/test/valid/rbcs-memcached-tls.xml @@ -13,7 +13,7 @@ read-timeout="PT5M" write-timeout="PT5M"/> - + diff --git a/rbcs-server/src/test/resources/net/woggioni/rbcs/server/test/valid/rbcs-memcached.xml b/rbcs-server/src/test/resources/net/woggioni/rbcs/server/test/valid/rbcs-memcached.xml index 0b6fb2e..0437fe8 100644 --- a/rbcs-server/src/test/resources/net/woggioni/rbcs/server/test/valid/rbcs-memcached.xml +++ b/rbcs-server/src/test/resources/net/woggioni/rbcs/server/test/valid/rbcs-memcached.xml @@ -12,7 +12,7 @@ idle-timeout="PT30M" max-request-size="101325"/> - + diff --git a/rbcs-servlet/Dockerfile b/rbcs-servlet/Dockerfile new file mode 100644 index 0000000..e6991ed --- /dev/null +++ b/rbcs-servlet/Dockerfile @@ -0,0 +1,3 @@ +FROM tomcat:jdk21 + +COPY ./rbcs-servlet-*.war /usr/local/tomcat/webapps/rbcs-servlet.war \ No newline at end of file diff --git a/rbcs-servlet/build.gradle b/rbcs-servlet/build.gradle new file mode 100644 index 0000000..26ece13 --- /dev/null +++ b/rbcs-servlet/build.gradle @@ -0,0 +1,33 @@ +plugins { + alias(catalog.plugins.kotlin.jvm) + alias(catalog.plugins.gradle.docker) + id 'war' +} + +import com.bmuschko.gradle.docker.tasks.image.DockerBuildImage + +dependencies { + compileOnly catalog.jakarta.servlet.api + compileOnly catalog.jakarta.enterprise.cdi.api + + implementation catalog.jwo + implementation catalog.jakarta.el + implementation catalog.jakarta.cdi.el.api + implementation catalog.weld.servlet.core + implementation catalog.weld.web +} + +Provider prepareDockerBuild = tasks.register('prepareDockerBuild', Copy) { + group = 'docker' + into project.layout.buildDirectory.file('docker') + from(tasks.war) + from(file('Dockerfile')) +} + +Provider dockerBuild = tasks.register('dockerBuildImage', DockerBuildImage) { + group = 'docker' + dependsOn(prepareDockerBuild) + images.add('gitea.woggioni.net/woggioni/rbcs/servlet:latest') + images.add("gitea.woggioni.net/woggioni/rbcs/servlet:${version}") +} + diff --git a/rbcs-servlet/conf/server.xml b/rbcs-servlet/conf/server.xml new file mode 100644 index 0000000..ec62f95 --- /dev/null +++ b/rbcs-servlet/conf/server.xml @@ -0,0 +1,140 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/rbcs-servlet/conf/tomcat-users.xml b/rbcs-servlet/conf/tomcat-users.xml new file mode 100644 index 0000000..691b052 --- /dev/null +++ b/rbcs-servlet/conf/tomcat-users.xml @@ -0,0 +1,58 @@ + + + + + + + + + + diff --git a/rbcs-servlet/src/main/kotlin/net/woggioni/rbcs/servlet/CacheServlet.kt b/rbcs-servlet/src/main/kotlin/net/woggioni/rbcs/servlet/CacheServlet.kt new file mode 100644 index 0000000..92e95c8 --- /dev/null +++ b/rbcs-servlet/src/main/kotlin/net/woggioni/rbcs/servlet/CacheServlet.kt @@ -0,0 +1,169 @@ +package net.woggioni.rbcs.servlet + +import jakarta.annotation.PreDestroy +import jakarta.enterprise.context.ApplicationScoped +import jakarta.inject.Inject +import jakarta.servlet.annotation.WebServlet +import jakarta.servlet.http.HttpServlet +import jakarta.servlet.http.HttpServletRequest +import jakarta.servlet.http.HttpServletResponse +import net.woggioni.jwo.HttpClient.HttpStatus +import net.woggioni.jwo.JWO +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import java.nio.file.Path +import java.time.Duration +import java.time.Instant +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.PriorityBlockingQueue +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicLong +import java.util.logging.Logger + + +private class CacheKey(private val value: ByteArray) { + override fun equals(other: Any?) = if (other is CacheKey) { + value.contentEquals(other.value) + } else false + + override fun hashCode() = value.contentHashCode() +} + + +@ApplicationScoped +open class InMemoryServletCache : AutoCloseable { + + private val maxAge= Duration.ofDays(7) + private val maxSize = 0x8000000 + + companion object { + @JvmStatic + private val log = Logger.getLogger(this::class.java.name) + } + + private val size = AtomicLong() + private val map = ConcurrentHashMap() + + private class RemovalQueueElement(val key: CacheKey, val value: ByteArray, val expiry: Instant) : + Comparable { + override fun compareTo(other: RemovalQueueElement) = expiry.compareTo(other.expiry) + } + + private val removalQueue = PriorityBlockingQueue() + + @Volatile + private var running = false + + private val garbageCollector = Thread.ofVirtual().name("in-memory-cache-gc").start { + while (running) { + val el = removalQueue.poll(1, TimeUnit.SECONDS) ?: continue + val value = el.value + val now = Instant.now() + if (now > el.expiry) { + val removed = map.remove(el.key, value) + if (removed) { + updateSizeAfterRemoval(value) + } + } else { + removalQueue.put(el) + Thread.sleep(minOf(Duration.between(now, el.expiry), Duration.ofSeconds(1))) + } + } + } + + private fun removeEldest(): Long { + while (true) { + val el = removalQueue.take() + val value = el.value + val removed = map.remove(el.key, value) + if (removed) { + val newSize = updateSizeAfterRemoval(value) + return newSize + } + } + } + + private fun updateSizeAfterRemoval(removed: ByteArray): Long { + return size.updateAndGet { currentSize: Long -> + currentSize - removed.size + } + } + + @PreDestroy + override fun close() { + running = false + garbageCollector.join() + } + + open fun get(key: ByteArray) = map[CacheKey(key)] + + open fun put( + key: ByteArray, + value: ByteArray, + ) { + val cacheKey = CacheKey(key) + val oldSize = map.put(cacheKey, value)?.let { old -> + val result = old.size + result + } ?: 0 + val delta = value.size - oldSize + var newSize = size.updateAndGet { currentSize: Long -> + currentSize + delta + } + removalQueue.put(RemovalQueueElement(cacheKey, value, Instant.now().plus(maxAge))) + while (newSize > maxSize) { + newSize = removeEldest() + } + } +} + + +@WebServlet(urlPatterns = ["/cache/*"]) +class CacheServlet : HttpServlet() { + companion object { + @JvmStatic + private val log = Logger.getLogger(this::class.java.name) + } + + @Inject + private lateinit var cache : InMemoryServletCache + + private fun getKey(req : HttpServletRequest) : String { + return Path.of(req.pathInfo).fileName.toString() + } + + override fun doPut(req: HttpServletRequest, resp: HttpServletResponse) { + val baos = ByteArrayOutputStream() + baos.use { + JWO.copy(req.inputStream, baos) + } + val key = getKey(req) + cache.put(key.toByteArray(Charsets.UTF_8), baos.toByteArray()) + resp.status = 201 + resp.setContentLength(0) + log.fine { + "[${Thread.currentThread().name}] Added value for key $key" + } + } + + override fun doGet(req: HttpServletRequest, resp: HttpServletResponse) { + val key = getKey(req) + val value = cache.get(key.toByteArray(Charsets.UTF_8)) + if (value == null) { + log.fine { + "[${Thread.currentThread().name}] Cache miss for key $key" + } + resp.status = HttpStatus.NOT_FOUND.code + resp.setContentLength(0) + } else { + log.fine { + "[${Thread.currentThread().name}] Cache hit for key $key" + } + resp.status = HttpStatus.OK.code + resp.setContentLength(value.size) + ByteArrayInputStream(value).use { + JWO.copy(it, resp.outputStream) + } + } + } +} \ No newline at end of file diff --git a/rbcs-servlet/src/main/resources/META-INF/beans.xml b/rbcs-servlet/src/main/resources/META-INF/beans.xml new file mode 100644 index 0000000..253b7a3 --- /dev/null +++ b/rbcs-servlet/src/main/resources/META-INF/beans.xml @@ -0,0 +1,5 @@ + + \ No newline at end of file diff --git a/rbcs-servlet/src/main/resources/META-INF/context.xml b/rbcs-servlet/src/main/resources/META-INF/context.xml new file mode 100644 index 0000000..85bd8de --- /dev/null +++ b/rbcs-servlet/src/main/resources/META-INF/context.xml @@ -0,0 +1,7 @@ + + + + \ No newline at end of file diff --git a/rbcs-servlet/src/main/resources/logging.properties b/rbcs-servlet/src/main/resources/logging.properties new file mode 100644 index 0000000..436e24a --- /dev/null +++ b/rbcs-servlet/src/main/resources/logging.properties @@ -0,0 +1,8 @@ +handlers = java.util.logging.ConsoleHandler +.level=INFO +net.woggioni.rbcs.servlet.level=FINEST +java.util.logging.ConsoleHandler.level=INFO +java.util.logging.ConsoleHandler.formatter = java.util.logging.SimpleFormatter +java.util.logging.SimpleFormatter.format = %1$tF %1$tT [%4$s] %2$s %5$s %6$s%n +org.apache.catalina.core.ContainerBase.[Catalina].level=ALL +org.apache.catalina.core.ContainerBase.[Catalina].handlers=java.util.logging.ConsoleHandler diff --git a/rbcs-servlet/src/main/webapp/WEB-INF/web.xml b/rbcs-servlet/src/main/webapp/WEB-INF/web.xml new file mode 100644 index 0000000..3f86349 --- /dev/null +++ b/rbcs-servlet/src/main/webapp/WEB-INF/web.xml @@ -0,0 +1,8 @@ + + + org.jboss.weld.module.web.servlet.WeldTerminalListener + + \ No newline at end of file diff --git a/settings.gradle b/settings.gradle index e44344d..1d41b7c 100644 --- a/settings.gradle +++ b/settings.gradle @@ -31,4 +31,5 @@ include 'rbcs-server-memcache' include 'rbcs-cli' include 'rbcs-client' include 'rbcs-server' +include 'rbcs-servlet' include 'docker'