This commit is contained in:
dosangyoon
2022-08-05 15:17:33 +09:00
parent fc6400ba77
commit 57c5cc5638
2 changed files with 72 additions and 39 deletions

View File

@@ -15,42 +15,10 @@ class StockTrainer:
self.stock2Vector = Stock2Vector(RESOURCE_PATH)
return
def getDataset(self, stock_code):
VECTOR_SIZE = 299
result = self.stock2Vector.getTrainData(stock_code)
df, minmax_df = self.stock2Vector.preprocessData(result)
TOTAL_X, TOTAL_Y = [], []
for key in minmax_df:
if key == "date":
continue
elif key == "label":
TOTAL_Y.append(minmax_df[key].tolist())
else:
TOTAL_X.append(minmax_df[key].tolist())
SIZE_WIDTH = len(TOTAL_X[0])
SIZE_HEIGHT = len(TOTAL_X)
X, Y = [], []
for i in range(VECTOR_SIZE, SIZE_WIDTH):
temp_X, temp_Y = np.zeros((VECTOR_SIZE, VECTOR_SIZE)), np.zeros(0)
for j in range(SIZE_HEIGHT):
temp_X[j][0:VECTOR_SIZE] = TOTAL_X[j][i-VECTOR_SIZE:i]
temp_X = np.stack([temp_X, temp_X, temp_X], axis=-1)
X.append(temp_X)
if int(TOTAL_Y[0][i]) == 0:
Y.append([1, 0, 0])
elif int(TOTAL_Y[0][i]) == 0.5:
Y.append([0, 1, 0])
else:
Y.append([0, 0, 1])
X = np.asarray(X)
Y = np.asarray(Y)
return X, Y
def train(self, stock_code):
X, Y = self.getDataset(stock_code)
#X, Y = self.stock2Vector.getDataset3D(stock_code)
X, Y = self.stock2Vector.getDataset2D(stock_code)
# build model
n_classes = 3