fixed bug in net.woggioni.jzstd.ZstdInputStream

This commit is contained in:
Walter Oggioni
2020-05-01 21:15:26 +01:00
parent 65929adcd7
commit da433c9d14
4 changed files with 71 additions and 38 deletions

View File

@@ -54,13 +54,14 @@ public class ZstdInputStream extends InputStream {
if (input.pos.intValue() == input.size.intValue()) { if (input.pos.intValue() == input.size.intValue()) {
if (state == State.SOURCE_DEPLETED) { if (state == State.SOURCE_DEPLETED) {
return; return;
} else if(input.src.position() == input.src.capacity()) { } else if (input.src.position() == input.src.capacity()) {
input.src.position(0); input.src.position(0);
} }
int toRead = input.src.capacity() - input.src.position(); int toRead = input.src.capacity() - input.src.position();
if(toRead > 0) { if (toRead > 0) {
int read = source.read(buffer, 0, toRead); int read = source.read(buffer, 0, toRead);
if(read < 0) { if (read < 0) {
state = State.SOURCE_DEPLETED; state = State.SOURCE_DEPLETED;
break; break;
} }
@@ -68,10 +69,9 @@ public class ZstdInputStream extends InputStream {
} }
input.pos = zero; input.pos = zero;
input.size = new size_t(input.src.position()); input.size = new size_t(input.src.position());
output.pos = zero;
} }
int rc = ZstdLibrary.decompressStream(ctx.ctx, output, input); int rc = ZstdLibrary.decompressStream(ctx.ctx, output, input);
if(rc == 0) { if (rc == 0) {
state = State.CTX_FLUSHED; state = State.CTX_FLUSHED;
break; break;
} }
@@ -92,15 +92,14 @@ public class ZstdInputStream extends InputStream {
public int read(byte[] arr, int off, int len) { public int read(byte[] arr, int off, int len) {
int totalRead = 0; int totalRead = 0;
while (totalRead < len) { while (totalRead < len) {
if(output.pos.intValue() == output.dst.position()) { if (output.pos.intValue() == output.dst.position()) {
if(state == State.CTX_FLUSHED) { if (state == State.CTX_FLUSHED) {
if(totalRead == 0) --totalRead; if (totalRead == 0) --totalRead;
break; break;
} } else fill();
else fill();
} }
int toBeRead = Math.min(len, output.pos.intValue() - output.dst.position()); int toBeRead = Math.min(len, output.pos.intValue() - output.dst.position());
output.dst.get(arr, off, toBeRead); output.dst.get(arr, off + totalRead, toBeRead);
totalRead += toBeRead; totalRead += toBeRead;
} }
return totalRead; return totalRead;

View File

@@ -12,7 +12,7 @@ public class ZSTD_inBuffer extends Structure {
public size_t pos; public size_t pos;
public ZSTD_inBuffer(int size) { public ZSTD_inBuffer(int size) {
this.src = ByteBuffer.allocateDirect(size); this.src = ByteBuffer.allocateDirect(size).order(ByteOrder.nativeOrder());
this.size = new size_t(this.src.capacity()); this.size = new size_t(this.src.capacity());
this.pos = new size_t(this.src.capacity()); this.pos = new size_t(this.src.capacity());
} }

View File

@@ -12,7 +12,7 @@ public class ZSTD_outBuffer extends Structure {
public size_t pos; public size_t pos;
public ZSTD_outBuffer(int size) { public ZSTD_outBuffer(int size) {
this.dst = ByteBuffer.allocateDirect(size); this.dst = ByteBuffer.allocateDirect(size).order(ByteOrder.nativeOrder());
this.size = new size_t(dst.capacity()); this.size = new size_t(dst.capacity());
this.pos = new size_t(0); this.pos = new size_t(0);
} }

View File

@@ -2,7 +2,8 @@ package net.woggioni.jzstd;
import lombok.SneakyThrows; import lombok.SneakyThrows;
import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.Arguments;
@@ -15,7 +16,10 @@ import java.nio.file.Path;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.security.DigestInputStream; import java.security.DigestInputStream;
import java.security.MessageDigest; import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Random; import java.util.Random;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import java.util.stream.Stream; import java.util.stream.Stream;
class SplitOutputStream extends OutputStream { class SplitOutputStream extends OutputStream {
@@ -61,7 +65,7 @@ class HexSequenceGenerator extends InputStream {
public int read() { public int read() {
int result; int result;
if (count < length) { if (count < length) {
result = chars[(int)(count % chars.length)]; result = chars[(int) (count % chars.length)];
++count; ++count;
} else result = -1; } else result = -1;
return result; return result;
@@ -70,11 +74,11 @@ class HexSequenceGenerator extends InputStream {
@Override @Override
@SneakyThrows @SneakyThrows
public int read(byte[] arr, int off, int len) { public int read(byte[] arr, int off, int len) {
if(count == length) return -1; if (count == length) return -1;
else { else {
int i; int i;
for (i = 0; i < len && count + i < length; i++) { for (i = 0; i < len && count + i < length; i++) {
arr[i] = (byte) chars[(int)((count + i) % chars.length)]; arr[i] = (byte) chars[(int) ((count + i) % chars.length)];
} }
count += i; count += i;
return i; return i;
@@ -106,7 +110,7 @@ class RandomBytesGenerator extends InputStream {
@Override @Override
@SneakyThrows @SneakyThrows
public int read(byte[] arr, int off, int len) { public int read(byte[] arr, int off, int len) {
if(count == length) return -1; if (count == length) return -1;
else { else {
int i; int i;
for (i = 0; i < len && count + i < length; i++) { for (i = 0; i < len && count + i < length; i++) {
@@ -124,7 +128,7 @@ public class BasicTest {
public static String bytesToHex(byte[] bytes) { public static String bytesToHex(byte[] bytes) {
char[] hexChars = new char[bytes.length * 2]; char[] hexChars = new char[bytes.length * 2];
for(int j = 0; j < bytes.length; j++) { for (int j = 0; j < bytes.length; j++) {
int v = bytes[j] & 0xFF; int v = bytes[j] & 0xFF;
hexChars[j * 2] = hexArray[v >>> 4]; hexChars[j * 2] = hexArray[v >>> 4];
hexChars[j * 2 + 1] = hexArray[v & 0x0F]; hexChars[j * 2 + 1] = hexArray[v & 0x0F];
@@ -132,25 +136,36 @@ public class BasicTest {
return new String(hexChars); return new String(hexChars);
} }
private static class InputStreamProvider implements ArgumentsProvider { private static class TestCaseProvider implements ArgumentsProvider {
@Override @Override
@SneakyThrows @SneakyThrows
public Stream<? extends Arguments> provideArguments(ExtensionContext context) { public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
return Stream.of( Supplier<InputStream>[] streams = new Supplier[]{
Arguments.of(getClass().getResourceAsStream("/index.html")), () -> getClass().getResourceAsStream("/index.html"),
Arguments.of(new RandomBytesGenerator(123456, 1000)), () -> new RandomBytesGenerator(123456, 1000),
Arguments.of(new RandomBytesGenerator(654321, 65536)), () -> new RandomBytesGenerator(654321, 65536),
Arguments.of(new RandomBytesGenerator(101325, 12345678)), () -> new RandomBytesGenerator(101325, 12345678),
Arguments.of(new HexSequenceGenerator(12345678)) () -> new HexSequenceGenerator(12345678)
); };
int[] bufferSizes = new int[]{
0x100,
0x1000,
0x10000,
0x100000
};
return Arrays.stream(streams)
.flatMap(s -> Arrays.stream(bufferSizes).mapToObj(bufferSize -> Arguments.of(s.get(), bufferSize)));
} }
} }
private static final Path temporaryDirectory = Paths.get(System.getProperty("java.io.tmpdir"));
@SneakyThrows @SneakyThrows
@ParameterizedTest @ParameterizedTest
@ArgumentsSource(InputStreamProvider.class) @ArgumentsSource(TestCaseProvider.class)
public void test(InputStream testStream) { public void test(InputStream testStream, int bufferSize) {
MessageDigest md5 = MessageDigest.getInstance("MD5"); MessageDigest md5 = MessageDigest.getInstance("MD5");
boolean debug = false; boolean debug = false;
@@ -160,17 +175,17 @@ public class BasicTest {
OutputStream os; OutputStream os;
ByteArrayOutputStream baos = new ByteArrayOutputStream(); ByteArrayOutputStream baos = new ByteArrayOutputStream();
os = baos; os = baos;
if(debug) { if (debug) {
os = new SplitOutputStream(os, Files.newOutputStream(Paths.get("/tmp/out.bin.zst"))); os = new SplitOutputStream(os, Files.newOutputStream(temporaryDirectory.resolve("out.bin.zst")));
} }
os = ZstdOutputStream.from(os); os = ZstdOutputStream.from(os);
if (debug) { if (debug) {
new SplitOutputStream( os = new SplitOutputStream(
os, os,
Files.newOutputStream(Paths.get("/tmp/out.bin.original"))); Files.newOutputStream(temporaryDirectory.resolve("/tmp/out.bin.original")));
} }
try { try {
byte[] buffer = new byte[0x100000]; byte[] buffer = new byte[bufferSize];
while (true) { while (true) {
int read = is.read(buffer); int read = is.read(buffer);
if (read < 0) break; if (read < 0) break;
@@ -188,10 +203,29 @@ public class BasicTest {
md5.reset(); md5.reset();
try (InputStream is = new DigestInputStream( try (InputStream is = new DigestInputStream(
ZstdInputStream.from(compressedStream), md5)) { ZstdInputStream.from(compressedStream), md5)) {
byte[] buffer = new byte[0x10000]; OutputStream os;
if(debug) {
os = Files.newOutputStream(temporaryDirectory.resolve("out.bin"));
} else {
os = new OutputStream() {
@Override
public void write(int i) {
}
@Override
public void write(byte[] b, int off, int len) {
}
};
}
try {
byte[] buffer = new byte[bufferSize];
while (true) { while (true) {
int read = is.read(buffer); int read = is.read(buffer);
if (read < 0) break; if (read < 0) break;
os.write(buffer, 0, read);
}
} finally {
os.close();
} }
} }
byte[] roundTripDigest = md5.digest(); byte[] roundTripDigest = md5.digest();