From 7406aa8a86ec203d82c3ab3f1582a94ed557dc87 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E4=BD=95=E9=9B=A8=E6=99=B4?= <10205501458@stu.ecnu.edu.cn>
Date: Thu, 2 Jun 2022 11:05:48 +0800
Subject: [PATCH] Upload files to 'code/CNN'

---
 code/CNN/3.ipynb           | 498 +++++++++++++++++++++++++++++++++++++++++++++
 code/CNN/CNN_classifier.py | 188 +++++++++++++++++
 2 files changed, 686 insertions(+)
 create mode 100644 code/CNN/3.ipynb
 create mode 100644 code/CNN/CNN_classifier.py

diff --git a/code/CNN/3.ipynb b/code/CNN/3.ipynb
new file mode 100644
index 0000000..4da95a0
--- /dev/null
+++ b/code/CNN/3.ipynb
@@ -0,0 +1,498 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[stdout omitted: for each subject the cell prints its 6-column feature matrix together with its six integer labels, followed by the aggregated X and Y lists that are written to result.mat]\n"
+     ]
+    }
+   ],
+   "source": [
+    "import scipy.io as sio\n",
+    "import numpy as np\n",
+    "\n",
+    "'''a=sio.loadmat('s_2_e_1.mat')\n",
+    "b=sio.loadmat('s_2_e_2.mat')\n",
+    "c=sio.loadmat('s_2_e_3.mat')\n",
+    "print(a.keys())\n",
+    "for i in range(0,len(a[\"y\"])):\n",
+    "    print(a['y'][i],b['y'][i],c['y'][i],)\n",
+    "'''\n",
+    "X=[]\n",
+    "O=[]\n",
+    "for c in range(1,23):\n",
+    "    if c==17:\n",
+    "        continue\n",
+    "    a1=sio.loadmat('s_%d_e_1.mat'%(c))\n",
+    "    a2=sio.loadmat('s_%d_e_2.mat'%(c))\n",
+    "    a3=sio.loadmat('s_%d_e_3.mat'%(c))\n",
+    "\n",
+    "    length=len(a1[\"y\"])\n",
+    "    x=a1[\"x\"][length-1]\n",
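+    "    # x: the last feature matrix in a1['x'] (6 columns); o1..o6 below take the two final\n",
+    "    # y values from each of the three session files, scaled by 10 and truncated to int\n",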
+    "    #print(a1[\"y\"][length-1],a2[\"y\"][length-1],a3[\"y\"][length-1])\n",
+    "    o1=a1[\"y\"][length-1][0]\n",
+    "    o2=a1[\"y\"][length-1][1]\n",
+    "    o3=a2[\"y\"][length-1][0]\n",
+    "    o4=a2[\"y\"][length-1][1]\n",
+    "    o5=a3[\"y\"][length-1][0]\n",
+    "    o6=a3[\"y\"][length-1][1]\n",
+    "    o1=int(o1*10)\n",
+    "    o2=int(o2*10)\n",
+    "    o3=int(o3*10)\n",
+    "    o4=int(o4*10)\n",
+    "    o5=int(o5*10)\n",
+    "    o6=int(o6*10)\n",
+    "    o=[o1,o2,o3,o4,o5,o6]\n",
+    "    X.append(x)\n",
+    "    O.append(o)\n",
+    "    print(x,o)\n",
+    "print(X,O)\n",
+    "sio.savemat('result.mat', {'X':X,\"Y\":O,})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "60d0a6506e690051ad38d772e690e6e667f43736fc28c7db55a8b02aefc08bd8"
+  },
+  "kernelspec": {
+   "display_name": "Python 3.8.5 ('base')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/code/CNN/CNN_classifier.py b/code/CNN/CNN_classifier.py
new file mode 100644
index 0000000..8f2e095
--- /dev/null
+++ b/code/CNN/CNN_classifier.py
@@ -0,0 +1,188 @@
+# Load the libraries
+from keras.callbacks import ModelCheckpoint, EarlyStopping
+import time
+from sklearn.metrics import f1_score, precision_score, recall_score
+import gc
+from keras import backend as K
+import tensorflow as tf
+from sklearn.utils import class_weight
+import os
+import scipy.io as sio
+from tensorflow.keras.utils import to_categorical
+from keras.layers import Dense, Dropout, Input, Conv2D, BatchNormalization, Activation, MaxPooling2D, GlobalAveragePooling2D
+from keras.models import Model
+import numpy as np
+#from keras.optimizers import Adam
+from keras.optimizers import adam_v2
+# default optimizer (note: the compile() call below builds its own Adam instance)
+adam = adam_v2.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
+#from keras.utils import to_categorical
+
+t1 = time.time()
+np.random.seed(50)  # 200
+HIST_VLoss, HIST_TLoss, HIST_VAcc, HIST_TAcc, pred, SUB, P_score, F_score, R_score, K_val, Len = [], [], [], [], [], [], [], [], [], [], []
+# these lists collect the results
+for iterate in range(5):  # 5
+    # set the seed constant for each fold 1,200
+    seed = np.random.randint(1, 20)
+
+    for j in range(3):  # each class 3
+        for i in range(22):  # for each subject 22
+            training = sio.loadmat('Data_Class_%d_Subject_%d.mat' % ((j+1), (i+1)))
+            X = training['Feature']
+            Y = np.ravel(training['labels'])
+
+            identifier_val = training['identifier_val']
+            # read the data
+            X = np.reshape(X, (X.shape[0], X.shape[1], X.shape[2], 1))
+            loc = np.where(np.isin(identifier_val[:, 1], i+1))[0]
+            # split into training and validation sets; everything below is data preprocessing
+            # Shuffle unique trials to divide into training and validation set
+            loc0 = loc[np.where(Y[loc] == 0)[0]]
+            loc1 = loc[np.where(Y[loc] == 1)[0]]
+            classes = identifier_val[:, 0]
+            class0 = np.unique(classes[loc0])
+            class1 = np.unique(classes[loc1])
+
+            order = np.arange(len(class0))
+            np.random.seed(seed)
+            np.random.shuffle(order)
+            class0 = class0[order]
+
+            order = np.arange(len(class1))
+            np.random.seed(seed)
+            np.random.shuffle(order)
+            class1 = class1[order]
+
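+            # identifier_val[:, 1] is matched against the subject index above, and identifier_val[:, 0]
+            # appears to identify the trial each sample belongs to, so shuffling the unique trial IDs
+            # lets whole trials (rather than individual samples) be held out for validation below.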
+            # Divide 20% trials into validation set and rest into training set
+            n_val_trials = 5  # number of unique trials per class held out for validation
+            Test = np.append(class0[:n_val_trials], class1[:n_val_trials])
+            Train = np.append(class0[n_val_trials:], class1[n_val_trials:])
+
+            test_loc = np.where(np.logical_and(
+                np.isin(identifier_val[:, 0], Test), np.isin(identifier_val[:, 1], i+1)))[0]
+            train_loc = np.where(np.logical_and(
+                np.isin(identifier_val[:, 0], Train), np.isin(identifier_val[:, 1], i+1)))[0]
+
+            # Shuffle the data
+            order = np.arange(len(test_loc))
+            np.random.seed(seed)
+            np.random.shuffle(order)
+            test_loc = test_loc[order]
+
+            order = np.arange(len(train_loc))
+            np.random.seed(seed)
+            np.random.shuffle(order)
+            train_loc2 = train_loc[order]
+
+            # To account for class imbalance
+            class_weights = class_weight.compute_class_weight(
+                'balanced', classes=np.unique(Y[train_loc2]), y=Y[train_loc2])
+            print(class_weights)
+            if len(class_weights) == 1:
+                #sio.savemat('temp_%d.mat' %(iterate),{'iterate': iterate,'Y':Y,'train_loc':train_loc,'train_loc2':train_loc2,'npunique': Y[train_loc2]})
+                continue
+
+            class_weights = {i: class_weights[i] for i in range(len(class_weights))}
+            # class_weights = {i: class_weights[i] for i in range(2)}
+            # Divide into training and validation set
+            X_train = X[train_loc]
+            X_valid = X[test_loc]
+            y_train = Y[train_loc]
+            y_valid = Y[test_loc]
+            Y_train = to_categorical(y_train)
+            Y_valid = to_categorical(y_valid)
+            # start building the model
+            # %% Define model architecture
+            ch = 6  # Number of features
+            act = 'tanh'  # Activation function
+            dense_num_units, cnn_units = 16, 8  # Model architecture hyper parameters
+            bn_axis = 3  # batch normalization dimension
+
+            input_shape = (320, ch, 1)
+            Allinput_img = Input(shape=input_shape)
+            x_c = Conv2D(cnn_units, kernel_size=(2, 1), strides=(2, 1),
+                         name='convch_0', use_bias=True)(Allinput_img)  # convolution
+            x_c = BatchNormalization(axis=bn_axis)(x_c)  # batch normalization
+            x_c = Activation(act)(x_c)  # activation
+            x_c = MaxPooling2D(pool_size=(2, 1))(x_c)  # max pooling
+            # x_c is passed down layer by layer, as in a typical feed-forward network
+            x_c = Conv2D(cnn_units, kernel_size=(2, 1),
+                         strides=(2, 1), use_bias=True)(x_c)
+            x_c = BatchNormalization(axis=bn_axis)(x_c)
+            x_c = Activation(act)(x_c)
+            x_c = MaxPooling2D(pool_size=(2, 1))(x_c)
+
+            x_c = Conv2D(cnn_units, kernel_size=(2, ch),
+                         strides=(2, 1), use_bias=True)(x_c)
+            x_c = BatchNormalization(axis=bn_axis)(x_c)
+            x_c = Activation(act)(x_c)
+            x_c = MaxPooling2D(pool_size=(2, 1))(x_c)
+            # a few more layers
+            x_c = GlobalAveragePooling2D()(x_c)
+            x = Dropout(rate=0.5, name='Drop_D2')(x_c)
+            Out = Dense(2, activation='softmax', name='Allenc_18')(x)
+            model = Model(Allinput_img, Out)
+            model.compile(optimizer=adam_v2.Adam(learning_rate=0.000001),
+                          loss='categorical_crossentropy', metrics=['accuracy'])
+            model.summary()
+
+            # File to save the model weights to (temporary checkpoint file)
+            Checkpoint_filename = './save_weights_CNN_RR.hdf5'
+            # Callbacks to perform early stopping if the model does not improve in 5 successive epochs and to save
+            # model weights only if there is improvement
+            callback_array = [EarlyStopping(monitor='val_loss', min_delta=0, patience=5, verbose=1, mode='auto'),
+                              ModelCheckpoint(Checkpoint_filename, monitor='val_loss', verbose=1, save_best_only=True, mode='auto')]
+            # more epochs mean more training passes and a longer run
+            # Train the model
+            history = model.fit(X_train, Y_train,
+                                epochs=10, validation_data=(X_valid, Y_valid),
+                                batch_size=32, verbose=1, callbacks=callback_array, shuffle=True, class_weight=class_weights)
+            model.load_weights(Checkpoint_filename)
+            print('Iteration number: %d; Condition: %d; Subject: %d' %
+                  (iterate, j+1, i+1))
+
+            # Compute the performance metrics
+            y_pred = np.argmax(model.predict(X_valid), axis=1)
+            y_pred2 = model.predict(X_valid)
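+            # y_pred holds the hard class labels from argmax, while y_pred2 keeps the raw softmax
+            # probabilities; the macro/micro/weighted F1, precision and recall below are computed
+            # from the hard labels only.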
+            F_score.append([f1_score(y_valid, y_pred, average="macro"),
+                            f1_score(y_valid, y_pred, average="micro"),
+                            f1_score(y_valid, y_pred, average="weighted")])
+            P_score.append([precision_score(y_valid, y_pred, average="macro"),
+                            precision_score(y_valid, y_pred, average="micro"),
+                            precision_score(y_valid, y_pred, average="weighted")])
+            R_score.append([recall_score(y_valid, y_pred, average="macro"),
+                            recall_score(y_valid, y_pred, average="micro"),
+                            recall_score(y_valid, y_pred, average="weighted")])
+            print(y_pred)
+            print(y_pred2)
+            hist_dict = history.history
+            print('==== history_dict ====')
+            print(hist_dict.keys())
+
+            HIST_VLoss.append(history.history['val_loss'])
+            HIST_TLoss.append(history.history['loss'])
+            # HIST_VAcc.append(history.history['val_acc'])  # if this raises a KeyError, use 'val_accuracy' instead
+            HIST_VAcc.append(history.history['val_accuracy'])
+            # HIST_TAcc.append(history.history['acc'])  # if this raises a KeyError, use 'accuracy' instead
+            HIST_TAcc.append(history.history['accuracy'])
+            pred.append(model.evaluate(X_valid, Y_valid))  # [loss, accuracy] on the validation set
+            SUB.append(i)
+            Len.append(len(history.history['val_loss']))
+
+            # Reset the variables and delete the model
+            os.remove(Checkpoint_filename)
+            K.clear_session()
+            gc.collect()
+            del model, training, X, Y, identifier_val, loc, loc0, loc1, order, class0, class1, Test, Train, test_loc, train_loc, class_weights, X_train, X_valid, Y_valid, Y_train
+    # the P/R/F scores come from the predictions; the other saved entries are produced during training
+    # Save the final results to file
+    sio.savemat('Prediction_V%d.mat' % (iterate), {'HIST_VLoss': HIST_VLoss,
+                                                   'HIST_VAcc': HIST_VAcc, 'HIST_TLoss': HIST_TLoss,
+                                                   'HIST_TAcc': HIST_TAcc, 'pred': pred, 'sub': SUB,
+                                                   'F_score': F_score, 'P_score': P_score, 'R_score': R_score, 'Length': Len, })
+t2 = time.time()
+print(t2-t1)  # total runtime in seconds
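
For quick inspection of the saved outputs, a minimal sketch (not part of the patch) for summarizing the Prediction_V*.mat files is given below. It assumes the key names written by the savemat call above, with 'pred' holding the [loss, accuracy] pairs from model.evaluate and 'F_score' holding the [macro, micro, weighted] F1 values, and that the files sit in the working directory; the script name is hypothetical.

# summarize_predictions.py -- illustrative sketch only, not part of this patch
import glob

import numpy as np
import scipy.io as sio

for fname in sorted(glob.glob('Prediction_V*.mat')):
    res = sio.loadmat(fname)
    acc = np.asarray(res['pred'])[:, 1]          # each row is [loss, accuracy] from model.evaluate
    f1_macro = np.asarray(res['F_score'])[:, 0]  # columns are [macro, micro, weighted] F1
    print('%s: mean validation accuracy %.3f, mean macro F1 %.3f'
          % (fname, acc.mean(), f1_macro.mean()))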