You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

606 rivejä
69 KiB

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Load the MNIST dataset from the local .npz file.\n",
"mnist = tf.keras.datasets.mnist\n",
"(x_train, y_train), (x_test, y_test) = mnist.load_data(path=\"/data/data/mnist.npz\")\n",
"\n",
"# Sanity-check the dataset sizes: x holds the images, y the labels; each image is 28x28 pixels.\n",
"print(x_train.shape)\n",
"print(y_train.shape)\n",
"print(x_test.shape)\n",
"print(y_test.shape)\n",
"\n",
"# Show the first 9 training images to see which digits they are.\n",
"for i in range(9):\n",
"    plt.subplot(3, 3, 1 + i)\n",
"    plt.imshow(x_train[i], cmap='gray')\n",
"plt.show()\n",
"\n",
"# Print the matching labels.\n",
"print(y_train[:9])\n",
"\n",
"# Standard preprocessing: divide the pixel values by 255.\n",
"x_train, x_test = x_train / 255.0, x_test / 255.0\n",
"\n",
"# Number of passes over the training set. It is also used in the saved model's path,\n",
"# so the file name can never disagree with the actual training length.\n",
"EPOCHS = 6\n",
"\n",
"# Build a network with two fully connected layers.\n",
"model = tf.keras.models.Sequential([\n",
"    tf.keras.layers.Flatten(input_shape=(28, 28)),  # flatten each image into a 1-D vector\n",
"    tf.keras.layers.Dense(128, activation='relu'),  # first fully connected layer + ReLU\n",
"    tf.keras.layers.Dropout(0.2),  # dropout layer\n",
"    tf.keras.layers.Dense(10, activation='softmax')  # second fully connected layer + softmax over the 10 outputs\n",
"])\n",
"\n",
"# Training setup: SGD optimizer, sparse categorical cross-entropy loss, accuracy as the reported metric.\n",
"model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy'])\n",
"\n",
"# Train for EPOCHS epochs; one epoch is one full pass over the training images.\n",
"model.fit(x_train, y_train, epochs=EPOCHS)\n",
"\n",
"# Measure the trained model on the test set.\n",
"model.evaluate(x_test, y_test)\n",
"\n",
"# Look at the first 9 test images to compare them with the predictions below.\n",
"for i in range(9):\n",
"    plt.subplot(3, 3, 1 + i)\n",
"    plt.imshow(x_test[i], cmap='gray')\n",
"plt.show()\n",
"\n",
"# Print the digits the model predicts for those 9 images. A bare expression in the middle\n",
"# of a cell is not displayed, so the result has to be printed explicitly.\n",
"print(np.argmax(model(x_test[:9]).numpy(), axis=1))\n",
"\n",
"# Save the trained model.\n",
"model.save(\"/data/output/model_epoch_{}\".format(EPOCHS))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model.evaluate(x_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 0s 30us/sample - loss: 0.4680 - accuracy: 0.8846\n"
]
},
{
"data": {
"text/plain": [
"[0.46795412821769716, 0.8846]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(x_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'model' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-1-e80b0bf1617b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"
]
}
],
"source": [
"model.evaluate(x_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"(60000, 28, 28)\n",
"(60000,)\n",
"(10000, 28, 28)\n",
"(10000,)\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 9 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[5 0 4 1 9 2 1 3 1]\n",
"Epoch 1/6\n",
"60000/60000 [==============================] - 3s 58us/sample - loss: 1.6360 - accuracy: 0.5506\n",
"Epoch 2/6\n",
"60000/60000 [==============================] - 3s 55us/sample - loss: 0.9633 - accuracy: 0.7638\n",
"Epoch 3/6\n",
"60000/60000 [==============================] - 3s 55us/sample - loss: 0.7347 - accuracy: 0.8098\n",
"Epoch 4/6\n",
"60000/60000 [==============================] - 3s 55us/sample - loss: 0.6252 - accuracy: 0.8334\n",
"Epoch 5/6\n",
"60000/60000 [==============================] - 3s 55us/sample - loss: 0.5630 - accuracy: 0.8467\n",
"Epoch 6/6\n",
"60000/60000 [==============================] - 3s 55us/sample - loss: 0.5185 - accuracy: 0.8568\n",
"10000/10000 [==============================] - 0s 37us/sample - loss: 0.4304 - accuracy: 0.8923\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 9 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import os\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Load the MNIST dataset from the local .npz file.\n",
"mnist = tf.keras.datasets.mnist\n",
"(x_train, y_train), (x_test, y_test) = mnist.load_data(path=\"/data/data/mnist.npz\")\n",
"\n",
"# Sanity-check the dataset sizes: x holds the images, y the labels; each image is 28x28 pixels.\n",
"print(x_train.shape)\n",
"print(y_train.shape)\n",
"print(x_test.shape)\n",
"print(y_test.shape)\n",
"\n",
"# Show the first 9 training images to see which digits they are.\n",
"for i in range(9):\n",
"    plt.subplot(3, 3, 1 + i)\n",
"    plt.imshow(x_train[i], cmap='gray')\n",
"plt.show()\n",
"\n",
"# Print the matching labels.\n",
"print(y_train[:9])\n",
"\n",
"# Standard preprocessing: divide the pixel values by 255.\n",
"x_train, x_test = x_train / 255.0, x_test / 255.0\n",
"\n",
"# Number of passes over the training set. It is also used in the saved model's path,\n",
"# so the file name can never disagree with the actual training length.\n",
"EPOCHS = 6\n",
"\n",
"# Build a network with two fully connected layers.\n",
"model = tf.keras.models.Sequential([\n",
"    tf.keras.layers.Flatten(input_shape=(28, 28)),  # flatten each image into a 1-D vector\n",
"    tf.keras.layers.Dense(128, activation='relu'),  # first fully connected layer + ReLU\n",
"    tf.keras.layers.Dropout(0.2),  # dropout layer\n",
"    tf.keras.layers.Dense(10, activation='softmax')  # second fully connected layer + softmax over the 10 outputs\n",
"])\n",
"\n",
"# Training setup: SGD optimizer, sparse categorical cross-entropy loss, accuracy as the reported metric.\n",
"model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy'])\n",
"\n",
"# Train for EPOCHS epochs; one epoch is one full pass over the training images.\n",
"model.fit(x_train, y_train, epochs=EPOCHS)\n",
"\n",
"# Measure the trained model on the test set.\n",
"model.evaluate(x_test, y_test)\n",
"\n",
"# Look at the first 9 test images to compare them with the predictions below.\n",
"for i in range(9):\n",
"    plt.subplot(3, 3, 1 + i)\n",
"    plt.imshow(x_test[i], cmap='gray')\n",
"plt.show()\n",
"\n",
"# Print the digits the model predicts for those 9 images. A bare expression in the middle\n",
"# of a cell is not displayed, so the result has to be printed explicitly.\n",
"print(np.argmax(model(x_test[:9]).numpy(), axis=1))\n",
"\n",
"# Save the trained model.\n",
"model.save(\"/data/output/model_epoch_{}\".format(EPOCHS))\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 0s 30us/sample - loss: 0.4304 - accuracy: 0.8923\n"
]
},
{
"data": {
"text/plain": [
"[0.43036172440052034, 0.8923]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(x_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(60000, 28, 28)\n",
"(60000,)\n",
"(10000, 28, 28)\n",
"(10000,)\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 9 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[5 0 4 1 9 2 1 3 1]\n",
"Epoch 1/9\n",
"60000/60000 [==============================] - 3s 57us/sample - loss: 1.7841 - accuracy: 0.4954\n",
"Epoch 2/9\n",
"60000/60000 [==============================] - 3s 55us/sample - loss: 1.0707 - accuracy: 0.7433\n",
"Epoch 3/9\n",
"60000/60000 [==============================] - 3s 55us/sample - loss: 0.7987 - accuracy: 0.7936\n",
"Epoch 4/9\n",
"60000/60000 [==============================] - 3s 56us/sample - loss: 0.6696 - accuracy: 0.8208\n",
"Epoch 5/9\n",
"60000/60000 [==============================] - 3s 55us/sample - loss: 0.5945 - accuracy: 0.8386\n",
"Epoch 6/9\n",
"60000/60000 [==============================] - 3s 55us/sample - loss: 0.5450 - accuracy: 0.8495\n",
"Epoch 7/9\n",
"60000/60000 [==============================] - 3s 55us/sample - loss: 0.5106 - accuracy: 0.8577\n",
"Epoch 8/9\n",
"60000/60000 [==============================] - 3s 55us/sample - loss: 0.4841 - accuracy: 0.8640\n",
"Epoch 9/9\n",
"60000/60000 [==============================] - 3s 55us/sample - loss: 0.4627 - accuracy: 0.8687\n",
"10000/10000 [==============================] - 0s 38us/sample - loss: 0.3818 - accuracy: 0.9001\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAVEAAAD8CAYAAADOg5fGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xe0FPX9//HnW8SKXwMKiICCYEMsgAWJvYEdFRVED5afJQGDioUQE3v0aKIxRjAkErDECiqWKIRDsR9AxQYicgDv8SoiKogSBD6/P3Y/s3v77s7u7M7u63EO586dmd15s+97575n5lPMOYeIiORmk2IHICISZzqJioiEoJOoiEgIOomKiISgk6iISAg6iYqIhKCTqIhICKFOombWz8w+MbNFZjYyX0FJcSmv5Uu5zT/LtbG9mTUDFgLHAlXAbGCQc+7j/IUnUVNey5dyWxibhnjtgcAi59xiADN7HDgVaDAhZlbp3aNWOOdaFzuIJiiv2YtDXiHL3CqvmeU1zOV8e+DztO+rkuukYUuLHUAGlNfsxSGvoNxmK6O8hqlErZ51df5ymdklwCUhjiPRUl7LV5O5VV6zF+YkWgV0TPu+A/BF7Z2cc2OBsaDLg5hQXstXk7lVXrMX5nJ+NrCrmXU2s82AgcDk/IQlRaS8li/ltgByrkSdc+vNbBjwCtAMGOec+yhvkUlRKK/lS7ktjJybOOV0MF0ezHXO7V/sIPJNeVVey1RGeQ1zT1Sk4K6++moAttxySwD22WefYNuAAQNq7DtmzJhg+c033wTg4YcfLnSIUuHU7VNEJARdzkdLl30ZeuKJJ4C61WamPvvsMwCOOeYYAJYtW5afwOqnvEZkt912A2DBggUADB8+HID77ruvEIfLKK+qREVEQtA9USkZvvqEhitQX4EAvPLKKwDssssuAJx88snBti5dugAwePBgAG6//fb8BitF0aNHDwA2btwIQFVVVTHDAVSJioiEopOoiEgIupyXott//8S9+9NOO63Oto8+SrQFP+WUUwBYsWJFsO2HH34AYLPNNgPgrbfeCrbtu+++AGy33XYFiFiKZb/99gNgzZo1ADzzzDPFDAdQJSoiEkosKlH/kOHiiy8G4IsvUmMmrF27FoBHH30UgC+//BKARYsWRRmihNCuXTsAzFKDDPkKtG/fvgBUV1c3+PoRI0YA0K1btzrbXnzxxbzFKcXRvXv3YHnYsGFAaXWiUCUqIhJCLCrRO++8E4BOnTo1uM+ll14KwOrVq4FUJZMPvhmFj2POnDl5e2+B559/HoCuXbsG63weV65c2eTrBw4cCEDz5s0LEJ0U2x577BEsb7311kDN5nDFpkpURCQEnURFREJo8nLezMYBJwHLnXPdk+taAU8AnYAlwFnOuW8LFaR/oORH8Jk/f36wbc899wSgZ8+eABxxxBEA9O7dO9jn888T08p07Jg+qHdN69evB+Drr78GUg870vn+1+VwOV8Kea1t6dLspiq65pprgFR/6nRvv/12ja+VpBRzG8a1114bLPufkVL6HcykEh0P9Ku1biQwzTm3KzAt+b3Ey3iU13I1HuU2MhmN4mRmnYAX0v6qfQIc4ZyrNrN2wAzn3O4ZvE/BR4Vp2bIlkGqUCzB37lwADjjggAZf55tKLVy4EKhZ7bZq1QqAoUOHAjXHrcxSSY32E6e8eieddFKw/NRTTwGpxvbLly8PtvmHTTNnzowirJLKK+Qnt8Uexck/SF68eHGwzv9+pj9sKqCCDsrc1jlXDZBMSpuGdtTsgbGivJavjHKrvGav4E2cop498NtvE7d5pk+fXmfbtGnTmnz9GWecAaQqWoAPPvgAKK1mFcVWrFkhfRdRSFWgXnp+IqpAy04pzfZ5+OGH11nnn1mUklyfzn+VvCQg+XV5E/tLPCiv5Uu5LZBcK9HJwBDgjuTX5/IWUZG0aZO4uhk9ejQAm2yS+vty88
03A5k1/I65ks3rs88+C8Bxxx1XZ9tDDz0EwPXXXx9pTDFTsrltyN57711nne/wUkqarETN7DHgTWB3M6sys4tIJOJYM/sUODb5vcSI8lq+lNtoNVmJOucGNbDp6DzHIhFSXsuXchutWPSdj4JvvtS6dWsg9YAK4JNPPilKTJLq9NCnTx8ANt9882CbH1v01ltvBVLji0q8+Y4yF1xwAQDvvvtusG3q1KlFiakx6vYpIhJCxVeiv/zlLwEYObJmB47+/fsHyx9++GGkMUnKxIkTgfpHqH/kkUeA1PTIUh78NNe+k8vLL78cbPOdYkqJKlERkRAqvhI94YQTgNRYlL5B/ptvvlm0mCQ1p5IfWMabMWNGsHzDDTdEGZJExM+P5bukP/3008UMp0mqREVEQtBJVEQkhIq8nN9yyy2D5X79EiOGrVu3DkhdIv7888/RB1bh0h8ejRo1Cqg75cd7770XLKtJU/nYYYcdguVDDz0USDUtLIVpkRujSlREJISKrET9iOgAPXr0AFLNKN54442ixCSpqY+h7tivvu+8HiaVp/PPPz9Y9uNY/Oc//ylSNNlRJSoiEkJFVaInnngiAL///e+DdatWrQJSIzVJ8Vx11VUNbhs2bBig+6Dlauedd66zLr3rdSlTJSoiEkJFVKL+qe9f//pXAJo1axZse+mllwB46623og9MMua7AGbaauL777+vsb9/yr/tttvW2fcXv/gF0HglvGHDBgCuu+66YN2PP/6YUSzStPS5s7znn3++CJFkL5PxRDua2XQzm29mH5nZ8OT6VmY21cw+TX5t2dR7SelQXsuT8hq9TC7n1wMjnHN7Ar2BoWbWDU3BGnfKa3lSXiOWyaDM1YCfJXC1mc0H2gOnAkckd5sAzACuq+ctiiL9kt03X+rcuTNQc9Sf9IdMlSRueX3//fez2t9Pp1xdXQ1A27ZtATj77LNDxfHll18Gy7fddluo9yqEuOX1kEMOAWo2to+brO6JJuey7gG8jaZgLRvKa3lSXqOR8UnUzFoAE4ErnHOrzCyj1xVrCtYuXboEy7169aqxLf0BQqWPRVlKefUP+QBOPfXUUO915plnNrnP+vXrAdi4cWOdbZMnTwZgzpw5Nda/+uqroeKKSinltTGnnXYaUPPK0Y9kP2vWrEIfPi8yauJkZs1JJORR59yk5GpNwRpzymt5Ul6j1WQlaok/YQ8C851zd6dtKskpWH2j3SlTptTZ5rt7vvDCC5HGVIpKMa+nn356sHzttdcCdQcgSbfXXnsBjd/nHDduHABLliyps82Pmr9gwYKsYy1VpZjX+my11VZAajzfdH78UN+srNRlcjn/S+A84AMz80PojCKRjCeT07EuA5q+fpJSoryWJ+U1Ypk8nX8NaOiGiqZgjSnltTwpr9Erux5Ll1ySeLC400471dk2c+ZMIDXtgJSuO++8M+N9zznnnAJGIoXge5L5/vH+QR7AvffeW5SYcqW+8yIiIZRNJeob7V5++eVFjkREmuIr0T59+hQ5kvBUiYqIhFA2laifl6VFixZ1tvkG9RqLUkTyTZWoiEgIZVOJ1jZv3rxg+eijEy07Vq5cWaxwRKRMqRIVEQlBJ1ERkRAsyobnUY7iVKLmOuf2L3YQ+aa8Kq9lKqO8qhIVEQkh6gdLK4A1ya9xsz3h4647L2x5UF7Lk/KagUgv5wHMbE4cL33iGndU4vr5xDXuqMT184kybl3Oi4iEoJOoiEgIxTiJji3CMfMhrnFHJa6fT1zjjkpcP5/I4o78nqiISDnR5byISAg6iYqIhBDZSdTM+pnZJ2a2yMxGRnXcbJlZRzObbmbzzewjMxueXN/KzKaa2afJry2LHWupiENuldfsKa8ZxhDFPVEzawYsBI4FqoDZwCDn3McFP3iWknNyt3POvWNm2wBzgf7A+cBK59wdyR+ols6564oYakmIS26V1+wor5mLqhI9EFjknFvsnFsHPA6cGtGxs+Kcq3bOvZNcXg3MB9
qTiHdCcrcJJBIlMcmt8po15TVDoU6iWZT77YHP076vSq4raWbWCegBvA20dc5VQyJxQJviRVZYWV7GxS63lZpXKO/f2WLlNeeTaLLcvx84HugGDDKzbg3tXs+6km5bZWYtgInAFc65VcWOJypZ5hVilttKzSuU9+9sUfPqnMvpH3Aw8Era978FftvYviSSUMn/vs71847qXzZ5Tdu/2J9rsf+VfF5z/J0t9uda7H8Z5TXMKE71lfsH1d7JzC4BLgH2DnGscrG02AFkINu8SjzyChnkVnmtIaO8hrknmlG575wb6xKjqZwW4lgSnazy6mI4wk8FazK3ymv2wpxEq4COad93AL5oaGfn3EshjiXRySqvEivKbQGEOYnOBnY1s85mthkwEJicn7CkiJTX8qXcFkDO90Sdc+vNbBiJB0bNgHHOuY/yFpkUhfJavpTbwtBEddHShGblSXktT5qoTkSk0HQSFREJIerZPiOz9dZbB8t33XUXAJdeeikAc+fODbadeeaZACxdGpemfiJSSlSJioiEULaVaLt27YLliy++GICNGzcC0KtXr2DbSSedBMD9998fYXSSqZ49ewIwadIkADp16hTq/Y477rhgef78+QB8/vnnDe0uJebkk08GYPLkRMusYcOGAfDAAw8E+2zYsCHSmFSJioiEoJOoiEgIZXc537p1awAmTJjQxJ4SB3379gVg8803z8v7+ctBgAsvvBCAgQMH5uW9pTC22267YHn06NE1tv3tb38DYNy4ccG6n376KZrAklSJioiEUDaV6G9+8xsA+vdPzAJw4IEHZvS6ww47DIBNNkn8PZk3bx4As2bNyneIkoVNN038aJ5wwgl5fd/05m1XXXUVkGoOt2bNmrweS/LD/44CdOjQoca2xx57DIC1a9dGGlM6VaIiIiGUTSV6zz33AKlmTJk6/fTTa3z1je7PPvvsYJ/06kWiceSRRwJw8MEHA3DnnXfm5X1btkzNnNutW2JmjK222gpQJVpq/H3w3/3udw3u8/DDDwMQ5RggtakSFREJocmTqJmNM7PlZvZh2rpWZjbVzD5Nfm3Z2HtI6VFey5dyG60mh8Izs8OAH4CHnHPdk+vuBFY65+5ITrva0jl3XZMHK8DQWi+9lBgw//jjjwcyu5z/5ptvguUffvgBgJ133rnB/Zs1axYmxHQlM2RaKea1e/fuwfKMGTOAVK58LzOfr1z59wU45JBDgFTvtq+//jrXty2ZvEL+clvsofD23z/xkc6ePbvOtvXr1wPQvHnzQoaQn6HwnHOzgJW1Vp8K+IaYE4D+WYcnRaW8li/lNlq5Plhq65yrBnDOVZtZmzzG1KTDDz88WN59992BVAXaWCXq+9dOmTIlWPf9998DcNRRRwH138T+1a9+BcCYMWPChB0HRc3r9ddfHyz7Zkf9+vUDwlegrVq1Amr+7GT7EDLmiprbXJxxxhkNbkv/HS62gj+d1xSs5Ul5LU/Ka/ZyPYl+ZWbtkn/R2gHLG9rROTcWGAvh77H4EXwef/zxYN32229f777p44NOnDgRgJtuugmAH3/8scH9L7kk8fPju49CqnnNFltsAaS6mgH8/PPP2f0nSltR8jpgwACgZsP6RYsWATBnzpwwbx3wVxjp1ae/P/rdd9/l5RglLqPc5jOvYaU3svfWrVsHNN7sKWq5NnGaDAxJLg8BnstPOFJkymv5Um4LpMlK1MweA44AtjezKuAG4A7gSTO7CFgGnFnIID3fFbCh6hNg5syZQM1BJVasWNHke/tK9Pbbbwfg7rvvDrb5xti+IvVjGQJ89tlnGcVeakopr352Af85Q92BJnLlr14GDx4M1Bxr8tZbbwXK7mqipHKbiz59+tT4ms53iHjvvfcijakxTZ5EnXODGth0dJ5jkQgpr+VLuY2WeiyJiIRQNn3n/QMIP0ZkJpfw9fGX6v7yD+CAAw4IGZ3UZ9tttwWgd+/edbblqzmZf1DobwH5KUEApk+fnpdjSH
419vtWis0MVYmKiIQQy0rUj/2Z7qCDDsrLe5tZnWPUPt6NN94YLJ933nl5OW4l8qP0tG/fHkiNDZlPXbp0qfH9hx9+2MCeUip8d08vvQmaKlERkTITq0r0sssuAwrbXc/PwdOjR49gXe0upemVqORu9erVQKq5yj777BNs8900V66s3QU8M23aJHo1+ob83muvvZbT+0lh+cFgAM4555wa23zXbICqqqrIYsqUKlERkRB0EhURCSFWl/Pp093mi+8j76eKGDVqVIP7+vEmy62HS7H4qW19r6/0UXtefPFFoGbPsYb4cUh32WWXYJ3vqVR7vNwKG7kpNtKnRa79IHfq1KlRh5MVVaIiIiHEqhItBD8azNChQxvcZ8mSJQAMGZIYv2HZsmUFj6uS3HDDDUCqeRnAiSeeCGTW7Ml3rEivOhsaX2H8+PG5hikFVPsBIKSaNv3973+POpysqBIVEQmhIitRPy8TpEbGb8zHH38MqHlMoSxYsACAs846K1i33377AdC1a9cmX//000/XWTdhQmImjPTuu5C6DyuloUOHDkDdZk2Qas6UrzFlC0WVqIhICJmMJ9oReAjYAdgIjHXO3WtmrYAngE7AEuAs59y3hQu1/i6Znp/t0xs7dmywvOOOO9bYlv76TJ7WFqJVQLGVUl7r4xvg5zpu5OLFi+tdnz6jaDl2AS31vNbmxwyt73f62WefjTqcnGRSia4HRjjn9gR6A0PNrBswEpjmnNsVmJb8XuJDeS1PymvEMpkyudo5905yeTUwH2iPpmCNNeW1PCmv0cvqwZKZdQJ6AG9ThClY/QgufpqOdC+88AJQ/+V5Y5fsDW3z0ytXgmLntRD8rZ/0ZlNQnpfwDYlDXtMb2Xu+ydq9994bdTg5yfgkamYtgInAFc65VbV/OBt5naZgLWHKa3lSXqOT0UnUzJqTSMijzrlJydWRT8E6aVLi0Ndcc02wLn1q41z4rpx+xHM/Enp1dXWo942DUslrIfiG97W7fVaCOOW1b9++ddb5zizpozeVsibviVriT9iDwHznXHpHZk3BGmPKa3lSXqOXSSX6S+A84AMz8+1NRlGEKVj9tMbp0yH375+4Pz58+PCc3vO2224D4P777w8ZXeyUTF4LYYsttqjxfQU1so9FXps3bw7UnXkAYO3atUB8BvrJZMrk14CGbqhoCtaYUl7Lk/IaPfVYEhEJIZZ952fNmlVnecqUKUDqwVB6LyM/DbLvxZT+pNL3i5fycsEFFwCpkYBuueWWYoYjtfimhb5ffHpPskWLFhUlplypEhURCSGWlWh9Xn755RpfpbLNnj0bSI2MP3369GKGI7Vs2LABSI3nm94Ube7cuUWJKVeqREVEQrAoGyOXYqPsiM11zu1f7CDyTXlVXstURnlVJSoiEoJOoiIiIegkKiISgk6iIiIh6CQqIhKCTqIiIiFE3dh+BbAm+TVutid83DvnI5ASpLyWJ+U1A5G2EwUwszlxbFMX17ijEtfPJ65xRyWun0+UcetyXkQkBJ1ERURCKMZJdGwRjpkPcY07KnH9fOIad1Ti+vlEFnfk90RFRMqJLudFREKI7CRqZv3M7BMzW2RmI6M6brbMrKOZTTez+Wb2kZkNT65vZWZTzezT5NeWxY61VMQht8pr9pTXDGOI4nLezJoBC4FjgSpgNjDIOVdyc3Mk5+Ru55x7x8y2AeYC/YHzgZXOuTuSP1AtnXPXFTHUkhCX3Cqv2VFeMxdVJXogsMg5t9g5tw54HDg1omNnxTlX7Zx7J7m8GpgPtCcR74TkbhNIJEpiklvlNWvKa4ZCnUSzKPfbA5+nfV+VXFfSzKwT0AN4G2jrnKuGROKANsWLrLCyvIyLXW4rNa9Q3r+zxcprzifRZLl/P3A80A0YZGbdGtq9nnUl3SzAzFoAE4ErnHOrih1PVLLMK8Qst5WaVyjv39li5jVMJZpNuV8FdEz7vgPwRYhjF5SZNSeRkEedc5OSq79K3n/x92GWFy
u+Asv2Mi42ua3wvEKZ/s4WO685P1gyswFAP+fc/0t+fx5wkHNuWD37bkriJnXnELGWgxXOudbFDqIx2eQ1uX1T4OcIQyxFJZ9XyOl3VnnNIK9hKtGMyn0zuwR4C9gQ4ljlYmmxA8hAxnk1szkkclvp4pBXyCC3ymsNGeU1zEk0o3LfOTfWObe/c27XEMeS6GSb19iN8FPBmsyt8pq9MCfR2cCuZtbZzDYDBgKT8xOWFJHyWr6U2wLIeVBm59x6MxsGvAI0A8Y55z7KW2RSFMpr+VJuCyPSAUjMrGSbSERkbjleJimvymuZyiivGoBERCQEnURFRELQSVREJISoZ/sUEQmtZcvEyHY77bRTg/ssXZpo5nnllVcG6z788EMAFi5cCMC8efNCx6JKVEQkhNhXom3aJAZnefLJJwF44403ABg7NjXFypIlS/JyrG233RaAww47LFj38ssvA/Dzz5XeQ06kME488cRg+ZRTTgHgiCOOAKBr164Nvs5XmzvvnJo+fvPNN6+xT7NmzULHp0pURCQEnURFREKI5eW8v6kM8NFHiQ4X/lL7q6++AvJ3CZ/+3nPnzgWgdevUwC69evUCYNGiRXk7ntT1f//3fwDcfvvtwbru3bsDcMwxxwC6pRJnXbp0CZaHDh0KwMUXXwzAlltuGWwzq28MlfrttttueYqucapERURCiFUluv322wPwxBNPBOtatWoFwOjRowG4/PLL837c66+/HoDOnRPDoV566aXBNlWghTV48GAAbrvtNgA6duxYZx9fpX7zzTfRBSZ51aFDh2B5+PDhod5rwYIFQOoqtdBUiYqIhBCrSrRnz55AqnlDuptvvjmvx9prr72C5REjRgDwzDPPADUrYSkMX5n85S9/AWC77bYDoL4Bc+677z4Ahg1LDdC+cuXKQocoGfJXkJCqMl9//XUg1UTwf//7X7DP999/D8CaNWsA2HrrrYNtU6ZMAVKN5t9++20A3n333WCfn376qcbrC02VqIhICE2eRM1snJktN7MP09a1MrOpZvZp8mvLxt5DSo/yWr6U22hlcjk/Hvgb8FDaupHANOfcHcm5q0cC1+U/vATfK+mMM86os+2iiy4C4Ouvv87Lsfxl/H//+9862/zl/OrVq/NyrCIbT5Hz2pirr74aSD04bMzZZ58NQL9+/YJ1/kGUv9Rft25dvkMsZeMpgdz6y3B/CQ6w7777AnDaaafV2Pett1JTOvnbdr6ZYnr/+KqqKgA2btyY/4Bz1GQl6pybBdS+wXQqMCG5PAHon+e4pMCU1/Kl3EYr1wdLbZ1z1QDOuWoza5PHmOr485//DMC5554LpBq9Azz11FN5Pdahhx4KQNu2bYN148ePB+CRRx7J67FKUKR5rS29j/MFF1xQY9v7778PpDpTQKqRvec7RUCqkn300UcB+PLLL/MbbPxEltvNNtsMgH//+99AqvoE+OMf/wjUf6Xn1e4os2zZsjxHmF8FfzqfnDL5kkIfR6KlvJYn5TV7uZ5EvzKzdsm/aO2A5Q3t6JwbC4yF3Ods8c1a/H2QL75IzfIa9l6X71I2atQoAH7961/XOCbAhRdeGOoYMRJpXmvbb7/9guVtttkGgFdffRWAww8/HIAtttgi2GfQoEFAKnfpXQd32GEHAJ577jkAjj/+eKCimz5llNtc89qiRYtg+be//S0AJ510EgArVqwItv3pT38C4Mcff8wy/NKVaxOnycCQ5PIQ4Ln8hCNFpryWL+W2QJqsRM3sMeAIYHszqwJuAO4AnjSzi4BlwJmFDLK29PEF/ZO/7777DoAxY8Y0+Xpf1UCq4X7v3r1r7PP000+HDbOklWJe08d69FcC99xzT4191q5dGyz/61//AuDMMxNh7rLLLnXe01c8lfR0vhi57d8/9Zxq5MiRQOpepn/OAKmG9OWkyZOoc25QA5uOznMsEiHltXwpt9FSjyURkRBi0Xf+3nvvBeDII48EYMcddwy2+ak6/DiDfvqAxq
SPSVi7L/bixYuB1MMKiY5/UJTO37p59tlnG3zd/vvv3+A234j7hx9+CBmdNKZPnz511vn+7L6BfLlSJSoiEkIsKlHfuH6fffYBajaF8V39rrnmGiDV/XPChAk05OGHHw6Wa0+Z6ie6++yzz8KGLVl67LHHgmV/RXHAAQcAsMceewCw9957B/v4roN+pgP/cDF9nR8d3ef8448/LkjslW7AgAF11vnfzRtuuCFY55ucvffee9EEFgFVoiIiIVh94zMW7GB5apQdVnpTGD8yvf/L2LdvXyB/A5rUMtc51/ANvJjKV17TBxvxefFdOf197Pp+Xn0XQj83D8ALL7wAwK677grAP/7xDwAuu+yyfIRaW8XnNT0vjQ0O4rc98MADQOqedfogIz739Y1M7wcIevPNN4GC32/NKK+qREVEQtBJVEQkhIq8nPejMgGcd955QOom+NSpUwt56Iq/7MuUH6HJ9xzzl/XpP69+rNDrrksMi5nem8mPFuR7zyxdurTG+0JeHx5WfF7vuuuuYPmqq64qSDzp/O22GTNmADBw4MBCHEaX8yIihVZRlajvY50+0Zwfpd435H/nnXcKGULFVyzZ8pXjOeecA9RsxvSHP/wBqL8hvR+dy49p6ZtMpY8JO2TIkDqvy1HF57VZs2bBco8ePYDUZ7/ppqmWlH7K6002yU/95s9fN954Y7Du1ltvzct7o0pURKTwYtHYPl/8mJLpfFOYAlegkiPffKmxkdDr46fN9VcdvhL1VxyQalJVwWOM5s2GDRuC5Tlz5gCw22671dnv6KMTY6A0b94cSFWQvlNFtnzTt169euX0+nxQJSoiEkIm44l2JDFr4A7ARmCsc+5eM2sFPAF0ApYAZznnvi1cqOH5SnTNmjXBOj9/U6Upp7w25sknnwRSlaifGRRg2LBhANx8883RB1YgpZ7XadOm1fjed+FOr0TXr18PpMaL9R0lAK644gogdY+8FGRSia4HRjjn9gR6A0PNrBupKVh3BaYlv5f4UF7Lk/IasUymTK52zr2TXF4NzAfaoylYY015LU/Ka/SyauJkZp2AWUB3YJlz7hdp2751zrVs4vVFaeLk+0uPHj0agOXLU3N0+QnNIlKSTWHimtds+MvG11+Gbng8AAAEg0lEQVR/PVjnJ73bc889AVi4cGGub6+85qhnz54AzJ49u8F9pk+fHiz76XzSxwSG1O82wOWXX56v8DLKa8ZP582sBTARuMI5t6r2f6KR12kK1hKmvJYn5TU6GVWiZtYceAF4xTl3d3LdJ8ARaVOwznDO7d7E+xSlYvEjNPmxKNO7fV500UVAaopePw6ln2Qrz0qqYol7XnMxYsSIYNl3VZw0aRKQ6gIMqSZSGVJec+Q7RYwbNy5Yd9ZZZzX5Ot+k6sUXXwTg3HPPDbalPzgOKT+N7S3xJ+xBYL5PSJKmYI0x5bU8Ka/Ra7ISNbNDgFeBD0g0mQAYBbwNPAnsRHIKVudco62WS6USffDBB4NtM2fOBODKK68EUmMY5rFLYLqSqVjKIa+5aN26dbDs74927doVqDljwvvvv5/N2yqvIbVt2zZY/uc//wmk5s5q06ZNsG3JkiVAaqaC9O6eBZCfe6LOudeAhm6oaArWmFJey5PyGj31WBIRCaEiRnGqfTlf35TJ/hL/lltuAeDzzz8vRCglc9mXT3G6nE/np6Twl4jpE+UNHjw4m7dSXgvAP+jr3bt3sO6mm24CajZTLCCN4iQiUmgVUYkecsghQKqP9KxZs4JtY8aMAeDbbxPdiNetW1fIUFSxlKApU6YAcPDBBwfrDjroICDjKZaV1/KkSlREpNAqYjzR1157DYCjjjqqyJFIKRowYAAA8+bNC9b5Zk8ZVqJSwVSJioiEUBGVqEhjVq1aBUDnzp2LHInEkSpREZEQdBIVEQlBJ1ERkRB0EhURCSHqB0srgDXJr3GzPeHj3jkfgZQg5bU8Ka8ZiLTHEoCZzYlj7464xh2VuH4+cY07Kn
H9fKKMW5fzIiIh6CQqIhJCMU6iY4twzHyIa9xRievnE9e4oxLXzyeyuCO/JyoiUk50OS8iEkJkJ1Ez62dmn5jZIjMbGdVxs2VmHc1supnNN7OPzGx4cn0rM5tqZp8mv7YsdqylIg65VV6zp7xmGEMUl/Nm1gxYCBwLVAGzgUHOuZIbZyw5J3c759w7ZrYNMBfoD5wPrHTO3ZH8gWrpnLuuiKGWhLjkVnnNjvKauagq0QOBRc65xc65dcDjwKkRHTsrzrlq59w7yeXVwHygPYl4JyR3m0AiURKT3CqvWVNeMxTVSbQ9kD7zW1VyXUkzs05ADxJzdrd1zlVDInFAm4ZfWVFil1vlNSPKa4aiOonWNw92STcLMLMWwETgCufcqmLHU8JilVvlNWPKa4aiOolWAR3Tvu8AfBHRsbNmZs1JJORR59yk5Oqvkvdf/H2YSOZsjYHY5FZ5zYrymqGoTqKzgV3NrLOZbQYMBCZHdOysWGJS+geB+c65u9M2TQaGJJeHAM9FHVuJikVuldesKa+ZxhBVY3szOwH4C9AMGOecuy2SA2fJzA4BXgU+ADYmV48icZ/lSWAnYBlwpnNuZVGCLDFxyK3ymj3lNcMY1GNJRCR36rEkIhKCTqIiIiHoJCoiEoJOoiIiIegkKiISgk6iIiIh6CQqIhKCTqIiIiH8f426yqqIsOjrAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 9 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import os\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Load the MNIST dataset from the local .npz file.\n",
"mnist = tf.keras.datasets.mnist\n",
"(x_train, y_train), (x_test, y_test) = mnist.load_data(path=\"/data/data/mnist.npz\")\n",
"\n",
"# Sanity-check the dataset sizes: x holds the images, y the labels; each image is 28x28 pixels.\n",
"print(x_train.shape)\n",
"print(y_train.shape)\n",
"print(x_test.shape)\n",
"print(y_test.shape)\n",
"\n",
"# Show the first 9 training images to see which digits they are.\n",
"for i in range(9):\n",
"    plt.subplot(3, 3, 1 + i)\n",
"    plt.imshow(x_train[i], cmap='gray')\n",
"plt.show()\n",
"\n",
"# Print the matching labels.\n",
"print(y_train[:9])\n",
"\n",
"# Standard preprocessing: divide the pixel values by 255.\n",
"x_train, x_test = x_train / 255.0, x_test / 255.0\n",
"\n",
"# Number of passes over the training set. It is also used in the saved model's path,\n",
"# so the file name can never disagree with the actual training length.\n",
"EPOCHS = 9\n",
"\n",
"# Build a network with two fully connected layers.\n",
"model = tf.keras.models.Sequential([\n",
"    tf.keras.layers.Flatten(input_shape=(28, 28)),  # flatten each image into a 1-D vector\n",
"    tf.keras.layers.Dense(128, activation='relu'),  # first fully connected layer + ReLU\n",
"    tf.keras.layers.Dropout(0.2),  # dropout layer\n",
"    tf.keras.layers.Dense(10, activation='softmax')  # second fully connected layer + softmax over the 10 outputs\n",
"])\n",
"\n",
"# Training setup: SGD optimizer, sparse categorical cross-entropy loss, accuracy as the reported metric.\n",
"model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy'])\n",
"\n",
"# Train for EPOCHS epochs; one epoch is one full pass over the training images.\n",
"model.fit(x_train, y_train, epochs=EPOCHS)\n",
"\n",
"# Measure the trained model on the test set.\n",
"model.evaluate(x_test, y_test)\n",
"\n",
"# Look at the first 9 test images to compare them with the predictions below.\n",
"for i in range(9):\n",
"    plt.subplot(3, 3, 1 + i)\n",
"    plt.imshow(x_test[i], cmap='gray')\n",
"plt.show()\n",
"\n",
"# Print the digits the model predicts for those 9 images. A bare expression in the middle\n",
"# of a cell is not displayed, so the result has to be printed explicitly.\n",
"print(np.argmax(model(x_test[:9]).numpy(), axis=1))\n",
"\n",
"# Save the trained model.\n",
"model.save(\"/data/output/model_epoch_{}\".format(EPOCHS))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 0s 30us/sample - loss: 0.3818 - accuracy: 0.9001\n"
]
},
{
"data": {
"text/plain": [
"[0.381765931224823, 0.9001]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(x_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(60000, 28, 28)\n",
"(60000,)\n",
"(10000, 28, 28)\n",
"(10000,)\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 9 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[5 0 4 1 9 2 1 3 1]\n",
"Epoch 1/15\n",
"60000/60000 [==============================] - 4s 64us/sample - loss: 0.2953 - accuracy: 0.9138\n",
"Epoch 2/15\n",
"60000/60000 [==============================] - 4s 61us/sample - loss: 0.1440 - accuracy: 0.9567\n",
"Epoch 3/15\n",
"60000/60000 [==============================] - 4s 63us/sample - loss: 0.1086 - accuracy: 0.9677\n",
"Epoch 4/15\n",
"60000/60000 [==============================] - 4s 62us/sample - loss: 0.0888 - accuracy: 0.9721\n",
"Epoch 5/15\n",
"60000/60000 [==============================] - 4s 62us/sample - loss: 0.0763 - accuracy: 0.9758\n",
"Epoch 6/15\n",
"60000/60000 [==============================] - 4s 61us/sample - loss: 0.0658 - accuracy: 0.9789\n",
"Epoch 7/15\n",
"60000/60000 [==============================] - 4s 60us/sample - loss: 0.0572 - accuracy: 0.9813\n",
"Epoch 8/15\n",
"60000/60000 [==============================] - 4s 61us/sample - loss: 0.0519 - accuracy: 0.9832\n",
"Epoch 9/15\n",
"60000/60000 [==============================] - 4s 62us/sample - loss: 0.0503 - accuracy: 0.9830\n",
"Epoch 10/15\n",
"60000/60000 [==============================] - 4s 62us/sample - loss: 0.0430 - accuracy: 0.9857\n",
"Epoch 11/15\n",
"60000/60000 [==============================] - 4s 61us/sample - loss: 0.0414 - accuracy: 0.9861\n",
"Epoch 12/15\n",
"60000/60000 [==============================] - 4s 62us/sample - loss: 0.0383 - accuracy: 0.9877\n",
"Epoch 13/15\n",
"60000/60000 [==============================] - 4s 61us/sample - loss: 0.0362 - accuracy: 0.9876\n",
"Epoch 14/15\n",
"60000/60000 [==============================] - 4s 61us/sample - loss: 0.0342 - accuracy: 0.9881\n",
"Epoch 15/15\n",
"60000/60000 [==============================] - ETA: 0s - loss: 0.0336 - accuracy: 0.98 - 4s 61us/sample - loss: 0.0336 - accuracy: 0.9884\n",
"10000/10000 [==============================] - 0s 37us/sample - loss: 0.0869 - accuracy: 0.9791\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 9 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import os\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Load the MNIST dataset from the local .npz file.\n",
"mnist = tf.keras.datasets.mnist\n",
"(x_train, y_train), (x_test, y_test) = mnist.load_data(path=\"/data/data/mnist.npz\")\n",
"\n",
"# Sanity-check the dataset sizes: x holds the images, y the labels; each image is 28x28 pixels.\n",
"print(x_train.shape)\n",
"print(y_train.shape)\n",
"print(x_test.shape)\n",
"print(y_test.shape)\n",
"\n",
"# Show the first 9 training images to see which digits they are.\n",
"for i in range(9):\n",
"    plt.subplot(3, 3, 1 + i)\n",
"    plt.imshow(x_train[i], cmap='gray')\n",
"plt.show()\n",
"\n",
"# Print the matching labels.\n",
"print(y_train[:9])\n",
"\n",
"# Standard preprocessing: divide the pixel values by 255.\n",
"x_train, x_test = x_train / 255.0, x_test / 255.0\n",
"\n",
"# Number of passes over the training set. It is also used in the saved model's path,\n",
"# so the file name can never disagree with the actual training length.\n",
"EPOCHS = 15\n",
"\n",
"# Build a network with two fully connected layers.\n",
"model = tf.keras.models.Sequential([\n",
"    tf.keras.layers.Flatten(input_shape=(28, 28)),  # flatten each image into a 1-D vector\n",
"    tf.keras.layers.Dense(128, activation='relu'),  # first fully connected layer + ReLU\n",
"    tf.keras.layers.Dropout(0.2),  # dropout layer\n",
"    tf.keras.layers.Dense(10, activation='softmax')  # second fully connected layer + softmax over the 10 outputs\n",
"])\n",
"\n",
"# Training setup: Adam optimizer (this run does not use SGD), sparse categorical\n",
"# cross-entropy loss, accuracy as the reported metric.\n",
"model.compile(optimizer='Adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])\n",
"\n",
"# Train for EPOCHS epochs; one epoch is one full pass over the training images.\n",
"model.fit(x_train, y_train, epochs=EPOCHS)\n",
"\n",
"# Measure the trained model on the test set.\n",
"model.evaluate(x_test, y_test)\n",
"\n",
"# Look at the first 9 test images to compare them with the predictions below.\n",
"for i in range(9):\n",
"    plt.subplot(3, 3, 1 + i)\n",
"    plt.imshow(x_test[i], cmap='gray')\n",
"plt.show()\n",
"\n",
"# Print the digits the model predicts for those 9 images. A bare expression in the middle\n",
"# of a cell is not displayed, so the result has to be printed explicitly.\n",
"print(np.argmax(model(x_test[:9]).numpy(), axis=1))\n",
"\n",
"# Save the trained model.\n",
"model.save(\"/data/output/model_epoch_{}\".format(EPOCHS))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 0s 30us/sample - loss: 0.0869 - accuracy: 0.9791\n"
]
},
{
"data": {
"text/plain": [
"[0.08690820527771356, 0.9791]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(x_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tensorflow-2.0",
"language": "python",
"name": "tensorflow-2.0"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}