package net.woggioni.gbcs

import io.netty.bootstrap.ServerBootstrap
import io.netty.buffer.ByteBuf
import io.netty.buffer.Unpooled
import io.netty.channel.Channel
import io.netty.channel.ChannelDuplexHandler
import io.netty.channel.ChannelFuture
import io.netty.channel.ChannelFutureListener
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelInitializer
import io.netty.channel.ChannelOption
import io.netty.channel.ChannelPromise
import io.netty.channel.DefaultFileRegion
import io.netty.channel.SimpleChannelInboundHandler
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.DecoderException
import io.netty.handler.codec.compression.CompressionOptions
import io.netty.handler.codec.http.DefaultFullHttpResponse
import io.netty.handler.codec.http.DefaultHttpContent
import io.netty.handler.codec.http.DefaultHttpResponse
import io.netty.handler.codec.http.FullHttpRequest
import io.netty.handler.codec.http.FullHttpResponse
import io.netty.handler.codec.http.HttpContentCompressor
import io.netty.handler.codec.http.HttpHeaderNames
import io.netty.handler.codec.http.HttpHeaderValues
import io.netty.handler.codec.http.HttpMethod
import io.netty.handler.codec.http.HttpObjectAggregator
import io.netty.handler.codec.http.HttpRequest
import io.netty.handler.codec.http.HttpResponseStatus
import io.netty.handler.codec.http.HttpServerCodec
import io.netty.handler.codec.http.HttpUtil
import io.netty.handler.codec.http.HttpVersion
import io.netty.handler.codec.http.LastHttpContent
import io.netty.handler.ssl.ClientAuth
import io.netty.handler.ssl.SslContext
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.stream.ChunkedNioFile
import io.netty.handler.stream.ChunkedNioStream
import io.netty.handler.stream.ChunkedWriteHandler
import io.netty.util.concurrent.DefaultEventExecutorGroup
import io.netty.util.concurrent.EventExecutorGroup
import net.woggioni.gbcs.api.Cache
import net.woggioni.gbcs.api.Configuration
import net.woggioni.gbcs.api.Role
import net.woggioni.gbcs.api.exception.ContentTooLargeException
import net.woggioni.gbcs.auth.AbstractNettyHttpAuthenticator
import net.woggioni.gbcs.auth.Authorizer
import net.woggioni.gbcs.auth.ClientCertificateValidator
import net.woggioni.gbcs.auth.RoleAuthorizer
import net.woggioni.gbcs.base.GBCS.toUrl
import net.woggioni.gbcs.base.PasswordSecurity.decodePasswordHash
import net.woggioni.gbcs.base.PasswordSecurity.hashPassword
import net.woggioni.gbcs.base.Xml
import net.woggioni.gbcs.base.contextLogger
import net.woggioni.gbcs.base.info
import net.woggioni.gbcs.configuration.Parser
import net.woggioni.gbcs.configuration.Serializer
import net.woggioni.jwo.JWO
import net.woggioni.jwo.Tuple2
import java.io.OutputStream
import java.net.InetSocketAddress
import java.nio.channels.FileChannel
import java.nio.file.Files
import java.nio.file.InvalidPathException
import java.nio.file.Path
import java.security.KeyStore
import java.security.PrivateKey
import java.security.cert.X509Certificate
import java.util.Arrays
import java.util.Base64
import java.util.regex.Matcher
import java.util.regex.Pattern
import javax.naming.ldap.LdapName
import javax.net.ssl.SSLEngine
import javax.net.ssl.SSLPeerUnverifiedException

/**
 * Netty-based HTTP(S) server that exposes a [Cache] through `GET` (read an entry) and
 * `PUT` (store an entry) requests, configured by a [Configuration].
 *
 * Call [run] to bind the server; the returned [ServerHandle] is used to stop it.
 */
class GradleBuildCacheServer(private val cfg: Configuration) {

    /**
     * [HttpContentCompressor] variant that also accepts raw [ByteBuf] writes.
     *
     * Readable buffers are wrapped in a [DefaultHttpContent] before being handed to the
     * compressor; this is needed because the [ChunkedWriteHandler] used to send files emits
     * plain buffers.
     */
    private class HttpChunkContentCompressor(
        threshold: Int,
        vararg compressionOptions: CompressionOptions = emptyArray()
    ) : HttpContentCompressor(threshold, *compressionOptions) {

        override fun write(ctx: ChannelHandlerContext, msg: Any, promise: ChannelPromise) {
            // Only non-empty buffers are wrapped: empty buffers can be used to detect when the
            // content has been flushed and they confuse the HttpContentCompressor if converted.
            val outbound: Any = if (msg is ByteBuf && msg.isReadable) DefaultHttpContent(msg) else msg
            super.write(ctx, outbound, promise)
        }
    }

    /**
     * Authenticator that derives the caller's roles from the first certificate presented by
     * the TLS peer, using the optional user and group extractors.
     */
    private class ClientCertificateAuthenticator(
        authorizer: Authorizer,
        private val sslEngine: SSLEngine,
        private val userExtractor: Configuration.UserExtractor?,
        private val groupExtractor: Configuration.GroupExtractor?,
    ) : AbstractNettyHttpAuthenticator(authorizer) {

        /**
         * @return the union of the roles of the extracted group and user, or `null` when the
         * peer is unverified or presented no certificate.
         */
        override fun authenticate(ctx: ChannelHandlerContext, req: HttpRequest): Set<Role>? {
            val peerCertificates = try {
                sslEngine.session.peerCertificates
            } catch (es: SSLPeerUnverifiedException) {
                null
            }?.takeIf { it.isNotEmpty() } ?: return null

            val clientCertificate = peerCertificates.first() as X509Certificate
            val user = userExtractor?.extract(clientCertificate)
            val group = groupExtractor?.extract(clientCertificate)
            return (group?.roles ?: emptySet()) + (user?.roles ?: emptySet())
        }

        companion object {
            private val log = contextLogger()
        }
    }

    /**
     * Authenticator implementing HTTP Basic authentication against the configured users.
     */
    private class NettyHttpBasicAuthenticator(
        // NOTE(review): the value type argument was reconstructed from usage
        // (`password`, `roles`, `cfg.users`) -- confirm it is Configuration.User.
        private val users: Map<String, Configuration.User>,
        authorizer: Authorizer
    ) : AbstractNettyHttpAuthenticator(authorizer) {

        /**
         * Parses the `Authorization` header and checks the supplied password against the
         * stored hash of the named user.
         *
         * @return the user's roles on success, `null` when the header is missing or malformed,
         * the user is unknown, the user has no password, or the password does not match.
         */
        override fun authenticate(ctx: ChannelHandlerContext, req: HttpRequest): Set<Role>? {
            val authorizationHeader = req.headers()[HttpHeaderNames.AUTHORIZATION] ?: run {
                log.debug(ctx) { "Missing Authorization header" }
                return null
            }
            val cursor = authorizationHeader.indexOf(' ')
            if (cursor < 0) {
                log.debug(ctx) { "Invalid Authorization header: '$authorizationHeader'" }
                return null
            }
            val authenticationType = authorizationHeader.substring(0, cursor)
            if ("Basic" != authenticationType) {
                log.debug(ctx) { "Invalid authentication type header: '$authenticationType'" }
                return null
            }
            // The header is client-controlled: malformed Base64 must be treated as a failed
            // authentication rather than escape as an exception.
            val credentials = try {
                String(Base64.getDecoder().decode(authorizationHeader.substring(cursor + 1)))
            } catch (iae: IllegalArgumentException) {
                log.debug(ctx) { "Invalid Base64 encoding in Authorization header" }
                return null
            }
            val colon = credentials.indexOf(':')
            if (colon < 0) {
                log.debug(ctx) { "Missing colon from authentication" }
                return null
            }
            val username = credentials.substring(0, colon)
            val password = credentials.substring(colon + 1)

            val user = users[username] ?: return null
            val passwordAndSalt = user.password ?: return null
            // Re-hash the supplied password with the salt of the stored hash and compare.
            val (_, salt) = decodePasswordHash(passwordAndSalt)
            val candidate = hashPassword(password, Base64.getEncoder().encodeToString(salt))
            return if (candidate == passwordAndSalt) user.roles else null
        }

        companion object {
            private val log = contextLogger()
        }
    }

    /**
     * Builds the handler pipeline of every accepted connection: optional TLS, HTTP codec,
     * compression, chunked writes, request aggregation, optional authentication, the cache
     * [ServerHandler] and finally the [ExceptionHandler].
     */
    private class ServerInitializer(
        private val cfg: Configuration,
        private val eventExecutorGroup: EventExecutorGroup
    ) : ChannelInitializer<Channel>() {

        /** TLS context built from the configuration, or `null` when TLS is not configured. */
        private val sslContext: SslContext? = cfg.tls?.let(Companion::createSslCtx)

        /**
         * Creates the extractor mapping a client certificate to a configured user, or `null`
         * when no user extractor is configured. The extractor throws when no user matches.
         */
        private fun userExtractor(authentication: Configuration.ClientCertificateAuthentication) =
            authentication.userExtractor?.let { extractor ->
                val pattern = Pattern.compile(extractor.pattern)
                val rdnType = extractor.rdnType
                Configuration.UserExtractor { cert: X509Certificate ->
                    val userName = rdnCapture(cert, rdnType, pattern)
                    cfg.users[userName] ?: throw java.lang.RuntimeException("Failed to extract user")
                }
            }

        /**
         * Creates the extractor mapping a client certificate to a configured group, or `null`
         * when no group extractor is configured. The extractor throws when no group matches.
         */
        private fun groupExtractor(authentication: Configuration.ClientCertificateAuthentication) =
            authentication.groupExtractor?.let { extractor ->
                val pattern = Pattern.compile(extractor.pattern)
                val rdnType = extractor.rdnType
                Configuration.GroupExtractor { cert: X509Certificate ->
                    val groupName = rdnCapture(cert, rdnType, pattern)
                    cfg.groups[groupName] ?: throw java.lang.RuntimeException("Failed to extract group")
                }
            }

        override fun initChannel(ch: Channel) {
            val pipeline = ch.pipeline()
            val auth = cfg.authentication

            // The TLS handler, when present, must be the first handler of the pipeline.
            val sslHandler = sslContext?.newHandler(ch.alloc())?.also { pipeline.addLast(it) }

            val authenticator: AbstractNettyHttpAuthenticator? = when (auth) {
                is Configuration.BasicAuthentication ->
                    NettyHttpBasicAuthenticator(cfg.users, RoleAuthorizer())

                // Client certificate authentication is only installed when TLS is enabled.
                is Configuration.ClientCertificateAuthentication -> sslHandler?.let { handler ->
                    ClientCertificateAuthenticator(
                        RoleAuthorizer(),
                        handler.engine(),
                        userExtractor(auth),
                        groupExtractor(auth)
                    )
                }

                else -> null
            }

            pipeline.addLast(HttpServerCodec())
            pipeline.addLast(HttpChunkContentCompressor(1024))
            pipeline.addLast(ChunkedWriteHandler())
            pipeline.addLast(HttpObjectAggregator(Int.MAX_VALUE))
            authenticator?.let { pipeline.addLast(it) }

            // NOTE(review): the cache is materialized once per accepted connection -- confirm
            // that materialize() is cheap or returns a shared instance.
            val cacheImplementation = cfg.cache.materialize()
            val prefix = Path.of("/").resolve(Path.of(cfg.serverPath ?: "/"))
            // The cache handler runs on the dedicated executor group, not on the event loop.
            pipeline.addLast(eventExecutorGroup, ServerHandler(cacheImplementation, prefix))
            pipeline.addLast(ExceptionHandler())
        }

        companion object {

            /**
             * Builds the server [SslContext] from the configured keystore and, when client
             * verification is enabled, the configured trust store.
             *
             * @throws IllegalArgumentException when no keystore is configured.
             */
            private fun createSslCtx(tls: Configuration.Tls): SslContext {
                val keyStore = tls.keyStore ?: throw IllegalArgumentException("No keystore configured")
                val javaKeyStore = loadKeystore(keyStore.file, keyStore.password)
                val serverKey = javaKeyStore.getKey(
                    keyStore.keyAlias,
                    keyStore.keyPassword?.let(String::toCharArray)
                ) as PrivateKey
                val serverCert: Array<X509Certificate> = javaKeyStore.getCertificateChain(keyStore.keyAlias)
                    .map { it as X509Certificate }
                    .toTypedArray()
                return SslContextBuilder.forServer(serverKey, *serverCert).apply {
                    if (tls.isVerifyClients) {
                        clientAuth(ClientAuth.OPTIONAL)
                        tls.trustStore?.let { trustStore ->
                            val ts = loadKeystore(trustStore.file, trustStore.password)
                            trustManager(
                                ClientCertificateValidator.getTrustManager(ts, trustStore.isCheckCertificateStatus)
                            )
                        }
                    }
                }.build()
            }

            /**
             * Loads a keystore from [file], choosing the keystore type from the file extension
             * (`.jks` for JKS, `.p12`/`.pfx` for PKCS12, case-insensitive).
             *
             * @throws IllegalArgumentException when the file has no or an unsupported extension.
             */
            fun loadKeystore(file: Path, password: String?): KeyStore {
                val ext = JWO.splitExtension(file)
                    .map(Tuple2::get_2)
                    .orElseThrow {
                        IllegalArgumentException(
                            "Keystore file '${file}' must have .jks, .p12, .pfx extension"
                        )
                    }
                // The extension is compared without its first character (the separator).
                val keystore = when (ext.substring(1).lowercase()) {
                    "jks" -> KeyStore.getInstance("JKS")
                    "p12", "pfx" -> KeyStore.getInstance("PKCS12")
                    else -> throw IllegalArgumentException(
                        "Keystore file '${file}' must have .jks, .p12, .pfx extension"
                    )
                }
                Files.newInputStream(file).use {
                    keystore.load(it, password?.let(String::toCharArray))
                }
                return keystore
            }

            /**
             * Returns capture group 1 of [pattern] matched against the value of the first RDN
             * of type [rdnType] in the certificate subject, or `null` when there is no such
             * RDN or its value does not match.
             */
            private fun rdnCapture(cert: X509Certificate, rdnType: String, pattern: Pattern): String? =
                LdapName(cert.subjectX500Principal.name).rdns
                    .find { it.type == rdnType }
                    ?.let { pattern.matcher(it.value.toString()) }
                    ?.takeIf(Matcher::matches)
                    ?.group(1)
        }
    }

    /**
     * Tail handler translating exceptions raised in the pipeline into HTTP responses or a
     * connection close.
     */
    private class ExceptionHandler : ChannelDuplexHandler() {
        private val log = contextLogger()

        private val NOT_AUTHORIZED: FullHttpResponse = DefaultFullHttpResponse(
            HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN, Unpooled.EMPTY_BUFFER
        ).apply {
            headers()[HttpHeaderNames.CONTENT_LENGTH] = "0"
        }

        private val TOO_BIG: FullHttpResponse = DefaultFullHttpResponse(
            HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER
        ).apply {
            headers()[HttpHeaderNames.CONTENT_LENGTH] = "0"
        }

        override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
            when (cause) {
                is DecoderException -> {
                    log.error(cause.message, cause)
                    ctx.close()
                }

                // Unverified TLS peer: answer 403 and keep the connection unless the write fails.
                is SSLPeerUnverifiedException -> {
                    ctx.writeAndFlush(NOT_AUTHORIZED.retainedDuplicate())
                        .addListener(ChannelFutureListener.CLOSE_ON_FAILURE)
                }

                // Cache rejected the payload: answer 413 and keep the connection unless the write fails.
                is ContentTooLargeException -> {
                    ctx.writeAndFlush(TOO_BIG.retainedDuplicate())
                        .addListener(ChannelFutureListener.CLOSE_ON_FAILURE)
                }

                else -> {
                    log.error(cause.message, cause)
                    ctx.close()
                }
            }
        }
    }

    /**
     * Serves cache entries: `GET <serverPrefix>/<key>` reads an entry, `PUT <serverPrefix>/<key>`
     * stores the request body. Any other method or path is answered with `400 Bad Request`.
     */
    private class ServerHandler(private val cache: Cache, private val serverPrefix: Path) :
        SimpleChannelInboundHandler<FullHttpRequest>() {

        override fun channelRead0(ctx: ChannelHandlerContext, msg: FullHttpRequest) {
            // `when` compares with equals(), so it does not rely on HttpMethod instance identity.
            when (msg.method()) {
                HttpMethod.GET -> handleGet(ctx, msg)
                HttpMethod.PUT -> handlePut(ctx, msg)
                else -> {
                    log.warn(ctx) { "Got request with unhandled method '${msg.method().name()}'" }
                    sendEmptyResponse(ctx, msg, HttpResponseStatus.BAD_REQUEST)
                }
            }
        }

        /** Streams the cache entry named by the request URI, or answers 404 on a cache miss. */
        private fun handleGet(ctx: ChannelHandlerContext, msg: FullHttpRequest) {
            val key = cacheKey(msg.uri()) ?: run {
                rejectUnhandledPath(ctx, msg)
                return
            }
            val channel = cache.get(key)
            if (channel == null) {
                log.debug(ctx) { "Cache miss for key '$key'" }
                sendEmptyResponse(ctx, msg, HttpResponseStatus.NOT_FOUND)
                return
            }
            log.debug(ctx) { "Cache hit for key '$key'" }

            val keepAlive: Boolean = HttpUtil.isKeepAlive(msg)
            val response = DefaultHttpResponse(msg.protocolVersion(), HttpResponseStatus.OK)
            response.headers()[HttpHeaderNames.CONTENT_TYPE] = HttpHeaderValues.APPLICATION_OCTET_STREAM
            if (keepAlive) {
                // The length is not announced, so a kept-alive connection needs chunked framing.
                response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE)
                response.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED)
            } else {
                // End of content is signalled by closing the connection.
                response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE)
                response.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.IDENTITY)
            }
            ctx.write(response)

            when (channel) {
                is FileChannel -> {
                    if (keepAlive) {
                        ctx.write(ChunkedNioFile(channel))
                        ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT)
                    } else {
                        ctx.writeAndFlush(DefaultFileRegion(channel, 0, channel.size()))
                            .addListener(ChannelFutureListener.CLOSE)
                    }
                }

                else -> {
                    ctx.write(ChunkedNioStream(channel))
                    ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT)
                }
            }
        }

        /** Stores the request body under the key named by the request URI and answers 201. */
        private fun handlePut(ctx: ChannelHandlerContext, msg: FullHttpRequest) {
            val key = cacheKey(msg.uri()) ?: run {
                rejectUnhandledPath(ctx, msg)
                return
            }
            log.debug(ctx) { "Added value for key '$key' to build cache" }
            // Always copy exactly the readable bytes: ByteBuf.array() exposes the whole backing
            // array (ignoring offset and length), is unsupported by buffers without an
            // accessible array, and belongs to a buffer released once this method returns.
            val content = msg.content()
            val bodyBytes = ByteArray(content.readableBytes()).also { content.readBytes(it) }
            cache.put(key, bodyBytes)

            val response = DefaultFullHttpResponse(
                msg.protocolVersion(), HttpResponseStatus.CREATED, Unpooled.copiedBuffer(key.toByteArray())
            )
            response.headers()[HttpHeaderNames.CONTENT_LENGTH] = response.content().readableBytes()
            ctx.writeAndFlush(response)
        }

        /**
         * Extracts the cache key (last path segment) from [uri].
         *
         * @return the key, or `null` when the URI is not a valid path, has no last segment
         * (e.g. `/`) or its parent differs from [serverPrefix].
         */
        private fun cacheKey(uri: String): String? {
            val path = try {
                Path.of(uri)
            } catch (ipe: InvalidPathException) {
                return null
            }
            // Path.getFileName() returns null for paths without name elements such as "/".
            val key = path.fileName?.toString() ?: return null
            return key.takeIf { serverPrefix == path.parent }
        }

        /** Logs and answers `400 Bad Request` for a URI outside of the served prefix. */
        private fun rejectUnhandledPath(ctx: ChannelHandlerContext, msg: FullHttpRequest) {
            log.warn(ctx) { "Got request for unhandled path '${msg.uri()}'" }
            sendEmptyResponse(ctx, msg, HttpResponseStatus.BAD_REQUEST)
        }

        /** Writes and flushes a body-less response with the given [status]. */
        private fun sendEmptyResponse(
            ctx: ChannelHandlerContext,
            msg: FullHttpRequest,
            status: HttpResponseStatus
        ) {
            val response = DefaultFullHttpResponse(msg.protocolVersion(), status)
            response.headers()[HttpHeaderNames.CONTENT_LENGTH] = 0
            ctx.writeAndFlush(response)
        }

        companion object {
            private val log = contextLogger()
        }
    }

    /**
     * Handle on a running server. [shutdown] requests the listening channel to close;
     * [close] waits for the channel to be closed and then shuts down the executor groups.
     */
    class ServerHandle(
        httpChannelFuture: ChannelFuture,
        private val executorGroups: Iterable<EventExecutorGroup>
    ) : AutoCloseable {
        private val httpChannel: Channel = httpChannelFuture.channel()

        private val closeFuture: ChannelFuture = httpChannel.closeFuture()

        /** Closes the listening channel and returns the future of that operation. */
        fun shutdown(): ChannelFuture = httpChannel.close()

        /**
         * Blocks until the listening channel is closed, then shuts down every executor group
         * and waits for each of them to terminate.
         */
        override fun close() {
            try {
                closeFuture.sync()
            } finally {
                executorGroups.forEach {
                    it.shutdownGracefully().sync()
                }
            }
            log.info {
                "GradleBuildCacheServer has been gracefully shut down"
            }
        }
    }

    /**
     * Binds the server to the configured host and port and starts accepting connections.
     *
     * Blocks until the bind completes. If start-up fails, the thread pools created here are
     * shut down before the failure is rethrown.
     *
     * @return a [ServerHandle] used to stop the server.
     */
    fun run(): ServerHandle {
        // Create the multithreaded event loops for the server; the same group is used for
        // accepting connections and for serving them.
        val bossGroup = NioEventLoopGroup(0)
        val serverSocketChannel = NioServerSocketChannel::class.java
        val workerGroup = bossGroup
        // A null thread factory lets the executor group fall back to its default.
        val threadFactory = if (cfg.isUseVirtualThread) {
            Thread.ofVirtual().factory()
        } else {
            null
        }
        val eventExecutorGroup =
            DefaultEventExecutorGroup(Runtime.getRuntime().availableProcessors(), threadFactory)
        val executorGroups: Set<EventExecutorGroup> = setOf(bossGroup, workerGroup, eventExecutorGroup)

        try {
            // A helper class that simplifies server configuration
            val bootstrap = ServerBootstrap().apply {
                // Configure the server
                group(bossGroup, workerGroup)
                channel(serverSocketChannel)
                childHandler(ServerInitializer(cfg, eventExecutorGroup))
                option(ChannelOption.SO_BACKLOG, 128)
                childOption(ChannelOption.SO_KEEPALIVE, true)
            }

            // Bind and start to accept incoming connections.
            val bindAddress = InetSocketAddress(cfg.host, cfg.port)
            val httpChannel = bootstrap.bind(bindAddress).sync()
            log.info {
                "GradleBuildCacheServer is listening on ${cfg.host}:${cfg.port}"
            }
            return ServerHandle(httpChannel, executorGroups)
        } catch (t: Throwable) {
            // Do not leak the thread pools when the server cannot be started.
            executorGroups.forEach { it.shutdownGracefully() }
            throw t
        }
    }

    companion object {

        /** URL of the default configuration bundled on the classpath. */
        val DEFAULT_CONFIGURATION_URL by lazy { "classpath:net/woggioni/gbcs/gbcs-default.xml".toUrl() }

        /** Parses the XML document stored in [configurationFile] into a [Configuration]. */
        fun loadConfiguration(configurationFile: Path): Configuration {
            val dbf = Xml.newDocumentBuilderFactory(null)
            val db = dbf.newDocumentBuilder()
            val doc = Files.newInputStream(configurationFile).use(db::parse)
            return Parser.parse(doc)
        }

        /** Serializes [conf] and writes the result to [outputStream]. */
        fun dumpConfiguration(conf: Configuration, outputStream: OutputStream) {
            Xml.write(Serializer.serialize(conf), outputStream)
        }

        private val log = contextLogger()
    }
}