|
@ -55,4 +55,4 @@ plt.savefig('./mnist/output/2.jpg') |
|
|
# np.argmax(model(x_test[:9]).numpy(), axis=1) |
|
|
# np.argmax(model(x_test[:9]).numpy(), axis=1) |
|
|
|
|
|
|
|
|
#保存训练好的模型 |
|
|
#保存训练好的模型 |
|
|
model.save("./mnist/output/model_epoch_5") |
|
|
|
|
|
|
|
|
model.save("./mnist/output/model_epoch_10") |