Merge remote-tracking branch 'origin/master'
This commit is contained in:
@@ -44,14 +44,14 @@ class VitTrainer:
|
||||
save_strategy="epoch",
|
||||
evaluation_strategy="epoch",
|
||||
learning_rate=2e-5,
|
||||
per_device_train_batch_size=32,
|
||||
per_device_eval_batch_size=32,
|
||||
per_device_train_batch_size=128,
|
||||
per_device_eval_batch_size=128,
|
||||
weight_decay=0.01,
|
||||
load_best_model_at_end=True,
|
||||
metric_for_best_model="accuracy",
|
||||
logging_dir=os.path.join(self.RESOURCE_PATH, 'model', 'logs'),
|
||||
remove_unused_columns=False,
|
||||
num_train_epochs=4,
|
||||
num_train_epochs=24,
|
||||
)
|
||||
|
||||
return
|
||||
@@ -236,5 +236,5 @@ if __name__ == "__main__":
|
||||
stock_code = "252670"
|
||||
vitTrainer = VitTrainer(RESOURCE_PATH)
|
||||
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220701", eDate="20220731")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220501", eDate="20220819")
|
||||
vitTrainer.train(train_ds, val_ds, model_path)
|
||||
|
||||
Reference in New Issue
Block a user