Skip to content

Commit 59fe29e

Browse files
authored
Merge pull request #2 from anton164/master
Added more tests for conv1d and tiling
2 parents a1c5702 + c62118c commit 59fe29e

2 files changed

Lines changed: 59 additions & 0 deletions

File tree

tests/test_conv.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,51 @@ def test_conv1d_simple():
1616
assert out[0, 0, 2] == 2 * 1 + 3 * 2
1717
assert out[0, 0, 3] == 3 * 1
1818

19+
@pytest.mark.task4_1
20+
def test_conv1d_simple_backward():
21+
input_tensor = minitorch.tensor_fromlist([0, 1, 2, 3]).view(1, 1, 4)
22+
weight = minitorch.tensor_fromlist([[1, 2, 3]]).view(1, 1, 3)
23+
grad_output = minitorch.tensor_fromlist([0, 1, 2, 3]).view(1, 1, 4)
24+
ctx = minitorch.Context()
25+
ctx.save_for_backward(input_tensor, weight)
26+
grad_input, grad_weight = minitorch.Conv1dFun.backward(ctx, grad_output)
27+
28+
assert grad_input[0, 0, 0] == weight[0, 0, 0] * grad_output[0, 0, 0]
29+
assert (
30+
grad_input[0, 0, 1]
31+
== weight[0, 0, 0] * grad_output[0, 0, 1]
32+
+ weight[0, 0, 1] * grad_output[0, 0, 0]
33+
)
34+
assert (
35+
grad_input[0, 0, 2]
36+
== weight[0, 0, 0] * grad_output[0, 0, 2]
37+
+ weight[0, 0, 1] * grad_output[0, 0, 1]
38+
+ weight[0, 0, 2] * grad_output[0, 0, 0]
39+
)
40+
assert (
41+
grad_input[0, 0, 3]
42+
== weight[0, 0, 0] * grad_output[0, 0, 3]
43+
+ weight[0, 0, 1] * grad_output[0, 0, 2]
44+
+ weight[0, 0, 2] * grad_output[0, 0, 1]
45+
)
1946

2047
@pytest.mark.task4_1
2148
@given(tensors(shape=(1, 1, 6)), tensors(shape=(1, 1, 4)))
2249
def test_conv1d(input, weight):
2350
print(input, weight)
2451
minitorch.grad_check(minitorch.Conv1dFun.apply, input, weight)
2552

53+
@pytest.mark.task4_1
54+
def test_conv1d_in_channel():
55+
t = minitorch.tensor_fromlist([[0, 1, 2, 3], [0, 1, 2, 3]]).view(1, 2, 4)
56+
t.requires_grad_(True)
57+
t2 = minitorch.tensor_fromlist([[1, 2, 3], [1, 2, 3]]).view(1, 2, 3)
58+
out = minitorch.Conv1dFun.apply(t, t2)
59+
60+
assert out[0, 0, 0] == (0 * 1 + 1 * 2 + 2 * 3) * 2
61+
assert out[0, 0, 1] == (1 * 1 + 2 * 2 + 3 * 3) * 2
62+
assert out[0, 0, 2] == (2 * 1 + 3 * 2) * 2
63+
assert out[0, 0, 3] == (3 * 1) * 2
2664

2765
@pytest.mark.task4_1
2866
@given(tensors(shape=(2, 2, 6)), tensors(shape=(3, 2, 2)))

tests/test_nn.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,27 @@
33
from .strategies import tensors, assert_close
44
import pytest
55

6+
@pytest.mark.task4_3
7+
def test_tile():
8+
t = minitorch.tensor_fromlist(
9+
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
10+
).view(1, 1, 4, 4)
11+
tiled, _, _ = minitorch.tile(t, (2, 2))
12+
assert tiled[0, 0, 0, 0, 0] == 1
13+
assert tiled[0, 0, 0, 0, 1] == 2
14+
assert tiled[0, 0, 0, 0, 2] == 5
15+
assert tiled[0, 0, 0, 0, 3] == 6
16+
17+
@pytest.mark.task4_3
18+
def test_tile_2():
19+
t = minitorch.tensor_fromlist(
20+
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
21+
).view(1, 1, 4, 4)
22+
tiled, _, _ = minitorch.tile(t, (1, 2))
23+
assert tiled[0, 0, 0, 0, 0] == 1
24+
assert tiled[0, 0, 0, 0, 1] == 2
25+
assert tiled[0, 0, 0, 1, 0] == 3
26+
assert tiled[0, 0, 0, 1, 1] == 4
627

728
@pytest.mark.task4_3
829
@given(tensors(shape=(1, 1, 4, 4)))

0 commit comments

Comments
 (0)