This commit is contained in:
dsyoon
2022-08-07 00:36:18 +09:00
parent f43688f0da
commit c173a6d7dc
10 changed files with 427 additions and 203 deletions

View File

@@ -154,13 +154,16 @@ class Stock2Vector(HTS):
return df, minmax_df
def getTrainData(self, stock_code):
def getTrainData(self, stock_code, sDate=None, eDate=None):
tableName = 'hts'
conn = sqlite3.connect(os.path.join(self.RESOURCE_PATH, "hts.db"))
cursor = conn.cursor()
cursor.execute('SELECT ymd, hms, open, high, low, close, volume, label FROM ' + tableName + ' WHERE CODE=? and (ymd >= ? and ymd <= ?) order by ymd desc, hms ', (stock_code, "20220726", "20220731"))
#cursor.execute('SELECT ymd, hms, open, high, low, close, volume, label FROM ' + tableName + ' WHERE CODE=? order by ymd desc, hms ', (stock_code,))
if sDate is None and eDate is None:
cursor.execute('SELECT ymd, hms, open, high, low, close, volume, label FROM ' + tableName + ' WHERE CODE=? order by ymd desc, hms ', (stock_code,))
else:
cursor.execute('SELECT ymd, hms, open, high, low, close, volume, label FROM ' + tableName + ' WHERE CODE=? and (ymd >= ? and ymd <= ?) order by ymd desc, hms ', (stock_code, sDate, eDate))
db_result = cursor.fetchall()
temp_result = []
for rows in db_result:
@@ -168,6 +171,9 @@ class Stock2Vector(HTS):
temp_result.sort(key=lambda x: (x[0], x[1]))
result = {"check": set(), "time": [], "open": [], "close": [], "high": [], "low": [], "vol": [], "label": []}
if len(db_result) == 0:
return result
for rows in temp_result:
ymd = rows[0] # hts.날짜
hms = rows[1] # hts.시간
@@ -246,9 +252,9 @@ class Stock2Vector(HTS):
return np.asarray(vector)
def getDataset2D(self, stock_code, VECTOR_SIZE = 381):
result = self.getTrainData(stock_code)
df, minmax_df = self.preprocessData(result)
def getDataset2D(self, data, VECTOR_SIZE = 381):
df, minmax_df = self.preprocessData(data)
TOTAL_X, TOTAL_Y = [], []
for key in minmax_df:
@@ -262,38 +268,24 @@ class Stock2Vector(HTS):
SIZE_WIDTH = len(TOTAL_X[0])
SIZE_HEIGHT = len(TOTAL_X)
X, Y = [], []
for i in range(VECTOR_SIZE, SIZE_WIDTH):
for i in range(VECTOR_SIZE-1, SIZE_WIDTH):
temp_X, temp_Y = np.zeros((VECTOR_SIZE, VECTOR_SIZE)), np.zeros(0)
for j in range(SIZE_HEIGHT):
temp_X[j][0:VECTOR_SIZE] = TOTAL_X[j][i-VECTOR_SIZE:i]
temp_X[j][0:VECTOR_SIZE] = TOTAL_X[j][i-VECTOR_SIZE+1:i+1]
X.append(temp_X)
if TOTAL_Y[0][i] == 0:
#Y.append([1, 0, 0])
Y.append(0)
elif TOTAL_Y[0][i] == 0.5:
#Y.append([0, 1, 0])
Y.append(1)
else:
#Y.append([0, 0, 1])
Y.append(2)
X = np.asarray(X)
Y = np.asarray(Y)
Y = np.asarray(Y, dtype='int64')
return X, Y
def makeDataset2D(self, stock_code, outFileName=None):
    """Return the 2D dataset for ``stock_code`` as an ``(X, Y)`` tuple.

    Thin wrapper around ``self.getDataset2D(stock_code)``: the two values
    it yields are unpacked and handed back unchanged.

    ``outFileName`` is accepted for call-site compatibility but is not
    used; this method writes no file.
    """
    # NOTE: a CSV export (flatten X to 2D, stack with Y, DataFrame.to_csv)
    # was sketched here earlier and is intentionally left disabled.
    features, labels = self.getDataset2D(stock_code)
    return features, labels
def getDataset3D(self, stock_code, VECTOR_SIZE = 299):
result = self.getTrainData(stock_code)
df, minmax_df = self.preprocessData(result)
def getDataset3D(self, data, VECTOR_SIZE = 299):
df, minmax_df = self.preprocessData(data)
TOTAL_X, TOTAL_Y = [], []
for key in minmax_df:
@@ -338,8 +330,8 @@ if __name__ == "__main__":
for stock_code in stock_codes:
stock2Vector = Stock2Vector(RESOURCE_PATH)
# X, Y = stock2Vector.getDataset2D(stock_code)
stock2Vector.makeDataset2D(stock_code, outFileName=os.path.join(RESOURCE_PATH, "tmp", "stock_features.csv"))
# data = self.stock2Vector.getTrainData(stock_code, sDate, eDate)
# X, Y = self.stock2Vector.getDataset2D(data)
for given_day in stock_codes[stock_code]:
data, minmax_data = stock2Vector.makeData(given_day, stock_code)