fixed memory leak in InMemoryCache

This commit is contained in:
2025-02-05 19:09:51 +08:00
parent 1b6cf1bd96
commit 7d0f24fa58
7 changed files with 125 additions and 37 deletions

View File

@@ -0,0 +1,19 @@
package net.woggioni.gbcs.common
import io.netty.buffer.ByteBuf
import java.io.InputStream
import java.io.OutputStream
class ByteBufOutputStream(private val buf : ByteBuf) : OutputStream() {
override fun write(b: Int) {
buf.writeByte(b)
}
override fun write(b: ByteArray, off: Int, len: Int) {
buf.writeBytes(b, off, len)
}
override fun close() {
buf.release()
}
}

View File

@@ -24,6 +24,7 @@ import io.netty.handler.codec.memcache.binary.FullBinaryMemcacheRequest
import io.netty.handler.codec.memcache.binary.FullBinaryMemcacheResponse import io.netty.handler.codec.memcache.binary.FullBinaryMemcacheResponse
import io.netty.util.concurrent.GenericFutureListener import io.netty.util.concurrent.GenericFutureListener
import net.woggioni.gbcs.common.ByteBufInputStream import net.woggioni.gbcs.common.ByteBufInputStream
import net.woggioni.gbcs.common.ByteBufOutputStream
import net.woggioni.gbcs.common.GBCS.digest import net.woggioni.gbcs.common.GBCS.digest
import net.woggioni.gbcs.common.HostAndPort import net.woggioni.gbcs.common.HostAndPort
import net.woggioni.gbcs.common.contextLogger import net.woggioni.gbcs.common.contextLogger
@@ -136,6 +137,7 @@ class MemcacheClient(private val cfg: MemcacheCacheConfiguration) : AutoCloseabl
response.completeExceptionally(ex) response.completeExceptionally(ex)
} }
}) })
request.touch()
channel.writeAndFlush(request) channel.writeAndFlush(request)
} else { } else {
response.completeExceptionally(channelFuture.cause()) response.completeExceptionally(channelFuture.cause())
@@ -162,29 +164,35 @@ class MemcacheClient(private val cfg: MemcacheCacheConfiguration) : AutoCloseabl
} }
} }
return sendRequest(request).thenApply { response -> return sendRequest(request).thenApply { response ->
when (val status = response.status()) { try {
BinaryMemcacheResponseStatus.SUCCESS -> { when (val status = response.status()) {
val compressionMode = cfg.compressionMode BinaryMemcacheResponseStatus.SUCCESS -> {
val content = response.content().retain() val compressionMode = cfg.compressionMode
response.release() val content = response.content().retain()
if (compressionMode != null) { content.touch()
when (compressionMode) { if (compressionMode != null) {
MemcacheCacheConfiguration.CompressionMode.GZIP -> { when (compressionMode) {
GZIPInputStream(ByteBufInputStream(content)) MemcacheCacheConfiguration.CompressionMode.GZIP -> {
} GZIPInputStream(ByteBufInputStream(content))
}
MemcacheCacheConfiguration.CompressionMode.DEFLATE -> { MemcacheCacheConfiguration.CompressionMode.DEFLATE -> {
InflaterInputStream(ByteBufInputStream(content)) InflaterInputStream(ByteBufInputStream(content))
}
} }
} } else {
} else { ByteBufInputStream(content)
ByteBufInputStream(content) }.let(Channels::newChannel)
}.let(Channels::newChannel) }
BinaryMemcacheResponseStatus.KEY_ENOENT -> {
null
}
else -> throw MemcacheException(status)
} }
BinaryMemcacheResponseStatus.KEY_ENOENT -> { } finally {
null response.release()
}
else -> throw MemcacheException(status)
} }
} }
} }
@@ -199,16 +207,18 @@ class MemcacheClient(private val cfg: MemcacheCacheConfiguration) : AutoCloseabl
extras.writeInt(0) extras.writeInt(0)
extras.writeInt(encodeExpiry(expiry)) extras.writeInt(encodeExpiry(expiry))
val compressionMode = cfg.compressionMode val compressionMode = cfg.compressionMode
content.retain()
val payload = if (compressionMode != null) { val payload = if (compressionMode != null) {
val inputStream = ByteBufInputStream(Unpooled.wrappedBuffer(content)) val inputStream = ByteBufInputStream(content)
val baos = ByteArrayOutputStream() val buf = content.alloc().buffer()
buf.retain()
val outputStream = when (compressionMode) { val outputStream = when (compressionMode) {
MemcacheCacheConfiguration.CompressionMode.GZIP -> { MemcacheCacheConfiguration.CompressionMode.GZIP -> {
GZIPOutputStream(baos) GZIPOutputStream(ByteBufOutputStream(buf))
} }
MemcacheCacheConfiguration.CompressionMode.DEFLATE -> { MemcacheCacheConfiguration.CompressionMode.DEFLATE -> {
DeflaterOutputStream(baos, Deflater(Deflater.DEFAULT_COMPRESSION, false)) DeflaterOutputStream(ByteBufOutputStream(buf), Deflater(Deflater.DEFAULT_COMPRESSION, false))
} }
} }
inputStream.use { i -> inputStream.use { i ->
@@ -216,7 +226,7 @@ class MemcacheClient(private val cfg: MemcacheCacheConfiguration) : AutoCloseabl
JWO.copy(i, o) JWO.copy(i, o)
} }
} }
Unpooled.wrappedBuffer(baos.toByteArray()) buf
} else { } else {
content content
} }

View File

@@ -1,12 +1,12 @@
package net.woggioni.gbcs.server.cache package net.woggioni.gbcs.server.cache
import io.netty.buffer.ByteBuf import io.netty.buffer.ByteBuf
import io.netty.buffer.Unpooled
import net.woggioni.gbcs.api.Cache import net.woggioni.gbcs.api.Cache
import net.woggioni.gbcs.common.ByteBufInputStream import net.woggioni.gbcs.common.ByteBufInputStream
import net.woggioni.gbcs.common.ByteBufOutputStream
import net.woggioni.gbcs.common.GBCS.digestString import net.woggioni.gbcs.common.GBCS.digestString
import net.woggioni.gbcs.common.contextLogger
import net.woggioni.jwo.JWO import net.woggioni.jwo.JWO
import java.io.ByteArrayOutputStream
import java.nio.channels.Channels import java.nio.channels.Channels
import java.security.MessageDigest import java.security.MessageDigest
import java.time.Duration import java.time.Duration
@@ -14,6 +14,7 @@ import java.time.Instant
import java.util.concurrent.CompletableFuture import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.PriorityBlockingQueue import java.util.concurrent.PriorityBlockingQueue
import java.util.concurrent.atomic.AtomicLong
import java.util.zip.Deflater import java.util.zip.Deflater
import java.util.zip.DeflaterOutputStream import java.util.zip.DeflaterOutputStream
import java.util.zip.Inflater import java.util.zip.Inflater
@@ -21,11 +22,18 @@ import java.util.zip.InflaterInputStream
class InMemoryCache( class InMemoryCache(
val maxAge: Duration, val maxAge: Duration,
val maxSize: Long,
val digestAlgorithm: String?, val digestAlgorithm: String?,
val compressionEnabled: Boolean, val compressionEnabled: Boolean,
val compressionLevel: Int val compressionLevel: Int
) : Cache { ) : Cache {
companion object {
@JvmStatic
private val log = contextLogger()
}
private val size = AtomicLong()
private val map = ConcurrentHashMap<String, ByteBuf>() private val map = ConcurrentHashMap<String, ByteBuf>()
private class RemovalQueueElement(val key: String, val value : ByteBuf, val expiry : Instant) : Comparable<RemovalQueueElement> { private class RemovalQueueElement(val key: String, val value : ByteBuf, val expiry : Instant) : Comparable<RemovalQueueElement> {
@@ -38,9 +46,17 @@ class InMemoryCache(
private val garbageCollector = Thread { private val garbageCollector = Thread {
while(true) { while(true) {
val el = removalQueue.take() val el = removalQueue.take()
val buf = el.value
val now = Instant.now() val now = Instant.now()
if(now > el.expiry) { if(now > el.expiry) {
map.remove(el.key, el.value) val removed = map.remove(el.key, buf)
if(removed) {
updateSizeAfterRemoval(buf)
//Decrease the reference count for map
buf.release()
}
//Decrease the reference count for removalQueue
buf.release()
} else { } else {
removalQueue.put(el) removalQueue.put(el)
Thread.sleep(minOf(Duration.between(now, el.expiry), Duration.ofSeconds(1))) Thread.sleep(minOf(Duration.between(now, el.expiry), Duration.ofSeconds(1)))
@@ -50,6 +66,28 @@ class InMemoryCache(
start() start()
} }
private fun removeEldest() : Long {
while(true) {
val el = removalQueue.take()
val buf = el.value
val removed = map.remove(el.key, buf)
//Decrease the reference count for removalQueue
buf.release()
if(removed) {
val newSize = updateSizeAfterRemoval(buf)
//Decrease the reference count for map
buf.release()
return newSize
}
}
}
private fun updateSizeAfterRemoval(removed: ByteBuf) : Long {
return size.updateAndGet { currentSize : Long ->
currentSize - removed.readableBytes()
}
}
override fun close() { override fun close() {
running = false running = false
garbageCollector.join() garbageCollector.join()
@@ -64,11 +102,13 @@ class InMemoryCache(
).let { digest -> ).let { digest ->
map[digest] map[digest]
?.let { value -> ?.let { value ->
val copy = value.retainedDuplicate()
copy.touch("This has to be released by the caller of the cache")
if (compressionEnabled) { if (compressionEnabled) {
val inflater = Inflater() val inflater = Inflater()
Channels.newChannel(InflaterInputStream(ByteBufInputStream(value), inflater)) Channels.newChannel(InflaterInputStream(ByteBufInputStream(copy), inflater))
} else { } else {
Channels.newChannel(ByteBufInputStream(value)) Channels.newChannel(ByteBufInputStream(copy))
} }
} }
}.let { }.let {
@@ -81,18 +121,29 @@ class InMemoryCache(
?.let { md -> ?.let { md ->
digestString(key.toByteArray(), md) digestString(key.toByteArray(), md)
} ?: key).let { digest -> } ?: key).let { digest ->
content.retain()
val value = if (compressionEnabled) { val value = if (compressionEnabled) {
val deflater = Deflater(compressionLevel) val deflater = Deflater(compressionLevel)
val baos = ByteArrayOutputStream() val buf = content.alloc().buffer()
DeflaterOutputStream(baos, deflater).use { stream -> buf.retain()
JWO.copy(ByteBufInputStream(content), stream) DeflaterOutputStream(ByteBufOutputStream(buf), deflater).use { outputStream ->
ByteBufInputStream(content).use { inputStream ->
JWO.copy(inputStream, outputStream)
}
} }
Unpooled.wrappedBuffer(baos.toByteArray()) buf
} else { } else {
content content
} }
map[digest] = value val old = map.put(digest, value)
removalQueue.put(RemovalQueueElement(digest, value, Instant.now().plus(maxAge))) val delta = value.readableBytes() - (old?.readableBytes() ?: 0)
var newSize = size.updateAndGet { currentSize : Long ->
currentSize + delta
}
removalQueue.put(RemovalQueueElement(digest, value.retain(), Instant.now().plus(maxAge)))
while(newSize > maxSize) {
newSize = removeEldest()
}
}.let { }.let {
CompletableFuture.completedFuture<Void>(null) CompletableFuture.completedFuture<Void>(null)
} }

View File

@@ -6,12 +6,14 @@ import java.time.Duration
data class InMemoryCacheConfiguration( data class InMemoryCacheConfiguration(
val maxAge: Duration, val maxAge: Duration,
val maxSize: Long,
val digestAlgorithm : String?, val digestAlgorithm : String?,
val compressionEnabled: Boolean, val compressionEnabled: Boolean,
val compressionLevel: Int, val compressionLevel: Int,
) : Configuration.Cache { ) : Configuration.Cache {
override fun materialize() = InMemoryCache( override fun materialize() = InMemoryCache(
maxAge, maxAge,
maxSize,
digestAlgorithm, digestAlgorithm,
compressionEnabled, compressionEnabled,
compressionLevel compressionLevel

View File

@@ -21,6 +21,9 @@ class InMemoryCacheProvider : CacheProvider<InMemoryCacheConfiguration> {
val maxAge = el.renderAttribute("max-age") val maxAge = el.renderAttribute("max-age")
?.let(Duration::parse) ?.let(Duration::parse)
?: Duration.ofDays(1) ?: Duration.ofDays(1)
val maxSize = el.renderAttribute("max-size")
?.let(java.lang.Long::decode)
?: 0x1000000
val enableCompression = el.renderAttribute("enable-compression") val enableCompression = el.renderAttribute("enable-compression")
?.let(String::toBoolean) ?.let(String::toBoolean)
?: true ?: true
@@ -31,6 +34,7 @@ class InMemoryCacheProvider : CacheProvider<InMemoryCacheConfiguration> {
return InMemoryCacheConfiguration( return InMemoryCacheConfiguration(
maxAge, maxAge,
maxSize,
digestAlgorithm, digestAlgorithm,
enableCompression, enableCompression,
compressionLevel compressionLevel
@@ -43,6 +47,7 @@ class InMemoryCacheProvider : CacheProvider<InMemoryCacheConfiguration> {
val prefix = doc.lookupPrefix(GBCS.GBCS_NAMESPACE_URI) val prefix = doc.lookupPrefix(GBCS.GBCS_NAMESPACE_URI)
attr("xs:type", "${prefix}:inMemoryCacheType", GBCS.XML_SCHEMA_NAMESPACE_URI) attr("xs:type", "${prefix}:inMemoryCacheType", GBCS.XML_SCHEMA_NAMESPACE_URI)
attr("max-age", maxAge.toString()) attr("max-age", maxAge.toString())
attr("max-size", maxSize.toString())
digestAlgorithm?.let { digestAlgorithm -> digestAlgorithm?.let { digestAlgorithm ->
attr("digest", digestAlgorithm) attr("digest", digestAlgorithm)
} }

View File

@@ -107,7 +107,7 @@ class ServerHandler(private val cache: Cache, private val serverPrefix: Path) :
log.debug(ctx) { log.debug(ctx) {
"Added value for key '$key' to build cache" "Added value for key '$key' to build cache"
} }
cache.put(key, msg.content().retain()).thenRun { cache.put(key, msg.content()).thenRun {
val response = DefaultFullHttpResponse( val response = DefaultFullHttpResponse(
msg.protocolVersion(), HttpResponseStatus.CREATED, msg.protocolVersion(), HttpResponseStatus.CREATED,
Unpooled.copiedBuffer(key.toByteArray()) Unpooled.copiedBuffer(key.toByteArray())

View File

@@ -52,6 +52,7 @@
<xs:complexContent> <xs:complexContent>
<xs:extension base="gbcs:cacheType"> <xs:extension base="gbcs:cacheType">
<xs:attribute name="max-age" type="xs:duration" default="P1D"/> <xs:attribute name="max-age" type="xs:duration" default="P1D"/>
<xs:attribute name="max-size" type="xs:token" default="0x1000000"/>
<xs:attribute name="digest" type="xs:token" default="MD5"/> <xs:attribute name="digest" type="xs:token" default="MD5"/>
<xs:attribute name="enable-compression" type="xs:boolean" default="true"/> <xs:attribute name="enable-compression" type="xs:boolean" default="true"/>
<xs:attribute name="compression-level" type="xs:byte" default="-1"/> <xs:attribute name="compression-level" type="xs:byte" default="-1"/>