|
|
@ -0,0 +1,122 @@ |
|
|
|
import pandas as pd |
|
|
|
import datetime as DT |
|
|
|
import numpy as np |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
import matplotlib.dates as mdates |
|
|
|
import math |
|
|
|
import statsmodels.api as sm |
|
|
|
from sklearn.metrics import mean_squared_error |
|
|
|
|
|
|
|
# Open the data file.
# The path is a raw string: it contains backslashes followed by letters
# ("\d", "\C", "\c") that are not valid escape sequences, so they must be
# taken literally rather than interpreted by the string parser.
dataset = pd.read_csv(r'E:\dase intro\COVID-19Analysis\COVID-19\covid-19-all.csv')
|
|
|
|
|
|
|
# Data preprocessing
|
|
|
def parse_ymd(s: str) -> str:
    """Normalize a 'year-month-day' string to zero-padded 'YYYY-MM-DD' form.

    Args:
        s: Date text made of three integer fields separated by '-'.

    Returns:
        The same date formatted with ``%Y-%m-%d``.

    Raises:
        ValueError: If ``s`` does not split into exactly three integer
            fields, or the fields do not form a valid calendar date.
    """
    year_s, mon_s, day_s = s.split('-')
    # The datetime module is imported under the alias DT in this file, so the
    # class has to be reached through that alias.
    return DT.datetime(int(year_s), int(mon_s), int(day_s)).strftime("%Y-%m-%d")
|
|
|
# Replace missing values with 0 so they do not propagate into the sums below.
dataset = dataset.fillna(0)

# Parse the date column into real timestamps.
dataset['Date'] = pd.to_datetime(dataset['Date'])

# Keep only the columns of interest, then collapse the rows to one per
# (country, date) pair by summing the three count columns.
group_keys = ['Country/Region', 'Date']
selected = dataset[['Country/Region', 'Confirmed', 'Recovered', 'Deaths', 'Date']]
dataset = selected.groupby(group_keys).sum().reset_index()
|
|
|
|
|
|
|
# Extract the data for China and the US.
# Both frames are re-indexed with the same daily calendar.
# NOTE(review): this assumes each country has exactly one row per day between
# 2020-01-22 and 2020-12-09 (in date order); a different row count makes the
# index assignment fail -- confirm against the data file.
observation_days = pd.date_range('2020-01-22', '2020-12-09', freq='1D')

is_china = dataset['Country/Region'] == 'China'
CN = dataset[is_china]
CN.index = pd.Index(observation_days)

is_us = dataset['Country/Region'] == 'US'
US = dataset[is_us]
US.index = pd.Index(observation_days)
|
|
|
|
|
|
|
# Split into training and test sets.
# A single cutoff constant is used for both sides of the split so the two
# halves can never drift apart (the training cutoff used to carry a stray
# trailing space and relied on lenient date parsing).
SPLIT_DATE = '2020-11-01'

# Rows strictly before the cutoff are used for fitting ...
trainCN = CN[CN['Date'] < SPLIT_DATE]
# ... and rows on or after the cutoff are held out for evaluation.
testCN = CN[CN['Date'] >= SPLIT_DATE]

trainUS = US[US['Date'] < SPLIT_DATE]
testUS = US[US['Date'] >= SPLIT_DATE]
|
|
|
|
|
|
|
# Autoregressive integrated moving average model (ARIMA)
# Work on copies of the test sets, leaving testCN / testUS themselves untouched.
yCNARIMA, yUSARIMA = testCN.copy(), testUS.copy()
|
|
|
|
|
|
|
# Train the models
def _fit_sarimax(series, **model_kwargs):
    """Build a SARIMAX model on ``series`` and return its fitted results.

    No model orders are passed here, so the library defaults apply unless
    ``model_kwargs`` overrides them.
    """
    return sm.tsa.statespace.SARIMAX(series, **model_kwargs).fit()


# China: one model per series.
fitCNconfirmed = _fit_sarimax(trainCN['Confirmed'])
fitCNrecovered = _fit_sarimax(trainCN['Recovered'])
fitCNdeaths = _fit_sarimax(trainCN['Deaths'])

# US: the confirmed-cases model is the only one given an explicit trend ('ct').
fitUSconfirmed = _fit_sarimax(trainUS['Confirmed'], trend='ct')
fitUSrecovered = _fit_sarimax(trainUS['Recovered'])
fitUSdeaths = _fit_sarimax(trainUS['Deaths'])
|
|
|
|
|
|
|
# Test: predict over the held-out window and store the results next to the
# observed values in yCNARIMA / yUSARIMA.
test_window = dict(start="2020-11-01", end="2020-12-09")

yCNARIMA['SARIMAconfirmed'] = fitCNconfirmed.predict(dynamic=True, **test_window)
yCNARIMA['SARIMArecovered'] = fitCNrecovered.predict(dynamic=True, **test_window)
yCNARIMA['SARIMAdeaths'] = fitCNdeaths.predict(dynamic=True, **test_window)

# NOTE: the US confirmed-cases prediction is the only one requested without
# dynamic=True.
yUSARIMA['SARIMAconfirmed'] = fitUSconfirmed.predict(**test_window)
yUSARIMA['SARIMArecovered'] = fitUSrecovered.predict(dynamic=True, **test_window)
yUSARIMA['SARIMAdeaths'] = fitUSdeaths.predict(dynamic=True, **test_window)
|
|
|
|
|
|
|
# Forecast the next seven days
forecast_start, forecast_end = '2020-12-10', '2020-12-16'
forecast_days = pd.date_range(forecast_start, forecast_end, freq='1D')

# The forecast frames are indexed by date from the start. Assigning a Series
# to a DataFrame column aligns on the index, so storing the date-indexed
# predictions in a frame with a plain 0..6 index would leave every
# prediction column filled with NaN.
forecastCNARIMA = pd.DataFrame({'Date': forecast_days}, index=forecast_days)
forecastUSARIMA = pd.DataFrame({'Date': forecast_days}, index=forecast_days)

# China
forecastCNARIMA['confirmedPred'] = fitCNconfirmed.predict(start=forecast_start, end=forecast_end, dynamic=True)
forecastCNARIMA['recoveredPred'] = fitCNrecovered.predict(start=forecast_start, end=forecast_end, dynamic=True)
forecastCNARIMA['deathsPred'] = fitCNdeaths.predict(start=forecast_start, end=forecast_end, dynamic=True)

# US
forecastUSARIMA['confirmedPred'] = fitUSconfirmed.predict(start=forecast_start, end=forecast_end, dynamic=True)
forecastUSARIMA['recoveredPred'] = fitUSrecovered.predict(start=forecast_start, end=forecast_end, dynamic=True)
forecastUSARIMA['deathsPred'] = fitUSdeaths.predict(start=forecast_start, end=forecast_end, dynamic=True)
|
|
|
|
|
|
|
# RMSE
def _rmse(actual, predicted):
    """Return the root-mean-square error between two equally long sequences.

    The square root (exponent 0.5) of the mean squared error is taken so the
    result is expressed in the same unit as the data.
    """
    return np.sqrt(mean_squared_error(np.asarray(actual), np.asarray(predicted)))


# China: observed test values versus the test-window predictions.
rmseCNARIMACon = _rmse(testCN['Confirmed'], yCNARIMA['SARIMAconfirmed'])
rmseCNARIMARec = _rmse(testCN['Recovered'], yCNARIMA['SARIMArecovered'])
rmseCNARIMADea = _rmse(testCN['Deaths'], yCNARIMA['SARIMAdeaths'])

# US
rmseUSARIMACon = _rmse(testUS['Confirmed'], yUSARIMA['SARIMAconfirmed'])
rmseUSARIMARec = _rmse(testUS['Recovered'], yUSARIMA['SARIMArecovered'])
rmseUSARIMADea = _rmse(testUS['Deaths'], yUSARIMA['SARIMAdeaths'])
|
|
|
|
|
|
|
# Visualization
fig = plt.figure()

# Upper panel: China.
axCNARIMA = fig.add_subplot(211)
axCNARIMA.set_title("ARIMA (CN)", verticalalignment="bottom", fontsize="13")

# Label all three frames with daily dates so they share one time axis.
CN.index = pd.Index(pd.date_range('2020-01-22', '2020-12-09', freq='1D'))
yCNARIMA.index = pd.Index(pd.date_range('2020-11-01', '2020-12-09', freq='1D'))
forecastCNARIMA.index = pd.Index(pd.date_range('2020-12-10', '2020-12-16', freq='1D'))

# Observed series, drawn dotted.
for _column, _label in (('Confirmed', "confirmed"), ('Recovered', "recovered"), ('Deaths', "deaths")):
    axCNARIMA.plot(CN[_column], label=_label, linestyle=":")

# Predictions over the test window.
for _kind in ("confirmed", "recovered", "deaths"):
    axCNARIMA.plot(yCNARIMA['SARIMA' + _kind], label=_kind + " test")

# Seven-day forecast.
for _kind in ("confirmed", "recovered", "deaths"):
    axCNARIMA.plot(forecastCNARIMA[_kind + 'Pred'], label=_kind + " prediction")
|
|
|
|
|
|
|
# Lower panel: US.
axUSARIMA = fig.add_subplot(212)
axUSARIMA.set_title("ARIMA (US)", verticalalignment="bottom", fontsize="13")

# Label all three frames with daily dates so they share one time axis.
US.index = pd.Index(pd.date_range('2020-01-22', '2020-12-09', freq='1D'))
yUSARIMA.index = pd.Index(pd.date_range('2020-11-01', '2020-12-09', freq='1D'))
forecastUSARIMA.index = pd.Index(pd.date_range('2020-12-10', '2020-12-16', freq='1D'))

# Observed series, drawn dotted.
for _column, _label in (('Confirmed', "confirmed"), ('Recovered', "recovered"), ('Deaths', "deaths")):
    axUSARIMA.plot(US[_column], label=_label, linestyle=":")

# Predictions over the test window.
for _kind in ("confirmed", "recovered", "deaths"):
    axUSARIMA.plot(yUSARIMA['SARIMA' + _kind], label=_kind + " test")

# Seven-day forecast.
for _kind in ("confirmed", "recovered", "deaths"):
    axUSARIMA.plot(forecastUSARIMA[_kind + 'Pred'], label=_kind + " prediction")
|
|
|
|
|
|
|
plt.tight_layout()

# Slant the date tick labels so they do not overlap.
plt.gcf().autofmt_xdate()

# plt.legend() only annotates the current (most recently created) axes, which
# would leave the other subplot without a legend; add one to every subplot.
for _ax in plt.gcf().get_axes():
    _ax.legend(labelspacing=0.05)

plt.show()