init
This commit is contained in:
126
VitTrainer.py
126
VitTrainer.py
@@ -1,11 +1,13 @@
|
||||
# tensor - numpy - PILImage 변환 (https://qlsenddl-lab.tistory.com/37)
|
||||
|
||||
import os
|
||||
os.environ['KMP_DUPLICATE_LIB_OK']='True'
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset, load_dataset
|
||||
from datasets import Dataset, load_metric, ClassLabel
|
||||
from datasets import load_metric
|
||||
from transformers import AutoConfig
|
||||
from transformers import TrainingArguments, Trainer
|
||||
from transformers import ViTForImageClassification
|
||||
from torch.utils.data import DataLoader
|
||||
@@ -44,14 +46,14 @@ class VitTrainer:
|
||||
save_strategy="epoch",
|
||||
evaluation_strategy="epoch",
|
||||
learning_rate=2e-5,
|
||||
per_device_train_batch_size=16,
|
||||
per_device_eval_batch_size=16,
|
||||
per_device_train_batch_size=32,
|
||||
per_device_eval_batch_size=32,
|
||||
weight_decay=0.01,
|
||||
load_best_model_at_end=True,
|
||||
metric_for_best_model="accuracy",
|
||||
logging_dir=os.path.join(self.RESOURCE_PATH, 'model', 'logs'),
|
||||
remove_unused_columns=False,
|
||||
num_train_epochs=14,
|
||||
num_train_epochs=4,
|
||||
)
|
||||
|
||||
return
|
||||
@@ -117,7 +119,7 @@ class VitTrainer:
|
||||
train_ds.set_transform(self.train_transforms)
|
||||
val_ds.set_transform(self.val_transforms)
|
||||
|
||||
train_dataloader = DataLoader(train_ds, collate_fn=self.collate_fn, batch_size=4)
|
||||
train_dataloader = DataLoader(train_ds, collate_fn=self.collate_fn, batch_size=32)
|
||||
|
||||
batch = next(iter(train_dataloader))
|
||||
for k,v in batch.items():
|
||||
@@ -157,7 +159,7 @@ class VitTrainer:
|
||||
train_ds.set_transform(self.train_transforms)
|
||||
val_ds.set_transform(self.val_transforms)
|
||||
|
||||
train_dataloader = DataLoader(train_ds, collate_fn=self.collate_fn, batch_size=4)
|
||||
train_dataloader = DataLoader(train_ds, collate_fn=self.collate_fn, batch_size=32)
|
||||
|
||||
batch = next(iter(train_dataloader))
|
||||
for k,v in batch.items():
|
||||
@@ -211,6 +213,14 @@ class VitTrainer:
|
||||
train_ds = Dataset.from_dict(train_data)
|
||||
val_ds = Dataset.from_dict(val_dsta)
|
||||
|
||||
features = train_ds.features.copy()
|
||||
features["label"] = ClassLabel(num_classes=self.num_labels, names=["none", "sell", "buy"])
|
||||
def adjust_labels(batch):
|
||||
batch["label"] = [lbl for lbl in batch["label"]]
|
||||
return batch
|
||||
train_ds = train_ds.map(adjust_labels, batched=True, features=features)
|
||||
val_ds = train_ds.map(adjust_labels, batched=True, features=features)
|
||||
|
||||
return train_ds, val_ds
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -222,107 +232,5 @@ if __name__ == "__main__":
|
||||
stock_code = "252670"
|
||||
vitTrainer = VitTrainer(RESOURCE_PATH)
|
||||
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220701", eDate="20220731")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220809", eDate="20220812")
|
||||
vitTrainer.train(train_ds, val_ds, model_path)
|
||||
|
||||
"""
|
||||
print("ym: 2020-07")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200701", eDate="20200731")
|
||||
vitTrainer.train(train_ds, val_ds, model_path)
|
||||
|
||||
print ("ym: 2020-08")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200725", eDate="20200831")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2020-09")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200825", eDate="20200931")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2020-10")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200925", eDate="20201031")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2020-11")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20201025", eDate="20201131")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2020-12")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20201125", eDate="20201231")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-01")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20201225", eDate="20210131")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-02")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210125", eDate="20210231")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-03")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210225", eDate="20210331")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-04")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210325", eDate="20210431")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-05")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210425", eDate="20210531")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-06")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210525", eDate="20210631")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-07")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210625", eDate="20210731")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-08")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210725", eDate="20210831")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-09")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210825", eDate="20210931")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-10")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210925", eDate="20212031")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-11")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20211025", eDate="20211131")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-12")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20211125", eDate="20211231")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2022-01")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20211225", eDate="20220131")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2022-02")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220125", eDate="20220231")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2022-03")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220225", eDate="20220331")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2022-04")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220325", eDate="20220431")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2022-05")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220425", eDate="20220531")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2022-06")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220525", eDate="20220631")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2022-07")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220625", eDate="20220731")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
"""
|
||||
Reference in New Issue
Block a user