DaSE-Computer-Vision-2021
您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

65 行
2.2 KiB

  1. from builtins import range
  2. import numpy as np
  3. from random import shuffle
  4. from past.builtins import xrange
  5. def softmax_loss_naive(W, X, y, reg):
  6. """
  7. Softmax loss function, naive implementation (with loops)
  8. Inputs have dimension D, there are C classes, and we operate on minibatches
  9. of N examples.
  10. Inputs:
  11. - W: A numpy array of shape (D, C) containing weights.
  12. - X: A numpy array of shape (N, D) containing a minibatch of data.
  13. - y: A numpy array of shape (N,) containing training labels; y[i] = c means
  14. that X[i] has label c, where 0 <= c < C.
  15. - reg: (float) regularization strength
  16. Returns a tuple of:
  17. - loss as single float
  18. - gradient with respect to weights W; an array of same shape as W
  19. """
  20. # Initialize the loss and gradient to zero.
  21. loss = 0.0
  22. dW = np.zeros_like(W)
  23. #############################################################################
  24. # TODO: 使用显式循环计算softmax损失及其梯度。
  25. # 将损失和梯度分别保存在loss和dW中。
  26. # 如果你不小心,很容易遇到数值不稳定的情况。
  27. # 不要忘了正则化!
  28. #############################################################################
  29. # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
  30. pass
  31. # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
  32. return loss, dW
  33. def softmax_loss_vectorized(W, X, y, reg):
  34. """
  35. Softmax loss function, vectorized version.
  36. Inputs and outputs are the same as softmax_loss_naive.
  37. """
  38. # Initialize the loss and gradient to zero.
  39. loss = 0.0
  40. dW = np.zeros_like(W)
  41. #############################################################################
  42. # TODO: 不使用显式循环计算softmax损失及其梯度。
  43. # 将损失和梯度分别保存在loss和dW中。
  44. # 如果你不小心,很容易遇到数值不稳定的情况。
  45. # 不要忘了正则化!
  46. #############################################################################
  47. # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
  48. pass
  49. # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
  50. return loss, dW