Files
DeepStock/StockTrainer.py
dosangyoon 57c5cc5638 init
2022-08-05 15:17:33 +09:00

65 lines
2.1 KiB
Python

import os
import keras
import numpy as np
import tensorflow as tf
from stock.util.Stock2Vector import Stock2Vector
from classification_models.keras import Classifiers
class StockTrainer:
    """Train an InceptionResNetV2-based classifier on per-stock datasets.

    Training data comes from ``Stock2Vector.getDataset2D``. Weights are
    checkpointed to ``<RESOURCE_PATH>/model/stock.ckpt`` by a
    ``ModelCheckpoint`` callback and reloaded at the start of the next
    ``train`` call, so repeated runs resume from the last saved weights.
    """

    # Class-level placeholders kept for backward compatibility; __init__
    # overwrites both on every instance.
    RESOURCE_PATH = None
    stock2Vector = None

    def __init__(self, RESOURCE_PATH):
        """Create a trainer.

        Args:
            RESOURCE_PATH: Root resource directory. It is passed to
                ``Stock2Vector``, and its ``model`` sub-directory holds the
                training checkpoint.
        """
        self.RESOURCE_PATH = RESOURCE_PATH
        self.stock2Vector = Stock2Vector(RESOURCE_PATH)

    @staticmethod
    def _checkpoint_exists(checkpoint_filename):
        """Return True if weights were previously saved at ``checkpoint_filename``.

        For a path without an ``.h5`` suffix, TF2 Keras ``save_weights`` (used
        by ``ModelCheckpoint(save_weights_only=True)``) writes the TensorFlow
        checkpoint format: ``<path>.index`` plus ``<path>.data-*`` shards, but
        no file named ``<path>`` itself. Checking only
        ``os.path.isfile(path)`` would therefore never detect a previous run,
        so both layouts are accepted.
        """
        return (os.path.isfile(checkpoint_filename)
                or os.path.isfile(checkpoint_filename + ".index"))

    @staticmethod
    def _build_model(backbone, n_classes, input_shape):
        """Build and compile ``backbone`` topped with a softmax classifier head.

        Args:
            backbone: Model constructor returned by ``Classifiers.get``; it is
                instantiated without its top (``include_top=False``).
            n_classes: Width of the softmax output layer.
            input_shape: Input shape passed to the backbone.

        Returns:
            A Keras model compiled with SGD, categorical cross-entropy loss
            and an accuracy metric.
        """
        base_model = backbone(input_shape=input_shape, include_top=False)
        # Use the same ``keras`` module as ``keras.models.Model`` (and the
        # classification_models.keras backbone) rather than mixing in
        # ``tf.keras``: if the two resolve to different Keras installations,
        # their layers and tensors are not interchangeable.
        pooled = keras.layers.GlobalAveragePooling2D()(base_model.output)
        output = keras.layers.Dense(n_classes, activation='softmax')(pooled)
        model = keras.models.Model(inputs=[base_model.input], outputs=[output])
        model.compile(optimizer='SGD', loss='categorical_crossentropy',
                      metrics=['accuracy'])
        return model

    def train(self, stock_code, *, n_classes=3, input_shape=(299, 299, 3),
              batch_size=10000, epochs=10000, patience=5):
        """Train, or resume training, the classifier for ``stock_code``.

        Args:
            stock_code: Stock identifier passed to
                ``Stock2Vector.getDataset2D``.
            n_classes: Number of output classes. ``categorical_crossentropy``
                expects ``Y`` to be one-hot with this many columns (assumed to
                hold for Stock2Vector output — TODO confirm).
            input_shape: Backbone input shape; presumably equal to
                ``X.shape[1:]`` — TODO confirm against Stock2Vector.
            batch_size: Samples per gradient step.
            epochs: Upper bound on the number of epochs; early stopping will
                typically end training sooner.
            patience: Epochs without improvement of the training loss before
                early stopping triggers.
        """
        X, Y = self.stock2Vector.getDataset2D(stock_code)

        backbone, preprocess_input = Classifiers.get('inceptionresnetv2')
        X = preprocess_input(X)

        model_dir = os.path.join(self.RESOURCE_PATH, "model")
        # Make sure the checkpoint directory exists before the first save.
        os.makedirs(model_dir, exist_ok=True)
        checkpoint_filename = os.path.join(model_dir, "stock.ckpt")

        model = self._build_model(backbone, n_classes, input_shape)

        checkpoint_cb = keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_filename, save_weights_only=True, verbose=1)
        early_stop_cb = keras.callbacks.EarlyStopping(
            monitor='loss', patience=patience, mode="auto")

        # Resume from the previous run's weights, if any were saved.
        if self._checkpoint_exists(checkpoint_filename):
            model.load_weights(checkpoint_filename)

        # NOTE(review): a batch of 10000 samples at 299x299x3 through
        # InceptionResNetV2 is very likely to exhaust accelerator memory;
        # consider passing a much smaller batch_size.
        model.fit(x=X,
                  y=Y,
                  batch_size=batch_size,
                  epochs=epochs,
                  callbacks=[checkpoint_cb, early_stop_cb])
if __name__ == "__main__":
    # Resolve resources relative to this file. abspath() guards against
    # __file__ being a bare relative name, for which dirname() returns "".
    PROJECT_HOME = os.path.dirname(os.path.abspath(__file__))
    RESOURCE_PATH = os.path.join(PROJECT_HOME, "resources")
    # stock code -> list of date strings (presumably YYYYMMDD). Only the keys
    # are used below; the dates are not passed to the trainer.
    # 252670
    # 122630
    stock_codes = {
        "252670": ['20220729'],
    }
    for stock_code in stock_codes:
        stockTrainer = StockTrainer(RESOURCE_PATH)
        stockTrainer.train(stock_code)
    print("done...")