@@ -16,13 +16,51 @@ def test_conv1d_simple():
1616 assert out [0 , 0 , 2 ] == 2 * 1 + 3 * 2
1717 assert out [0 , 0 , 3 ] == 3 * 1
1818
19+ @pytest .mark .task4_1
20+ def test_conv1d_simple_backward ():
21+ input_tensor = minitorch .tensor_fromlist ([0 , 1 , 2 , 3 ]).view (1 , 1 , 4 )
22+ weight = minitorch .tensor_fromlist ([[1 , 2 , 3 ]]).view (1 , 1 , 3 )
23+ grad_output = minitorch .tensor_fromlist ([0 , 1 , 2 , 3 ]).view (1 , 1 , 4 )
24+ ctx = minitorch .Context ()
25+ ctx .save_for_backward (input_tensor , weight )
26+ grad_input , grad_weight = minitorch .Conv1dFun .backward (ctx , grad_output )
27+
28+ assert grad_input [0 , 0 , 0 ] == weight [0 , 0 , 0 ] * grad_output [0 , 0 , 0 ]
29+ assert (
30+ grad_input [0 , 0 , 1 ]
31+ == weight [0 , 0 , 0 ] * grad_output [0 , 0 , 1 ]
32+ + weight [0 , 0 , 1 ] * grad_output [0 , 0 , 0 ]
33+ )
34+ assert (
35+ grad_input [0 , 0 , 2 ]
36+ == weight [0 , 0 , 0 ] * grad_output [0 , 0 , 2 ]
37+ + weight [0 , 0 , 1 ] * grad_output [0 , 0 , 1 ]
38+ + weight [0 , 0 , 2 ] * grad_output [0 , 0 , 0 ]
39+ )
40+ assert (
41+ grad_input [0 , 0 , 3 ]
42+ == weight [0 , 0 , 0 ] * grad_output [0 , 0 , 3 ]
43+ + weight [0 , 0 , 1 ] * grad_output [0 , 0 , 2 ]
44+ + weight [0 , 0 , 2 ] * grad_output [0 , 0 , 1 ]
45+ )
1946
2047@pytest .mark .task4_1
2148@given (tensors (shape = (1 , 1 , 6 )), tensors (shape = (1 , 1 , 4 )))
2249def test_conv1d (input , weight ):
2350 print (input , weight )
2451 minitorch .grad_check (minitorch .Conv1dFun .apply , input , weight )
2552
53+ @pytest .mark .task4_1
54+ def test_conv1d_in_channel ():
55+ t = minitorch .tensor_fromlist ([[0 , 1 , 2 , 3 ], [0 , 1 , 2 , 3 ]]).view (1 , 2 , 4 )
56+ t .requires_grad_ (True )
57+ t2 = minitorch .tensor_fromlist ([[1 , 2 , 3 ], [1 , 2 , 3 ]]).view (1 , 2 , 3 )
58+ out = minitorch .Conv1dFun .apply (t , t2 )
59+
60+ assert out [0 , 0 , 0 ] == (0 * 1 + 1 * 2 + 2 * 3 ) * 2
61+ assert out [0 , 0 , 1 ] == (1 * 1 + 2 * 2 + 3 * 3 ) * 2
62+ assert out [0 , 0 , 2 ] == (2 * 1 + 3 * 2 ) * 2
63+ assert out [0 , 0 , 3 ] == (3 * 1 ) * 2
2664
2765@pytest .mark .task4_1
2866@given (tensors (shape = (2 , 2 , 6 )), tensors (shape = (3 , 2 , 2 )))
0 commit comments