Skip to content

Commit 59df643

Browse files
authored
Update kron tests to explicitly cover static and dynamic shapes (#1912)
1 parent ecd1e07 commit 59df643

1 file changed

Lines changed: 19 additions & 9 deletions

File tree

tests/tensor/test_nlinalg.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -732,36 +732,46 @@ def setup_method(self):
732732
super().setup_method()
733733

734734
def test_vec_vec_kron_raises(self):
735+
"""Ensure kron raises an error for 1D inputs."""
735736
x = vector()
736737
y = vector()
737738
with pytest.raises(
738739
TypeError, match="kron: inputs dimensions must sum to 3 or more"
739740
):
740741
kron(x, y)
741742

743+
@pytest.mark.parametrize("static_shape", [True, False])
742744
@pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)])
743745
@pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)])
744-
def test_perform(self, shp0, shp1):
746+
def test_perform(self, static_shape, shp0, shp1):
747+
"""Test kron execution and symbolic shape inference."""
745748
if len(shp0) + len(shp1) == 2:
746749
pytest.skip("Sum of shp0 and shp1 must be more than 2")
747750

748-
x = tensor(dtype="floatX", shape=shp0)
749751
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
750-
751-
y = tensor(dtype="floatX", shape=shp1)
752752
b = self.rng.random(shp1).astype(config.floatX)
753753

754+
# Using np.kron to evaluate expected numerical output and dimensionality
755+
np_val = np.kron(a, b)
756+
757+
# Determine tensor shapes
758+
shape_x = shp0 if static_shape else (None,) * len(shp0)
759+
shape_y = shp1 if static_shape else (None,) * len(shp1)
760+
shape_out = np_val.shape if static_shape else (None,) * np_val.ndim
761+
762+
x = tensor(dtype="floatX", shape=shape_x)
763+
y = tensor(dtype="floatX", shape=shape_y)
764+
754765
kron_xy = kron(x, y)
766+
767+
# Assert symbolic shape inference immediately after node creation
768+
assert kron_xy.type.shape == shape_out
769+
755770
f = function([x, y], kron_xy)
756771
out = f(a, b)
757772

758-
# Using np.kron to compare outputs
759-
np_val = np.kron(a, b)
760773
np.testing.assert_allclose(out, np_val)
761774

762-
# Regression test for issue #1867
763-
assert kron_xy.type.shape == np_val.shape
764-
765775
@pytest.mark.parametrize(
766776
"i, shp0, shp1",
767777
[(0, (2, 3), (6, 7)), (1, (2, 3), (4, 3, 5)), (2, (2, 4, 3), (4, 3, 5))],

0 commit comments

Comments
 (0)