Skip to content

Commit 326d270

Browse files
committed
Add sparse support for other learning rules
1 parent bed0623 commit 326d270

1 file changed

Lines changed: 18 additions & 14 deletions

File tree

bindsnet/learning/MCC_learning.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -523,16 +523,18 @@ def _connection_update(self, **kwargs) -> None:
523523
self.average_buffer_index + 1
524524
) % self.average_update
525525

526-
if self.continues_update:
527-
self.feature_value += self.nu[0] * torch.mean(
528-
self.average_buffer, dim=0
529-
)
530-
elif self.average_buffer_index == 0:
531-
self.feature_value += self.nu[0] * torch.mean(
526+
if self.continues_update or self.average_buffer_index == 0:
527+
update = self.nu[0] * torch.mean(
532528
self.average_buffer, dim=0
533529
)
530+
if self.feature_value.is_sparse:
531+
update = update.to_sparse()
532+
self.feature_value += update
534533
else:
535-
self.feature_value += self.nu[0] * self.reduction(update, dim=0)
534+
update = self.nu[0] * self.reduction(update, dim=0)
535+
if self.feature_value.is_sparse:
536+
update = update.to_sparse()
537+
self.feature_value += update
536538

537539
# Update P^+ and P^- values.
538540
self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
@@ -701,14 +703,16 @@ def _connection_update(self, **kwargs) -> None:
701703
self.average_buffer_index + 1
702704
) % self.average_update
703705

704-
if self.continues_update:
705-
self.feature_value += torch.mean(self.average_buffer, dim=0)
706-
elif self.average_buffer_index == 0:
707-
self.feature_value += torch.mean(self.average_buffer, dim=0)
706+
if self.continues_update or self.average_buffer_index == 0:
707+
update = torch.mean(self.average_buffer, dim=0)
708+
if self.feature_value.is_sparse:
709+
update = update.to_sparse()
710+
self.feature_value += update
708711
else:
709-
self.feature_value += (
710-
self.nu[0] * self.connection.dt * reward * self.eligibility_trace
711-
)
712+
update = self.nu[0] * self.connection.dt * reward * self.eligibility_trace
713+
if self.feature_value.is_sparse:
714+
update = update.to_sparse()
715+
self.feature_value += update
712716

713717
# Update P^+ and P^- values.
714718
self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus) # Decay

0 commit comments

Comments
 (0)