diff --git a/Simulation.py b/Simulation.py index 5b62658..c44caba 100644 --- a/Simulation.py +++ b/Simulation.py @@ -28,6 +28,9 @@ class Simulation (HTS): return def draw(self, stock_code, given_day, data, bsLine): + if bsLine is None: + return + # 어제 데이터는 지운다. data = data.loc[pd.DatetimeIndex(data.index).day == int(given_day[6:])] buy_line = bsLine['buy'][381:] @@ -144,7 +147,7 @@ class Simulation (HTS): X, Y = self.stock2Vector.getDataset2D(data) predY = self.stockPredictor.predict(X, Y) - print (predY) + bsLine = None else: LAST_DATA = self.stock2Vector.getLastData(stock_code, today) result = self.stock2Vector.getRealTime(stock_code, today, LAST_DATA) @@ -169,10 +172,10 @@ if __name__ == "__main__": stock_codes = { # 252670 # 122630 - "252670": ['20200731'], + "252670": ['20220801', '20220802', '20220803', '20220804', '20220805'], } - method = "ml" # "ml", "answer" + method = "rul" # "rul", "ml", "answer" for stock_code in stock_codes: simulation = Simulation(RESOURCE_PATH) diff --git a/VitTrainer.py b/VitTrainer.py index e127586..6adc103 100755 --- a/VitTrainer.py +++ b/VitTrainer.py @@ -52,7 +52,7 @@ class VitTrainer: metric_for_best_model="accuracy", logging_dir='logs', remove_unused_columns=False, - num_train_epochs=20, + num_train_epochs=14, ) return diff --git a/stock/util/StockPredictor.py b/stock/util/StockPredictor.py index c667eac..b4f8a70 100644 --- a/stock/util/StockPredictor.py +++ b/stock/util/StockPredictor.py @@ -24,7 +24,7 @@ class StockPredictor: def __init__(self, RESOURCE_PATH): self.RESOURCE_PATH = RESOURCE_PATH - self.model_dir = os.path.join(RESOURCE_PATH, "tmp") + self.model_dir = os.path.join(RESOURCE_PATH, "model") self.stock2Vector = Stock2Vector(RESOURCE_PATH) self.set_seed(42)