From 9b09c6bd738a665e6a3a154326419fb6064d8782 Mon Sep 17 00:00:00 2001 From: Walter Oggioni Date: Sun, 14 May 2023 15:38:15 +0800 Subject: [PATCH] added jmath library --- benchmark/build.gradle | 2 + build.gradle | 61 +- gradle.properties | 2 +- gradle/wrapper/gradle-wrapper.properties | 2 +- jmath/build.gradle | 18 + .../net/woggioni/jmath/BigIntegerExt.java | 50 ++ .../net/woggioni/jmath/DecimalFactory.java | 32 + .../java/net/woggioni/jmath/DecimalValue.java | 63 ++ .../java/net/woggioni/jmath/FloatFactory.java | 23 + .../java/net/woggioni/jmath/FloatValue.java | 58 ++ .../net/woggioni/jmath/IntegerFactory.java | 25 + .../java/net/woggioni/jmath/IntegerValue.java | 57 ++ .../main/java/net/woggioni/jmath/Matrix.java | 561 ++++++++++++++++++ .../java/net/woggioni/jmath/NumericType.java | 10 + .../woggioni/jmath/NumericTypeFactory.java | 10 + .../java/net/woggioni/jmath/Rational.java | 125 ++++ .../net/woggioni/jmath/RationalFactory.java | 23 + .../jmath/SingularMatrixException.java | 7 + .../net/woggioni/jmath/SizeException.java | 7 + .../main/java/net/woggioni/jmath/Vector.java | 142 +++++ .../java/net/woggioni/jmath/MatrixTest.java | 206 +++++++ .../java/net/woggioni/jmath/RationalTest.java | 102 ++++ settings.gradle | 1 + src/main/java/net/woggioni/jwo/Hash.java | 5 + .../java/net/woggioni/jwo/Requirement.java | 25 + 25 files changed, 1592 insertions(+), 25 deletions(-) create mode 100644 jmath/build.gradle create mode 100644 jmath/src/main/java/net/woggioni/jmath/BigIntegerExt.java create mode 100644 jmath/src/main/java/net/woggioni/jmath/DecimalFactory.java create mode 100644 jmath/src/main/java/net/woggioni/jmath/DecimalValue.java create mode 100644 jmath/src/main/java/net/woggioni/jmath/FloatFactory.java create mode 100644 jmath/src/main/java/net/woggioni/jmath/FloatValue.java create mode 100644 jmath/src/main/java/net/woggioni/jmath/IntegerFactory.java create mode 100644 jmath/src/main/java/net/woggioni/jmath/IntegerValue.java create mode 100644 jmath/src/main/java/net/woggioni/jmath/Matrix.java create mode 100644 jmath/src/main/java/net/woggioni/jmath/NumericType.java create mode 100644 jmath/src/main/java/net/woggioni/jmath/NumericTypeFactory.java create mode 100644 jmath/src/main/java/net/woggioni/jmath/Rational.java create mode 100644 jmath/src/main/java/net/woggioni/jmath/RationalFactory.java create mode 100644 jmath/src/main/java/net/woggioni/jmath/SingularMatrixException.java create mode 100644 jmath/src/main/java/net/woggioni/jmath/SizeException.java create mode 100644 jmath/src/main/java/net/woggioni/jmath/Vector.java create mode 100644 jmath/src/test/java/net/woggioni/jmath/MatrixTest.java create mode 100644 jmath/src/test/java/net/woggioni/jmath/RationalTest.java create mode 100644 src/main/java/net/woggioni/jwo/Requirement.java diff --git a/benchmark/build.gradle b/benchmark/build.gradle index ca3b431..301acbd 100644 --- a/benchmark/build.gradle +++ b/benchmark/build.gradle @@ -1,4 +1,6 @@ plugins { + id 'java-library' + alias(catalog.plugins.lombok) alias(catalog.plugins.envelope) } diff --git a/build.gradle b/build.gradle index abfd9ae..2e4c16d 100644 --- a/build.gradle +++ b/build.gradle @@ -1,13 +1,11 @@ plugins { + id 'java-library' id 'maven-publish' alias(catalog.plugins.multi.release.jar) - alias(catalog.plugins.lombok) apply false + alias(catalog.plugins.lombok) } allprojects { - apply plugin: 'java-library' - apply plugin: catalog.plugins.lombok.get().pluginId - group = "net.woggioni" version = getProperty('jwo.version') @@ -17,16 +15,45 @@ allprojects { } mavenCentral() } - - dependencies { - testImplementation catalog.junit.jupiter.api - testImplementation catalog.junit.jupiter.params - testRuntimeOnly catalog.junit.jupiter.engine + + pluginManager.withPlugin('java-library') { + + dependencies { + testImplementation catalog.junit.jupiter.api + testImplementation catalog.junit.jupiter.params + testRuntimeOnly catalog.junit.jupiter.engine + } + + test { + useJUnitPlatform() + } } - lombok { - version = catalog.versions.lombok.get() + pluginManager.withPlugin(catalog.plugins.lombok.get().pluginId) { + lombok { + version = catalog.versions.lombok.get() + } } + + pluginManager.withPlugin('maven-publish') { + publishing { + repositories { + maven { + url = 'https://mvn.woggioni.net/' + } + } + publications { + maven(MavenPublication) { + from(components["java"]) + } + } + } + } +} + +java { + withJavadocJar() + withSourcesJar() } ext { @@ -76,18 +103,6 @@ test { jvmArgs(['--add-opens', 'java.base/sun.nio.fs=ALL-UNNAMED']) } -publishing { - repositories { - maven { - url = 'https://mvn.woggioni.net/' - } - } - publications { - maven(MavenPublication) { - from(components["java"]) - } - } -} diff --git a/gradle.properties b/gradle.properties index 1655678..a1c52ab 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,3 +1,3 @@ -jwo.version = 2023.03 +jwo.version = 2023.05 lys.version = 2023.03 guice.version = 5.0.1 diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 0c85a1f..37aef8d 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-8.1-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.1.1-bin.zip networkTimeout=10000 zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/jmath/build.gradle b/jmath/build.gradle new file mode 100644 index 0000000..951da63 --- /dev/null +++ b/jmath/build.gradle @@ -0,0 +1,18 @@ +plugins { + id 'java-library' + alias(catalog.plugins.lombok) + id 'maven-publish' +} + +dependencies { + implementation project(':') +} + +java { + withJavadocJar() + withSourcesJar() + + toolchain { + languageVersion = JavaLanguageVersion.of(17) + } +} \ No newline at end of file diff --git a/jmath/src/main/java/net/woggioni/jmath/BigIntegerExt.java b/jmath/src/main/java/net/woggioni/jmath/BigIntegerExt.java new file mode 100644 index 0000000..92b5fdd --- /dev/null +++ b/jmath/src/main/java/net/woggioni/jmath/BigIntegerExt.java @@ -0,0 +1,50 @@ +package net.woggioni.jmath; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; + +import java.math.BigInteger; +import java.util.Objects; + +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class BigIntegerExt { + + static public int gcd(int a, int b) { + int tmp; + while (b != 0) { + tmp = a; + a = b; + b = tmp % b; + } + return a; + } + + static public BigInteger gcd(BigInteger a, BigInteger b) { + BigInteger tmp; + while (!BigInteger.ZERO.equals(b)) { + tmp = a; + a = b; + b = tmp.mod(b); + } + return a; + } + +// static BigInteger gcd(BigInteger n1, BigInteger n2) { +// BigInteger remainder; +// BigInteger result; +// while (true) { +// remainder = n1.mod(n2); +// result = n2; +// if (BigInteger.ZERO.equals(remainder)) break; +// else { +// n1 = n2; +// n2 = remainder; +// } +// } +// return result; +// } + + public static BigInteger mcm(BigInteger n1, BigInteger n2) { + return n1.multiply(n2).divide(gcd(n1, n2)); + } +} diff --git a/jmath/src/main/java/net/woggioni/jmath/DecimalFactory.java b/jmath/src/main/java/net/woggioni/jmath/DecimalFactory.java new file mode 100644 index 0000000..0db8761 --- /dev/null +++ b/jmath/src/main/java/net/woggioni/jmath/DecimalFactory.java @@ -0,0 +1,32 @@ +package net.woggioni.jmath; + +import lombok.AccessLevel; +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +import java.math.BigDecimal; +import java.math.MathContext; + +@RequiredArgsConstructor +public class DecimalFactory implements NumericTypeFactory { + + @Getter(AccessLevel.PACKAGE) + private final MathContext ctx; + @Override + public DecimalValue getZero() { + return new DecimalValue(BigDecimal.ZERO, this); + } + + @Override + public DecimalValue getOne() { + return new DecimalValue(BigDecimal.ONE, this); + } + + @Override + public DecimalValue[] getArray(int size) { + return new DecimalValue[size]; + } + + @Getter + private static final DecimalFactory instance = new DecimalFactory(MathContext.DECIMAL128); +} diff --git a/jmath/src/main/java/net/woggioni/jmath/DecimalValue.java b/jmath/src/main/java/net/woggioni/jmath/DecimalValue.java new file mode 100644 index 0000000..93d18b7 --- /dev/null +++ b/jmath/src/main/java/net/woggioni/jmath/DecimalValue.java @@ -0,0 +1,63 @@ +package net.woggioni.jmath; + +import lombok.EqualsAndHashCode; +import lombok.RequiredArgsConstructor; + +import java.math.BigDecimal; + +@EqualsAndHashCode(onlyExplicitlyIncluded = true) +@RequiredArgsConstructor +public class DecimalValue implements NumericType { + + @EqualsAndHashCode.Include + private final BigDecimal value; + private final DecimalFactory decimalFactory; + + @Override + public DecimalValue add(DecimalValue other) { + return new DecimalValue(value.add(other.value), decimalFactory); + } + + @Override + public DecimalValue sub(DecimalValue other) { + return new DecimalValue(value.subtract(other.value), decimalFactory); + } + + @Override + public DecimalValue mul(DecimalValue other) { + return new DecimalValue(value.multiply(other.value), decimalFactory); + } + + @Override + public DecimalValue div(DecimalValue other) { + return new DecimalValue(value.divide(other.value, decimalFactory.getCtx()), decimalFactory); + } + + @Override + public DecimalValue abs() { + return new DecimalValue(value.abs(), decimalFactory); + } + + @Override + public DecimalValue sqrt() { + return new DecimalValue(value.sqrt(decimalFactory.getCtx()), decimalFactory); + } + + @Override + public int compareTo(DecimalValue o) { + return value.compareTo(o.value); + } + + public static DecimalValue of(double n, DecimalFactory decimalFactory) { + return new DecimalValue(BigDecimal.valueOf(n), decimalFactory); + } + + public static DecimalValue of(double n) { + return new DecimalValue(BigDecimal.valueOf(n), DecimalFactory.getInstance()); + } + + @Override + public String toString() { + return value.toString(); + } +} diff --git a/jmath/src/main/java/net/woggioni/jmath/FloatFactory.java b/jmath/src/main/java/net/woggioni/jmath/FloatFactory.java new file mode 100644 index 0000000..e350b55 --- /dev/null +++ b/jmath/src/main/java/net/woggioni/jmath/FloatFactory.java @@ -0,0 +1,23 @@ +package net.woggioni.jmath; + +import lombok.Getter; + +public class FloatFactory implements NumericTypeFactory { + @Override + public FloatValue getZero() { + return new FloatValue(0); + } + + @Override + public FloatValue getOne() { + return new FloatValue(1); + } + + @Override + public FloatValue[] getArray(int size) { + return new FloatValue[size]; + } + + @Getter + private static final NumericTypeFactory instance = new FloatFactory(); +} diff --git a/jmath/src/main/java/net/woggioni/jmath/FloatValue.java b/jmath/src/main/java/net/woggioni/jmath/FloatValue.java new file mode 100644 index 0000000..1979872 --- /dev/null +++ b/jmath/src/main/java/net/woggioni/jmath/FloatValue.java @@ -0,0 +1,58 @@ +package net.woggioni.jmath; + +import lombok.EqualsAndHashCode; +import lombok.RequiredArgsConstructor; + +import java.util.Comparator; +import java.util.Objects; + +@EqualsAndHashCode(onlyExplicitlyIncluded = true) +@RequiredArgsConstructor +public class FloatValue implements NumericType { + @EqualsAndHashCode.Include + private final float value; + + @Override + public FloatValue add(FloatValue other) { + return new FloatValue(value + other.value); + } + + @Override + public FloatValue sub(FloatValue other) { + return new FloatValue(value - other.value); + } + + @Override + public FloatValue mul(FloatValue other) { + return new FloatValue(value * other.value); + } + + @Override + public FloatValue div(FloatValue other) { + return new FloatValue(value / other.value); + } + + @Override + public FloatValue abs() { + return new FloatValue(Math.abs(value)); + } + + @Override + public FloatValue sqrt() { + return new FloatValue((float) Math.sqrt(value)); + } + + @Override + public int compareTo(FloatValue o) { + return Comparator.naturalOrder().compare(value, o.value); + } + + public static FloatValue of(float n) { + return new FloatValue(n); + } + + @Override + public String toString() { + return Objects.toString(value); + } +} diff --git a/jmath/src/main/java/net/woggioni/jmath/IntegerFactory.java b/jmath/src/main/java/net/woggioni/jmath/IntegerFactory.java new file mode 100644 index 0000000..f5cb3b4 --- /dev/null +++ b/jmath/src/main/java/net/woggioni/jmath/IntegerFactory.java @@ -0,0 +1,25 @@ +package net.woggioni.jmath; + +import lombok.Getter; + +import java.math.BigInteger; + +public class IntegerFactory implements NumericTypeFactory { + @Override + public IntegerValue getZero() { + return new IntegerValue(BigInteger.ZERO); + } + + @Override + public IntegerValue getOne() { + return new IntegerValue(BigInteger.ONE); + } + + @Override + public IntegerValue[] getArray(int size) { + return new IntegerValue[size]; + } + + @Getter + private static final NumericTypeFactory instance = new IntegerFactory(); +} diff --git a/jmath/src/main/java/net/woggioni/jmath/IntegerValue.java b/jmath/src/main/java/net/woggioni/jmath/IntegerValue.java new file mode 100644 index 0000000..7fa8c7e --- /dev/null +++ b/jmath/src/main/java/net/woggioni/jmath/IntegerValue.java @@ -0,0 +1,57 @@ +package net.woggioni.jmath; + +import lombok.EqualsAndHashCode; +import lombok.RequiredArgsConstructor; + +import java.math.BigInteger; + +@EqualsAndHashCode(onlyExplicitlyIncluded = true) +@RequiredArgsConstructor +public class IntegerValue implements NumericType { + @EqualsAndHashCode.Include + private final BigInteger value; + + @Override + public IntegerValue add(IntegerValue other) { + return new IntegerValue(value.add(other.value)); + } + + @Override + public IntegerValue sub(IntegerValue other) { + return new IntegerValue(value.subtract(other.value)); + } + + @Override + public IntegerValue mul(IntegerValue other) { + return new IntegerValue(value.multiply(other.value)); + } + + @Override + public IntegerValue div(IntegerValue other) { + return new IntegerValue(value.divide(other.value)); + } + + @Override + public IntegerValue abs() { + return new IntegerValue(value.abs()); + } + + @Override + public IntegerValue sqrt() { + return new IntegerValue(value.sqrt()); + } + + @Override + public int compareTo(IntegerValue o) { + return value.compareTo(o.value); + } + + public static IntegerValue of(long n) { + return new IntegerValue(BigInteger.valueOf(n)); + } + + @Override + public String toString() { + return value.toString(); + } +} diff --git a/jmath/src/main/java/net/woggioni/jmath/Matrix.java b/jmath/src/main/java/net/woggioni/jmath/Matrix.java new file mode 100644 index 0000000..dbdfccf --- /dev/null +++ b/jmath/src/main/java/net/woggioni/jmath/Matrix.java @@ -0,0 +1,561 @@ +package net.woggioni.jmath; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.SneakyThrows; + +import java.util.Objects; +import java.util.function.BiFunction; +import java.util.function.Function; + +import static net.woggioni.jwo.Requirement.require; + +public class Matrix> { + + public interface Pivot { + > Matrix mul(Matrix m); + } + + @RequiredArgsConstructor + private static class PivotImpl implements Pivot { + @Getter + private final int[] values; + + @Getter + @Setter + private int permutations = 0; + + @Override + public > Matrix mul(Matrix m) { + return Matrix.of(m.numericTypeFactory, m.getRows(), m.getColumns(), (i, j) -> m.get(values[i], j)); + } + } + + public interface ValueGenerator> { + T generate(int row, int column); + } + + private final NumericTypeFactory numericTypeFactory; + private final T[] values; + private final int rows; + private final int columns; + + @SneakyThrows + public Matrix(int rows, int columns, NumericTypeFactory numericTypeFactory) { + this.numericTypeFactory = numericTypeFactory; + this.rows = rows; + this.columns = columns; + this.values = numericTypeFactory.getArray(rows * columns); + for (int i = 0; i < rows; i++) { + for (int j = 0; j < columns; j++) { + set(i, j, numericTypeFactory.getZero()); + } + } + } + + public static > Matrix of( + NumericTypeFactory numericTypeFactory, + int rows, + int columns, + ValueGenerator generator) { + Matrix result = new Matrix<>(rows, columns, numericTypeFactory); + for (int i = 0; i < rows; i++) { + for (int j = 0; j < columns; j++) { + result.set(i, j, generator.generate(i, j)); + } + } + return result; + } + + public static > Matrix of( + NumericTypeFactory numericTypeFactory, + int rows, + int columns) { + return of(numericTypeFactory, rows, columns, (i, j) -> numericTypeFactory.getZero()); + } + + public static > Matrix of(NumericTypeFactory numericTypeFactory, T[][] values) { + int rows = values.length; + int columns = values[0].length; + return of(numericTypeFactory, rows, columns, (i, j) -> values[i][j]); + } + + public static > Matrix of( + NumericTypeFactory numericTypeFactory, int rows, int columns, T... values) { + return of(numericTypeFactory, rows, columns, (i, j) -> values[i * columns + j]); + } + + public T get(int row, int column) { + return values[row * columns + column]; + } + + public void set(int row, int column, T value) { + values[row * columns + column] = value; + } + + public int getRows() { + return rows; + } + + public int getColumns() { + return columns; + } + + private void requireSameSize(Matrix other) { + require(() -> getRows() == other.getRows() && getColumns() == other.getColumns()) + .otherwise(SizeException.class, "Matrix dimension mismatch: (%d, %d) vs (%d, %d)", + getRows(), getColumns(), other.getRows(), other.getColumns() + ); + } + + public Matrix map(Matrix other, BiFunction op) { + requireSameSize(other); + return of(numericTypeFactory, getRows(), getColumns(), (i, j) -> op.apply(get(i, j), other.get(i, j))); + } + + public > Matrix map(NumericTypeFactory numericTypeFactory, Function op) { + return of(numericTypeFactory, getRows(), getColumns(), (i, j) -> op.apply(get(i, j))); + } + + public Matrix map(Function op) { + return of(numericTypeFactory, getRows(), getColumns(), (i, j) -> op.apply(get(i, j))); + } + + public Matrix add(Matrix other) { + return map(other, T::add); + } + + public Matrix sub(Matrix other) { + return map(other, T::sub); + } + + public Matrix mul(Matrix other) { + return map(other, T::mul); + } + + public Matrix div(Matrix other) { + return map(other, T::div); + } + + public Matrix add(T value) { + return map(numericTypeFactory, (T v) -> v.add(value)); + } + + public Matrix sub(T value) { + return map(numericTypeFactory, (T v) -> v.sub(value)); + } + + public Matrix mul(T value) { + return map(numericTypeFactory, (T v) -> v.mul(value)); + } + + public Matrix div(T value) { + return map(numericTypeFactory, (T v) -> v.div(value)); + } + + + public Vector solve(Vector b) { + Matrix tmp = clone(); + Pivot pivot = tmp.lup(); + return tmp.luSolve(b, pivot); + } + + private void swapRows(int id1, int id2) { + for (int i = 0; i < getColumns(); i++) { + T tmp = get(id1, i); + set(id1, i, get(id2, i)); + set(id2, i, tmp); + } + } + + private void swapRows(int id1, int id2, PivotImpl pivot) { + swapRows(id1, id2); + int tmp = pivot.getValues()[id1]; + pivot.getValues()[id1] = pivot.getValues()[id2]; + pivot.getValues()[id2] = tmp; + pivot.permutations += 1; + } + + private void swapRows(int id1, int id2, PivotImpl pivot, Matrix other) { + swapRows(id1, id2, pivot); + other.swapRows(id1, id2); + } + + private void luRow(int i) { + if (Objects.equals(numericTypeFactory.getZero(), get(i, i))) { + throw new RuntimeException("Matrix is singular"); + } + for (int j = i; j < getColumns(); j++) { + for (int k = 0; k < i; k++) { + set(i, j, get(i, j).sub(get(i, k).mul(get(k, j)))); + } + } + for (int j = i + 1; j < getColumns(); j++) { + for (int k = 0; k < i; k++) { + set(j, i, get(j, i).sub(get(j, k).mul(get(k, i)))); + } + set(j, i, get(j, i).div(get(i, i))); + } + } + + private void luPivot(int i, Pivot pivot) { + T max = get(i, i).abs(); + int max_index = i; + for (int j = i + 1; j < getRows(); j++) { + if (get(j, i).abs().compareTo(max) > 0) { + max = get(i, j).abs(); + max_index = j; + } + } + if (max_index != i) { + swapRows(i, max_index, (PivotImpl) pivot); + } + } + + public Pivot lup() { + if (getRows() != getColumns()) throw new SizeException("Matrix must be square"); + int size = getRows(); + PivotImpl pivot = newPivot(); + for (int i = 0; i < size; i++) { + luPivot(i, pivot); + luRow(i); + } + return pivot; + } + + private PivotImpl newPivot() { + int sz = getRows(); + int[] result = new int[sz]; + for (int i = 0; i < sz; i++) { + result[i] = i; + } + return new PivotImpl(result); + } + + private void addRow(int sourceIndex, int destIndex, T factor) { + int columns = getColumns(); + for (int i = 0; i < columns; i++) { + set(destIndex, i, get(destIndex, i).add(get(sourceIndex, i).mul(factor))); + } + } + + private void addRow(int sourceIndex, int destIndex, T factor, Matrix other) { + addRow(sourceIndex, destIndex, factor); + other.addRow(sourceIndex, destIndex, factor); + } + + public void gaussJordanLow() { + PivotImpl pivot = newPivot(); + int rows = getRows(); + int columns = getColumns(); + for (int i = 0; i < rows; i++) { + if (Objects.equals(numericTypeFactory.getZero(), get(i, i))) { + for (int j = i + 1; j < columns; j++) { + if (!Objects.equals(numericTypeFactory.getZero(), get(j, i))) { + swapRows(i, j, pivot); + break; + } + } + } + for (int j = i + 1; j < rows; j++) { + T ii = get(i, i); + if (!Objects.equals(numericTypeFactory.getZero(), ii)) { + T factor = get(j, i).div(ii).mul(numericTypeFactory.getMinusOne()); + addRow(i, j, factor); + } + } + } + } + + private void gaussJordanLow(Matrix other) { + PivotImpl pivot = newPivot(); + int rows = getRows(); + int columns = getColumns(); + for (int i = 0; i < rows; i++) { + if (Objects.equals(numericTypeFactory.getZero(), get(i, i))) { + for (int j = i + 1; j < columns; j++) { + if (!Objects.equals(numericTypeFactory.getZero(), get(j, i))) { + swapRows(i, j, pivot, other); + break; + } + } + } + for (int j = i + 1; j < rows; j++) { + T ii = get(i, i); + if (!Objects.equals(numericTypeFactory.getZero(), ii)) { + T factor = get(j, i).div(ii).mul(numericTypeFactory.getMinusOne()); + addRow(i, j, factor, other); + } + } + } + } + + private void gaussJordanHigh() { + PivotImpl pivot = newPivot(); + int i = getRows(); + while (i-- > 0) { + if (Objects.equals(numericTypeFactory.getZero(), get(i, i))) { + int j = i; + while (j-- > 0) { + if (!Objects.equals(numericTypeFactory.getZero(), get(j, i))) { + swapRows(i, j, pivot); + break; + } + } + } + int j = i; + while (j-- > 0) { + T ii = get(i, i); + if (!Objects.equals(numericTypeFactory.getZero(), ii)) { + T factor = get(j, i).div(ii).mul(numericTypeFactory.getMinusOne()); + addRow(i, j, factor); + } + } + } + } + + private void gaussJordanHigh(Matrix other) { + PivotImpl pivot = newPivot(); + int i = getRows(); + while (i-- > 0) { + if (Objects.equals(numericTypeFactory.getZero(), get(i, i))) { + int j = i; + while (j-- > 0) { + if (!Objects.equals(numericTypeFactory.getZero(), get(j, i))) { + swapRows(i, j, pivot, other); + break; + } + } + } + int j = i; + while (j-- > 0) { + T ii = get(i, i); + if (!Objects.equals(numericTypeFactory.getZero(), ii)) { + T factor = get(j, i).div(ii).mul(numericTypeFactory.getMinusOne()); + addRow(i, j, factor, other); + } + } + } + } + + public T det() { + require(() -> getRows() == getColumns()).otherwise(SizeException.class, "Matrix must be square"); + Matrix clone = clone(); + clone.gaussJordanLow(); + T result = numericTypeFactory.getOne(); + for (int i = 0; i < getRows(); i++) + result = result.mul(clone.get(i, i)); + return result; + } + + public Matrix invert() { + require(() -> getRows() == getColumns()).otherwise(SizeException.class, "Matrix must be square"); + int sz = getRows(); + int col = getColumns(); + Matrix tmp = clone(); + Matrix result = identity(numericTypeFactory, sz); + tmp.gaussJordanLow(result); + tmp.gaussJordanHigh(result); + for (int i = 0; i < sz; i++) { + T f = tmp.get(i, i); + for (int j = 0; j < col; j++) { + result.set(i, j, result.get(i, j).div(f)); + } + } + return result; + } + + public Matrix triu() { + return of(numericTypeFactory, getRows(), getColumns(), (i, j) -> + i <= j ? get(i, j) : numericTypeFactory.getZero() + ); + } + + public Matrix triu(T diagValue) { + return of(numericTypeFactory, getRows(), getColumns(), (i, j) -> { + T result; + if (i < j) { + result = get(i, j); + } else if (i == j) { + result = diagValue; + } else { + result = numericTypeFactory.getZero(); + } + return result; + }); + } + + + public Matrix tril() { + return of(numericTypeFactory, getRows(), getColumns(), (i, j) -> + i >= j ? get(i, j) : numericTypeFactory.getZero() + ); + } + + public Matrix tril(T diagValue) { + return of(numericTypeFactory, getRows(), getColumns(), (i, j) -> { + T result; + if (i > j) { + result = get(i, j); + } else if (i == j) { + result = diagValue; + } else { + result = numericTypeFactory.getZero(); + } + return result; + }); + } + + + @SneakyThrows + public Vector luSolve(Vector b, Pivot pivot) { + Objects.requireNonNull(b); + PivotImpl pivotImpl = (PivotImpl) pivot; + int[] pivotValues = pivotImpl.getValues(); + int size = getRows(); + if (pivotValues.length != size) throw new SizeException( + String.format("Pivot length is %d must be %d instead", pivotValues.length, size)); + + for (int i = 0; i < pivotValues.length; i++) pivotValues[i] = i; + Vector x = Vector.of(numericTypeFactory, size); + + for (int i = 0; i < size; i++) { + x.set(i, b.get(pivotValues[i])); + for (int k = 0; k < i; k++) { + x.set(i, x.get(i).sub(get(i, k).mul(x.get(k)))); + } + } + int i = size; + while (i-- > 0) { + for (int k = i + 1; k < size; k++) { + x.set(i, x.get(i).sub(get(i, k).mul(x.get(k)))); + } + if (!Objects.equals(get(i, i), numericTypeFactory.getZero())) { + x.set(i, x.get(i).div(get(i, i))); + } else throw new SingularMatrixException("Matrix is singular"); + } + return x; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append('['); + for (int i = 0; i < getRows(); i++) { + sb.append('['); + for (int j = 0; j < getRows(); j++) { + if (j > 0) sb.append(", "); + sb.append(get(i, j)); + } + sb.append(']'); + sb.append('\n'); + } + sb.append(']'); + return sb.toString(); + } + + @Override + public Matrix clone() { + return of(numericTypeFactory, getRows(), getColumns(), this::get); + } + + public static > Matrix identity(NumericTypeFactory numericTypeFactory, int size) { + return of(numericTypeFactory, size, size, (i, j) -> { + T result; + if (i == j) result = numericTypeFactory.getOne(); + else result = numericTypeFactory.getZero(); + return result; + }); + } + + public T luDet() { + require(() -> getRows() == getColumns()).otherwise(SizeException.class, "Matrix must be square"); + Matrix clone = clone(); + PivotImpl pivot = (PivotImpl) clone.lup(); + T result = numericTypeFactory.getOne(); + for (int i = 0; i < rows; i++) { + result = result.mul(clone.get(i, i)); + } + if (pivot.permutations % 2 != 0) { + result = result.mul(numericTypeFactory.getMinusOne()); + } + return result; + } + + public Matrix transpose() { + return Matrix.of(numericTypeFactory, getColumns(), getRows(), (i, j) -> get(j, i)); + } + + public Matrix mmul(Matrix m2) { + return Matrix.of(numericTypeFactory, getRows(), m2.getColumns(), (i, j) -> { + T result = numericTypeFactory.getZero(); + for (int k = 0; k < getColumns(); k++) { + result = result.add(get(i, k).mul(m2.get(k, j))); + } + return result; + }); + } + + public Vector mmul(Vector v) { + return Vector.of(numericTypeFactory, getRows(), i -> { + int columns = getColumns(); + T result = numericTypeFactory.getZero(); + for (int j = 0; j < columns; j++) { + result = result.add(get(i, j).mul(v.get(j))); + } + return result; + }); + } +// public T luDet() { +// if(getRows() != getColumns()) { +// throw newThrowable( +// SizeException.class, +// "Matrix must be square in order to compute the determinant"); +// } +// int[] pivot = lup(); +// T result = numericTypeFactory.getOne(); +// for(int i=0; i> other)) + return false; + if (getRows() != other.getRows() || getColumns() != other.getColumns()) { + return false; + } + int r = getRows(); + int c = getColumns(); + for (int i = 0; i < r; i++) { + for (int j = 0; j < c; j++) { + if (!Objects.equals(get(i, j), other.get(i, j))) return false; + } + } + return true; + } + + public T squaredNorm2() { + T result = numericTypeFactory.getZero(); + for (int i = 0; i < rows; i++) { + for (int j = 0; j < columns; j++) { + T value = get(i, j); + result = result.add(value.mul(value)); + } + } + return result; + } + + public T norm2() { + return squaredNorm2().sqrt(); + } +} diff --git a/jmath/src/main/java/net/woggioni/jmath/NumericType.java b/jmath/src/main/java/net/woggioni/jmath/NumericType.java new file mode 100644 index 0000000..ab31304 --- /dev/null +++ b/jmath/src/main/java/net/woggioni/jmath/NumericType.java @@ -0,0 +1,10 @@ +package net.woggioni.jmath; + +public interface NumericType extends Comparable { + T add(T other); + T sub(T other); + T mul(T other); + T div(T other); + T abs(); + T sqrt(); +} diff --git a/jmath/src/main/java/net/woggioni/jmath/NumericTypeFactory.java b/jmath/src/main/java/net/woggioni/jmath/NumericTypeFactory.java new file mode 100644 index 0000000..eb988a2 --- /dev/null +++ b/jmath/src/main/java/net/woggioni/jmath/NumericTypeFactory.java @@ -0,0 +1,10 @@ +package net.woggioni.jmath; + +public interface NumericTypeFactory> { + T getZero(); + T getOne(); + default T getMinusOne() { + return getZero().sub(getOne()); + } + T[] getArray(int size); +} diff --git a/jmath/src/main/java/net/woggioni/jmath/Rational.java b/jmath/src/main/java/net/woggioni/jmath/Rational.java new file mode 100644 index 0000000..a76c269 --- /dev/null +++ b/jmath/src/main/java/net/woggioni/jmath/Rational.java @@ -0,0 +1,125 @@ +package net.woggioni.jmath; + +import java.math.BigInteger; +import java.util.Objects; + +import static net.woggioni.jmath.BigIntegerExt.mcm; + +public class Rational implements NumericType { + + public static final Rational ZERO = new Rational(0, 1); + public static final Rational ONE = new Rational(1, 1); + public static final Rational MINUS_ONE = new Rational(-1, 1); + public static final BigInteger MINUS_ONE_BI = BigInteger.ZERO.subtract(BigInteger.ONE); + + private BigInteger num; + private BigInteger den; + + @Override + public Rational add(Rational other) { + return new Rational( + this.num.multiply(other.den).add(other.num.multiply(this.den)), + this.den.multiply(other.den) + ).simplify(); + } + + @Override + public Rational sub(Rational other) { + return new Rational( + this.num.multiply(other.den).subtract(other.num.multiply(this.den)), + this.den.multiply(other.den) + ).simplify(); + } + + @Override + public Rational mul(Rational other) { + return new Rational(this.num.multiply(other.num), this.den.multiply(other.den)).simplify(); + } + + @Override + public Rational div(Rational other) { + return new Rational(this.num.multiply(other.den), this.den.multiply(other.num)).simplify(); + } + + @Override + public Rational sqrt() { + return new Rational(num.sqrt(), den.sqrt()); + } + + @Override + public Rational abs() { + return new Rational(num.abs(), den.abs()); + } + + public Rational(BigInteger num, BigInteger den) { + this.num = num; + this.den = den; + } + + public Rational(long num, long den) { + this.num = BigInteger.valueOf(num); + this.den = BigInteger.valueOf(den); + } + + public static Rational of(long num, long den) { + return new Rational(num, den); + } + + public static Rational of(long n) { + return new Rational(n, 1); + } + + private Rational simplify() { + BigInteger gcd = BigIntegerExt.gcd(num.abs(), den.abs()); + num = num.divide(gcd); + den = den.divide(gcd); + if(den.compareTo(BigInteger.ZERO) < 0) { + num = num.multiply(MINUS_ONE_BI); + den = den.multiply(MINUS_ONE_BI); + } + return this; + } + + public BigInteger getNum() { + return num; + } + + public BigInteger getDen() { + return den; + } + @Override + public String toString() { + String result; + if (Objects.equals(BigInteger.ZERO, num) && !Objects.equals(BigInteger.ZERO, den)) result = "0"; + else if (Objects.equals(BigInteger.ONE, den.abs())) result = num.multiply(den.abs().divide(den)).toString(); + else { + boolean negative = Objects.equals(num, num.abs()) ^ Objects.equals(den, den.abs()); + result = String.format("%s%d/%d", negative ? "-" : "", num.abs(), den.abs()); + } + return result; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Rational rational = (Rational) o; + BigInteger mcm = mcm(den, rational.den); + BigInteger v1 = mcm.divide(den).multiply(num); + BigInteger v2 = mcm.divide(rational.den).multiply(rational.num); + return Objects.equals(v1, v2); + } + + @Override + public int hashCode() { + return Objects.hash(num, den); + } + + @Override + public int compareTo(Rational o) { + BigInteger mcm = mcm(getDen(), o.getDen()); + BigInteger n1 = mcm.divide(getDen()).multiply(getNum()); + BigInteger n2 = mcm.divide(o.getDen()).multiply(o.getNum()); + return n1.compareTo(n2); + } +} \ No newline at end of file diff --git a/jmath/src/main/java/net/woggioni/jmath/RationalFactory.java b/jmath/src/main/java/net/woggioni/jmath/RationalFactory.java new file mode 100644 index 0000000..ef24712 --- /dev/null +++ b/jmath/src/main/java/net/woggioni/jmath/RationalFactory.java @@ -0,0 +1,23 @@ +package net.woggioni.jmath; + +import lombok.Getter; + +public class RationalFactory implements NumericTypeFactory { + @Override + public Rational getZero() { + return Rational.ZERO; + } + + @Override + public Rational getOne() { + return Rational.ONE; + } + + @Override + public Rational[] getArray(int size) { + return new Rational[size]; + } + + @Getter + private static final NumericTypeFactory instance = new RationalFactory(); +} diff --git a/jmath/src/main/java/net/woggioni/jmath/SingularMatrixException.java b/jmath/src/main/java/net/woggioni/jmath/SingularMatrixException.java new file mode 100644 index 0000000..e432f39 --- /dev/null +++ b/jmath/src/main/java/net/woggioni/jmath/SingularMatrixException.java @@ -0,0 +1,7 @@ +package net.woggioni.jmath; + +public class SingularMatrixException extends RuntimeException { + public SingularMatrixException(String msg) { + super(msg); + } +} diff --git a/jmath/src/main/java/net/woggioni/jmath/SizeException.java b/jmath/src/main/java/net/woggioni/jmath/SizeException.java new file mode 100644 index 0000000..c833d51 --- /dev/null +++ b/jmath/src/main/java/net/woggioni/jmath/SizeException.java @@ -0,0 +1,7 @@ +package net.woggioni.jmath; + +public class SizeException extends RuntimeException { + public SizeException(String msg) { + super(msg); + } +} diff --git a/jmath/src/main/java/net/woggioni/jmath/Vector.java b/jmath/src/main/java/net/woggioni/jmath/Vector.java new file mode 100644 index 0000000..32045c2 --- /dev/null +++ b/jmath/src/main/java/net/woggioni/jmath/Vector.java @@ -0,0 +1,142 @@ +package net.woggioni.jmath; + +import lombok.AccessLevel; +import lombok.RequiredArgsConstructor; + +import java.util.Objects; +import java.util.function.IntFunction; + +import static net.woggioni.jwo.Requirement.require; + +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) +public class Vector> { + private final NumericTypeFactory numericTypeFactory; + private final T[] values; + + private Vector(NumericTypeFactory numericTypeFactory, int size) { + this(numericTypeFactory, numericTypeFactory.getArray(size)); + } + + private Vector(NumericTypeFactory numericTypeFactory, int size, IntFunction valueGenerator) { + this(numericTypeFactory, size); + for (int i = 0; i < size; i++) + set(i, valueGenerator.apply(i)); + } + + public int size() { + return values.length; + } + + public T get(int index) { + return values[index]; + } + + public void set(int index, T value) { + values[index] = value; + } + + private void requireSameSize(Vector other) { + require(() -> other.size() == size()).otherwise(SizeException.class, "Vectors must be of same size"); + } + + public Vector sum(Vector other) { + requireSameSize(other); + return new Vector<>(numericTypeFactory, size(), index -> get(index).add(other.get(index))); + } + + public Vector sub(Vector other) { + requireSameSize(other); + return new Vector<>(numericTypeFactory, size(), index -> get(index).sub(other.get(index))); + } + + public Vector mul(Vector other) { + requireSameSize(other); + return new Vector<>(numericTypeFactory, size(), index -> get(index).mul(other.get(index))); + } + + public Vector div(Vector other) { + requireSameSize(other); + return new Vector<>(numericTypeFactory, size(), index -> get(index).div(other.get(index))); + } + + public Vector mul(Matrix m) { + return Vector.of(numericTypeFactory, size(), i -> { + T result = numericTypeFactory.getZero(); + for (int j = 0; j < m.getRows(); j++) { + result = result.add(get(j).mul(m.get(j, i))); + } + return result; + }); + } + + public T innerProduct(Vector other) { + requireSameSize(other); + T result = numericTypeFactory.getZero(); + int sz = size(); + for (int i = 0; i < sz; i++) + result = result.add(get(i).mul(other.get(i))); + return result; + } + + public T norm() { + T result = numericTypeFactory.getZero(); + int sz = size(); + for (int i = 0; i < sz; i++) { + T value = get(i); + result = result.add(value.mul(value)); + } + return result; + } + + public T abs() { + return norm().sqrt(); + } + + @Override + public Vector clone() { + int sz = size(); + Vector result = new Vector<>(numericTypeFactory, sz); + for (int i = 0; i < sz; i++) { + result.set(i, get(i)); + } + return result; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof Vector vector)) { + return false; + } + if (size() != vector.size()) return false; + for (int i = 0; i < size(); i++) { + if (!Objects.equals(get(i), vector.get(i))) return false; + } + return true; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append('['); + for (int i = 0; i < size(); i++) { + if (i > 0) sb.append(", "); + sb.append(get(i)); + } + sb.append(']'); + return sb.toString(); + } + + public static > Vector of(NumericTypeFactory numericTypeFactory, T... values) { + return new Vector<>(numericTypeFactory, values); + } + + public static > Vector of(NumericTypeFactory numericTypeFactory, int size) { + return new Vector<>(numericTypeFactory, size); + } + + public static > Vector of(NumericTypeFactory numericTypeFactory, + int size, + IntFunction generator) { + return new Vector<>(numericTypeFactory, size, generator); + } +} diff --git a/jmath/src/test/java/net/woggioni/jmath/MatrixTest.java b/jmath/src/test/java/net/woggioni/jmath/MatrixTest.java new file mode 100644 index 0000000..48160fe --- /dev/null +++ b/jmath/src/test/java/net/woggioni/jmath/MatrixTest.java @@ -0,0 +1,206 @@ +package net.woggioni.jmath; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.util.stream.IntStream; + +public class MatrixTest { + + private static Matrix testMatrix; + private static Matrix integerTestMatrix; + + @BeforeAll + public static void setup() { + testMatrix = Matrix.of(RationalFactory.getInstance(), new Rational[][]{ + new Rational[]{ + Rational.ONE, Rational.ZERO, Rational.of(-1, 2), Rational.ZERO, Rational.ZERO + }, + new Rational[]{ + Rational.ZERO, Rational.ONE, Rational.of(-1, 2), Rational.ZERO, Rational.ZERO + }, + new Rational[]{ + Rational.of(-4, 9), Rational.ZERO, Rational.ONE, Rational.ZERO, Rational.ZERO + }, + new Rational[]{ + Rational.of(-1, 3), Rational.ZERO, Rational.ZERO, Rational.ONE, Rational.ZERO + }, + new Rational[]{ + Rational.of(-2, 9), Rational.ZERO, Rational.ZERO, Rational.ZERO, Rational.ONE + }, + }); + + NumericTypeFactory integerFactory = IntegerFactory.getInstance(); + integerTestMatrix = Matrix.of(integerFactory, 3, 3, + IntStream.of(1, 2, 3, 4, 5, 6, 7, 8, 9).mapToObj(IntegerValue::of).toArray(IntegerValue[]::new)); + } + + @Test + void trilTest() { + NumericTypeFactory rationalFactory = IntegerFactory.getInstance(); + Matrix expected = Matrix.of(rationalFactory, 3, 3, + IntStream.of(1, 0, 0, 4, 5, 0, 7, 8, 9).mapToObj(IntegerValue::of).toArray(IntegerValue[]::new)); + Assertions.assertEquals(expected, integerTestMatrix.tril()); + Matrix expected2 = Matrix.of(rationalFactory, 3, 3, + IntStream.of(-3, 0, 0, 4, -3, 0, 7, 8, -3).mapToObj(IntegerValue::of).toArray(IntegerValue[]::new)); + Assertions.assertEquals(expected2, integerTestMatrix.tril(IntegerValue.of(-3))); + } + + @Test + void triuTest() { + NumericTypeFactory rationalFactory = IntegerFactory.getInstance(); + Matrix expected = Matrix.of(rationalFactory, 3, 3, + IntStream.of(1, 2, 3, 0, 5, 6, 0, 0, 9).mapToObj(IntegerValue::of).toArray(IntegerValue[]::new)); + Assertions.assertEquals(expected, integerTestMatrix.triu()); + Matrix expected2 = Matrix.of(rationalFactory, 3, 3, + IntStream.of(-3, 2, 3, 0, -3, 6, 0, 0, -3).mapToObj(IntegerValue::of).toArray(IntegerValue[]::new)); + Assertions.assertEquals(expected2, integerTestMatrix.triu(IntegerValue.of(-3))); + } + + @Test + void transposeTest() { + NumericTypeFactory integerFactory = IntegerFactory.getInstance(); + Matrix m = Matrix.of(integerFactory, 3, 3, + (i, j) -> IntegerValue.of(1 + i * 3L + j)); + Matrix transpose = Matrix.of(integerFactory, 3, 3, + (i, j) -> m.get(j, i)); + Assertions.assertEquals(transpose, m.transpose()); + } + + @Test + void testInverse() { + NumericTypeFactory rationalFactory = RationalFactory.getInstance(); + int[] nums = new int[]{1, 2, 3, 4, 5, 6, 8, 7, 9}; + Matrix m = Matrix.of(rationalFactory, 3, 3, + (i, j) -> Rational.of(nums[i * 3 + j], 1)); + Matrix inverse = m.invert(); + Matrix identity = Matrix.identity(rationalFactory, 3); + Assertions.assertEquals(identity, m.mmul(inverse)); + Assertions.assertEquals(identity, inverse.mmul(m)); + } + + @Test + void addTest() { + NumericTypeFactory integerFactory = IntegerFactory.getInstance(); + Matrix m2 = Matrix.of(integerFactory, 3, 3, + (i, j) -> integerTestMatrix.get(i, j).mul(integerFactory.getMinusOne())); + Matrix expected = Matrix.of(integerFactory, 3, 3, + (i, j) -> integerFactory.getZero()); + Assertions.assertEquals(expected, integerTestMatrix.add(m2)); + } + + @Test + void subTest() { + NumericTypeFactory integerFactory = IntegerFactory.getInstance(); + Matrix expected = Matrix.of(integerFactory, 3, 3, + (i, j) -> integerFactory.getZero()); + Assertions.assertEquals(expected, integerTestMatrix.sub(integerTestMatrix)); + } + + @Test + void mulTest() { + NumericTypeFactory integerFactory = IntegerFactory.getInstance(); + Matrix expected = Matrix.of(integerFactory, 3, 3, + (i, j) -> integerTestMatrix.get(i, j).mul(integerTestMatrix.get(i, j))); + Assertions.assertEquals(expected, integerTestMatrix.mul(integerTestMatrix)); + } + + @Test + void divTest() { + NumericTypeFactory integerFactory = IntegerFactory.getInstance(); + Matrix expected = Matrix.of(integerFactory, 3, 3, + (i, j) -> IntegerValue.of(1)); + Assertions.assertEquals(expected, integerTestMatrix.div(integerTestMatrix)); + } + + @Test + void linProdTest() { + NumericTypeFactory integerFactory = IntegerFactory.getInstance(); + Matrix m = Matrix.of(integerFactory, 2, 5, + IntStream.of(1, 4, 8, 2, 5, 7, 3, 6, 9, 0).mapToObj(IntegerValue::of).toArray(IntegerValue[]::new)); + Matrix expected = Matrix.of(integerFactory, 2, 2, + IntStream.of(110, 85, 85, 175).mapToObj(IntegerValue::of).toArray(IntegerValue[]::new)); + Assertions.assertEquals(expected, m.mmul(m.transpose())); + + Matrix expected2 = Matrix.of(integerFactory, 5, 5, + IntStream.of(50, 25, 50, 65, 5, + 25, 25, 50, 35, 20, + 50, 50, 100, 70, 40, + 65, 35, 70, 85, 10, + 5, 20, 40, 10, 25) + .mapToObj(IntegerValue::of).toArray(IntegerValue[]::new)); + Assertions.assertEquals(expected2, m.transpose().mmul(m)); + } + + @Test + void linProdVectorTest() { + NumericTypeFactory integerFactory = IntegerFactory.getInstance(); + Matrix m = Matrix.of(integerFactory, 3, 3, + IntStream.of(1, 2, 3, 4, 5, 6, 8, 7, 9).mapToObj(IntegerValue::of).toArray(IntegerValue[]::new)); + Vector v = Vector.of(integerFactory, + IntStream.of(1, 2, 3).mapToObj(IntegerValue::of).toArray(IntegerValue[]::new)); + Vector expected = Vector.of(integerFactory, + IntStream.of(14, 32, 49).mapToObj(IntegerValue::of).toArray(IntegerValue[]::new)); + Assertions.assertEquals(expected, m.mmul(v)); + } + + @Test + void determinantTest() { + NumericTypeFactory rationalFactory = RationalFactory.getInstance(); + Matrix m = Matrix.of(rationalFactory, 3, 3, + IntStream.of(1, 2, 3, 4, 5, 6, 8, 7, 9).mapToObj(Rational::of).toArray(Rational[]::new)); + Assertions.assertEquals(Rational.of(-9), m.det()); + Assertions.assertEquals(Rational.of(-9), m.luDet()); + Matrix m2 = Matrix.of(rationalFactory, 3, 3, + IntStream.of(1, 2, 3, 4, 5, 6, 7, 8, 9).mapToObj(Rational::of).toArray(Rational[]::new)); + Assertions.assertEquals(Rational.ZERO, m2.det()); + Assertions.assertEquals(Rational.ZERO, m2.luDet()); + } + + @Test + void luTest() { + NumericTypeFactory rationalFactory = RationalFactory.getInstance(); + Matrix m = Matrix.of(rationalFactory, 3, 3, + IntStream.of(8, 7, 9, 1, 2, 3, 4, 5, 6).mapToObj(Rational::of).toArray(Rational[]::new)); + Matrix lu = m.clone(); + Matrix.Pivot p = lu.lup(); + Matrix result = p.mul(lu.tril(Rational.ONE).mmul(lu.triu())); + Assertions.assertEquals(m, result); + } + + @Test + void mmulTest() { + NumericTypeFactory integerFactory = IntegerFactory.getInstance(); + Matrix m = Matrix.of(integerFactory, 3, 3, + IntStream.of(1, 2, 3, 4, 5, 6, 8, 7, 9).mapToObj(IntegerValue::of).toArray(IntegerValue[]::new)); + Vector v = Vector.of( + integerFactory, IntStream.of(1, 2, 3) + .mapToObj(IntegerValue::of).toArray(IntegerValue[]::new)); + Vector expected = Vector.of(integerFactory, + IntStream.of(33, 33, 42).mapToObj(IntegerValue::of).toArray(IntegerValue[]::new)); + Assertions.assertEquals(expected, v.mul(m)); + } + + @Test + public void linearSystemTest() { + Vector expectedSolution = Vector.of(RationalFactory.getInstance(), + Rational.of(9, 14), + Rational.of(9, 14), + Rational.of(2, 7), + Rational.of(3, 14), + Rational.of(1, 7) + ); + + Vector b = Vector.of(RationalFactory.getInstance(), + Rational.of(1, 2), Rational.of(1, 2), Rational.ZERO, Rational.ZERO, Rational.ZERO + ); + + Vector solution = testMatrix.solve(b); + + Assertions.assertEquals(expectedSolution, solution); + Matrix inverse = testMatrix.invert(); + Assertions.assertEquals(expectedSolution, inverse.mmul(b)); + } + +} \ No newline at end of file diff --git a/jmath/src/test/java/net/woggioni/jmath/RationalTest.java b/jmath/src/test/java/net/woggioni/jmath/RationalTest.java new file mode 100644 index 0000000..379a8aa --- /dev/null +++ b/jmath/src/test/java/net/woggioni/jmath/RationalTest.java @@ -0,0 +1,102 @@ +package net.woggioni.jmath; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; + +import java.math.BigInteger; +import static java.math.BigInteger.valueOf; +import java.util.stream.Stream; + +public class RationalTest { + + private enum Operation { + ADD, SUBTRACT, MULTIPLY, DIVIDE + } + + private static class GcdTestCaseProvider implements ArgumentsProvider { + @Override + public Stream provideArguments(ExtensionContext context) { + return Stream.of( + Arguments.of(valueOf(10), valueOf(25), valueOf(5)), + Arguments.of(valueOf(9), valueOf(7), valueOf(1)), + Arguments.of(valueOf(101), valueOf(3), valueOf(1)), + Arguments.of(valueOf(5), valueOf(105), valueOf(5)), + Arguments.of(valueOf(91), valueOf(91), valueOf(91)), + Arguments.of(valueOf(45), valueOf(81), valueOf(9)), + Arguments.of(valueOf(1), valueOf(1), valueOf(1)), + Arguments.of(valueOf(11), valueOf(34), valueOf(1)) + ); + } + } + @ParameterizedTest(name="input: {0}, {1}, expected outcome: {2}") + @ArgumentsSource(GcdTestCaseProvider.class) + void gcdTest(BigInteger n1, BigInteger n2, BigInteger expected) { + Assertions.assertEquals(expected, BigIntegerExt.gcd(n1, n2)); + } + + private static class RationalTestCaseProvider implements ArgumentsProvider { + @Override + public Stream provideArguments(ExtensionContext context) { + return Stream.of( + Arguments.of( + Rational.of(3, 4), Rational.of(1, 5), + Operation.ADD, + Rational.of(19, 20)), + Arguments.of( + Rational.of(1, 2), Rational.of(1, 2), + Operation.ADD, + Rational.of(1, 1)), + Arguments.of( + Rational.of(3, 4), Rational.of(2, 3), + Operation.MULTIPLY, + Rational.of(1, 2)), + Arguments.of( + Rational.of(4, 5), Rational.of(3, 2), + Operation.DIVIDE, + Rational.of(8, 15)), + Arguments.of( + Rational.of(1, 1), Rational.of(2, 1), + Operation.SUBTRACT, + Rational.of(1, -1)) + ); + } + } + + @ParameterizedTest(name="input: {0}, {1}, operation: {2}, expected outcome: {3}") + @ArgumentsSource(RationalTestCaseProvider.class) + void rationalTest(Rational r1, Rational r2, Operation operation, Rational expected) { + Rational result; + switch (operation) { + case ADD: + result = r1.add(r2); + break; + case SUBTRACT: + result = r1.sub(r2); + break; + case MULTIPLY: + result = r1.mul(r2); + break; + case DIVIDE: + result = r1.div(r2); + break; + default: + throw new RuntimeException("This should never happen"); + } + Assertions.assertEquals(expected, result); + } + + + @Test + void sumTest() { + Rational s5 = Rational.of(9, 22); + Rational s6 = Rational.of(3, 22); + Rational s4 = Rational.of(5, 22); + Rational s3 = Rational.of(5, 22); + Assertions.assertEquals(Rational.ONE, s5.add(s6).add(s4).add(s3)); + } +} \ No newline at end of file diff --git a/settings.gradle b/settings.gradle index abdb3b9..4abac85 100644 --- a/settings.gradle +++ b/settings.gradle @@ -31,3 +31,4 @@ dependencyResolutionManagement { rootProject.name = 'jwo' include('benchmark') +include('jmath') diff --git a/src/main/java/net/woggioni/jwo/Hash.java b/src/main/java/net/woggioni/jwo/Hash.java index 4fe1de8..746d377 100644 --- a/src/main/java/net/woggioni/jwo/Hash.java +++ b/src/main/java/net/woggioni/jwo/Hash.java @@ -22,6 +22,11 @@ public class Hash { SHA512("SHA-512"); private final String key; + + @SneakyThrows + public MessageDigest newMessageDigest() { + return MessageDigest.getInstance(key); + } } @Getter diff --git a/src/main/java/net/woggioni/jwo/Requirement.java b/src/main/java/net/woggioni/jwo/Requirement.java new file mode 100644 index 0000000..263dbaf --- /dev/null +++ b/src/main/java/net/woggioni/jwo/Requirement.java @@ -0,0 +1,25 @@ +package net.woggioni.jwo; + +import lombok.AccessLevel; +import lombok.RequiredArgsConstructor; +import lombok.SneakyThrows; + +import java.util.function.Supplier; + +import static net.woggioni.jwo.JWO.newThrowable; + +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) +public class Requirement { + private final Supplier booleanSupplier; + + public static Requirement require(Supplier booleanSupplier) { + return new Requirement(booleanSupplier); + } + + @SneakyThrows + public void otherwise(Class cls, String format, Object... args) { + if(!booleanSupplier.get()) { + throw newThrowable(cls, format, args); + } + } +}