This commit is contained in:
dsyoon
2022-08-07 13:38:07 +09:00
parent c173a6d7dc
commit 655eedc421
3 changed files with 8 additions and 5 deletions

View File

@@ -28,6 +28,9 @@ class Simulation (HTS):
return
def draw(self, stock_code, given_day, data, bsLine):
if bsLine is None:
return
# 어제 데이터는 지운다.
data = data.loc[pd.DatetimeIndex(data.index).day == int(given_day[6:])]
buy_line = bsLine['buy'][381:]
@@ -144,7 +147,7 @@ class Simulation (HTS):
X, Y = self.stock2Vector.getDataset2D(data)
predY = self.stockPredictor.predict(X, Y)
print (predY)
bsLine = None
else:
LAST_DATA = self.stock2Vector.getLastData(stock_code, today)
result = self.stock2Vector.getRealTime(stock_code, today, LAST_DATA)
@@ -169,10 +172,10 @@ if __name__ == "__main__":
stock_codes = {
# 252670
# 122630
"252670": ['20200731'],
"252670": ['20220801', '20220802', '20220803', '20220804', '20220805'],
}
method = "ml" # "ml", "answer"
method = "rul" # "rul", "ml", "answer"
for stock_code in stock_codes:
simulation = Simulation(RESOURCE_PATH)

View File

@@ -52,7 +52,7 @@ class VitTrainer:
metric_for_best_model="accuracy",
logging_dir='logs',
remove_unused_columns=False,
num_train_epochs=20,
num_train_epochs=14,
)
return

View File

@@ -24,7 +24,7 @@ class StockPredictor:
def __init__(self, RESOURCE_PATH):
self.RESOURCE_PATH = RESOURCE_PATH
self.model_dir = os.path.join(RESOURCE_PATH, "tmp")
self.model_dir = os.path.join(RESOURCE_PATH, "model")
self.stock2Vector = Stock2Vector(RESOURCE_PATH)
self.set_seed(42)