init
This commit is contained in:
@@ -44,14 +44,14 @@ class VitTrainer:
|
|||||||
save_strategy="epoch",
|
save_strategy="epoch",
|
||||||
evaluation_strategy="epoch",
|
evaluation_strategy="epoch",
|
||||||
learning_rate=2e-5,
|
learning_rate=2e-5,
|
||||||
per_device_train_batch_size=32,
|
per_device_train_batch_size=128,
|
||||||
per_device_eval_batch_size=32,
|
per_device_eval_batch_size=128,
|
||||||
weight_decay=0.01,
|
weight_decay=0.01,
|
||||||
load_best_model_at_end=True,
|
load_best_model_at_end=True,
|
||||||
metric_for_best_model="accuracy",
|
metric_for_best_model="accuracy",
|
||||||
logging_dir=os.path.join(self.RESOURCE_PATH, 'model', 'logs'),
|
logging_dir=os.path.join(self.RESOURCE_PATH, 'model', 'logs'),
|
||||||
remove_unused_columns=False,
|
remove_unused_columns=False,
|
||||||
num_train_epochs=4,
|
num_train_epochs=24,
|
||||||
)
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
@@ -236,5 +236,5 @@ if __name__ == "__main__":
|
|||||||
stock_code = "252670"
|
stock_code = "252670"
|
||||||
vitTrainer = VitTrainer(RESOURCE_PATH)
|
vitTrainer = VitTrainer(RESOURCE_PATH)
|
||||||
|
|
||||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220701", eDate="20220731")
|
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220501", eDate="20220819")
|
||||||
vitTrainer.train(train_ds, val_ds, model_path)
|
vitTrainer.train(train_ds, val_ds, model_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user