|
4 | 4 | import torch |
5 | 5 | from scipy.spatial.distance import euclidean |
6 | 6 | from torch.nn.modules.utils import _pair |
| 7 | +from torch import device |
7 | 8 |
|
8 | 9 | from bindsnet.learning import PostPre |
| 10 | +from bindsnet.learning.MCC_learning import PostPre as MMCPostPre |
9 | 11 | from bindsnet.network import Network |
10 | 12 | from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes |
11 | | -from bindsnet.network.topology import Connection, LocalConnection |
| 13 | +from bindsnet.network.topology import Connection, LocalConnection, MulticompartmentConnection |
| 14 | +from bindsnet.network.topology_features import Weight |
12 | 15 |
|
13 | 16 |
|
14 | 17 | class TwoLayerNetwork(Network): |
@@ -94,6 +97,9 @@ class DiehlAndCook2015(Network): |
94 | 97 | def __init__( |
95 | 98 | self, |
96 | 99 | n_inpt: int, |
| 100 | + device: device, |
| 101 | + batch_size: int, |
| 102 | + sparse: bool = False, |
97 | 103 | n_neurons: int = 100, |
98 | 104 | exc: float = 22.5, |
99 | 105 | inh: float = 17.5, |
@@ -169,28 +175,61 @@ def __init__( |
169 | 175 | ) |
170 | 176 |
|
171 | 177 | # Connections |
172 | | - w = 0.3 * torch.rand(self.n_inpt, self.n_neurons) |
173 | | - input_exc_conn = Connection( |
| 178 | + if sparse: |
| 179 | + w = 0.3 * torch.rand(batch_size, self.n_inpt, self.n_neurons) |
| 180 | + else: |
| 181 | + w = 0.3 * torch.rand(self.n_inpt, self.n_neurons) |
| 182 | + input_exc_conn = MulticompartmentConnection( |
174 | 183 | source=input_layer, |
175 | 184 | target=exc_layer, |
176 | | - w=w, |
177 | | - update_rule=PostPre, |
178 | | - nu=nu, |
179 | | - reduction=reduction, |
180 | | - wmin=wmin, |
181 | | - wmax=wmax, |
182 | | - norm=norm, |
| 185 | + device=device, |
| 186 | + pipeline=[ |
| 187 | + Weight( |
| 188 | + 'weight', |
| 189 | + w, |
| 190 | + range=[wmin, wmax], |
| 191 | + norm=norm, |
| 192 | + reduction=reduction, |
| 193 | + nu=nu, |
| 194 | + learning_rule=MMCPostPre, |
| 195 | + sparse=sparse |
| 196 | + ) |
| 197 | + ] |
183 | 198 | ) |
184 | 199 | w = self.exc * torch.diag(torch.ones(self.n_neurons)) |
185 | | - exc_inh_conn = Connection( |
186 | | - source=exc_layer, target=inh_layer, w=w, wmin=0, wmax=self.exc |
| 200 | + if sparse: |
| 201 | + w = w.unsqueeze(0).expand(batch_size, -1, -1) |
| 202 | + exc_inh_conn = MulticompartmentConnection( |
| 203 | + source=exc_layer, |
| 204 | + target=inh_layer, |
| 205 | + device=device, |
| 206 | + pipeline=[ |
| 207 | + Weight( |
| 208 | + 'weight', |
| 209 | + w, |
| 210 | + range=[0, self.exc], |
| 211 | + sparse=sparse |
| 212 | + ) |
| 213 | + ] |
187 | 214 | ) |
188 | 215 | w = -self.inh * ( |
189 | 216 | torch.ones(self.n_neurons, self.n_neurons) |
190 | 217 | - torch.diag(torch.ones(self.n_neurons)) |
191 | 218 | ) |
192 | | - inh_exc_conn = Connection( |
193 | | - source=inh_layer, target=exc_layer, w=w, wmin=-self.inh, wmax=0 |
| 219 | + if sparse: |
| 220 | + w = w.unsqueeze(0).expand(batch_size, -1, -1) |
| 221 | + inh_exc_conn = MulticompartmentConnection( |
| 222 | + source=inh_layer, |
| 223 | + target=exc_layer, |
| 224 | + device=device, |
| 225 | + pipeline=[ |
| 226 | + Weight( |
| 227 | + 'weight', |
| 228 | + w, |
| 229 | + range=[-self.inh, 0], |
| 230 | + sparse=sparse |
| 231 | + ) |
| 232 | + ] |
194 | 233 | ) |
195 | 234 |
|
196 | 235 | # Add to network |
|
0 commit comments