You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

230 lines
8.2 KiB

{
"cells": [
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"'1.13.1'"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import tensorflow as tf\n",
"import tensorflow.examples.tutorials.mnist.input_data as input_data\n",
"\n",
"tf.__version__"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
"Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
"Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
"Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
]
}
],
"source": [
"MNIST=input_data.read_data_sets(\"MNIST_data\", one_hot=True)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9149\n"
]
}
],
"source": [
"# 定义一个占位符x\n",
"x = tf.placeholder(tf.float32, [None, 784]) # 张量的形状是[None, 784],None表第一个维度任意\n",
"\n",
"# 定义变量W,b,是可以被修改的张量,用来存放机器学习模型参数\n",
"W = tf.Variable(tf.zeros([784, 10]))\n",
"b = tf.Variable(tf.zeros([10]))\n",
"\n",
"# 实现模型, y是预测分布\n",
"y = tf.nn.softmax(tf.matmul(x, W) + b)\n",
"\n",
"# 训练模型,y_是实际分布\n",
"y_ = tf.placeholder(\"float\", [None, 10])\n",
"cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) # 交叉嫡,cost function\n",
"\n",
"# 使用梯度下降来降低cost,学习速率为0.01\n",
"train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)\n",
"\n",
"# 初始化已经创建的变量\n",
"init = tf.global_variables_initializer()\n",
"\n",
"# 在一个Session中启动模型,并初始化变量\n",
"sess = tf.Session()\n",
"sess.run(init)\n",
"\n",
"# # 训练模型,运行1000次,每次随机抽取100个\n",
"for i in range(1, 1000):\n",
" batch_xs, batch_ys = MNIST.train.next_batch(100)\n",
" sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})\n",
"\n",
"# 验证正确率\n",
"correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))\n",
"accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))\n",
"print(sess.run(accuracy, feed_dict={x: MNIST.test.images, y_: MNIST.test.labels}))"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"当前迭代次数为0,训练正确率为0.8999999761581421\n",
"当前迭代次数为200,训练正确率为0.9200000166893005\n",
"当前迭代次数为400,训练正确率为0.9100000262260437\n",
"当前迭代次数为600,训练正确率为0.9399999976158142\n",
"当前迭代次数为800,训练正确率为0.8999999761581421\n",
"当前迭代次数为1000,训练正确率为0.8799999952316284\n",
"当前迭代次数为1200,训练正确率为0.9200000166893005\n",
"当前迭代次数为1400,训练正确率为0.9100000262260437\n",
"当前迭代次数为1600,训练正确率为0.9200000166893005\n",
"当前迭代次数为1800,训练正确率为0.8399999737739563\n",
"当前迭代次数为2000,训练正确率为0.8899999856948853\n",
"当前迭代次数为2200,训练正确率为0.8700000047683716\n",
"当前迭代次数为2400,训练正确率为0.8999999761581421\n",
"当前迭代次数为2600,训练正确率为0.8899999856948853\n",
"当前迭代次数为2800,训练正确率为0.8999999761581421\n",
"当前迭代次数为3000,训练正确率为0.9100000262260437\n",
"当前迭代次数为3200,训练正确率为0.949999988079071\n",
"当前迭代次数为3400,训练正确率为0.9300000071525574\n",
"当前迭代次数为3600,训练正确率为0.8999999761581421\n",
"当前迭代次数为3800,训练正确率为0.9100000262260437\n",
"当前迭代次数为4000,训练正确率为0.9700000286102295\n",
"当前迭代次数为4200,训练正确率为0.8899999856948853\n",
"当前迭代次数为4400,训练正确率为0.8999999761581421\n",
"当前迭代次数为4600,训练正确率为0.9300000071525574\n",
"当前迭代次数为4800,训练正确率为0.9399999976158142\n",
"当前迭代次数为5000,训练正确率为0.8899999856948853\n",
"当前迭代次数为5200,训练正确率为0.9100000262260437\n",
"当前迭代次数为5400,训练正确率为0.9100000262260437\n",
"当前迭代次数为5600,训练正确率为0.9100000262260437\n",
"当前迭代次数为5800,训练正确率为0.9300000071525574\n",
"当前迭代次数为6000,训练正确率为0.9200000166893005\n",
"当前迭代次数为6200,训练正确率为0.8700000047683716\n",
"当前迭代次数为6400,训练正确率为0.9200000166893005\n",
"当前迭代次数为6600,训练正确率为0.9399999976158142\n",
"当前迭代次数为6800,训练正确率为0.8999999761581421\n",
"当前迭代次数为7000,训练正确率为0.9300000071525574\n",
"当前迭代次数为7200,训练正确率为0.9300000071525574\n",
"当前迭代次数为7400,训练正确率为0.9300000071525574\n",
"当前迭代次数为7600,训练正确率为0.8799999952316284\n",
"当前迭代次数为7800,训练正确率为0.9100000262260437\n",
"当前迭代次数为8000,训练正确率为0.8999999761581421\n",
"当前迭代次数为8200,训练正确率为0.9599999785423279\n",
"当前迭代次数为8400,训练正确率为0.9300000071525574\n",
"当前迭代次数为8600,训练正确率为0.8799999952316284\n",
"当前迭代次数为8800,训练正确率为0.9200000166893005\n",
"当前迭代次数为9000,训练正确率为0.8600000143051147\n",
"当前迭代次数为9200,训练正确率为0.9700000286102295\n",
"当前迭代次数为9400,训练正确率为0.9200000166893005\n",
"当前迭代次数为9600,训练正确率为0.8899999856948853\n",
"当前迭代次数为9800,训练正确率为0.8799999952316284\n"
]
}
],
"source": [
" \n",
"for i in range(10000):\n",
" batch = MNIST.train.next_batch(100)\n",
" if i % 200 == 0:\n",
" print(\"当前迭代次数为{},训练正确率为{}\".format(i, accuracy.eval(feed_dict={x:batch[0], y_: batch[1]},session=sess)))\n",
" train_step.run(feed_dict={x: batch[0], y_: batch[1]},session=sess)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 测试正确率为 0.912\n"
]
}
],
"source": [
"#测试阶段\n",
"print(\" 测试正确率为 %g\"%accuracy.eval(feed_dict={x: MNIST.test.images, y_: MNIST.test.labels},session=sess)) "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}