init
This commit is contained in:
83
StockTrainer.py
Normal file
83
StockTrainer.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import os
|
||||
import keras
|
||||
import numpy as np
|
||||
from stock.util.Stock2Vector import Stock2Vector
|
||||
from classification_models.keras import Classifiers
|
||||
|
||||
class StockTrainer:
|
||||
|
||||
RESOURCE_PATH = None
|
||||
stock2Vector = None
|
||||
|
||||
def __init__(self, RESOURCE_PATH):
|
||||
self.RESOURCE_PATH = RESOURCE_PATH
|
||||
self.stock2Vector = Stock2Vector(RESOURCE_PATH)
|
||||
return
|
||||
|
||||
def getDataset(self, stock_code):
|
||||
df, minmax_df = self.stock2Vector.makeTrainData(stock_code)
|
||||
|
||||
TOTAL_X, TOTAL_Y = [], []
|
||||
for key in df:
|
||||
if key == "date":
|
||||
continue
|
||||
elif key == "label":
|
||||
TOTAL_Y.append(df[key].tolist())
|
||||
else:
|
||||
TOTAL_X.append(df[key].tolist())
|
||||
|
||||
X, Y = [], []
|
||||
for i in range(299, len(TOTAL_X[0])):
|
||||
temp_X, temp_Y = np.zeros((299, 299)), np.zeros(0)
|
||||
idx = 0
|
||||
for j in range(i-299, i):
|
||||
for k in range(len(TOTAL_X)):
|
||||
temp_X[k][idx] = TOTAL_X[k][j]
|
||||
idx += 1
|
||||
X.append(temp_X)
|
||||
Y.append(TOTAL_Y[i])
|
||||
|
||||
|
||||
|
||||
|
||||
return X, Y
|
||||
|
||||
def train(self, stock_code):
|
||||
ResNet18, preprocess_input = Classifiers.get('inceptionresnetv2')
|
||||
|
||||
X, Y = self.getDataset(stock_code)
|
||||
|
||||
X = preprocess_input(X)
|
||||
|
||||
n_classes = 3
|
||||
|
||||
# build model
|
||||
base_model = ResNet18(input_shape=(299, 299, 3), weights='imagenet', include_top=False)
|
||||
x = keras.layers.GlobalAveragePooling2D()(base_model.output)
|
||||
output = keras.layers.Dense(n_classes, activation='softmax')(x)
|
||||
model = keras.models.Model(inputs=[base_model.input], outputs=[output])
|
||||
|
||||
# train
|
||||
model.compile(optimizer='SGD', loss='categorical_crossentropy', metrics=['accuracy'])
|
||||
model.fit(X, Y)
|
||||
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
PROJECT_HOME = os.path.join(os.path.dirname(__file__))
|
||||
RESOURCE_PATH = os.path.join(PROJECT_HOME, "resources")
|
||||
|
||||
# to check bying
|
||||
stock_codes = {
|
||||
# 252670
|
||||
# 122630
|
||||
"252670": ['20220729'],
|
||||
}
|
||||
|
||||
for stock_code in stock_codes:
|
||||
stockTrainer = StockTrainer(RESOURCE_PATH)
|
||||
stockTrainer.train(stock_code)
|
||||
|
||||
|
||||
print ("done...")
|
||||
Reference in New Issue
Block a user