@@ -732,36 +732,46 @@ def setup_method(self):
732732 super ().setup_method ()
733733
734734 def test_vec_vec_kron_raises (self ):
735+ """Ensure kron raises an error for 1D inputs."""
735736 x = vector ()
736737 y = vector ()
737738 with pytest .raises (
738739 TypeError , match = "kron: inputs dimensions must sum to 3 or more"
739740 ):
740741 kron (x , y )
741742
743+ @pytest .mark .parametrize ("static_shape" , [True , False ])
742744 @pytest .mark .parametrize ("shp0" , [(2 ,), (2 , 3 ), (2 , 3 , 4 ), (2 , 3 , 4 , 5 )])
743745 @pytest .mark .parametrize ("shp1" , [(6 ,), (6 , 7 ), (6 , 7 , 8 ), (6 , 7 , 8 , 9 )])
744- def test_perform (self , shp0 , shp1 ):
746+ def test_perform (self , static_shape , shp0 , shp1 ):
747+ """Test kron execution and symbolic shape inference."""
745748 if len (shp0 ) + len (shp1 ) == 2 :
746749 pytest .skip ("Sum of shp0 and shp1 must be more than 2" )
747750
748- x = tensor (dtype = "floatX" , shape = shp0 )
749751 a = np .asarray (self .rng .random (shp0 )).astype (config .floatX )
750-
751- y = tensor (dtype = "floatX" , shape = shp1 )
752752 b = self .rng .random (shp1 ).astype (config .floatX )
753753
754+ # Using np.kron to evaluate expected numerical output and dimensionality
755+ np_val = np .kron (a , b )
756+
757+ # Determine tensor shapes
758+ shape_x = shp0 if static_shape else (None ,) * len (shp0 )
759+ shape_y = shp1 if static_shape else (None ,) * len (shp1 )
760+ shape_out = np_val .shape if static_shape else (None ,) * np_val .ndim
761+
762+ x = tensor (dtype = "floatX" , shape = shape_x )
763+ y = tensor (dtype = "floatX" , shape = shape_y )
764+
754765 kron_xy = kron (x , y )
766+
767+ # Assert symbolic shape inference immediately after node creation
768+ assert kron_xy .type .shape == shape_out
769+
755770 f = function ([x , y ], kron_xy )
756771 out = f (a , b )
757772
758- # Using np.kron to compare outputs
759- np_val = np .kron (a , b )
760773 np .testing .assert_allclose (out , np_val )
761774
762- # Regression test for issue #1867
763- assert kron_xy .type .shape == np_val .shape
764-
765775 @pytest .mark .parametrize (
766776 "i, shp0, shp1" ,
767777 [(0 , (2 , 3 ), (6 , 7 )), (1 , (2 , 3 ), (4 , 3 , 5 )), (2 , (2 , 4 , 3 ), (4 , 3 , 5 ))],
0 commit comments