package net.woggioni.gbcs.server import io.netty.bootstrap.ServerBootstrap import io.netty.buffer.ByteBuf import io.netty.buffer.Unpooled import io.netty.channel.Channel import io.netty.channel.ChannelDuplexHandler import io.netty.channel.ChannelFuture import io.netty.channel.ChannelFutureListener import io.netty.channel.ChannelHandler.Sharable import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelInboundHandlerAdapter import io.netty.channel.ChannelInitializer import io.netty.channel.ChannelOption import io.netty.channel.ChannelPromise import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.codec.DecoderException import io.netty.handler.codec.compression.CompressionOptions import io.netty.handler.codec.http.DefaultFullHttpResponse import io.netty.handler.codec.http.DefaultHttpContent import io.netty.handler.codec.http.FullHttpResponse import io.netty.handler.codec.http.HttpContentCompressor import io.netty.handler.codec.http.HttpHeaderNames import io.netty.handler.codec.http.HttpObjectAggregator import io.netty.handler.codec.http.HttpRequest import io.netty.handler.codec.http.HttpResponseStatus import io.netty.handler.codec.http.HttpServerCodec import io.netty.handler.codec.http.HttpVersion import io.netty.handler.ssl.ClientAuth import io.netty.handler.ssl.SslContext import io.netty.handler.ssl.SslContextBuilder import io.netty.handler.ssl.SslHandler import io.netty.handler.stream.ChunkedWriteHandler import io.netty.handler.timeout.IdleStateEvent import io.netty.handler.timeout.IdleStateHandler import io.netty.handler.timeout.ReadTimeoutException import io.netty.handler.timeout.ReadTimeoutHandler import io.netty.handler.timeout.WriteTimeoutException import io.netty.handler.timeout.WriteTimeoutHandler import io.netty.util.concurrent.DefaultEventExecutorGroup import io.netty.util.concurrent.EventExecutorGroup import net.woggioni.gbcs.api.Configuration import net.woggioni.gbcs.api.Role import net.woggioni.gbcs.api.exception.ConfigurationException import net.woggioni.gbcs.api.exception.ContentTooLargeException import net.woggioni.gbcs.common.GBCS.toUrl import net.woggioni.gbcs.common.PasswordSecurity.decodePasswordHash import net.woggioni.gbcs.common.PasswordSecurity.hashPassword import net.woggioni.gbcs.common.Xml import net.woggioni.gbcs.common.contextLogger import net.woggioni.gbcs.common.debug import net.woggioni.gbcs.common.info import net.woggioni.gbcs.server.auth.AbstractNettyHttpAuthenticator import net.woggioni.gbcs.server.auth.Authorizer import net.woggioni.gbcs.server.auth.ClientCertificateValidator import net.woggioni.gbcs.server.auth.RoleAuthorizer import net.woggioni.gbcs.server.configuration.Parser import net.woggioni.gbcs.server.configuration.Serializer import net.woggioni.gbcs.server.handler.ServerHandler import net.woggioni.jwo.JWO import net.woggioni.jwo.Tuple2 import java.io.OutputStream import java.net.InetSocketAddress import java.nio.file.Files import java.nio.file.Path import java.security.KeyStore import java.security.PrivateKey import java.security.cert.X509Certificate import java.util.Arrays import java.util.Base64 import java.util.concurrent.TimeUnit import java.util.regex.Matcher import java.util.regex.Pattern import javax.naming.ldap.LdapName import javax.net.ssl.SSLPeerUnverifiedException class GradleBuildCacheServer(private val cfg: Configuration) { private val log = contextLogger() companion object { val DEFAULT_CONFIGURATION_URL by lazy { "classpath:net/woggioni/gbcs/gbcs-default.xml".toUrl() } private const val SSL_HANDLER_NAME = "sslHandler" fun loadConfiguration(configurationFile: Path): Configuration { val doc = Files.newInputStream(configurationFile).use { Xml.parseXml(configurationFile.toUri().toURL(), it) } return Parser.parse(doc) } fun dumpConfiguration(conf: Configuration, outputStream: OutputStream) { Xml.write(Serializer.serialize(conf), outputStream) } } private class HttpChunkContentCompressor( threshold: Int, vararg compressionOptions: CompressionOptions = emptyArray() ) : HttpContentCompressor(threshold, *compressionOptions) { override fun write(ctx: ChannelHandlerContext, msg: Any, promise: ChannelPromise) { var message: Any? = msg if (message is ByteBuf) { // convert ByteBuf to HttpContent to make it work with compression. This is needed as we use the // ChunkedWriteHandler to send files when compression is enabled. val buff = message if (buff.isReadable) { // We only encode non empty buffers, as empty buffers can be used for determining when // the content has been flushed and it confuses the HttpContentCompressor // if we let it go message = DefaultHttpContent(buff) } } super.write(ctx, message, promise) } } @Sharable private class ClientCertificateAuthenticator( authorizer: Authorizer, private val anonymousUserRoles: Set?, private val userExtractor: Configuration.UserExtractor?, private val groupExtractor: Configuration.GroupExtractor?, ) : AbstractNettyHttpAuthenticator(authorizer) { override fun authenticate(ctx: ChannelHandlerContext, req: HttpRequest): Set? { return try { val sslHandler = (ctx.pipeline().get(SSL_HANDLER_NAME) as? SslHandler) ?: throw ConfigurationException("Client certificate authentication cannot be used when TLS is disabled") val sslEngine = sslHandler.engine() sslEngine.session.peerCertificates.takeIf { it.isNotEmpty() }?.let { peerCertificates -> val clientCertificate = peerCertificates.first() as X509Certificate val user = userExtractor?.extract(clientCertificate) val group = groupExtractor?.extract(clientCertificate) (group?.roles ?: emptySet()) + (user?.roles ?: emptySet()) } ?: anonymousUserRoles } catch (es: SSLPeerUnverifiedException) { anonymousUserRoles } } } @Sharable private class NettyHttpBasicAuthenticator( private val users: Map, authorizer: Authorizer ) : AbstractNettyHttpAuthenticator(authorizer) { private val log = contextLogger() override fun authenticate(ctx: ChannelHandlerContext, req: HttpRequest): Set? { val authorizationHeader = req.headers()[HttpHeaderNames.AUTHORIZATION] ?: let { log.debug(ctx) { "Missing Authorization header" } return users[""]?.roles } val cursor = authorizationHeader.indexOf(' ') if (cursor < 0) { log.debug(ctx) { "Invalid Authorization header: '$authorizationHeader'" } return users[""]?.roles } val authenticationType = authorizationHeader.substring(0, cursor) if ("Basic" != authenticationType) { log.debug(ctx) { "Invalid authentication type header: '$authenticationType'" } return users[""]?.roles } val (username, password) = Base64.getDecoder().decode(authorizationHeader.substring(cursor + 1)) .let(::String) .let { val colon = it.indexOf(':') if (colon < 0) { log.debug(ctx) { "Missing colon from authentication" } return null } it.substring(0, colon) to it.substring(colon + 1) } return username.let(users::get)?.takeIf { user -> user.password?.let { passwordAndSalt -> val (_, salt) = decodePasswordHash(passwordAndSalt) hashPassword(password, Base64.getEncoder().encodeToString(salt)) == passwordAndSalt } ?: false }?.roles } } private class ServerInitializer( private val cfg: Configuration, private val eventExecutorGroup: EventExecutorGroup ) : ChannelInitializer() { companion object { private fun createSslCtx(tls: Configuration.Tls): SslContext { val keyStore = tls.keyStore return if (keyStore == null) { throw IllegalArgumentException("No keystore configured") } else { val javaKeyStore = loadKeystore(keyStore.file, keyStore.password) val serverKey = javaKeyStore.getKey( keyStore.keyAlias, (keyStore.keyPassword ?: "").let(String::toCharArray) ) as PrivateKey val serverCert: Array = Arrays.stream(javaKeyStore.getCertificateChain(keyStore.keyAlias)) .map { it as X509Certificate } .toArray { size -> Array(size) { null } } SslContextBuilder.forServer(serverKey, *serverCert).apply { if (tls.isVerifyClients) { clientAuth(ClientAuth.OPTIONAL) tls.trustStore?.let { trustStore -> val ts = loadKeystore(trustStore.file, trustStore.password) trustManager( ClientCertificateValidator.getTrustManager(ts, trustStore.isCheckCertificateStatus) ) } } }.build() } } fun loadKeystore(file: Path, password: String?): KeyStore { val ext = JWO.splitExtension(file) .map(Tuple2::get_2) .orElseThrow { IllegalArgumentException( "Keystore file '${file}' must have .jks, .p12, .pfx extension" ) } val keystore = when (ext.substring(1).lowercase()) { "jks" -> KeyStore.getInstance("JKS") "p12", "pfx" -> KeyStore.getInstance("PKCS12") else -> throw IllegalArgumentException( "Keystore file '${file}' must have .jks, .p12, .pfx extension" ) } Files.newInputStream(file).use { keystore.load(it, password?.let(String::toCharArray)) } return keystore } } private val log = contextLogger() private val serverHandler = let { val cacheImplementation = cfg.cache.materialize() val prefix = Path.of("/").resolve(Path.of(cfg.serverPath ?: "/")) ServerHandler(cacheImplementation, prefix) } private val exceptionHandler = ExceptionHandler() private val authenticator = when (val auth = cfg.authentication) { is Configuration.BasicAuthentication -> NettyHttpBasicAuthenticator(cfg.users, RoleAuthorizer()) is Configuration.ClientCertificateAuthentication -> { ClientCertificateAuthenticator( RoleAuthorizer(), cfg.users[""]?.roles, userExtractor(auth), groupExtractor(auth) ) } else -> null } private val sslContext: SslContext? = cfg.tls?.let(Companion::createSslCtx) private fun userExtractor(authentication: Configuration.ClientCertificateAuthentication) = authentication.userExtractor?.let { extractor -> val pattern = Pattern.compile(extractor.pattern) val rdnType = extractor.rdnType Configuration.UserExtractor { cert: X509Certificate -> val userName = LdapName(cert.subjectX500Principal.name).rdns.find { it.type == rdnType }?.let { pattern.matcher(it.value.toString()) }?.takeIf(Matcher::matches)?.group(1) cfg.users[userName] ?: throw java.lang.RuntimeException("Failed to extract user") } } private fun groupExtractor(authentication: Configuration.ClientCertificateAuthentication) = authentication.groupExtractor?.let { extractor -> val pattern = Pattern.compile(extractor.pattern) val rdnType = extractor.rdnType Configuration.GroupExtractor { cert: X509Certificate -> val groupName = LdapName(cert.subjectX500Principal.name).rdns.find { it.type == rdnType }?.let { pattern.matcher(it.value.toString()) }?.takeIf(Matcher::matches)?.group(1) cfg.groups[groupName] ?: throw java.lang.RuntimeException("Failed to extract group") } } override fun initChannel(ch: Channel) { log.debug { "Created connection ${ch.id().asShortText()} with ${ch.remoteAddress()}" } ch.closeFuture().addListener { log.debug { "Closed connection ${ch.id().asShortText()} with ${ch.remoteAddress()}" } } val pipeline = ch.pipeline() cfg.connection.apply { pipeline.addLast(ReadTimeoutHandler(readTimeout.toMillis(), TimeUnit.MILLISECONDS)) pipeline.addLast(WriteTimeoutHandler(writeTimeout.toMillis(), TimeUnit.MILLISECONDS)) pipeline.addLast(IdleStateHandler(false, 0, 0, idleTimeout.toMillis(), TimeUnit.MILLISECONDS)) } pipeline.addLast(object : ChannelInboundHandlerAdapter() { override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) { if (evt is IdleStateEvent) { log.debug { "Idle timeout reached on channel ${ch.id().asShortText()}, closing the connection" } ctx.close() } } }) sslContext?.newHandler(ch.alloc())?.also { pipeline.addLast(SSL_HANDLER_NAME, it) } pipeline.addLast(HttpServerCodec()) pipeline.addLast(HttpChunkContentCompressor(1024)) pipeline.addLast(ChunkedWriteHandler()) pipeline.addLast(HttpObjectAggregator(cfg.connection.maxRequestSize)) authenticator?.let { pipeline.addLast(it) } pipeline.addLast(eventExecutorGroup, serverHandler) pipeline.addLast(exceptionHandler) } } @Sharable private class ExceptionHandler : ChannelDuplexHandler() { private val log = contextLogger() private val NOT_AUTHORIZED: FullHttpResponse = DefaultFullHttpResponse( HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN, Unpooled.EMPTY_BUFFER ).apply { headers()[HttpHeaderNames.CONTENT_LENGTH] = "0" } private val TOO_BIG: FullHttpResponse = DefaultFullHttpResponse( HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER ).apply { headers()[HttpHeaderNames.CONTENT_LENGTH] = "0" } override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { when (cause) { is DecoderException -> { log.error(cause.message, cause) ctx.close() } is SSLPeerUnverifiedException -> { ctx.writeAndFlush(NOT_AUTHORIZED.retainedDuplicate()) .addListener(ChannelFutureListener.CLOSE_ON_FAILURE) } is ContentTooLargeException -> { ctx.writeAndFlush(TOO_BIG.retainedDuplicate()) .addListener(ChannelFutureListener.CLOSE_ON_FAILURE) } is ReadTimeoutException -> { log.debug { val channelId = ctx.channel().id().asShortText() "Read timeout on channel $channelId, closing the connection" } ctx.close() } is WriteTimeoutException -> { log.debug { val channelId = ctx.channel().id().asShortText() "Write timeout on channel $channelId, closing the connection" } ctx.close() } else -> { log.error(cause.message, cause) ctx.close() } } } } class ServerHandle( httpChannelFuture: ChannelFuture, private val executorGroups: Iterable ) : AutoCloseable { private val httpChannel: Channel = httpChannelFuture.channel() private val closeFuture: ChannelFuture = httpChannel.closeFuture() private val log = contextLogger() fun shutdown(): ChannelFuture { return httpChannel.close() } override fun close() { try { closeFuture.sync() } finally { executorGroups.forEach { it.shutdownGracefully().sync() } } log.info { "GradleBuildCacheServer has been gracefully shut down" } } } fun run(): ServerHandle { // Create the multithreaded event loops for the server val bossGroup = NioEventLoopGroup(0) val serverSocketChannel = NioServerSocketChannel::class.java val workerGroup = bossGroup val eventExecutorGroup = run { val threadFactory = if (cfg.eventExecutor.isUseVirtualThreads) { Thread.ofVirtual().factory() } else { null } DefaultEventExecutorGroup(Runtime.getRuntime().availableProcessors(), threadFactory) } // A helper class that simplifies server configuration val bootstrap = ServerBootstrap().apply { // Configure the server group(bossGroup, workerGroup) channel(serverSocketChannel) childHandler(ServerInitializer(cfg, eventExecutorGroup)) option(ChannelOption.SO_BACKLOG, cfg.incomingConnectionsBacklogSize) childOption(ChannelOption.SO_KEEPALIVE, true) } // Bind and start to accept incoming connections. val bindAddress = InetSocketAddress(cfg.host, cfg.port) val httpChannel = bootstrap.bind(bindAddress).sync() log.info { "GradleBuildCacheServer is listening on ${cfg.host}:${cfg.port}" } return ServerHandle(httpChannel, setOf(bossGroup, workerGroup, eventExecutorGroup)) } }