version 1.0

This commit is contained in:
2024-12-23 16:00:14 +08:00
parent 13f7ecc88a
commit 688a196a52
29 changed files with 1951 additions and 433 deletions

View File

@@ -1,21 +1,5 @@
package net.woggioni.gbcs
import java.net.InetSocketAddress
import java.net.URL
import java.net.URLStreamHandlerFactory
import java.nio.channels.FileChannel
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.StandardCopyOption
import java.nio.file.StandardOpenOption
import java.security.KeyStore
import java.security.MessageDigest
import java.security.PrivateKey
import java.security.cert.X509Certificate
import java.util.AbstractMap.SimpleEntry
import java.util.Arrays
import java.util.Base64
import java.util.concurrent.Executors
import io.netty.bootstrap.ServerBootstrap
import io.netty.buffer.ByteBuf
import io.netty.buffer.Unpooled
@@ -54,23 +38,44 @@ import io.netty.handler.ssl.ClientAuth
import io.netty.handler.ssl.SslContext
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.stream.ChunkedNioFile
import io.netty.handler.stream.ChunkedNioStream
import io.netty.handler.stream.ChunkedWriteHandler
import io.netty.util.concurrent.DefaultEventExecutorGroup
import io.netty.util.concurrent.DefaultThreadFactory
import io.netty.util.concurrent.EventExecutorGroup
import net.woggioni.gbcs.cache.Cache
import net.woggioni.gbcs.cache.FileSystemCache
import net.woggioni.gbcs.configuration.Configuration
import net.woggioni.gbcs.url.ClasspathUrlStreamHandlerFactoryProvider
import javax.net.ssl.SSLPeerUnverifiedException
import net.woggioni.jwo.Application
import net.woggioni.jwo.JWO
import net.woggioni.jwo.Tuple2
import java.net.InetSocketAddress
import java.net.URI
import java.util.concurrent.ForkJoinPool
import java.net.URL
import java.net.URLStreamHandlerFactory
import java.nio.channels.FileChannel
import java.nio.file.Files
import java.nio.file.Path
import java.security.KeyStore
import java.security.MessageDigest
import java.security.PrivateKey
import java.security.cert.X509Certificate
import java.util.Arrays
import java.util.Base64
import java.util.concurrent.Executors
import java.util.regex.Matcher
import java.util.regex.Pattern
import javax.naming.ldap.LdapName
import javax.net.ssl.SSLEngine
import javax.net.ssl.SSLPeerUnverifiedException
class GradleBuildCacheServer(private val cfg : Configuration) {
class GradleBuildCacheServer(private val cfg: Configuration) {
internal class HttpChunkContentCompressor(threshold : Int, vararg compressionOptions: CompressionOptions = emptyArray())
: HttpContentCompressor(threshold, *compressionOptions) {
private class HttpChunkContentCompressor(
threshold: Int,
vararg compressionOptions: CompressionOptions = emptyArray()
) : HttpContentCompressor(threshold, *compressionOptions) {
override fun write(ctx: ChannelHandlerContext, msg: Any, promise: ChannelPromise) {
var message: Any? = msg
if (message is ByteBuf) {
@@ -88,14 +93,42 @@ class GradleBuildCacheServer(private val cfg : Configuration) {
}
}
private class NettyHttpBasicAuthenticator(
private val credentials: Map<String, String>, authorizer: Authorizer) : AbstractNettyHttpAuthenticator(authorizer) {
private class ClientCertificateAuthenticator(
authorizer: Authorizer,
private val sslEngine: SSLEngine,
private val userExtractor: Configuration.UserExtractor?,
private val groupExtractor: Configuration.GroupExtractor?,
) : AbstractNettyHttpAuthenticator(authorizer) {
companion object {
private val log = contextLogger()
}
override fun authenticate(ctx: ChannelHandlerContext, req: HttpRequest): String? {
override fun authenticate(ctx: ChannelHandlerContext, req: HttpRequest): Set<Role>? {
return try {
sslEngine.session.peerCertificates
} catch (es : SSLPeerUnverifiedException) {
null
}?.takeIf {
it.isNotEmpty()
}?.let { peerCertificates ->
val clientCertificate = peerCertificates.first() as X509Certificate
val user = userExtractor?.extract(clientCertificate)
val group = groupExtractor?.extract(clientCertificate)
(group?.roles ?: emptySet()) + (user?.roles ?: emptySet())
}
}
}
private class NettyHttpBasicAuthenticator(
private val users: Map<String, Configuration.User>, authorizer: Authorizer
) : AbstractNettyHttpAuthenticator(authorizer) {
companion object {
private val log = contextLogger()
}
override fun authenticate(ctx: ChannelHandlerContext, req: HttpRequest): Set<Role>? {
val authorizationHeader = req.headers()[HttpHeaderNames.AUTHORIZATION] ?: let {
log.debug(ctx) {
"Missing Authorization header"
@@ -116,67 +149,77 @@ class GradleBuildCacheServer(private val cfg : Configuration) {
}
return null
}
val (user, password) = Base64.getDecoder().decode(authorizationHeader.substring(cursor + 1))
.let(::String)
.let {
val colon = it.indexOf(':')
if(colon < 0) {
log.debug(ctx) {
"Missing colon from authentication"
}
return null
val (username, password) = Base64.getDecoder().decode(authorizationHeader.substring(cursor + 1))
.let(::String)
.let {
val colon = it.indexOf(':')
if (colon < 0) {
log.debug(ctx) {
"Missing colon from authentication"
}
it.substring(0, colon) to it.substring(colon + 1)
return null
}
return user.takeIf {
credentials[user] == password
}
it.substring(0, colon) to it.substring(colon + 1)
}
return username.let(users::get)?.takeIf { user ->
user.password?.let { passwordAndSalt ->
val (_, salt) = decodePasswordHash(passwordAndSalt)
hashPassword(password, Base64.getEncoder().encodeToString(salt)) == passwordAndSalt
} ?: false
}?.roles
}
}
private class ServerInitializer(private val cfg : Configuration) : ChannelInitializer<Channel>() {
private class ServerInitializer(private val cfg: Configuration) : ChannelInitializer<Channel>() {
private fun createSslCtx(tlsConfiguration : TlsConfiguration) : SslContext {
val keyStore = tlsConfiguration.keyStore
return if(keyStore == null) {
private fun createSslCtx(tls: Configuration.Tls): SslContext {
val keyStore = tls.keyStore
return if (keyStore == null) {
throw IllegalArgumentException("No keystore configured")
} else {
val javaKeyStore = loadKeystore(keyStore.file, keyStore.password)
val serverKey = javaKeyStore.getKey(
keyStore.keyAlias, keyStore.keyPassword?.let(String::toCharArray)) as PrivateKey
val serverCert : Array<X509Certificate> = Arrays.stream(javaKeyStore.getCertificateChain(keyStore.keyAlias))
.map { it as X509Certificate }
.toArray {size -> Array<X509Certificate?>(size) { null } }
keyStore.keyAlias, keyStore.keyPassword?.let(String::toCharArray)
) as PrivateKey
val serverCert: Array<X509Certificate> =
Arrays.stream(javaKeyStore.getCertificateChain(keyStore.keyAlias))
.map { it as X509Certificate }
.toArray { size -> Array<X509Certificate?>(size) { null } }
SslContextBuilder.forServer(serverKey, *serverCert).apply {
if(tlsConfiguration.verifyClients) {
clientAuth(ClientAuth.REQUIRE)
val trustStore = tlsConfiguration.trustStore
if (tls.verifyClients) {
clientAuth(ClientAuth.OPTIONAL)
val trustStore = tls.trustStore
if (trustStore != null) {
val ts = loadKeystore(trustStore.file, trustStore.password)
trustManager(
ClientCertificateValidator.getTrustManager(ts, trustStore.checkCertificateStatus))
ClientCertificateValidator.getTrustManager(ts, trustStore.checkCertificateStatus)
)
}
}
}.build()
}
}
private val sslContext : SslContext? = cfg.tlsConfiguration?.let(this::createSslCtx)
private val sslContext: SslContext? = cfg.tls?.let(this::createSslCtx)
private val group: EventExecutorGroup = DefaultEventExecutorGroup(Runtime.getRuntime().availableProcessors())
companion object {
val group: EventExecutorGroup = DefaultEventExecutorGroup(Runtime.getRuntime().availableProcessors())
fun loadKeystore(file : Path, password : String?) : KeyStore {
fun loadKeystore(file: Path, password: String?): KeyStore {
val ext = JWO.splitExtension(file)
.map(Tuple2<String, String>::get_2)
.orElseThrow {
IllegalArgumentException(
"Keystore file '${file}' must have .jks, .p12, .pfx extension")
"Keystore file '${file}' must have .jks, .p12, .pfx extension"
)
}
val keystore = when(ext.substring(1).lowercase()) {
val keystore = when (ext.substring(1).lowercase()) {
"jks" -> KeyStore.getInstance("JKS")
"p12", "pfx" -> KeyStore.getInstance("PKCS12")
else -> throw IllegalArgumentException(
"Keystore file '${file}' must have .jks, .p12, .pfx extension")
"Keystore file '${file}' must have .jks, .p12, .pfx extension"
)
}
Files.newInputStream(file).use {
keystore.load(it, password?.let(String::toCharArray))
@@ -185,21 +228,71 @@ class GradleBuildCacheServer(private val cfg : Configuration) {
}
}
private fun userExtractor(authentication: Configuration.ClientCertificateAuthentication) =
authentication.userExtractor?.let { extractor ->
val pattern = Pattern.compile(extractor.pattern)
val rdnType = extractor.rdnType
Configuration.UserExtractor { cert: X509Certificate ->
val userName = LdapName(cert.subjectX500Principal.name).rdns.find {
it.type == rdnType
}?.let {
pattern.matcher(it.value.toString())
}?.takeIf(Matcher::matches)?.group(1)
cfg.users[userName] ?: throw java.lang.RuntimeException("Failed to extract user")
}
}
private fun groupExtractor(authentication: Configuration.ClientCertificateAuthentication) =
authentication.groupExtractor?.let { extractor ->
val pattern = Pattern.compile(extractor.pattern)
val rdnType = extractor.rdnType
Configuration.GroupExtractor { cert: X509Certificate ->
val groupName = LdapName(cert.subjectX500Principal.name).rdns.find {
it.type == rdnType
}?.let {
pattern.matcher(it.value.toString())
}?.takeIf(Matcher::matches)?.group(1)
cfg.groups[groupName] ?: throw java.lang.RuntimeException("Failed to extract group")
}
}
override fun initChannel(ch: Channel) {
val userAuthorizer = UserAuthorizer(cfg.users)
val pipeline = ch.pipeline()
if(sslContext != null) {
val auth = cfg.authentication
var authenticator : AbstractNettyHttpAuthenticator? = null
if (auth is Configuration.BasicAuthentication) {
val roleAuthorizer = RoleAuthorizer()
authenticator = (NettyHttpBasicAuthenticator(cfg.users, roleAuthorizer))
}
if (sslContext != null) {
val sslHandler = sslContext.newHandler(ch.alloc())
pipeline.addLast(sslHandler)
if(auth is Configuration.ClientCertificateAuthentication) {
val roleAuthorizer = RoleAuthorizer()
authenticator = ClientCertificateAuthenticator(
roleAuthorizer,
sslHandler.engine(),
userExtractor(auth),
groupExtractor(auth)
)
}
}
pipeline.addLast(HttpServerCodec())
pipeline.addLast(HttpChunkContentCompressor(1024))
pipeline.addLast(ChunkedWriteHandler())
pipeline.addLast(HttpObjectAggregator(Int.MAX_VALUE))
// pipeline.addLast(NettyHttpBasicAuthenticator(mapOf("user" to "password")), userAuthorizer)
pipeline.addLast(group, ServerHandler(cfg.cacheFolder, cfg.serverPath))
authenticator?.let{
pipeline.addLast(it)
}
val cacheImplementation = when(val cache = cfg.cache) {
is Configuration.FileSystemCache -> {
FileSystemCache(cache.root, cache.maxAge)
}
else -> throw NotImplementedError()
}
pipeline.addLast(group, ServerHandler(cacheImplementation, cfg.serverPath))
pipeline.addLast(ExceptionHandler())
Files.createDirectories(cfg.cacheFolder)
}
}
@@ -207,36 +300,42 @@ class GradleBuildCacheServer(private val cfg : Configuration) {
private val log = contextLogger()
private val NOT_AUTHORIZED: FullHttpResponse = DefaultFullHttpResponse(
HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN, Unpooled.EMPTY_BUFFER).apply {
HttpVersion.HTTP_1_1, HttpResponseStatus.FORBIDDEN, Unpooled.EMPTY_BUFFER
).apply {
headers()[HttpHeaderNames.CONTENT_LENGTH] = "0"
}
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
when(cause) {
when (cause) {
is DecoderException -> {
log.error(cause.message, cause)
ctx.close()
}
is SSLPeerUnverifiedException -> {
ctx.writeAndFlush(NOT_AUTHORIZED.retainedDuplicate()).addListener(ChannelFutureListener.CLOSE_ON_FAILURE)
ctx.writeAndFlush(NOT_AUTHORIZED.retainedDuplicate())
.addListener(ChannelFutureListener.CLOSE_ON_FAILURE)
}
else -> {
log.error(cause.message, cause)
ctx.close()
}
log.error(cause.message, cause)
ctx.close()
}
}
}
}
private class ServerHandler(private val cacheDir: Path, private val serverPrefix: String) : SimpleChannelInboundHandler<FullHttpRequest>() {
private class ServerHandler(private val cache: Cache, private val serverPrefix: String?) :
SimpleChannelInboundHandler<FullHttpRequest>() {
companion object {
private val log = contextLogger()
private fun splitPath(req: HttpRequest): Map.Entry<String, String> {
private fun splitPath(req: HttpRequest): Pair<String?, String> {
val uri = req.uri()
val i = uri.lastIndexOf('/')
if (i < 0) throw RuntimeException(String.format("Malformed request URI: '%s'", uri))
return SimpleEntry(uri.substring(0, i), uri.substring(i + 1))
return uri.substring(0, i).takeIf(String::isNotEmpty) to uri.substring(i + 1)
}
}
@@ -246,14 +345,13 @@ class GradleBuildCacheServer(private val cfg : Configuration) {
if (method === HttpMethod.GET) {
val (prefix, key) = splitPath(msg)
if (serverPrefix == prefix) {
val file = cacheDir.resolve(digestString(key.toByteArray()))
if (Files.exists(file)) {
cache.get(digestString(key.toByteArray()))?.let { channel ->
log.debug(ctx) {
"Cache hit for key '$key'"
}
val response = DefaultHttpResponse(msg.protocolVersion(), HttpResponseStatus.OK)
response.headers()[HttpHeaderNames.CONTENT_TYPE] = HttpHeaderValues.APPLICATION_OCTET_STREAM
if(!keepAlive) {
if (!keepAlive) {
response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE)
response.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.IDENTITY)
} else {
@@ -261,14 +359,22 @@ class GradleBuildCacheServer(private val cfg : Configuration) {
response.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED)
}
ctx.write(response)
val channel = FileChannel.open(file, StandardOpenOption.READ)
if(keepAlive) {
ctx.write(ChunkedNioFile(channel))
ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT)
} else {
ctx.writeAndFlush(DefaultFileRegion(channel, 0, Files.size(file))).addListener(ChannelFutureListener.CLOSE)
when (channel) {
is FileChannel -> {
if (keepAlive) {
ctx.write(ChunkedNioFile(channel))
ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT)
} else {
ctx.writeAndFlush(DefaultFileRegion(channel, 0, channel.size()))
.addListener(ChannelFutureListener.CLOSE)
}
}
else -> {
ctx.write(ChunkedNioStream(channel))
ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT)
}
}
} else {
} ?: let {
log.debug(ctx) {
"Cache miss for key '$key'"
}
@@ -291,19 +397,11 @@ class GradleBuildCacheServer(private val cfg : Configuration) {
"Added value for key '$key' to build cache"
}
val content = msg.content()
val file = cacheDir.resolve(digestString(key.toByteArray()))
val tmpFile = Files.createTempFile(cacheDir, null, ".tmp")
try {
Files.newOutputStream(tmpFile).use {
content.readBytes(it, content.readableBytes())
}
Files.move(tmpFile, file, StandardCopyOption.ATOMIC_MOVE)
} catch (t : Throwable) {
Files.delete(tmpFile)
throw t
}
val response = DefaultFullHttpResponse(msg.protocolVersion(), HttpResponseStatus.CREATED,
Unpooled.copiedBuffer(key.toByteArray()))
cache.put(digestString(key.toByteArray()), content)
val response = DefaultFullHttpResponse(
msg.protocolVersion(), HttpResponseStatus.CREATED,
Unpooled.copiedBuffer(key.toByteArray())
)
response.headers()[HttpHeaderNames.CONTENT_LENGTH] = response.content().readableBytes()
ctx.writeAndFlush(response)
} else {
@@ -326,13 +424,14 @@ class GradleBuildCacheServer(private val cfg : Configuration) {
}
class ServerHandle(
private val httpChannel : ChannelFuture,
private val httpChannel: ChannelFuture,
private val bossGroup: EventLoopGroup,
private val workerGroup: EventLoopGroup) : AutoCloseable {
private val workerGroup: EventLoopGroup
) : AutoCloseable {
private val closeFuture : ChannelFuture = httpChannel.channel().closeFuture()
private val closeFuture: ChannelFuture = httpChannel.channel().closeFuture()
fun shutdown() : ChannelFuture {
fun shutdown(): ChannelFuture {
return httpChannel.channel().close()
}
@@ -341,7 +440,7 @@ class GradleBuildCacheServer(private val cfg : Configuration) {
closeFuture.sync()
} finally {
val fut1 = workerGroup.shutdownGracefully()
val fut2 = if(bossGroup !== workerGroup) {
val fut2 = if (bossGroup !== workerGroup) {
bossGroup.shutdownGracefully()
} else null
fut1.sync()
@@ -350,11 +449,11 @@ class GradleBuildCacheServer(private val cfg : Configuration) {
}
}
fun run() : ServerHandle {
fun run(): ServerHandle {
// Create the multithreaded event loops for the server
val bossGroup = NioEventLoopGroup()
val serverSocketChannel = NioServerSocketChannel::class.java
val workerGroup = if(cfg.useVirtualThread) {
val workerGroup = if (cfg.useVirtualThread) {
NioEventLoopGroup(0, Executors.newVirtualThreadPerTaskExecutor())
} else {
NioEventLoopGroup(0, Executors.newWorkStealingPool())
@@ -378,12 +477,20 @@ class GradleBuildCacheServer(private val cfg : Configuration) {
companion object {
private fun String.toUrl() : URL = URL.of(URI(this), null)
private val log by lazy {
contextLogger()
}
private const val PROTOCOL_HANDLER = "java.protocol.handler.pkgs"
private const val HANDLERS_PACKAGE = "net.woggioni.gbcs.url"
val CONFIGURATION_SCHEMA_URL by lazy {
"classpath:net/woggioni/gbcs/gbcs.xsd".toUrl()
}
val DEFAULT_CONFIGURATION_URL by lazy { "classpath:net/woggioni/gbcs/gbcs-default.xml".toUrl() }
/**
* Reset any cached handlers just in case a jar protocol has already been used. We
* reset the handler by trying to set a null [URLStreamHandlerFactory] which
@@ -396,6 +503,7 @@ class GradleBuildCacheServer(private val cfg : Configuration) {
// Ignore
}
}
fun registerUrlProtocolHandler() {
val handlers = System.getProperty(PROTOCOL_HANDLER, "")
System.setProperty(
@@ -405,21 +513,19 @@ class GradleBuildCacheServer(private val cfg : Configuration) {
resetCachedUrlHandlers()
}
fun loadConfiguration(args: Array<String>) : Configuration {
registerUrlProtocolHandler()
fun loadConfiguration(args: Array<String>): Configuration {
// registerUrlProtocolHandler()
URL.setURLStreamHandlerFactory(ClasspathUrlStreamHandlerFactoryProvider())
Thread.currentThread().contextClassLoader = GradleBuildCacheServer::class.java.classLoader
// Thread.currentThread().contextClassLoader = GradleBuildCacheServer::class.java.classLoader
val app = Application.builder("gbcs")
.configurationDirectoryEnvVar("GBCS_CONFIGURATION_DIR")
.configurationDirectoryPropertyKey("net.woggioni.gbcs.conf.dir")
.build()
val confDir = app.computeConfigurationDirectory()
val configurationFile = confDir.resolve("gbcs.xml")
log.info { "here" }
if(!Files.exists(configurationFile)) {
if (!Files.exists(configurationFile)) {
Files.createDirectories(confDir)
val defaultConfigurationFileResourcePath = "classpath:net/woggioni/gbcs/gbcs-default.xml"
val defaultConfigurationFileResource = URI(defaultConfigurationFileResourcePath).toURL()
val defaultConfigurationFileResource = DEFAULT_CONFIGURATION_URL
Files.newOutputStream(configurationFile).use { outputStream ->
defaultConfigurationFileResource.openStream().use { inputStream ->
JWO.copy(inputStream, outputStream)
@@ -427,37 +533,36 @@ class GradleBuildCacheServer(private val cfg : Configuration) {
}
}
// val schemaUrl = javaClass.getResource("/net/woggioni/gbcs/gbcs.xsd")
val schemaUrl = URL.of(URI("classpath:net/woggioni/gbcs/gbcs.xsd"), null)
val dbf = Xml.newDocumentBuilderFactory()
val schemaUrl = CONFIGURATION_SCHEMA_URL
val dbf = Xml.newDocumentBuilderFactory(schemaUrl)
// dbf.schema = Xml.getSchema(this::class.java.module.getResourceAsStream("/net/woggioni/gbcs/gbcs.xsd"))
dbf.schema = Xml.getSchema(schemaUrl)
val db = dbf.newDocumentBuilder().apply {
setErrorHandler(Xml.ErrorHandler(schemaUrl))
}
val doc = Files.newInputStream(configurationFile).use(db::parse)
return Configuration.parse(doc.documentElement)
return Configuration.parse(doc)
}
@JvmStatic
fun main(args: Array<String>) {
val configuration = loadConfiguration(args)
// Runtime.getRuntime().addShutdownHook(Thread {
// Thread.sleep(5000)
// javaClass.classLoader.loadClass("net.woggioni.jwo.exception.ChildProcessException")
// }
// )
GradleBuildCacheServer(configuration).run().use {
}
}
fun digest(data : ByteArray,
md : MessageDigest = MessageDigest.getInstance("MD5")) : ByteArray {
fun digest(
data: ByteArray,
md: MessageDigest = MessageDigest.getInstance("MD5")
): ByteArray {
md.update(data)
return md.digest()
}
fun digestString(data : ByteArray,
md : MessageDigest = MessageDigest.getInstance("MD5")) : String {
fun digestString(
data: ByteArray,
md: MessageDigest = MessageDigest.getInstance("MD5")
): String {
return JWO.bytesToHex(digest(data, md))
}
}
@@ -467,8 +572,9 @@ object GraalNativeImageConfiguration {
@JvmStatic
fun main(args: Array<String>) {
val conf = GradleBuildCacheServer.loadConfiguration(args)
val handle = GradleBuildCacheServer(conf).run()
Thread.sleep(10_000)
handle.shutdown()
GradleBuildCacheServer(conf).run().use {
Thread.sleep(3000)
it.shutdown()
}
}
}