@@ -523,16 +523,18 @@ def _connection_update(self, **kwargs) -> None:
523523 self .average_buffer_index + 1
524524 ) % self .average_update
525525
526- if self .continues_update :
527- self .feature_value += self .nu [0 ] * torch .mean (
528- self .average_buffer , dim = 0
529- )
530- elif self .average_buffer_index == 0 :
531- self .feature_value += self .nu [0 ] * torch .mean (
526+ if self .continues_update or self .average_buffer_index == 0 :
527+ update = self .nu [0 ] * torch .mean (
532528 self .average_buffer , dim = 0
533529 )
530+ if self .feature_value .is_sparse :
531+ update = update .to_sparse ()
532+ self .feature_value += update
534533 else :
535- self .feature_value += self .nu [0 ] * self .reduction (update , dim = 0 )
534+ update = self .nu [0 ] * self .reduction (update , dim = 0 )
535+ if self .feature_value .is_sparse :
536+ update = update .to_sparse ()
537+ self .feature_value += update
536538
537539 # Update P^+ and P^- values.
538540 self .p_plus *= torch .exp (- self .connection .dt / self .tc_plus )
@@ -701,14 +703,16 @@ def _connection_update(self, **kwargs) -> None:
701703 self .average_buffer_index + 1
702704 ) % self .average_update
703705
704- if self .continues_update :
705- self .feature_value += torch .mean (self .average_buffer , dim = 0 )
706- elif self .average_buffer_index == 0 :
707- self .feature_value += torch .mean (self .average_buffer , dim = 0 )
706+ if self .continues_update or self .average_buffer_index == 0 :
707+ update = torch .mean (self .average_buffer , dim = 0 )
708+ if self .feature_value .is_sparse :
709+ update = update .to_sparse ()
710+ self .feature_value += update
708711 else :
709- self .feature_value += (
710- self .nu [0 ] * self .connection .dt * reward * self .eligibility_trace
711- )
712+ update = self .nu [0 ] * self .connection .dt * reward * self .eligibility_trace
713+ if self .feature_value .is_sparse :
714+ update = update .to_sparse ()
715+ self .feature_value += update
712716
713717 # Update P^+ and P^- values.
714718 self .p_plus *= torch .exp (- self .connection .dt / self .tc_plus ) # Decay
0 commit comments