You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

230 lines
8.2 KiB

2 years ago
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 37,
  6. "metadata": {
  7. "scrolled": true
  8. },
  9. "outputs": [
  10. {
  11. "data": {
  12. "text/plain": [
  13. "'1.13.1'"
  14. ]
  15. },
  16. "execution_count": 37,
  17. "metadata": {},
  18. "output_type": "execute_result"
  19. }
  20. ],
  21. "source": [
  22. "import tensorflow as tf\n",
  23. "import tensorflow.examples.tutorials.mnist.input_data as input_data\n",
  24. "\n",
  25. "tf.__version__"
  26. ]
  27. },
  28. {
  29. "cell_type": "code",
  30. "execution_count": 36,
  31. "metadata": {
  32. "scrolled": true
  33. },
  34. "outputs": [
  35. {
  36. "name": "stdout",
  37. "output_type": "stream",
  38. "text": [
  39. "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
  40. "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
  41. "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
  42. "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
  43. ]
  44. }
  45. ],
  46. "source": [
  47. "MNIST=input_data.read_data_sets(\"MNIST_data\", one_hot=True)"
  48. ]
  49. },
  50. {
  51. "cell_type": "code",
  52. "execution_count": 34,
  53. "metadata": {},
  54. "outputs": [
  55. {
  56. "name": "stdout",
  57. "output_type": "stream",
  58. "text": [
  59. "0.9149\n"
  60. ]
  61. }
  62. ],
  63. "source": [
  64. "# 定义一个占位符x\n",
  65. "x = tf.placeholder(tf.float32, [None, 784]) # 张量的形状是[None, 784],None表第一个维度任意\n",
  66. "\n",
  67. "# 定义变量W,b,是可以被修改的张量,用来存放机器学习模型参数\n",
  68. "W = tf.Variable(tf.zeros([784, 10]))\n",
  69. "b = tf.Variable(tf.zeros([10]))\n",
  70. "\n",
  71. "# 实现模型, y是预测分布\n",
  72. "y = tf.nn.softmax(tf.matmul(x, W) + b)\n",
  73. "\n",
  74. "# 训练模型,y_是实际分布\n",
  75. "y_ = tf.placeholder(\"float\", [None, 10])\n",
  76. "cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) # 交叉嫡,cost function\n",
  77. "\n",
  78. "# 使用梯度下降来降低cost,学习速率为0.01\n",
  79. "train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)\n",
  80. "\n",
  81. "# 初始化已经创建的变量\n",
  82. "init = tf.global_variables_initializer()\n",
  83. "\n",
  84. "# 在一个Session中启动模型,并初始化变量\n",
  85. "sess = tf.Session()\n",
  86. "sess.run(init)\n",
  87. "\n",
  88. "# # 训练模型,运行1000次,每次随机抽取100个\n",
  89. "for i in range(1, 1000):\n",
  90. " batch_xs, batch_ys = MNIST.train.next_batch(100)\n",
  91. " sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})\n",
  92. "\n",
  93. "# 验证正确率\n",
  94. "correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))\n",
  95. "accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))\n",
  96. "print(sess.run(accuracy, feed_dict={x: MNIST.test.images, y_: MNIST.test.labels}))"
  97. ]
  98. },
  99. {
  100. "cell_type": "code",
  101. "execution_count": 41,
  102. "metadata": {},
  103. "outputs": [
  104. {
  105. "name": "stdout",
  106. "output_type": "stream",
  107. "text": [
  108. "当前迭代次数为0,训练正确率为0.8999999761581421\n",
  109. "当前迭代次数为200,训练正确率为0.9200000166893005\n",
  110. "当前迭代次数为400,训练正确率为0.9100000262260437\n",
  111. "当前迭代次数为600,训练正确率为0.9399999976158142\n",
  112. "当前迭代次数为800,训练正确率为0.8999999761581421\n",
  113. "当前迭代次数为1000,训练正确率为0.8799999952316284\n",
  114. "当前迭代次数为1200,训练正确率为0.9200000166893005\n",
  115. "当前迭代次数为1400,训练正确率为0.9100000262260437\n",
  116. "当前迭代次数为1600,训练正确率为0.9200000166893005\n",
  117. "当前迭代次数为1800,训练正确率为0.8399999737739563\n",
  118. "当前迭代次数为2000,训练正确率为0.8899999856948853\n",
  119. "当前迭代次数为2200,训练正确率为0.8700000047683716\n",
  120. "当前迭代次数为2400,训练正确率为0.8999999761581421\n",
  121. "当前迭代次数为2600,训练正确率为0.8899999856948853\n",
  122. "当前迭代次数为2800,训练正确率为0.8999999761581421\n",
  123. "当前迭代次数为3000,训练正确率为0.9100000262260437\n",
  124. "当前迭代次数为3200,训练正确率为0.949999988079071\n",
  125. "当前迭代次数为3400,训练正确率为0.9300000071525574\n",
  126. "当前迭代次数为3600,训练正确率为0.8999999761581421\n",
  127. "当前迭代次数为3800,训练正确率为0.9100000262260437\n",
  128. "当前迭代次数为4000,训练正确率为0.9700000286102295\n",
  129. "当前迭代次数为4200,训练正确率为0.8899999856948853\n",
  130. "当前迭代次数为4400,训练正确率为0.8999999761581421\n",
  131. "当前迭代次数为4600,训练正确率为0.9300000071525574\n",
  132. "当前迭代次数为4800,训练正确率为0.9399999976158142\n",
  133. "当前迭代次数为5000,训练正确率为0.8899999856948853\n",
  134. "当前迭代次数为5200,训练正确率为0.9100000262260437\n",
  135. "当前迭代次数为5400,训练正确率为0.9100000262260437\n",
  136. "当前迭代次数为5600,训练正确率为0.9100000262260437\n",
  137. "当前迭代次数为5800,训练正确率为0.9300000071525574\n",
  138. "当前迭代次数为6000,训练正确率为0.9200000166893005\n",
  139. "当前迭代次数为6200,训练正确率为0.8700000047683716\n",
  140. "当前迭代次数为6400,训练正确率为0.9200000166893005\n",
  141. "当前迭代次数为6600,训练正确率为0.9399999976158142\n",
  142. "当前迭代次数为6800,训练正确率为0.8999999761581421\n",
  143. "当前迭代次数为7000,训练正确率为0.9300000071525574\n",
  144. "当前迭代次数为7200,训练正确率为0.9300000071525574\n",
  145. "当前迭代次数为7400,训练正确率为0.9300000071525574\n",
  146. "当前迭代次数为7600,训练正确率为0.8799999952316284\n",
  147. "当前迭代次数为7800,训练正确率为0.9100000262260437\n",
  148. "当前迭代次数为8000,训练正确率为0.8999999761581421\n",
  149. "当前迭代次数为8200,训练正确率为0.9599999785423279\n",
  150. "当前迭代次数为8400,训练正确率为0.9300000071525574\n",
  151. "当前迭代次数为8600,训练正确率为0.8799999952316284\n",
  152. "当前迭代次数为8800,训练正确率为0.9200000166893005\n",
  153. "当前迭代次数为9000,训练正确率为0.8600000143051147\n",
  154. "当前迭代次数为9200,训练正确率为0.9700000286102295\n",
  155. "当前迭代次数为9400,训练正确率为0.9200000166893005\n",
  156. "当前迭代次数为9600,训练正确率为0.8899999856948853\n",
  157. "当前迭代次数为9800,训练正确率为0.8799999952316284\n"
  158. ]
  159. }
  160. ],
  161. "source": [
  162. " \n",
  163. "for i in range(10000):\n",
  164. " batch = MNIST.train.next_batch(100)\n",
  165. " if i % 200 == 0:\n",
  166. " print(\"当前迭代次数为{},训练正确率为{}\".format(i, accuracy.eval(feed_dict={x:batch[0], y_: batch[1]},session=sess)))\n",
  167. " train_step.run(feed_dict={x: batch[0], y_: batch[1]},session=sess)"
  168. ]
  169. },
  170. {
  171. "cell_type": "code",
  172. "execution_count": 45,
  173. "metadata": {},
  174. "outputs": [
  175. {
  176. "name": "stdout",
  177. "output_type": "stream",
  178. "text": [
  179. " 测试正确率为 0.912\n"
  180. ]
  181. }
  182. ],
  183. "source": [
  184. "#测试阶段\n",
  185. "print(\" 测试正确率为 %g\"%accuracy.eval(feed_dict={x: MNIST.test.images, y_: MNIST.test.labels},session=sess)) "
  186. ]
  187. },
  188. {
  189. "cell_type": "code",
  190. "execution_count": null,
  191. "metadata": {},
  192. "outputs": [],
  193. "source": []
  194. }
  195. ],
  196. "metadata": {
  197. "kernelspec": {
  198. "display_name": "Python 3",
  199. "language": "python",
  200. "name": "python3"
  201. },
  202. "language_info": {
  203. "codemirror_mode": {
  204. "name": "ipython",
  205. "version": 3
  206. },
  207. "file_extension": ".py",
  208. "mimetype": "text/x-python",
  209. "name": "python",
  210. "nbconvert_exporter": "python",
  211. "pygments_lexer": "ipython3",
  212. "version": "3.7.3"
  213. },
  214. "toc": {
  215. "base_numbering": 1,
  216. "nav_menu": {},
  217. "number_sections": true,
  218. "sideBar": true,
  219. "skip_h1_title": false,
  220. "title_cell": "Table of Contents",
  221. "title_sidebar": "Contents",
  222. "toc_cell": false,
  223. "toc_position": {},
  224. "toc_section_display": true,
  225. "toc_window_display": false
  226. }
  227. },
  228. "nbformat": 4,
  229. "nbformat_minor": 2
  230. }