Browse Source

Upload files to 'code/CNN'

master
何雨晴 3 years ago
parent
commit
7406aa8a86
2 changed files with 686 additions and 0 deletions
  1. +498
    -0
      code/CNN/3.ipynb
  2. +188
    -0
      code/CNN/CNN_classifier.py

+ 498
- 0
code/CNN/3.ipynb View File

@ -0,0 +1,498 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[-0.73833905 -0.14152196 -0.35304517 -0.5694006 -0.41008208 -0.11232905]\n",
" [-0.73833905 -0.14164066 -0.35324653 -0.56959656 -0.41008208 -0.11232905]\n",
" [-0.73833905 -0.14176103 -0.35345072 -0.56979527 -0.41342335 -0.34642073]\n",
" ...\n",
" [-0.73833905 -0.35299514 -0.71177197 -0.91850029 -0.40122532 0.31341636]\n",
" [-0.73833905 -0.35869641 -0.72144317 -0.92791195 -0.40122532 0.31341636]\n",
" [-0.73833905 -0.36224577 -0.72746403 -0.93377123 -0.40122532 0.31341636]] [4, 5, 4, 5, 4, 5]\n",
"[[-0.70953335 -1.24696123 0.80366164 0.32426513 1.75007941 0.98942482]\n",
" [-0.70953335 -1.24696123 0.80399547 0.32467858 1.75007941 0.98942482]\n",
" [-0.70953335 -1.24696123 0.80433397 0.32509783 1.76582028 0.98942482]\n",
" ...\n",
" [-0.70953335 -1.24696123 -0.58801131 -1.39938764 1.79806729 0.42918451]\n",
" [-0.70953335 -1.24696123 -0.64469193 -1.46958926 1.79806729 0.42918451]\n",
" [-0.70953335 -1.24696123 -0.68310364 -1.51716398 1.79806729 0.42918451]] [1, 8, 3, 6, 4, 5]\n",
"[[-0.45775858 -0.49249489 -0.32148183 -0.61943689 -0.54651911 0.00407354]\n",
" [-0.45775858 -0.49249489 -0.32139901 -0.61934213 -0.54651911 0.00407354]\n",
" [-0.45775858 -0.49249489 -0.32132185 -0.61925384 -0.54651911 0.00407354]\n",
" ...\n",
" [-0.45775858 -0.49249489 -0.71650663 -1.07141423 -0.58342324 -0.48046198]\n",
" [-0.45775858 -0.49249489 -0.72670863 -1.0830871 -0.58342324 -0.48046198]\n",
" [-0.45775858 -0.49249489 -0.73306012 -1.09035432 -0.58342324 -0.48046198]] [8, 1, 8, 1, 2, 7]\n",
"[[ 1.30598423 -1.28489274 -0.95294996 -0.59131444 -0.11004319 -0.35756886]\n",
" [ 1.30544086 -1.28489274 -0.95294996 -0.59170509 -0.11004319 -0.35756886]\n",
" [ 1.30469694 -1.28489274 -0.95294996 -0.59223993 -0.11004319 -0.35756886]\n",
" ...\n",
" [-0.47438622 -1.28489274 -0.95294996 -1.87129372 -0.04617476 -0.82399449]\n",
" [-0.52386171 -1.28489274 -0.95294996 -1.90686362 -0.04617476 -0.82399449]\n",
" [-0.5546629 -1.28489274 -0.95294996 -1.92900783 -0.04617476 -0.82399449]] [2, 7, 5, 4, 4, 5]\n",
"[[-0.77336447 0.83714232 -0.02652395 -0.84068222 -0.31514096 -0.45631725]\n",
" [-0.77336447 0.83719004 -0.02650423 -0.84064236 -0.31514096 -0.45631725]\n",
" [-0.77336447 0.83722937 -0.02648796 -0.8406095 -0.31514096 -0.45631725]\n",
" ...\n",
" [-0.77336447 -1.03465463 -0.80044849 -2.40429892 -0.15185466 0.00683089]\n",
" [-0.77336447 -1.11527661 -0.83378294 -2.47164697 -0.15185466 0.00683089]\n",
" [-0.77336447 -1.16991318 -0.85637331 -2.51728795 -0.15185466 0.00683089]] [4, 5, 6, 3, 4, 5]\n",
"[[-0.76625021 -0.21085965 0.30426279 -0.59697468 -0.28227719 -0.55859043]\n",
" [-0.76625021 -0.21085965 0.30388242 -0.59733074 -0.28227719 -0.55859043]\n",
" [-0.76625021 -0.21085965 0.30351743 -0.5976724 -0.28227719 -0.55859043]\n",
" ...\n",
" [-0.76625021 -0.21085965 -0.73297402 -1.56792586 -0.29692441 -1.03858709]\n",
" [-0.76625021 -0.21085965 -0.74910391 -1.58302495 -0.29692441 -1.03858709]\n",
" [-0.76625021 -0.21085965 -0.75895844 -1.59224972 -0.29364569 -0.38492619]] [8, 1, 5, 4, 3, 6]\n",
"[[-1.04641915 -0.20380784 0.46457631 -1.11356293 -0.78146525 -0.57856754]\n",
" [-1.04641915 -0.20380784 0.46458987 -1.11355069 -0.78146525 -0.57856754]\n",
" [-1.04641915 -0.20380784 0.46459897 -1.11354247 -0.78146525 -0.57856754]\n",
" ...\n",
" [-1.04641915 -0.20380784 -0.6299704 -2.10203693 -0.73427062 -0.65765889]\n",
" [-1.04641915 -0.20380784 -0.65953927 -2.12874028 -0.73427062 -0.65765889]\n",
" [-1.04641915 -0.20380784 -0.6779475 -2.14536456 -0.73427062 -0.65765889]] [3, 6, 5, 4, 4, 5]\n",
"[[-0.4662164 -0.92535848 0.25824924 -0.46050554 -0.37632731 -0.85179158]\n",
" [-0.4662164 -0.92535848 0.25683678 -0.46197126 -0.37632731 -0.85179158]\n",
" [-0.4662164 -0.92535848 0.25569732 -0.46315368 -0.37632731 -0.85179158]\n",
" ...\n",
" [-0.4662164 -0.92535848 -0.92033216 -1.6835251 -0.31534675 -1.24232196]\n",
" [-0.4662164 -0.92535848 -0.96908416 -1.73411528 -0.31534675 -1.24232196]\n",
" [-0.4662164 -0.92535848 -1.00212256 -1.76839939 -0.31534675 -1.24232196]] [2, 7, 3, 6, 4, 5]\n",
"[[-0.22565103 -0.52872559 0.68793776 0.5577097 -0.55639546 -1.75613221]\n",
" [-0.22565103 -0.52928228 0.68686071 0.55636073 -0.55639546 -1.75613221]\n",
" [-0.22565103 -0.529847 0.68576813 0.55499232 -0.5784538 -1.47606102]\n",
" ...\n",
" [-0.22565103 -1.28516026 -0.77555202 -1.27526296 -0.6496459 -1.16293839]\n",
" [-0.22565103 -1.31437712 -0.83207848 -1.34606048 -0.6496459 -1.16293839]\n",
" [-0.22565103 -1.33417699 -0.87038569 -1.394039 -0.65453774 -1.16293839]] [7, 2, 4, 5, 7, 2]\n",
"[[ 0.62743838 0.23976437 -0.46310077 0.11777361 0.55671807 -1.16009444]\n",
" [ 0.6252212 0.23672288 -0.46310077 0.11543644 0.55671807 -1.16009444]\n",
" [ 0.62331947 0.23411411 -0.46310077 0.11343178 0.55671807 -1.16009444]\n",
" ...\n",
" [-0.82841518 -1.75735736 -0.46310077 -1.41687342 0.10305765 -1.50787441]\n",
" [-0.875806 -1.82236749 -0.46310077 -1.46682912 0.10305765 -1.50787441]\n",
" [-0.90792207 -1.86642391 -0.46310077 -1.50068336 0.10305765 -1.50787441]] [2, 7, 6, 3, 2, 7]\n",
"[[-0.14585094 -0.59648574 -0.4026451 -0.64208217 -1.06833399 0.59287188]\n",
" [-0.14576973 -0.59620239 -0.4026451 -0.64195377 -1.06833399 0.59287188]\n",
" [-0.14567293 -0.59586462 -0.4026451 -0.6418007 -1.06833399 0.59287188]\n",
" ...\n",
" [-0.50682212 -1.85604045 -0.4026451 -1.21285978 -1.07844381 -1.49617308]\n",
" [-0.512524 -1.8759363 -0.4026451 -1.22187575 -1.07844381 -1.49617308]\n",
" [-0.51600644 -1.88808774 -0.4026451 -1.22738228 -1.09930667 -0.37006732]] [3, 6, 5, 4, 1, 8]\n",
"[[-0.35862825 -0.35714423 -0.39331462 -0.70993826 -0.37556131 -0.28470072]\n",
" [-0.35849929 -0.35714423 -0.39331462 -0.70985678 -0.37556131 -0.28470072]\n",
" [-0.35839306 -0.35714423 -0.39331462 -0.70978965 -0.37556131 -0.28470072]\n",
" ...\n",
" [-0.88192112 -0.35714423 -0.39331462 -1.04058493 -0.37906982 0.15335923]\n",
" [-0.90278179 -0.35714423 -0.39331462 -1.0537659 -0.37906982 0.15335923]\n",
" [-0.91691875 -0.35714423 -0.39331462 -1.06269845 -0.37906982 0.15335923]] [5, 4, 6, 3, 7, 2]\n",
"[[-0.19698857 -0.73098187 0.92686605 0.47388338 0.24309849 -0.80674761]\n",
" [-0.19698857 -0.73098187 0.92686605 0.47388338 0.24309849 -0.80674761]\n",
" [-0.19698857 -0.73098187 0.92686605 0.47388338 0.24309849 -0.80674761]\n",
" ...\n",
" [-0.19698857 -0.73098187 -0.28154858 -0.35253519 0.24177257 -0.71297856]\n",
" [-0.19698857 -0.73098187 -0.30076571 -0.36567753 0.24177257 -0.71297856]\n",
" [-0.19698857 -0.73098187 -0.31250238 -0.37370408 0.24154317 -0.80674761]] [9, 0, 7, 2, 1, 8]\n",
"[[-0.79352257 1.92045095 0.22047343 0.19939002 0.39408206 -3.64419815]\n",
" [-0.79352257 1.91778509 0.21951147 0.19754164 0.39408206 -3.64419815]\n",
" [-0.79352257 1.91512947 0.21855321 0.19570036 0.39408206 -3.64419815]\n",
" ...\n",
" [-0.79352257 -0.43667077 -0.63007766 -1.43492797 0.48893761 -0.98543255]\n",
" [-0.79352257 -0.50089675 -0.65325317 -1.47945926 0.48893761 -0.98543255]\n",
" [-0.79352257 -0.54088346 -0.6676821 -1.50718418 0.48893761 -0.98543255]] [6, 3, 2, 7, 5, 4]\n",
"[[-0.27237369 -0.54063198 -0.50811608 -0.72395271 -0.68405473 0.33905917]\n",
" [-0.27237369 -0.54063198 -0.50730244 -0.72304336 -0.68405473 0.33905917]\n",
" [-0.27237369 -0.54063198 -0.50658053 -0.72223654 -0.68405473 0.33905917]\n",
" ...\n",
" [-0.27237369 -0.54063198 -1.40019084 -1.72095853 -0.66467772 -0.95228849]\n",
" [-0.27237369 -0.54063198 -1.41464306 -1.73711071 -0.66467772 -0.95228849]\n",
" [-0.27237369 -0.54063198 -1.4234696 -1.74697548 -0.66324022 -0.59486012]] [7, 2, 5, 4, 6, 3]\n",
"[[-0.8197209 -0.37802163 0.11338408 -0.49811041 -0.80759158 -1.6838975 ]\n",
" [-0.8197209 -0.37802163 0.11338337 -0.49811124 -0.80759158 -1.6838975 ]\n",
" [-0.8197209 -0.37802163 0.11338152 -0.49811342 -0.80759158 -1.6838975 ]\n",
" ...\n",
" [-0.8197209 -0.37802163 -0.71477178 -1.47488888 -0.72840232 -0.83375274]\n",
" [-0.8197209 -0.37802163 -0.72865273 -1.49126094 -0.72840232 -0.83375274]\n",
" [-0.8197209 -0.37802163 -0.73713019 -1.50125979 -0.6939122 -0.38602261]] [3, 6, 5, 4, 5, 4]\n",
"[[-0.39336858 -0.1677936 0.37677439 0.31212638 -0.05409307 -0.36385253]\n",
" [-0.39336858 -0.16852957 0.37647502 0.31173676 -0.05409307 -0.36385253]\n",
" [-0.39336858 -0.16941129 0.37611637 0.31126997 -0.05409307 -0.36385253]\n",
" ...\n",
" [-0.39336858 -1.68641238 -0.24094592 -0.49183219 -0.0618552 -0.37143746]\n",
" [-0.39336858 -1.74696276 -0.26557567 -0.52388763 -0.0618552 -0.37143746]\n",
" [-0.39336858 -1.78799705 -0.28226696 -0.54561123 -0.0618552 -0.37143746]] [3, 6, 6, 3, 2, 7]\n",
"[[-0.70112291 -0.95802485 1.38959749 1.76155062 2.51005844 0.91094868]\n",
" [-0.70112291 -0.95802485 1.39309019 1.76624349 2.51005844 0.91094868]\n",
" [-0.70112291 -0.95802485 1.39646557 1.77077872 2.51005844 0.91094868]\n",
" ...\n",
" [-0.70112291 -0.95802485 -0.62365684 -0.94350128 2.63141296 0.71376734]\n",
" [-0.70112291 -0.95802485 -0.67914414 -1.01805522 2.63141296 0.71376734]\n",
" [-0.70112291 -0.95802485 -0.713688 -1.0644691 2.63141296 0.71376734]] [4, 5, 8, 1, 5, 4]\n",
"[[-0.39162345 -0.77307617 0.37867865 0.21940014 0.63658961 -0.86950231]\n",
" [-0.39162345 -0.77307617 0.37839947 0.21909086 0.63658961 -0.86950231]\n",
" [-0.39162345 -0.77307617 0.37807441 0.21873075 0.63658961 -0.86950231]\n",
" ...\n",
" [-0.39162345 -0.77307617 -0.84416418 -1.13529109 0.85064753 -0.53320699]\n",
" [-0.39162345 -0.77307617 -0.87837299 -1.17318833 0.85064753 -0.53320699]\n",
" [-0.39162345 -0.77307617 -0.89966981 -1.19678141 0.85064753 -0.53320699]] [5, 4, 3, 6, 5, 4]\n",
"[[ 0.67573725 -0.81465401 -0.55765161 -0.24191291 0.11970908 -0.70480243]\n",
" [ 0.67526792 -0.81465401 -0.55765161 -0.24281012 0.11970908 -0.70480243]\n",
" [ 0.67473702 -0.81465401 -0.55765161 -0.24382503 0.11970908 -0.70480243]\n",
" ...\n",
" [-0.88682589 -0.81465401 -0.55765161 -3.22902427 0.19913457 0.3208753 ]\n",
" [-0.9121858 -0.81465401 -0.55765161 -3.27750416 0.19913457 0.3208753 ]\n",
" [-0.92767402 -0.81465401 -0.55765161 -3.30711257 0.22318534 0.76601815]] [4, 5, 3, 6, 5, 4]\n",
"[[-0.49583169 0.34641847 1.94129529 1.46273709 0.18874347 2.1760096 ]\n",
" [-0.49583169 0.3476835 1.94386122 1.46571892 0.18874347 2.1760096 ]\n",
" [-0.49583169 0.34875708 1.94603882 1.4682495 0.18874347 2.1760096 ]\n",
" ...\n",
" [-0.49583169 -0.86448159 -0.51483512 -1.39151191 0.10361879 1.29895904]\n",
" [-0.49583169 -0.91026036 -0.60769054 -1.49941843 0.10361879 1.29895904]\n",
" [-0.49583169 -0.94128389 -0.67061714 -1.57254492 0.10361879 1.29895904]] [5, 4, 3, 6, 6, 3]\n",
"[array([[-0.73833905, -0.14152196, -0.35304517, -0.5694006 , -0.41008208,\n",
" -0.11232905],\n",
" [-0.73833905, -0.14164066, -0.35324653, -0.56959656, -0.41008208,\n",
" -0.11232905],\n",
" [-0.73833905, -0.14176103, -0.35345072, -0.56979527, -0.41342335,\n",
" -0.34642073],\n",
" ...,\n",
" [-0.73833905, -0.35299514, -0.71177197, -0.91850029, -0.40122532,\n",
" 0.31341636],\n",
" [-0.73833905, -0.35869641, -0.72144317, -0.92791195, -0.40122532,\n",
" 0.31341636],\n",
" [-0.73833905, -0.36224577, -0.72746403, -0.93377123, -0.40122532,\n",
" 0.31341636]]), array([[-0.70953335, -1.24696123, 0.80366164, 0.32426513, 1.75007941,\n",
" 0.98942482],\n",
" [-0.70953335, -1.24696123, 0.80399547, 0.32467858, 1.75007941,\n",
" 0.98942482],\n",
" [-0.70953335, -1.24696123, 0.80433397, 0.32509783, 1.76582028,\n",
" 0.98942482],\n",
" ...,\n",
" [-0.70953335, -1.24696123, -0.58801131, -1.39938764, 1.79806729,\n",
" 0.42918451],\n",
" [-0.70953335, -1.24696123, -0.64469193, -1.46958926, 1.79806729,\n",
" 0.42918451],\n",
" [-0.70953335, -1.24696123, -0.68310364, -1.51716398, 1.79806729,\n",
" 0.42918451]]), array([[-0.45775858, -0.49249489, -0.32148183, -0.61943689, -0.54651911,\n",
" 0.00407354],\n",
" [-0.45775858, -0.49249489, -0.32139901, -0.61934213, -0.54651911,\n",
" 0.00407354],\n",
" [-0.45775858, -0.49249489, -0.32132185, -0.61925384, -0.54651911,\n",
" 0.00407354],\n",
" ...,\n",
" [-0.45775858, -0.49249489, -0.71650663, -1.07141423, -0.58342324,\n",
" -0.48046198],\n",
" [-0.45775858, -0.49249489, -0.72670863, -1.0830871 , -0.58342324,\n",
" -0.48046198],\n",
" [-0.45775858, -0.49249489, -0.73306012, -1.09035432, -0.58342324,\n",
" -0.48046198]]), array([[ 1.30598423, -1.28489274, -0.95294996, -0.59131444, -0.11004319,\n",
" -0.35756886],\n",
" [ 1.30544086, -1.28489274, -0.95294996, -0.59170509, -0.11004319,\n",
" -0.35756886],\n",
" [ 1.30469694, -1.28489274, -0.95294996, -0.59223993, -0.11004319,\n",
" -0.35756886],\n",
" ...,\n",
" [-0.47438622, -1.28489274, -0.95294996, -1.87129372, -0.04617476,\n",
" -0.82399449],\n",
" [-0.52386171, -1.28489274, -0.95294996, -1.90686362, -0.04617476,\n",
" -0.82399449],\n",
" [-0.5546629 , -1.28489274, -0.95294996, -1.92900783, -0.04617476,\n",
" -0.82399449]]), array([[-0.77336447, 0.83714232, -0.02652395, -0.84068222, -0.31514096,\n",
" -0.45631725],\n",
" [-0.77336447, 0.83719004, -0.02650423, -0.84064236, -0.31514096,\n",
" -0.45631725],\n",
" [-0.77336447, 0.83722937, -0.02648796, -0.8406095 , -0.31514096,\n",
" -0.45631725],\n",
" ...,\n",
" [-0.77336447, -1.03465463, -0.80044849, -2.40429892, -0.15185466,\n",
" 0.00683089],\n",
" [-0.77336447, -1.11527661, -0.83378294, -2.47164697, -0.15185466,\n",
" 0.00683089],\n",
" [-0.77336447, -1.16991318, -0.85637331, -2.51728795, -0.15185466,\n",
" 0.00683089]]), array([[-0.76625021, -0.21085965, 0.30426279, -0.59697468, -0.28227719,\n",
" -0.55859043],\n",
" [-0.76625021, -0.21085965, 0.30388242, -0.59733074, -0.28227719,\n",
" -0.55859043],\n",
" [-0.76625021, -0.21085965, 0.30351743, -0.5976724 , -0.28227719,\n",
" -0.55859043],\n",
" ...,\n",
" [-0.76625021, -0.21085965, -0.73297402, -1.56792586, -0.29692441,\n",
" -1.03858709],\n",
" [-0.76625021, -0.21085965, -0.74910391, -1.58302495, -0.29692441,\n",
" -1.03858709],\n",
" [-0.76625021, -0.21085965, -0.75895844, -1.59224972, -0.29364569,\n",
" -0.38492619]]), array([[-1.04641915, -0.20380784, 0.46457631, -1.11356293, -0.78146525,\n",
" -0.57856754],\n",
" [-1.04641915, -0.20380784, 0.46458987, -1.11355069, -0.78146525,\n",
" -0.57856754],\n",
" [-1.04641915, -0.20380784, 0.46459897, -1.11354247, -0.78146525,\n",
" -0.57856754],\n",
" ...,\n",
" [-1.04641915, -0.20380784, -0.6299704 , -2.10203693, -0.73427062,\n",
" -0.65765889],\n",
" [-1.04641915, -0.20380784, -0.65953927, -2.12874028, -0.73427062,\n",
" -0.65765889],\n",
" [-1.04641915, -0.20380784, -0.6779475 , -2.14536456, -0.73427062,\n",
" -0.65765889]]), array([[-0.4662164 , -0.92535848, 0.25824924, -0.46050554, -0.37632731,\n",
" -0.85179158],\n",
" [-0.4662164 , -0.92535848, 0.25683678, -0.46197126, -0.37632731,\n",
" -0.85179158],\n",
" [-0.4662164 , -0.92535848, 0.25569732, -0.46315368, -0.37632731,\n",
" -0.85179158],\n",
" ...,\n",
" [-0.4662164 , -0.92535848, -0.92033216, -1.6835251 , -0.31534675,\n",
" -1.24232196],\n",
" [-0.4662164 , -0.92535848, -0.96908416, -1.73411528, -0.31534675,\n",
" -1.24232196],\n",
" [-0.4662164 , -0.92535848, -1.00212256, -1.76839939, -0.31534675,\n",
" -1.24232196]]), array([[-0.22565103, -0.52872559, 0.68793776, 0.5577097 , -0.55639546,\n",
" -1.75613221],\n",
" [-0.22565103, -0.52928228, 0.68686071, 0.55636073, -0.55639546,\n",
" -1.75613221],\n",
" [-0.22565103, -0.529847 , 0.68576813, 0.55499232, -0.5784538 ,\n",
" -1.47606102],\n",
" ...,\n",
" [-0.22565103, -1.28516026, -0.77555202, -1.27526296, -0.6496459 ,\n",
" -1.16293839],\n",
" [-0.22565103, -1.31437712, -0.83207848, -1.34606048, -0.6496459 ,\n",
" -1.16293839],\n",
" [-0.22565103, -1.33417699, -0.87038569, -1.394039 , -0.65453774,\n",
" -1.16293839]]), array([[ 0.62743838, 0.23976437, -0.46310077, 0.11777361, 0.55671807,\n",
" -1.16009444],\n",
" [ 0.6252212 , 0.23672288, -0.46310077, 0.11543644, 0.55671807,\n",
" -1.16009444],\n",
" [ 0.62331947, 0.23411411, -0.46310077, 0.11343178, 0.55671807,\n",
" -1.16009444],\n",
" ...,\n",
" [-0.82841518, -1.75735736, -0.46310077, -1.41687342, 0.10305765,\n",
" -1.50787441],\n",
" [-0.875806 , -1.82236749, -0.46310077, -1.46682912, 0.10305765,\n",
" -1.50787441],\n",
" [-0.90792207, -1.86642391, -0.46310077, -1.50068336, 0.10305765,\n",
" -1.50787441]]), array([[-0.14585094, -0.59648574, -0.4026451 , -0.64208217, -1.06833399,\n",
" 0.59287188],\n",
" [-0.14576973, -0.59620239, -0.4026451 , -0.64195377, -1.06833399,\n",
" 0.59287188],\n",
" [-0.14567293, -0.59586462, -0.4026451 , -0.6418007 , -1.06833399,\n",
" 0.59287188],\n",
" ...,\n",
" [-0.50682212, -1.85604045, -0.4026451 , -1.21285978, -1.07844381,\n",
" -1.49617308],\n",
" [-0.512524 , -1.8759363 , -0.4026451 , -1.22187575, -1.07844381,\n",
" -1.49617308],\n",
" [-0.51600644, -1.88808774, -0.4026451 , -1.22738228, -1.09930667,\n",
" -0.37006732]]), array([[-0.35862825, -0.35714423, -0.39331462, -0.70993826, -0.37556131,\n",
" -0.28470072],\n",
" [-0.35849929, -0.35714423, -0.39331462, -0.70985678, -0.37556131,\n",
" -0.28470072],\n",
" [-0.35839306, -0.35714423, -0.39331462, -0.70978965, -0.37556131,\n",
" -0.28470072],\n",
" ...,\n",
" [-0.88192112, -0.35714423, -0.39331462, -1.04058493, -0.37906982,\n",
" 0.15335923],\n",
" [-0.90278179, -0.35714423, -0.39331462, -1.0537659 , -0.37906982,\n",
" 0.15335923],\n",
" [-0.91691875, -0.35714423, -0.39331462, -1.06269845, -0.37906982,\n",
" 0.15335923]]), array([[-0.19698857, -0.73098187, 0.92686605, 0.47388338, 0.24309849,\n",
" -0.80674761],\n",
" [-0.19698857, -0.73098187, 0.92686605, 0.47388338, 0.24309849,\n",
" -0.80674761],\n",
" [-0.19698857, -0.73098187, 0.92686605, 0.47388338, 0.24309849,\n",
" -0.80674761],\n",
" ...,\n",
" [-0.19698857, -0.73098187, -0.28154858, -0.35253519, 0.24177257,\n",
" -0.71297856],\n",
" [-0.19698857, -0.73098187, -0.30076571, -0.36567753, 0.24177257,\n",
" -0.71297856],\n",
" [-0.19698857, -0.73098187, -0.31250238, -0.37370408, 0.24154317,\n",
" -0.80674761]]), array([[-0.79352257, 1.92045095, 0.22047343, 0.19939002, 0.39408206,\n",
" -3.64419815],\n",
" [-0.79352257, 1.91778509, 0.21951147, 0.19754164, 0.39408206,\n",
" -3.64419815],\n",
" [-0.79352257, 1.91512947, 0.21855321, 0.19570036, 0.39408206,\n",
" -3.64419815],\n",
" ...,\n",
" [-0.79352257, -0.43667077, -0.63007766, -1.43492797, 0.48893761,\n",
" -0.98543255],\n",
" [-0.79352257, -0.50089675, -0.65325317, -1.47945926, 0.48893761,\n",
" -0.98543255],\n",
" [-0.79352257, -0.54088346, -0.6676821 , -1.50718418, 0.48893761,\n",
" -0.98543255]]), array([[-0.27237369, -0.54063198, -0.50811608, -0.72395271, -0.68405473,\n",
" 0.33905917],\n",
" [-0.27237369, -0.54063198, -0.50730244, -0.72304336, -0.68405473,\n",
" 0.33905917],\n",
" [-0.27237369, -0.54063198, -0.50658053, -0.72223654, -0.68405473,\n",
" 0.33905917],\n",
" ...,\n",
" [-0.27237369, -0.54063198, -1.40019084, -1.72095853, -0.66467772,\n",
" -0.95228849],\n",
" [-0.27237369, -0.54063198, -1.41464306, -1.73711071, -0.66467772,\n",
" -0.95228849],\n",
" [-0.27237369, -0.54063198, -1.4234696 , -1.74697548, -0.66324022,\n",
" -0.59486012]]), array([[-0.8197209 , -0.37802163, 0.11338408, -0.49811041, -0.80759158,\n",
" -1.6838975 ],\n",
" [-0.8197209 , -0.37802163, 0.11338337, -0.49811124, -0.80759158,\n",
" -1.6838975 ],\n",
" [-0.8197209 , -0.37802163, 0.11338152, -0.49811342, -0.80759158,\n",
" -1.6838975 ],\n",
" ...,\n",
" [-0.8197209 , -0.37802163, -0.71477178, -1.47488888, -0.72840232,\n",
" -0.83375274],\n",
" [-0.8197209 , -0.37802163, -0.72865273, -1.49126094, -0.72840232,\n",
" -0.83375274],\n",
" [-0.8197209 , -0.37802163, -0.73713019, -1.50125979, -0.6939122 ,\n",
" -0.38602261]]), array([[-0.39336858, -0.1677936 , 0.37677439, 0.31212638, -0.05409307,\n",
" -0.36385253],\n",
" [-0.39336858, -0.16852957, 0.37647502, 0.31173676, -0.05409307,\n",
" -0.36385253],\n",
" [-0.39336858, -0.16941129, 0.37611637, 0.31126997, -0.05409307,\n",
" -0.36385253],\n",
" ...,\n",
" [-0.39336858, -1.68641238, -0.24094592, -0.49183219, -0.0618552 ,\n",
" -0.37143746],\n",
" [-0.39336858, -1.74696276, -0.26557567, -0.52388763, -0.0618552 ,\n",
" -0.37143746],\n",
" [-0.39336858, -1.78799705, -0.28226696, -0.54561123, -0.0618552 ,\n",
" -0.37143746]]), array([[-0.70112291, -0.95802485, 1.38959749, 1.76155062, 2.51005844,\n",
" 0.91094868],\n",
" [-0.70112291, -0.95802485, 1.39309019, 1.76624349, 2.51005844,\n",
" 0.91094868],\n",
" [-0.70112291, -0.95802485, 1.39646557, 1.77077872, 2.51005844,\n",
" 0.91094868],\n",
" ...,\n",
" [-0.70112291, -0.95802485, -0.62365684, -0.94350128, 2.63141296,\n",
" 0.71376734],\n",
" [-0.70112291, -0.95802485, -0.67914414, -1.01805522, 2.63141296,\n",
" 0.71376734],\n",
" [-0.70112291, -0.95802485, -0.713688 , -1.0644691 , 2.63141296,\n",
" 0.71376734]]), array([[-0.39162345, -0.77307617, 0.37867865, 0.21940014, 0.63658961,\n",
" -0.86950231],\n",
" [-0.39162345, -0.77307617, 0.37839947, 0.21909086, 0.63658961,\n",
" -0.86950231],\n",
" [-0.39162345, -0.77307617, 0.37807441, 0.21873075, 0.63658961,\n",
" -0.86950231],\n",
" ...,\n",
" [-0.39162345, -0.77307617, -0.84416418, -1.13529109, 0.85064753,\n",
" -0.53320699],\n",
" [-0.39162345, -0.77307617, -0.87837299, -1.17318833, 0.85064753,\n",
" -0.53320699],\n",
" [-0.39162345, -0.77307617, -0.89966981, -1.19678141, 0.85064753,\n",
" -0.53320699]]), array([[ 0.67573725, -0.81465401, -0.55765161, -0.24191291, 0.11970908,\n",
" -0.70480243],\n",
" [ 0.67526792, -0.81465401, -0.55765161, -0.24281012, 0.11970908,\n",
" -0.70480243],\n",
" [ 0.67473702, -0.81465401, -0.55765161, -0.24382503, 0.11970908,\n",
" -0.70480243],\n",
" ...,\n",
" [-0.88682589, -0.81465401, -0.55765161, -3.22902427, 0.19913457,\n",
" 0.3208753 ],\n",
" [-0.9121858 , -0.81465401, -0.55765161, -3.27750416, 0.19913457,\n",
" 0.3208753 ],\n",
" [-0.92767402, -0.81465401, -0.55765161, -3.30711257, 0.22318534,\n",
" 0.76601815]]), array([[-0.49583169, 0.34641847, 1.94129529, 1.46273709, 0.18874347,\n",
" 2.1760096 ],\n",
" [-0.49583169, 0.3476835 , 1.94386122, 1.46571892, 0.18874347,\n",
" 2.1760096 ],\n",
" [-0.49583169, 0.34875708, 1.94603882, 1.4682495 , 0.18874347,\n",
" 2.1760096 ],\n",
" ...,\n",
" [-0.49583169, -0.86448159, -0.51483512, -1.39151191, 0.10361879,\n",
" 1.29895904],\n",
" [-0.49583169, -0.91026036, -0.60769054, -1.49941843, 0.10361879,\n",
" 1.29895904],\n",
" [-0.49583169, -0.94128389, -0.67061714, -1.57254492, 0.10361879,\n",
" 1.29895904]])] [[4, 5, 4, 5, 4, 5], [1, 8, 3, 6, 4, 5], [8, 1, 8, 1, 2, 7], [2, 7, 5, 4, 4, 5], [4, 5, 6, 3, 4, 5], [8, 1, 5, 4, 3, 6], [3, 6, 5, 4, 4, 5], [2, 7, 3, 6, 4, 5], [7, 2, 4, 5, 7, 2], [2, 7, 6, 3, 2, 7], [3, 6, 5, 4, 1, 8], [5, 4, 6, 3, 7, 2], [9, 0, 7, 2, 1, 8], [6, 3, 2, 7, 5, 4], [7, 2, 5, 4, 6, 3], [3, 6, 5, 4, 5, 4], [3, 6, 6, 3, 2, 7], [4, 5, 8, 1, 5, 4], [5, 4, 3, 6, 5, 4], [4, 5, 3, 6, 5, 4], [5, 4, 3, 6, 6, 3]]\n"
]
}
],
"source": [
"import scipy.io as sio\n",
"import numpy as np\n",
"\n",
"'''a=sio.loadmat('s_2_e_1.mat')\n",
"b=sio.loadmat('s_2_e_2.mat')\n",
"c=sio.loadmat('s_2_e_3.mat')\n",
"print(a.keys())\n",
"for i in range(0,len(a[\"y\"])):\n",
" print(a['y'][i],b['y'][i],c['y'][i],)\n",
"'''\n",
"X=[]\n",
"O=[]\n",
"for c in range(1,23):\n",
" if c==17:\n",
" continue\n",
" a1=sio.loadmat('s_%d_e_1.mat'%(c))\n",
" a2=sio.loadmat('s_%d_e_2.mat'%(c))\n",
" a3=sio.loadmat('s_%d_e_3.mat'%(c))\n",
" \n",
" lenth=len(a1[\"y\"])\n",
" x=a1[\"x\"][lenth-1]\n",
" #print(a1[\"y\"][lenth-1],a2[\"y\"][lenth-1],a3[\"y\"][lenth-1])\n",
" o1=a1[\"y\"][lenth-1][0]\n",
" o2=a1[\"y\"][lenth-1][1]\n",
" o3=a2[\"y\"][lenth-1][0]\n",
" o4=a2[\"y\"][lenth-1][1]\n",
" o5=a3[\"y\"][lenth-1][0]\n",
" o6=a3[\"y\"][lenth-1][1]\n",
" o1=int(o1*10)\n",
" o2=int(o2*10)\n",
" o3=int(o3*10)\n",
" o4=int(o4*10)\n",
" o5=int(o5*10)\n",
" o6=int(o6*10)\n",
" o=[o1,o2,o3,o4,o5,o6]\n",
" X.append(x)\n",
" O.append(o)\n",
" print(x,o)\n",
"print(X,O)\n",
"sio.savemat('result.mat', {'X':X,\"Y\":O,}) "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "60d0a6506e690051ad38d772e690e6e667f43736fc28c7db55a8b02aefc08bd8"
},
"kernelspec": {
"display_name": "Python 3.8.5 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

+ 188
- 0
code/CNN/CNN_classifier.py View File

@ -0,0 +1,188 @@
# Load the libraries
from keras.callbacks import ModelCheckpoint, EarlyStopping
import time
from sklearn.metrics import f1_score, precision_score, recall_score
import gc
from keras import backend as K
import tensorflow as tf
from sklearn.utils import class_weight
import os
import scipy.io as sio
from tensorflow.keras.utils import to_categorical
from keras.layers import Dense, Dropout, Input, Conv2D, BatchNormalization, Activation, MaxPooling2D, GlobalAveragePooling2D
from keras.models import Model
import numpy as np
#from keras.optimizers import Adam
from keras.optimizers import adam_v2
adam = adam_v2.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
#from keras.utils import to_categorical
t1 = time.time()
np.random.seed(50) # 200
HIST_VLoss, HIST_TLoss, HIST_VAcc, HIST_TAcc, pred, SUB, P_score, F_score, R_score, K_val, Len = [
], [], [], [], [], [], [], [], [], [], []
'''这些是结果'''
for iterate in range(5): # 5
# set the seed constant for each fold 1,200
seed = np.random.randint(1, 20)
for j in range(3): # each class 3
for i in range(22): # for each subjects 22
training = sio.loadmat(
'Data_Class_%d_Subject_%d.mat' % ((j+1), (i+1)))
X = training['Feature']
Y = np.ravel(training['labels'])
identifier_val = training['identifier_val']
'''读取数据'''
X = np.reshape(X, (X.shape[0], X.shape[1], X.shape[2], 1))
loc = np.where(np.isin(identifier_val[:, 1], i+1))[0]
'''划分训练集验证集,下面都是数据预处理'''
# Shuffle unique trials to divide into training and validation set
loc0 = loc[np.where(Y[loc] == 0)[0]]
loc1 = loc[np.where(Y[loc] == 1)[0]]
classes = identifier_val[:, 0]
class0 = np.unique(classes[loc0])
class1 = np.unique(classes[loc1])
order = np.arange(len(class0))
np.random.seed(seed)
np.random.shuffle(order)
class0 = class0[order]
order = np.arange(len(class1))
np.random.seed(seed)
np.random.shuffle(order)
class1 = class1[order]
# Divide 20% trials into validation set and rest into training set
classes = 5
Test = np.append(class0[:classes], class1[:classes])
Train = np.append(class0[classes:], class1[classes:])
test_loc = np.where(np.logical_and(
np.isin(identifier_val[:, 0], Test), np.isin(identifier_val[:, 1], i+1)))[0]
train_loc = np.where(np.logical_and(
np.isin(identifier_val[:, 0], Train), np.isin(identifier_val[:, 1], i+1)))[0]
# Shuffle the data
order = np.arange(len(test_loc))
np.random.seed(seed)
np.random.shuffle(order)
test_loc = test_loc[order]
order = np.arange(len(train_loc))
np.random.seed(seed)
np.random.shuffle(order)
train_loc2 = train_loc[order]
# To account for class imbalance
class_weights = class_weight.compute_class_weight('balanced',
np.unique(
Y[train_loc2]),
Y[train_loc2])
print(class_weights)
if len(class_weights) == 1:
#sio.savemat('temp_%d.mat' %(iterate),{'iterate': iterate,'Y':Y,'train_loc':train_loc,'train_loc2':train_loc2,'npunique': Y[train_loc2]})
continue
class_weights = {i: class_weights[i]
for i in range(len(class_weights))}
# class_weights = {i: class_weights[i] for i in range(2)}
# Divide into training and validation set
X_train = X[train_loc]
X_valid = X[test_loc]
y_train = Y[train_loc]
y_valid = Y[test_loc]
Y_train = to_categorical(y_train)
Y_valid = to_categorical(y_valid)
'''开始初始化模型'''
# %% Define model architecture
ch = 6 # Number of features
act = 'tanh' # Activation function
dense_num_units, cnn_units = 16, 8 # Model architecture hyper parameters
bn_axis = 3 # batch normalization dimension
input_shape = (320, ch, 1)
Allinput_img = Input(shape=(input_shape))
x_c = Conv2D(cnn_units, kernel_size=(2, 1), strides=(
2, 1), name='convch_0', use_bias=True)(Allinput_img) # 卷积
x_c = BatchNormalization(axis=3)(x_c) # 分批
x_c = Activation(act)(x_c) # 激活
x_c = MaxPooling2D(pool_size=(2, 1))(x_c) # 最大池化
'''x_c一层层往下传递,就是神经网络那种'''
x_c = Conv2D(cnn_units, kernel_size=(2, 1),
strides=(2, 1), use_bias=True)(x_c)
x_c = BatchNormalization(axis=3)(x_c)
x_c = Activation(act)(x_c)
x_c = MaxPooling2D(pool_size=(2, 1))(x_c)
x_c = Conv2D(cnn_units, kernel_size=(2, ch),
strides=(2, 1), use_bias=True)(x_c)
x_c = BatchNormalization(axis=3)(x_c)
x_c = Activation(act)(x_c)
x_c = MaxPooling2D(pool_size=(2, 1))(x_c)
'''多整几层'''
x_c = GlobalAveragePooling2D()(x_c)
x = (Dropout(rate=0.5, name='Drop_D2'))(x_c)
Out = (Dense(2, activation='softmax', name='Allenc_18'))(x)
model = Model(Allinput_img, Out)
model.compile(optimizer=adam_v2.Adam(lr=0.000001),
loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
# File to save the model weights to
Checkpoint_filename = './save_weights_CNN_RR.hdf5'
'''保存临时文件'''
# Callbacks to perfrorm early stopping if model does not improve in successive 5 epochs and to save
# model weights only if there is improvement
callback_array = [EarlyStopping(monitor='val_loss', min_delta=0, patience=5, verbose=1, mode='auto'),
ModelCheckpoint(Checkpoint_filename, monitor='val_loss', verbose=1, save_best_only=True, mode='auto')]
'''epoch越大,训练次数越多,越慢'''
# Train the model
history = model.fit(X_train, Y_train,
epochs=10, validation_data=(X_valid, Y_valid),
batch_size=32, verbose=1, callbacks=callback_array, shuffle=True, class_weight=class_weights)
model.load_weights(Checkpoint_filename)
print('Iteration number: %d; Condition: %d; Subject: %d' %
(iterate, j+1, i+1))
# Compute the performance metrics
y_pred = np.argmax(model.predict(X_valid), axis=1)
y_pred2 = model.predict(X_valid)
F_score.append([f1_score(y_valid, y_pred, average="macro"), f1_score(
y_valid, y_pred, average="micro"), f1_score(y_valid, y_pred, average="weighted")])
P_score.append([precision_score(y_valid, y_pred, average="macro"), precision_score(
y_valid, y_pred, average="micro"), precision_score(y_valid, y_pred, average="weighted")])
R_score.append([recall_score(y_valid, y_pred, average="macro"), recall_score(
y_valid, y_pred, average="micro"), recall_score(y_valid, y_pred, average="weighted")])
print(y_pred)
print(y_pred2)
hist_dict = history.history
print('====histore_dict----')
print(hist_dict.keys())
HIST_VLoss.append(history.history['val_loss'])
HIST_TLoss.append(history.history['loss'])
# HIST_VAcc.append(history.history['val_acc']) #报错就改成后面这个 val_accuracy
HIST_VAcc.append(history.history['val_accuracy'])
# HIST_TAcc.append(history.history['acc']) #报错就改成后面这个 accuracy
HIST_TAcc.append(history.history['accuracy'])
pred.append(model.evaluate(X_valid, Y_valid))
SUB.append(i)
Len.append(len(history.history['val_loss']))
# Reset the variables and delete the model
os.remove(Checkpoint_filename)
K.clear_session()
gc.collect()
del model, training, X, Y, identifier_val, loc, loc0, loc1, order, class0, class1, Test, Train, test_loc, train_loc, class_weights, X_train, X_valid, Y_valid, Y_train
'''prf三个,是那三个参数,由预测得出,另外的,是训练过程中产生的'''
'''保存文件'''
# Save the final results
sio.savemat('Prediction_V%d.mat' % (iterate), {'HIST_VLoss': HIST_VLoss,
'HIST_VAcc': HIST_VAcc, 'HIST_TLoss': HIST_TLoss,
'HIST_TAcc': HIST_TAcc, 'pred': pred, 'sub': SUB,
'F_score': F_score, 'P_score': P_score, 'R_score': R_score, 'Length': Len, })
t2 = time.time()
print(t2-t1)

Loading…
Cancel
Save