|
|
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 37,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'1.13.1'"
- ]
- },
- "execution_count": 37,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import tensorflow as tf\n",
- "import tensorflow.examples.tutorials.mnist.input_data as input_data\n",
- "\n",
- "tf.__version__"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 36,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
- "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
- "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
- "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
- ]
- }
- ],
- "source": [
- "MNIST=input_data.read_data_sets(\"MNIST_data\", one_hot=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "0.9149\n"
- ]
- }
- ],
- "source": [
- "# 定义一个占位符x\n",
- "x = tf.placeholder(tf.float32, [None, 784]) # 张量的形状是[None, 784],None表第一个维度任意\n",
- "\n",
- "# 定义变量W,b,是可以被修改的张量,用来存放机器学习模型参数\n",
- "W = tf.Variable(tf.zeros([784, 10]))\n",
- "b = tf.Variable(tf.zeros([10]))\n",
- "\n",
- "# 实现模型, y是预测分布\n",
- "y = tf.nn.softmax(tf.matmul(x, W) + b)\n",
- "\n",
- "# 训练模型,y_是实际分布\n",
- "y_ = tf.placeholder(\"float\", [None, 10])\n",
- "cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) # 交叉嫡,cost function\n",
- "\n",
- "# 使用梯度下降来降低cost,学习速率为0.01\n",
- "train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)\n",
- "\n",
- "# 初始化已经创建的变量\n",
- "init = tf.global_variables_initializer()\n",
- "\n",
- "# 在一个Session中启动模型,并初始化变量\n",
- "sess = tf.Session()\n",
- "sess.run(init)\n",
- "\n",
- "# # 训练模型,运行1000次,每次随机抽取100个\n",
- "for i in range(1, 1000):\n",
- " batch_xs, batch_ys = MNIST.train.next_batch(100)\n",
- " sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})\n",
- "\n",
- "# 验证正确率\n",
- "correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))\n",
- "accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))\n",
- "print(sess.run(accuracy, feed_dict={x: MNIST.test.images, y_: MNIST.test.labels}))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 41,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "当前迭代次数为0,训练正确率为0.8999999761581421\n",
- "当前迭代次数为200,训练正确率为0.9200000166893005\n",
- "当前迭代次数为400,训练正确率为0.9100000262260437\n",
- "当前迭代次数为600,训练正确率为0.9399999976158142\n",
- "当前迭代次数为800,训练正确率为0.8999999761581421\n",
- "当前迭代次数为1000,训练正确率为0.8799999952316284\n",
- "当前迭代次数为1200,训练正确率为0.9200000166893005\n",
- "当前迭代次数为1400,训练正确率为0.9100000262260437\n",
- "当前迭代次数为1600,训练正确率为0.9200000166893005\n",
- "当前迭代次数为1800,训练正确率为0.8399999737739563\n",
- "当前迭代次数为2000,训练正确率为0.8899999856948853\n",
- "当前迭代次数为2200,训练正确率为0.8700000047683716\n",
- "当前迭代次数为2400,训练正确率为0.8999999761581421\n",
- "当前迭代次数为2600,训练正确率为0.8899999856948853\n",
- "当前迭代次数为2800,训练正确率为0.8999999761581421\n",
- "当前迭代次数为3000,训练正确率为0.9100000262260437\n",
- "当前迭代次数为3200,训练正确率为0.949999988079071\n",
- "当前迭代次数为3400,训练正确率为0.9300000071525574\n",
- "当前迭代次数为3600,训练正确率为0.8999999761581421\n",
- "当前迭代次数为3800,训练正确率为0.9100000262260437\n",
- "当前迭代次数为4000,训练正确率为0.9700000286102295\n",
- "当前迭代次数为4200,训练正确率为0.8899999856948853\n",
- "当前迭代次数为4400,训练正确率为0.8999999761581421\n",
- "当前迭代次数为4600,训练正确率为0.9300000071525574\n",
- "当前迭代次数为4800,训练正确率为0.9399999976158142\n",
- "当前迭代次数为5000,训练正确率为0.8899999856948853\n",
- "当前迭代次数为5200,训练正确率为0.9100000262260437\n",
- "当前迭代次数为5400,训练正确率为0.9100000262260437\n",
- "当前迭代次数为5600,训练正确率为0.9100000262260437\n",
- "当前迭代次数为5800,训练正确率为0.9300000071525574\n",
- "当前迭代次数为6000,训练正确率为0.9200000166893005\n",
- "当前迭代次数为6200,训练正确率为0.8700000047683716\n",
- "当前迭代次数为6400,训练正确率为0.9200000166893005\n",
- "当前迭代次数为6600,训练正确率为0.9399999976158142\n",
- "当前迭代次数为6800,训练正确率为0.8999999761581421\n",
- "当前迭代次数为7000,训练正确率为0.9300000071525574\n",
- "当前迭代次数为7200,训练正确率为0.9300000071525574\n",
- "当前迭代次数为7400,训练正确率为0.9300000071525574\n",
- "当前迭代次数为7600,训练正确率为0.8799999952316284\n",
- "当前迭代次数为7800,训练正确率为0.9100000262260437\n",
- "当前迭代次数为8000,训练正确率为0.8999999761581421\n",
- "当前迭代次数为8200,训练正确率为0.9599999785423279\n",
- "当前迭代次数为8400,训练正确率为0.9300000071525574\n",
- "当前迭代次数为8600,训练正确率为0.8799999952316284\n",
- "当前迭代次数为8800,训练正确率为0.9200000166893005\n",
- "当前迭代次数为9000,训练正确率为0.8600000143051147\n",
- "当前迭代次数为9200,训练正确率为0.9700000286102295\n",
- "当前迭代次数为9400,训练正确率为0.9200000166893005\n",
- "当前迭代次数为9600,训练正确率为0.8899999856948853\n",
- "当前迭代次数为9800,训练正确率为0.8799999952316284\n"
- ]
- }
- ],
- "source": [
- " \n",
- "for i in range(10000):\n",
- " batch = MNIST.train.next_batch(100)\n",
- " if i % 200 == 0:\n",
- " print(\"当前迭代次数为{},训练正确率为{}\".format(i, accuracy.eval(feed_dict={x:batch[0], y_: batch[1]},session=sess)))\n",
- " train_step.run(feed_dict={x: batch[0], y_: batch[1]},session=sess)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 45,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " 测试正确率为 0.912\n"
- ]
- }
- ],
- "source": [
- "#测试阶段\n",
- "print(\" 测试正确率为 %g\"%accuracy.eval(feed_dict={x: MNIST.test.images, y_: MNIST.test.labels},session=sess)) "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.3"
- },
- "toc": {
- "base_numbering": 1,
- "nav_menu": {},
- "number_sections": true,
- "sideBar": true,
- "skip_h1_title": false,
- "title_cell": "Table of Contents",
- "title_sidebar": "Contents",
- "toc_cell": false,
- "toc_position": {},
- "toc_section_display": true,
- "toc_window_display": false
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
|