This commit is contained in:
dsyoon
2022-08-03 22:40:23 +09:00
parent 41e23641a8
commit ce26ef5623
5 changed files with 139 additions and 32 deletions

View File

@@ -1,7 +1,6 @@
import os
import keras
import numpy as np
from numpy import zeros, newaxis
import tensorflow as tf
from stock.util.Stock2Vector import Stock2Vector
from classification_models.keras import Classifiers
@@ -18,16 +17,17 @@ class StockTrainer:
def getDataset(self, stock_code):
VECTOR_SIZE = 299
df, minmax_df = self.stock2Vector.makeTrainData(stock_code)
result = self.stock2Vector.getTrainData(stock_code)
df, minmax_df = self.stock2Vector.preprocessData(result)
TOTAL_X, TOTAL_Y = [], []
for key in df:
for key in minmax_df:
if key == "date":
continue
elif key == "label":
TOTAL_Y.append(df[key].tolist())
TOTAL_Y.append(minmax_df[key].tolist())
else:
TOTAL_X.append(df[key].tolist())
TOTAL_X.append(minmax_df[key].tolist())
SIZE_WIDTH = len(TOTAL_X[0])
SIZE_HEIGHT = len(TOTAL_X)
@@ -44,38 +44,37 @@ class StockTrainer:
Y.append([0, 1, 0])
else:
Y.append([0, 0, 1])
if i >= VECTOR_SIZE+10:
break
X = np.asarray(X)
Y = np.asarray(Y)
return X, Y
def train(self, stock_code):
ResNet18, preprocess_input = Classifiers.get('inceptionresnetv2')
X, Y = self.getDataset(stock_code)
# build model
n_classes = 3
Inceptionresnetv2, preprocess_input = Classifiers.get('inceptionresnetv2')
X = preprocess_input(X)
n_classes = 3
# build model
base_model = ResNet18(input_shape=(299, 299, 3), include_top=False)
# train
checkpoint_filename = os.path.join(self.RESOURCE_PATH, "model", "stock.ckpt")
base_model = Inceptionresnetv2(input_shape=(299, 299, 3), include_top=False)
x = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
output = tf.keras.layers.Dense(n_classes, activation='softmax')(x)
model = keras.models.Model(inputs=[base_model.input], outputs=[output])
# train
model.compile(optimizer='SGD', loss='categorical_crossentropy', metrics=['accuracy'])
checkpoint_filename = os.path.join(self.RESOURCE_PATH, "model", "stock.ckpt")
chekpoint = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_filename, save_weights_only=True, verbose=1)
earlystop = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3, mode="auto")
earlystop = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=5, mode="auto")
if os.path.isfile(checkpoint_filename):
model.load_weights(checkpoint_filename)
model.fit(x=X,
y=Y,
epochs=10,
batch_size=10000,
epochs=10000,
callbacks=[chekpoint, earlystop])
return
@@ -95,5 +94,4 @@ if __name__ == "__main__":
stockTrainer = StockTrainer(RESOURCE_PATH)
stockTrainer.train(stock_code)
print ("done...")