import os
import keras
import numpy as np
from numpy import zeros, newaxis
import tensorflow as tf
from stock.util.Stock2Vector import Stock2Vector
from classification_models.keras import Classifiers


class StockTrainer:
    """Turn a stock's feature history into square 3-channel windows and
    train an InceptionResNetV2-based 3-class classifier on them."""

    RESOURCE_PATH = None
    stock2Vector = None

    # Window length (in rows of the training frame) and also the model's
    # input height/width; feature rows are zero-padded up to this size so
    # every sample is square.
    VECTOR_SIZE = 299

    def __init__(self, RESOURCE_PATH):
        """Remember the resource directory and build the Stock2Vector helper.

        Args:
            RESOURCE_PATH: base directory handed to ``Stock2Vector``; model
                weights are checkpointed under ``<RESOURCE_PATH>/model``.
        """
        self.RESOURCE_PATH = RESOURCE_PATH
        self.stock2Vector = Stock2Vector(RESOURCE_PATH)

    @staticmethod
    def _one_hot_label(value):
        """Map a raw label value to a 3-class one-hot list.

        0 -> ``[1, 0, 0]``, 0.5 -> ``[0, 1, 0]``, anything else -> ``[0, 0, 1]``.
        NOTE(review): assumes makeTrainData emits labels 0 / 0.5 / 1 - confirm.
        """
        # Compare the float value: the former ``int(value) == 0.5`` test could
        # never be true (int(0.5) == 0), so 0.5 labels landed in class 0.
        label = float(value)
        if label == 0:
            return [1, 0, 0]
        if label == 0.5:
            return [0, 1, 0]
        return [0, 0, 1]

    def getDataset(self, stock_code, max_samples=11):
        """Build the training windows for one stock.

        The sample for row index ``i`` is a (VECTOR_SIZE, VECTOR_SIZE) matrix
        whose first rows hold the VECTOR_SIZE values preceding row ``i`` of
        each feature column (remaining rows are zeros), replicated into 3
        channels. Its target is the one-hot encoding of the ``label`` column
        at row ``i``.

        Args:
            stock_code: code passed to ``Stock2Vector.makeTrainData``.
            max_samples: maximum number of windows to build, starting from the
                earliest possible one. The default of 11 preserves the
                historical behaviour; pass ``None`` to use every available row.

        Returns:
            ``(X, Y)`` as numpy arrays: X of shape (n, VECTOR_SIZE,
            VECTOR_SIZE, 3) and Y of shape (n, 3). Both are empty arrays when
            there are not more than VECTOR_SIZE rows.

        Raises:
            ValueError: if there are no feature columns, no ``label`` column,
                or more feature columns than VECTOR_SIZE.
        """
        size = self.VECTOR_SIZE
        df, _minmax_df = self.stock2Vector.makeTrainData(stock_code)

        # "date" is ignored, "label" holds the targets, every other column is
        # one feature row of the window.
        feature_columns, label_columns = [], []
        for key in df:
            if key == "date":
                continue
            if key == "label":
                label_columns.append(df[key].tolist())
            else:
                feature_columns.append(df[key].tolist())

        if not feature_columns:
            raise ValueError(f"no feature columns for stock {stock_code}")
        if not label_columns:
            raise ValueError(f"no 'label' column for stock {stock_code}")
        n_features = len(feature_columns)
        if n_features > size:
            raise ValueError(
                f"{n_features} feature columns do not fit in a {size}-row window"
            )

        labels = label_columns[0]
        n_rows = len(feature_columns[0])
        end = n_rows if max_samples is None else min(n_rows, size + max_samples)

        X, Y = [], []
        for i in range(size, end):
            window = np.zeros((size, size))
            # Row r holds feature r over the `size` rows before row i.
            window[:n_features] = [column[i - size:i] for column in feature_columns]
            # Replicate into 3 channels to match the (size, size, 3) model input.
            X.append(np.stack([window, window, window], axis=-1))
            Y.append(self._one_hot_label(labels[i]))

        return np.asarray(X), np.asarray(Y)

    def train(self, stock_code):
        """Build the dataset for ``stock_code`` and fine-tune the classifier.

        Trains for up to 10 epochs with SGD and categorical cross-entropy,
        saving weights through ModelCheckpoint to
        ``<RESOURCE_PATH>/model/stock.ckpt`` and early-stopping on ``loss``
        with patience 3.

        Raises:
            ValueError: if no training windows could be built.
        """
        # Classifiers.get returns the model constructor and its matching
        # input-preprocessing function.
        InceptionResNetV2, preprocess_input = Classifiers.get('inceptionresnetv2')

        X, Y = self.getDataset(stock_code)
        if len(X) == 0:
            raise ValueError(
                f"not enough history for stock {stock_code}: need more than "
                f"{self.VECTOR_SIZE} rows"
            )
        X = preprocess_input(X)
        n_classes = 3

        # Backbone without its classification head, plus a new softmax head.
        # NOTE(review): mixes standalone `keras` (Model) with `tf.keras`
        # layers; this only works if both resolve to the same Keras - confirm.
        base_model = InceptionResNetV2(
            input_shape=(self.VECTOR_SIZE, self.VECTOR_SIZE, 3),
            include_top=False,
        )
        x = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
        output = tf.keras.layers.Dense(n_classes, activation='softmax')(x)
        model = keras.models.Model(inputs=[base_model.input], outputs=[output])

        model.compile(
            optimizer='SGD',
            loss='categorical_crossentropy',
            metrics=['accuracy'],
        )

        checkpoint_filename = os.path.join(self.RESOURCE_PATH, "model", "stock.ckpt")
        checkpoint = tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_filename, save_weights_only=True, verbose=1
        )
        earlystop = tf.keras.callbacks.EarlyStopping(
            monitor='loss', patience=3, mode="auto"
        )
        model.fit(x=X, y=Y, epochs=10, callbacks=[checkpoint, earlystop])


if __name__ == "__main__":
    PROJECT_HOME = os.path.dirname(__file__)
    RESOURCE_PATH = os.path.join(PROJECT_HOME, "resources")

    # Stock codes to train (keys only; the date lists are not used below).
    # Another code used previously: 122630.
    stock_codes = {
        "252670": ['20220729'],
    }

    for stock_code in stock_codes:
        stockTrainer = StockTrainer(RESOURCE_PATH)
        stockTrainer.train(stock_code)

    print("done...")