package cljtest.linear; import base.Asserts; import base.ExtendedRandom; import base.TestCounter; import cljtest.ClojureScript; import clojure.lang.IPersistentVector; import common.Engine; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.function.*; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; /** * Clojure bridge. * * @author Georgiy Korneev (kgeorgiy@kgeorgiy.info) */ public interface Item { Item ZERO = value(0); Item ONE = value(1); int dim(); boolean isValid(); Item refill(ExtendedRandom random); Engine.Result toClojure(); default Value mapValue(final DoubleUnaryOperator f) { return value(f.applyAsDouble(value())); } default Vector map(final Function f) { throw new UnsupportedOperationException("map"); } default int size() { throw new UnsupportedOperationException("size"); } default Item get(final int index) { throw new UnsupportedOperationException("get"); } default double value() { throw new UnsupportedOperationException("getValue"); } static Stream args(final int argc, final Item shape, final ExtendedRandom random) { return Stream.generate(() -> shape.refill(random)).limit(argc); } static Item fromClojure(final Object value) { if (value instanceof Number n) { return value(n.doubleValue()); } else if (value instanceof IPersistentVector vector) { return vector(IntStream.range(0, vector.length()).mapToObj(vector::nth).map(Item::fromClojure)); } else { throw new AssertionError(value == null ? "null result" : "Unknown type " + value.getClass().getSimpleName()); } } static Vector vector(final Stream items) { return new Vector(items(items)); } static Value value(final double value) { return new Value(value); } static List items(final Stream items) { return items.collect(Collectors.toUnmodifiableList()); } static Supplier generator(final int... dims) { Supplier generator = () -> ZERO; for (int i = dims.length - 1; i >= 0; i--) { final int dim = dims[i]; final Supplier gen = generator; generator = () -> vector(Stream.generate(gen).limit(dim)); } return generator; } static IntFunction> same(final Supplier generator) { return same(generator.get()); } static IntFunction> same(final Item shape) { return n -> Collections.nCopies(n, shape); } static Engine.Result[] toClojure(final List args) { return toArray(args.stream().map(Item::toClojure)); } static Engine.Result[] toArray(final Stream> resultStream) { return resultStream.toArray(Engine.Result[]::new); } static List functions(final String prefix) { return functions(prefix, Operation.values()); } static List functions(final String prefix, final Operation... ops) { return Arrays.stream(ops).map(op -> op.function(prefix)).toList(); } record Value(double value) implements Item { public boolean isValid() { return Double.isFinite(value); } @Override public int dim() { return 0; } @Override public Value refill(final ExtendedRandom random) { return new Value(random.nextInt(1, 99) / 10.0); } @Override public Engine.Result toClojure() { return LinearTester.number(value); } @Override public boolean equals(final Object obj) { return obj instanceof Value v && Asserts.isEqual(value, v.value, 1e-7); } @Override public String toString() { return Double.toString(value); } } final class Vector implements Item { private final List items; private final int dim; private Vector(final List items) { this.items = items; dim = items.stream().mapToInt(Item::dim).max().orElse(0) + 1; } @Override public boolean isValid() { return items.stream().allMatch(Item::isValid); } @Override public int dim() { return dim; } public int size() { return items.size(); } public Item get(final int index) { return items.get(index); } @Override public Vector refill(final ExtendedRandom random) { return vector(items.stream().map(item -> item.refill(random))); } @Override public Engine.Result toClojure() { return LinearTester.vector(toArray(items.stream().map(Item::toClojure))); } @Override public boolean equals(final Object obj) { return obj instanceof Vector v && items.equals(v.items); } @Override public String toString() { return items.stream().map(Item::toString).collect(Collectors.joining(", ", "[", "]")); } @Override public Vector map(final Function f) { return vector(items.stream().map(f)); } } class Fun { private final Function, Item> expected; private final ClojureScript.F actual; public Fun(final String name, final Function, Item> implementation) { expected = implementation; actual = ClojureScript.function(name, Object.class); } public void test(final TestCounter counter, final Stream argStream) { final List args = items(argStream); test(counter, args, args); } public void test(final TestCounter counter, final List args, final List fakeArgs) { final Item expected = this.expected.apply(fakeArgs); // if (!expected.isValid()) { // return; // } test(counter, () -> { final Engine.Result result; try { result = actual.call(toClojure(args)); } catch (final RuntimeException | AssertionError e) { throw new AssertionError("No error expected for " + actual.callToString(toClojure(args)), e); } final Item actual = fromClojure(result.value()); if (!expected.equals(actual)) { throw new AssertionError(result.context() + ": expected " + expected + ", found " + actual); } }); // System.err.println("Testing? " + result.context); } private static void test(final TestCounter counter, final Runnable action) { counter.test(() -> { if (counter.getTestNo() % 1000 == 0) { counter.println("Test " + counter.getTestNo()); } action.run(); }); } public void test(final int args, final Item shape, final TestCounter counter, final ExtendedRandom random) { test(args, Item.same(shape), counter, random); } public void test(final int args, final IntFunction> shapes, final TestCounter counter, final ExtendedRandom random) { test(shapes.apply(args), counter, random); } public void test(final List shapes, final TestCounter counter, final ExtendedRandom random) { test(counter, shapes.stream().map(shape -> shape.refill(random))); } public void expectException(final TestCounter counter, final Stream items) { expectException(counter, toClojure(items.toList())); } protected void expectException(final TestCounter counter, final Engine.Result... args) { test(counter, () -> { final Engine.Result result = actual.expectException(args); final boolean ok = result.value() instanceof AssertionError; if (!ok) { result.value().printStackTrace(); } Asserts.assertTrue( "AssertionError expected instead of " + result.value() + " in " + result.context(), ok ); }); } } enum Operation { ADD("+", (a, b) -> a + b, a -> a, ZERO), SUB("-", (a, b) -> a - b, a -> -a, ZERO), MUL("*", (a, b) -> a * b, a -> a, ONE), DIV("d", (a, b) -> a / b, a -> 1 / a, ONE); private final String suffix; private final DoubleBinaryOperator binary; private final DoubleUnaryOperator unary; private final Item neutral; Operation(final String suffix, final DoubleBinaryOperator binary, final DoubleUnaryOperator unary, final Item neutral ) { this.suffix = suffix; this.binary = binary; this.unary = unary; this.neutral = neutral; } public String suffix() { return suffix; } public DoubleBinaryOperator binary() { return binary; } public DoubleUnaryOperator unary() { return unary; } public Item neutral() { return neutral; } public Item apply(final List args) { final Item first = args.get(0); if (first instanceof Value) { return value(args.size() == 1 ? unary.applyAsDouble(first.value()) : args.stream().map(Value.class::cast).mapToDouble(Value::value).reduce(binary).getAsDouble()); } else { return vector(IntStream.range(0, first.size()) .mapToObj(i -> apply(items(args.stream().map(Vector.class::cast).map(arg -> arg.get(i)))))); } } private Fun function(final String prefix) { return new Fun(prefix + suffix, this::apply); } } }