diff --git a/VitTrainer.py b/VitTrainer.py index b3afed1..505cd06 100755 --- a/VitTrainer.py +++ b/VitTrainer.py @@ -44,14 +44,14 @@ class VitTrainer: save_strategy="epoch", evaluation_strategy="epoch", learning_rate=2e-5, - per_device_train_batch_size=32, - per_device_eval_batch_size=32, + per_device_train_batch_size=128, + per_device_eval_batch_size=128, weight_decay=0.01, load_best_model_at_end=True, metric_for_best_model="accuracy", logging_dir=os.path.join(self.RESOURCE_PATH, 'model', 'logs'), remove_unused_columns=False, - num_train_epochs=4, + num_train_epochs=24, ) return @@ -236,5 +236,5 @@ if __name__ == "__main__": stock_code = "252670" vitTrainer = VitTrainer(RESOURCE_PATH) - train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220701", eDate="20220731") + train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220501", eDate="20220819") vitTrainer.train(train_ds, val_ds, model_path)