This commit is contained in:
dsyoon
2022-08-29 15:41:05 +09:00
parent 5815dc303f
commit 4a5fb415ef

View File

@@ -44,14 +44,14 @@ class VitTrainer:
save_strategy="epoch",
evaluation_strategy="epoch",
learning_rate=2e-5,
per_device_train_batch_size=32,
per_device_eval_batch_size=32,
per_device_train_batch_size=128,
per_device_eval_batch_size=128,
weight_decay=0.01,
load_best_model_at_end=True,
metric_for_best_model="accuracy",
logging_dir=os.path.join(self.RESOURCE_PATH, 'model', 'logs'),
remove_unused_columns=False,
num_train_epochs=4,
num_train_epochs=24,
)
return
@@ -236,5 +236,5 @@ if __name__ == "__main__":
stock_code = "252670"
vitTrainer = VitTrainer(RESOURCE_PATH)
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220701", eDate="20220731")
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220501", eDate="20220819")
vitTrainer.train(train_ds, val_ds, model_path)