init
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import keras
|
import keras
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numpy import zeros, newaxis
|
||||||
|
import tensorflow as tf
|
||||||
from stock.util.Stock2Vector import Stock2Vector
|
from stock.util.Stock2Vector import Stock2Vector
|
||||||
from classification_models.keras import Classifiers
|
from classification_models.keras import Classifiers
|
||||||
|
|
||||||
@@ -15,6 +17,7 @@ class StockTrainer:
|
|||||||
return
|
return
|
||||||
|
|
||||||
def getDataset(self, stock_code):
|
def getDataset(self, stock_code):
|
||||||
|
VECTOR_SIZE = 299
|
||||||
df, minmax_df = self.stock2Vector.makeTrainData(stock_code)
|
df, minmax_df = self.stock2Vector.makeTrainData(stock_code)
|
||||||
|
|
||||||
TOTAL_X, TOTAL_Y = [], []
|
TOTAL_X, TOTAL_Y = [], []
|
||||||
@@ -26,17 +29,26 @@ class StockTrainer:
|
|||||||
else:
|
else:
|
||||||
TOTAL_X.append(df[key].tolist())
|
TOTAL_X.append(df[key].tolist())
|
||||||
|
|
||||||
|
SIZE_WIDTH = len(TOTAL_X[0])
|
||||||
|
SIZE_HEIGHT = len(TOTAL_X)
|
||||||
X, Y = [], []
|
X, Y = [], []
|
||||||
for i in range(299, len(TOTAL_X[0])):
|
for i in range(VECTOR_SIZE, SIZE_WIDTH):
|
||||||
temp_X, temp_Y = np.zeros((299, 299)), np.zeros(0)
|
temp_X, temp_Y = np.zeros((VECTOR_SIZE, VECTOR_SIZE)), np.zeros(0)
|
||||||
idx = 0
|
for j in range(SIZE_HEIGHT):
|
||||||
for j in range(i-299, i):
|
temp_X[j][0:VECTOR_SIZE] = TOTAL_X[j][i-VECTOR_SIZE:i]
|
||||||
for k in range(len(TOTAL_X)):
|
temp_X = np.stack([temp_X, temp_X, temp_X], axis=-1)
|
||||||
temp_X[k][idx] = TOTAL_X[k][j]
|
|
||||||
idx += 1
|
|
||||||
X.append(temp_X)
|
X.append(temp_X)
|
||||||
Y.append(TOTAL_Y[0][i])
|
if int(TOTAL_Y[0][i]) == 0:
|
||||||
|
Y.append([1, 0, 0])
|
||||||
|
elif int(TOTAL_Y[0][i]) == 0.5:
|
||||||
|
Y.append([0, 1, 0])
|
||||||
|
else:
|
||||||
|
Y.append([0, 0, 1])
|
||||||
|
if i >= VECTOR_SIZE+10:
|
||||||
|
break
|
||||||
|
|
||||||
|
X = np.asarray(X)
|
||||||
|
Y = np.asarray(Y)
|
||||||
return X, Y
|
return X, Y
|
||||||
|
|
||||||
def train(self, stock_code):
|
def train(self, stock_code):
|
||||||
@@ -49,14 +61,22 @@ class StockTrainer:
|
|||||||
n_classes = 3
|
n_classes = 3
|
||||||
|
|
||||||
# build model
|
# build model
|
||||||
base_model = ResNet18(input_shape=(299, 299, 3), weights='imagenet', include_top=False)
|
base_model = ResNet18(input_shape=(299, 299, 3), include_top=False)
|
||||||
x = keras.layers.GlobalAveragePooling2D()(base_model.output)
|
x = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
|
||||||
output = keras.layers.Dense(n_classes, activation='softmax')(x)
|
output = tf.keras.layers.Dense(n_classes, activation='softmax')(x)
|
||||||
model = keras.models.Model(inputs=[base_model.input], outputs=[output])
|
model = keras.models.Model(inputs=[base_model.input], outputs=[output])
|
||||||
|
|
||||||
# train
|
# train
|
||||||
model.compile(optimizer='SGD', loss='categorical_crossentropy', metrics=['accuracy'])
|
model.compile(optimizer='SGD', loss='categorical_crossentropy', metrics=['accuracy'])
|
||||||
model.fit(X, Y)
|
|
||||||
|
checkpoint_filename = os.path.join(self.RESOURCE_PATH, "model", "stock.ckpt")
|
||||||
|
chekpoint = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_filename, save_weights_only=True, verbose=1)
|
||||||
|
earlystop = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3, mode="auto")
|
||||||
|
|
||||||
|
model.fit(x=X,
|
||||||
|
y=Y,
|
||||||
|
epochs=10,
|
||||||
|
callbacks=[chekpoint, earlystop])
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -65,7 +85,6 @@ if __name__ == "__main__":
|
|||||||
PROJECT_HOME = os.path.join(os.path.dirname(__file__))
|
PROJECT_HOME = os.path.join(os.path.dirname(__file__))
|
||||||
RESOURCE_PATH = os.path.join(PROJECT_HOME, "resources")
|
RESOURCE_PATH = os.path.join(PROJECT_HOME, "resources")
|
||||||
|
|
||||||
# to check bying
|
|
||||||
stock_codes = {
|
stock_codes = {
|
||||||
# 252670
|
# 252670
|
||||||
# 122630
|
# 122630
|
||||||
|
|||||||
Reference in New Issue
Block a user