Description
I downloaded this notebook, https://phuijse.github.io/MachineLearningBook/contents/neural_networks/torch-training.html, to run it in Jupyter, and I get the following error in cell 25:
```
RuntimeError Traceback (most recent call last)
File :12
Cell In[24], line 17, in train_one_epoch(epoch)
15 train_loss, valid_loss = 0.0, 0.0
16 for batchx, batchy in train_loader:
---> 17 train_loss += update_step(batchx, batchy)
18 for batchx, batchy in valid_loader:
19 valid_loss += evaluate_step(batchx, batchy)
Cell In[24], line 4, in update_step(data, label)
2 prediction = model(data)
3 optimizer.zero_grad()
----> 4 loss = criterion(prediction, label)
5 loss.backward()
6 optimizer.step()
File ~\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~\AppData\Roaming\Python\Python311\site-packages\torch\nn\modules\loss.py:1185, in CrossEntropyLoss.forward(self, input, target)
1184 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1185 return F.cross_entropy(input, target, weight=self.weight,
1186 ignore_index=self.ignore_index, reduction=self.reduction,
1187 label_smoothing=self.label_smoothing)
File ~\AppData\Roaming\Python\Python311\site-packages\torch\nn\functional.py:3086, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
3084 if size_average is not None or reduce is not None:
3085 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3086 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: expected scalar type Long but found Int
```
I haven't made any changes to the notebook, and I don't really understand what is causing the error.
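The last line of the traceback says that `CrossEntropyLoss` received a class-index target of type `Int` (`torch.int32`) where it expected `Long` (`torch.int64`). One plausible cause, which is an assumption since the notebook's data loading isn't shown here, is that the labels come from a NumPy integer array: on Windows, NumPy 1.x uses `int32` as its default integer type, and `torch.from_numpy` keeps that dtype. The minimal sketch below uses hypothetical labels and logits (not the notebook's data) to show the dtype involved and the usual fix of casting the target with `.long()`:

```python
import numpy as np
import torch

# Hypothetical labels standing in for the notebook's data (assumption).
# dtype=np.int32 mimics NumPy 1.x's default integer type on Windows.
labels_np = np.array([0, 1, 1, 0], dtype=np.int32)
logits = torch.randn(4, 2)  # 4 samples, 2 classes

label = torch.from_numpy(labels_np)
print(label.dtype)  # torch.int32: the "Int" in the error message

criterion = torch.nn.CrossEntropyLoss()

# Cast class-index targets to int64 (torch.long) before computing the loss.
label = label.long()
loss = criterion(logits, label)
print(label.dtype, loss.item())
```

In the notebook, the equivalent change would be `loss = criterion(prediction, label.long())` in `update_step`, and presumably the same in `evaluate_step` (its body isn't shown in the traceback). Alternatively, the labels can be converted to `torch.long` once when the dataset is built, so every batch already has the expected dtype.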