diff --git a/src/main/java/net/woggioni/jzstd/ZstdInputStream.java b/src/main/java/net/woggioni/jzstd/ZstdInputStream.java index 80b85a2..a7da44e 100644 --- a/src/main/java/net/woggioni/jzstd/ZstdInputStream.java +++ b/src/main/java/net/woggioni/jzstd/ZstdInputStream.java @@ -54,13 +54,14 @@ public class ZstdInputStream extends InputStream { if (input.pos.intValue() == input.size.intValue()) { if (state == State.SOURCE_DEPLETED) { return; - } else if(input.src.position() == input.src.capacity()) { + } else if (input.src.position() == input.src.capacity()) { input.src.position(0); } + int toRead = input.src.capacity() - input.src.position(); - if(toRead > 0) { + if (toRead > 0) { int read = source.read(buffer, 0, toRead); - if(read < 0) { + if (read < 0) { state = State.SOURCE_DEPLETED; break; } @@ -68,10 +69,9 @@ public class ZstdInputStream extends InputStream { } input.pos = zero; input.size = new size_t(input.src.position()); - output.pos = zero; } int rc = ZstdLibrary.decompressStream(ctx.ctx, output, input); - if(rc == 0) { + if (rc == 0) { state = State.CTX_FLUSHED; break; } @@ -92,15 +92,14 @@ public class ZstdInputStream extends InputStream { public int read(byte[] arr, int off, int len) { int totalRead = 0; while (totalRead < len) { - if(output.pos.intValue() == output.dst.position()) { - if(state == State.CTX_FLUSHED) { - if(totalRead == 0) --totalRead; + if (output.pos.intValue() == output.dst.position()) { + if (state == State.CTX_FLUSHED) { + if (totalRead == 0) --totalRead; break; - } - else fill(); + } else fill(); } int toBeRead = Math.min(len, output.pos.intValue() - output.dst.position()); - output.dst.get(arr, off, toBeRead); + output.dst.get(arr, off + totalRead, toBeRead); totalRead += toBeRead; } return totalRead; diff --git a/src/main/java/net/woggioni/jzstd/internal/ZSTD_inBuffer.java b/src/main/java/net/woggioni/jzstd/internal/ZSTD_inBuffer.java index 96563a9..bb2c501 100644 --- a/src/main/java/net/woggioni/jzstd/internal/ZSTD_inBuffer.java +++ b/src/main/java/net/woggioni/jzstd/internal/ZSTD_inBuffer.java @@ -12,7 +12,7 @@ public class ZSTD_inBuffer extends Structure { public size_t pos; public ZSTD_inBuffer(int size) { - this.src = ByteBuffer.allocateDirect(size); + this.src = ByteBuffer.allocateDirect(size).order(ByteOrder.nativeOrder()); this.size = new size_t(this.src.capacity()); this.pos = new size_t(this.src.capacity()); } diff --git a/src/main/java/net/woggioni/jzstd/internal/ZSTD_outBuffer.java b/src/main/java/net/woggioni/jzstd/internal/ZSTD_outBuffer.java index 02def98..1d432d9 100644 --- a/src/main/java/net/woggioni/jzstd/internal/ZSTD_outBuffer.java +++ b/src/main/java/net/woggioni/jzstd/internal/ZSTD_outBuffer.java @@ -12,7 +12,7 @@ public class ZSTD_outBuffer extends Structure { public size_t pos; public ZSTD_outBuffer(int size) { - this.dst = ByteBuffer.allocateDirect(size); + this.dst = ByteBuffer.allocateDirect(size).order(ByteOrder.nativeOrder()); this.size = new size_t(dst.capacity()); this.pos = new size_t(0); } diff --git a/src/test/java/net/woggioni/jzstd/BasicTest.java b/src/test/java/net/woggioni/jzstd/BasicTest.java index 1a72a69..fb693b3 100644 --- a/src/test/java/net/woggioni/jzstd/BasicTest.java +++ b/src/test/java/net/woggioni/jzstd/BasicTest.java @@ -2,7 +2,8 @@ package net.woggioni.jzstd; import lombok.SneakyThrows; import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -15,7 +16,10 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.security.DigestInputStream; import java.security.MessageDigest; +import java.util.Arrays; import java.util.Random; +import java.util.function.Supplier; +import java.util.stream.IntStream; import java.util.stream.Stream; class SplitOutputStream extends OutputStream { @@ -61,7 +65,7 @@ class HexSequenceGenerator extends InputStream { public int read() { int result; if (count < length) { - result = chars[(int)(count % chars.length)]; + result = chars[(int) (count % chars.length)]; ++count; } else result = -1; return result; @@ -70,11 +74,11 @@ class HexSequenceGenerator extends InputStream { @Override @SneakyThrows public int read(byte[] arr, int off, int len) { - if(count == length) return -1; + if (count == length) return -1; else { int i; for (i = 0; i < len && count + i < length; i++) { - arr[i] = (byte) chars[(int)((count + i) % chars.length)]; + arr[i] = (byte) chars[(int) ((count + i) % chars.length)]; } count += i; return i; @@ -106,7 +110,7 @@ class RandomBytesGenerator extends InputStream { @Override @SneakyThrows public int read(byte[] arr, int off, int len) { - if(count == length) return -1; + if (count == length) return -1; else { int i; for (i = 0; i < len && count + i < length; i++) { @@ -124,7 +128,7 @@ public class BasicTest { public static String bytesToHex(byte[] bytes) { char[] hexChars = new char[bytes.length * 2]; - for(int j = 0; j < bytes.length; j++) { + for (int j = 0; j < bytes.length; j++) { int v = bytes[j] & 0xFF; hexChars[j * 2] = hexArray[v >>> 4]; hexChars[j * 2 + 1] = hexArray[v & 0x0F]; @@ -132,25 +136,36 @@ public class BasicTest { return new String(hexChars); } - private static class InputStreamProvider implements ArgumentsProvider { + private static class TestCaseProvider implements ArgumentsProvider { @Override @SneakyThrows public Stream provideArguments(ExtensionContext context) { - return Stream.of( - Arguments.of(getClass().getResourceAsStream("/index.html")), - Arguments.of(new RandomBytesGenerator(123456, 1000)), - Arguments.of(new RandomBytesGenerator(654321, 65536)), - Arguments.of(new RandomBytesGenerator(101325, 12345678)), - Arguments.of(new HexSequenceGenerator(12345678)) - ); + Supplier[] streams = new Supplier[]{ + () -> getClass().getResourceAsStream("/index.html"), + () -> new RandomBytesGenerator(123456, 1000), + () -> new RandomBytesGenerator(654321, 65536), + () -> new RandomBytesGenerator(101325, 12345678), + () -> new HexSequenceGenerator(12345678) + }; + int[] bufferSizes = new int[]{ + 0x100, + 0x1000, + 0x10000, + 0x100000 + }; + + return Arrays.stream(streams) + .flatMap(s -> Arrays.stream(bufferSizes).mapToObj(bufferSize -> Arguments.of(s.get(), bufferSize))); } } + private static final Path temporaryDirectory = Paths.get(System.getProperty("java.io.tmpdir")); + @SneakyThrows @ParameterizedTest - @ArgumentsSource(InputStreamProvider.class) - public void test(InputStream testStream) { + @ArgumentsSource(TestCaseProvider.class) + public void test(InputStream testStream, int bufferSize) { MessageDigest md5 = MessageDigest.getInstance("MD5"); boolean debug = false; @@ -160,17 +175,17 @@ public class BasicTest { OutputStream os; ByteArrayOutputStream baos = new ByteArrayOutputStream(); os = baos; - if(debug) { - os = new SplitOutputStream(os, Files.newOutputStream(Paths.get("/tmp/out.bin.zst"))); + if (debug) { + os = new SplitOutputStream(os, Files.newOutputStream(temporaryDirectory.resolve("out.bin.zst"))); } os = ZstdOutputStream.from(os); if (debug) { - new SplitOutputStream( + os = new SplitOutputStream( os, - Files.newOutputStream(Paths.get("/tmp/out.bin.original"))); + Files.newOutputStream(temporaryDirectory.resolve("/tmp/out.bin.original"))); } try { - byte[] buffer = new byte[0x100000]; + byte[] buffer = new byte[bufferSize]; while (true) { int read = is.read(buffer); if (read < 0) break; @@ -188,10 +203,29 @@ public class BasicTest { md5.reset(); try (InputStream is = new DigestInputStream( ZstdInputStream.from(compressedStream), md5)) { - byte[] buffer = new byte[0x10000]; - while (true) { - int read = is.read(buffer); - if (read < 0) break; + OutputStream os; + if(debug) { + os = Files.newOutputStream(temporaryDirectory.resolve("out.bin")); + } else { + os = new OutputStream() { + @Override + public void write(int i) { + } + + @Override + public void write(byte[] b, int off, int len) { + } + }; + } + try { + byte[] buffer = new byte[bufferSize]; + while (true) { + int read = is.read(buffer); + if (read < 0) break; + os.write(buffer, 0, read); + } + } finally { + os.close(); } } byte[] roundTripDigest = md5.digest();