You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

274 regels
8.0 KiB

3 jaren geleden
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "metadata": {},
  7. "outputs": [
  8. {
  9. "ename": "ModuleNotFoundError",
  10. "evalue": "No module named 'surprise'",
  11. "output_type": "error",
  12. "traceback": [
  13. "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
  14. "\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
  15. "\u001b[1;32m<ipython-input-1-002ce27085d1>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0msurprise\u001b[0m \u001b[1;31m# run 'pip install scikit-surprise' to install surprise\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
  16. "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'surprise'"
  17. ]
  18. }
  19. ],
  20. "source": [
  21. "import numpy as np\n",
  22. "import surprise # run 'pip install scikit-surprise' to install surprise"
  23. ]
  24. },
  25. {
  26. "cell_type": "code",
  27. "execution_count": 3,
  28. "metadata": {
  29. "collapsed": true
  30. },
  31. "outputs": [],
  32. "source": [
  33. "class MatrixFacto(surprise.AlgoBase):\n",
  34. "    '''A basic rating prediction algorithm based on matrix factorization.'''\n",
  35. "\n",
  36. "    def __init__(self, learning_rate, n_epochs, n_factors):\n",
  37. "        surprise.AlgoBase.__init__(self)  # always initialise the base class first\n",
  38. "        self.lr = learning_rate  # learning rate for SGD\n",
  39. "        self.n_epochs = n_epochs  # number of iterations of SGD\n",
  40. "        self.n_factors = n_factors  # number of factors\n",
  41. "\n",
  42. "    def fit(self, trainset):\n",
  43. "        '''Learn the vectors p_u and q_i with SGD and return the fitted algorithm.'''\n",
  44. "        surprise.AlgoBase.fit(self, trainset)  # base class records self.trainset\n",
  45. "        print('Fitting data with SGD...')\n",
  46. "\n",
  47. "        # Randomly initialize the user and item factors.\n",
  48. "        p = np.random.normal(0, .1, (trainset.n_users, self.n_factors))\n",
  49. "        q = np.random.normal(0, .1, (trainset.n_items, self.n_factors))\n",
  50. "\n",
  51. "        # SGD procedure\n",
  52. "        for _ in range(self.n_epochs):\n",
  53. "            for u, i, r_ui in trainset.all_ratings():\n",
  54. "                err = r_ui - np.dot(p[u], q[i])\n",
  55. "                # Both updates use the pre-update vectors: the right-hand side\n",
  56. "                # tuple is fully evaluated before p[u] and q[i] are overwritten.\n",
  57. "                p[u], q[i] = (p[u] + self.lr * err * q[i],\n",
  58. "                              q[i] + self.lr * err * p[u])\n",
  59. "\n",
  60. "        # Keep the learned factors for estimate().\n",
  61. "        self.p, self.q = p, q\n",
  62. "        return self  # Surprise convention: fit() returns the algorithm itself\n",
  63. "\n",
  64. "    def estimate(self, u, i):\n",
  65. "        '''Return the estimated rating of user u for item i.'''\n",
  66. "\n",
  67. "        # return scalar product between p_u and q_i if user and item are known,\n",
  68. "        # else return the average of all ratings\n",
  69. "        if self.trainset.knows_user(u) and self.trainset.knows_item(i):\n",
  70. "            return np.dot(self.p[u], self.q[i])\n",
  71. "        else:\n",
  72. "            return self.trainset.global_mean"
  73. ]
  74. },
  75. {
  76. "cell_type": "code",
  77. "execution_count": 11,
  78. "metadata": {
  79. "collapsed": true
  80. },
  81. "outputs": [],
  82. "source": [
  83. "# Data loading: MovieLens 100k (https://grouplens.org/datasets/movielens/100k/), downloaded automatically.\n",
  84. "data = surprise.Dataset.load_builtin('ml-100k')\n",
  85. "if hasattr(data, 'split'):  # Dataset.split() only exists in Surprise < 1.1\n",
  86. "    data.split(2)  # split data for 2-folds cross validation"
  87. ]
  88. },
  89. {
  90. "cell_type": "code",
  91. "execution_count": 12,
  92. "metadata": {},
  93. "outputs": [
  94. {
  95. "name": "stdout",
  96. "output_type": "stream",
  97. "text": [
  98. "Evaluating RMSE of algorithm MatrixFacto.\n",
  99. "\n",
  100. "------------\n",
  101. "Fold 1\n",
  102. "Fitting data with SGD...\n",
  103. "RMSE: 0.9826\n",
  104. "------------\n",
  105. "Fold 2\n",
  106. "Fitting data with SGD...\n",
  107. "RMSE: 0.9873\n",
  108. "------------\n",
  109. "------------\n",
  110. "Mean RMSE: 0.9849\n",
  111. "------------\n",
  112. "------------\n"
  113. ]
  114. },
  115. {
  116. "data": {
  117. "text/plain": [
  118. "CaseInsensitiveDefaultDict(list,\n",
  119. " {'rmse': [0.98263312180825368, 0.9872549391926676]})"
  120. ]
  121. },
  122. "execution_count": 12,
  123. "metadata": {},
  124. "output_type": "execute_result"
  125. }
  126. ],
  127. "source": [
  128. "algo = MatrixFacto(learning_rate=.01, n_epochs=10, n_factors=10)\n",
  129. "surprise.model_selection.cross_validate(algo, data, measures=['RMSE'], cv=surprise.model_selection.KFold(n_splits=2, random_state=0), verbose=True)"
  130. ]
  131. },
  132. {
  133. "cell_type": "code",
  134. "execution_count": 13,
  135. "metadata": {},
  136. "outputs": [
  137. {
  138. "name": "stdout",
  139. "output_type": "stream",
  140. "text": [
  141. "Evaluating RMSE of algorithm KNNBasic.\n",
  142. "\n",
  143. "------------\n",
  144. "Fold 1\n",
  145. "Computing the msd similarity matrix...\n",
  146. "Done computing similarity matrix.\n",
  147. "RMSE: 1.0101\n",
  148. "------------\n",
  149. "Fold 2\n",
  150. "Computing the msd similarity matrix...\n",
  151. "Done computing similarity matrix.\n",
  152. "RMSE: 0.9982\n",
  153. "------------\n",
  154. "------------\n",
  155. "Mean RMSE: 1.0042\n",
  156. "------------\n",
  157. "------------\n"
  158. ]
  159. },
  160. {
  161. "data": {
  162. "text/plain": [
  163. "CaseInsensitiveDefaultDict(list,\n",
  164. " {'rmse': [1.0101383334175613, 0.99823558896449016]})"
  165. ]
  166. },
  167. "execution_count": 13,
  168. "metadata": {},
  169. "output_type": "execute_result"
  170. }
  171. ],
  172. "source": [
  173. "# try a neighborhood-based algorithm (2-fold CV with a fixed seed, so the folds are reproducible)\n",
  174. "algo = surprise.KNNBasic()\n",
  175. "surprise.model_selection.cross_validate(algo, data, measures=['RMSE'], cv=surprise.model_selection.KFold(n_splits=2, random_state=0), verbose=True)"
  176. ]
  177. },
  178. {
  179. "cell_type": "code",
  180. "execution_count": 14,
  181. "metadata": {},
  182. "outputs": [
  183. {
  184. "name": "stdout",
  185. "output_type": "stream",
  186. "text": [
  187. "Evaluating RMSE of algorithm SVD.\n",
  188. "\n",
  189. "------------\n",
  190. "Fold 1\n",
  191. "RMSE: 0.9604\n",
  192. "------------\n",
  193. "Fold 2\n",
  194. "RMSE: 0.9538\n",
  195. "------------\n",
  196. "------------\n",
  197. "Mean RMSE: 0.9571\n",
  198. "------------\n",
  199. "------------\n"
  200. ]
  201. },
  202. {
  203. "data": {
  204. "text/plain": [
  205. "CaseInsensitiveDefaultDict(list,\n",
  206. " {'rmse': [0.96042083843476056,\n",
  207. " 0.95382688332712151]})"
  208. ]
  209. },
  210. "execution_count": 14,
  211. "metadata": {},
  212. "output_type": "execute_result"
  213. }
  214. ],
  215. "source": [
  216. "# try a more sophisticated matrix factorization algorithm (2-fold CV with a fixed seed, so the folds are reproducible)\n",
  217. "algo = surprise.SVD()\n",
  218. "surprise.model_selection.cross_validate(algo, data, measures=['RMSE'], cv=surprise.model_selection.KFold(n_splits=2, random_state=0), verbose=True)"
  219. ]
  220. }
  221. ],
  222. "metadata": {
  223. "kernelspec": {
  224. "display_name": "Python 3",
  225. "language": "python",
  226. "name": "python3"
  227. },
  228. "language_info": {
  229. "codemirror_mode": {
  230. "name": "ipython",
  231. "version": 3
  232. },
  233. "file_extension": ".py",
  234. "mimetype": "text/x-python",
  235. "name": "python",
  236. "nbconvert_exporter": "python",
  237. "pygments_lexer": "ipython3",
  238. "version": "3.7.6"
  239. },
  240. "latex_envs": {
  241. "LaTeX_envs_menu_present": true,
  242. "autoclose": false,
  243. "autocomplete": true,
  244. "bibliofile": "biblio.bib",
  245. "cite_by": "apalike",
  246. "current_citInitial": 1,
  247. "eqLabelWithNumbers": true,
  248. "eqNumInitial": 1,
  249. "hotkeys": {
  250. "equation": "Ctrl-E",
  251. "itemize": "Ctrl-I"
  252. },
  253. "labels_anchors": false,
  254. "latex_user_defs": false,
  255. "report_style_numbering": false,
  256. "user_envs_cfg": false
  257. },
  258. "toc": {
  259. "base_numbering": 1,
  260. "nav_menu": {},
  261. "number_sections": true,
  262. "sideBar": true,
  263. "skip_h1_title": false,
  264. "title_cell": "Table of Contents",
  265. "title_sidebar": "Contents",
  266. "toc_cell": false,
  267. "toc_position": {},
  268. "toc_section_display": true,
  269. "toc_window_display": false
  270. }
  271. },
  272. "nbformat": 4,
  273. "nbformat_minor": 2
  274. }