# File: DeepStock/VitTrainer.py
# Commit: 16d0b4b01c "init" by dsyoon, 2022-08-07 20:18:33 +09:00
# Repository listing: 328 lines, 12 KiB, Python, executable file

# Converting between tensor, NumPy array, and PIL Image: https://qlsenddl-lab.tistory.com/37
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import random
import numpy as np
import torch
from datasets import Dataset, load_dataset
from datasets import load_metric
from transformers import TrainingArguments, Trainer
from transformers import ViTForImageClassification
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from transformers import ViTFeatureExtractor
from torchvision.transforms import (CenterCrop, Compose, Normalize, RandomHorizontalFlip, RandomResizedCrop, Resize, ToTensor)
from stock.util.Stock2Vector import Stock2Vector
class VitTrainer:
    """Train and fine-tune a ViT image classifier on stock vectors turned into images.

    The trainer pulls vectors and labels from ``Stock2Vector``, converts each
    vector to a PIL image, and trains a ``ViTForImageClassification`` model
    with three classes (``none`` / ``sell`` / ``buy``) through the Hugging
    Face ``Trainer``.
    """

    # Checkpoint that provides the preprocessing settings and model config.
    PRETRAINED_CHECKPOINT = "google/vit-base-patch16-224-in21k"

    # Class-level placeholders; real values are assigned in __init__ and
    # getFeature.
    RESOURCE_PATH = None
    stock2Vector = None
    num_labels = None
    id2label = None
    label2id = None
    args = None
    feature_extractor = None
    _train_transforms = None
    _val_transforms = None

    def __init__(self, RESOURCE_PATH):
        """Seed the RNGs and set up the data source, labels and training arguments.

        Args:
            RESOURCE_PATH: Directory handed to ``Stock2Vector``; the training
                arguments are later saved under ``<RESOURCE_PATH>/model``.
        """
        self.set_seed(42)
        self.RESOURCE_PATH = RESOURCE_PATH
        self.stock2Vector = Stock2Vector(RESOURCE_PATH)
        self.num_labels = 3
        self.id2label = {0: 'none', 1: 'sell', 2: 'buy'}
        self.label2id = {'none': 0, 'sell': 1, 'buy': 2}
        self.args = TrainingArguments(
            "stock_vit_predictor",
            save_strategy="epoch",
            evaluation_strategy="epoch",
            learning_rate=2e-5,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            weight_decay=0.01,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            logging_dir='logs',
            # Keep the raw 'img' column: the on-the-fly transforms read it.
            remove_unused_columns=False,
            num_train_epochs=14,
        )

    def set_seed(self, seed=42, n_gpu=0):
        """Seed Python's ``random``, NumPy and PyTorch.

        Args:
            seed: Seed value applied to every generator.
            n_gpu: When greater than zero, all CUDA generators are seeded
                explicitly as well.
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(seed)

    def train_transforms(self, examples):
        """Add ``pixel_values`` to a batch using the training augmentations.

        Args:
            examples: Batch dict with an ``'img'`` list of PIL images.

        Returns:
            The same dict with a ``'pixel_values'`` list added.
        """
        examples['pixel_values'] = [
            self._train_transforms(image.convert("RGB")) for image in examples['img']
        ]
        return examples

    def val_transforms(self, examples):
        """Add ``pixel_values`` to a batch using the validation preprocessing.

        Args:
            examples: Batch dict with an ``'img'`` list of PIL images.

        Returns:
            The same dict with a ``'pixel_values'`` list added.
        """
        examples['pixel_values'] = [
            self._val_transforms(image.convert("RGB")) for image in examples['img']
        ]
        return examples

    def collate_fn(self, examples):
        """Stack per-example tensors and labels into one model-ready batch.

        Args:
            examples: Sequence of dicts holding a ``"pixel_values"`` tensor
                and a ``"label"`` value.

        Returns:
            Dict with stacked ``"pixel_values"`` and a ``"labels"`` tensor.
        """
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        labels = torch.tensor([example["label"] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels}

    def compute_metrics(self, eval_pred):
        """Compute classification accuracy for the ``Trainer``.

        Accuracy is computed directly with NumPy so that no metric script has
        to be (re)loaded on every evaluation pass.

        Args:
            eval_pred: ``(logits, labels)`` pair; class scores lie on axis 1.

        Returns:
            ``{"accuracy": float}``; ``0.0`` when there are no labels.
        """
        predictions, labels = eval_pred
        predicted_ids = np.argmax(predictions, axis=1)
        labels = np.asarray(labels)
        if labels.size == 0:
            return {"accuracy": 0.0}
        return {"accuracy": float(np.mean(predicted_ids == labels))}

    def _image_size(self):
        """Return the extractor's target size in a form torchvision accepts.

        Older ``transformers`` releases expose ``size`` as an int, newer ones
        as a dict (``height``/``width`` or ``shortest_edge``); both are handled.
        """
        size = self.feature_extractor.size
        if not isinstance(size, dict):
            return size
        if "height" in size and "width" in size:
            return (size["height"], size["width"])
        if "shortest_edge" in size:
            return size["shortest_edge"]
        raise ValueError(f"Unsupported feature extractor size: {size!r}")

    def getFeature(self, model_path=None):
        """Load the feature extractor and build the train/validation transforms.

        Args:
            model_path: Directory of a previously saved extractor. When
                ``None`` the ``PRETRAINED_CHECKPOINT`` settings are used.
        """
        source = self.PRETRAINED_CHECKPOINT if model_path is None else model_path
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(source)

        size = self._image_size()
        normalize = Normalize(
            mean=self.feature_extractor.image_mean,
            std=self.feature_extractor.image_std,
        )
        # NOTE(review): RandomHorizontalFlip mirrors the image; if the x-axis
        # of these images encodes time this reverses it - confirm it is wanted.
        self._train_transforms = Compose(
            [
                RandomResizedCrop(size),
                RandomHorizontalFlip(),
                ToTensor(),
                normalize,
            ]
        )
        self._val_transforms = Compose(
            [
                Resize(size),
                CenterCrop(size),
                ToTensor(),
                normalize,
            ]
        )

    def _attach_transforms(self, train_ds, val_ds):
        """Install the on-the-fly transforms and print the shapes of one batch."""
        train_ds.set_transform(self.train_transforms)
        val_ds.set_transform(self.val_transforms)

        # Sanity check: collate a small batch and show the tensor shapes.
        loader = DataLoader(train_ds, collate_fn=self.collate_fn, batch_size=4)
        batch = next(iter(loader), None)
        if batch is None:
            raise ValueError("train_ds is empty; there is nothing to train on")
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                print(k, v.shape)

    def _fit_and_save(self, model, train_ds, val_ds, model_path):
        """Run the ``Trainer`` and persist model, extractor and training args."""
        trainer = Trainer(
            model,
            self.args,
            train_dataset=train_ds,
            eval_dataset=val_ds,
            data_collator=self.collate_fn,
            compute_metrics=self.compute_metrics,
            tokenizer=self.feature_extractor,
        )
        trainer.train()

        # Unwrap a distributed/parallel wrapper before saving.
        model_to_save = model.module if hasattr(model, "module") else model
        model_to_save.save_pretrained(model_path)
        self.feature_extractor.save_pretrained(model_path)

        # Use the instance's resource path (not a module-level global) so the
        # class also works when imported from another module.
        args_dir = os.path.join(self.RESOURCE_PATH, "model")
        os.makedirs(args_dir, exist_ok=True)
        torch.save(self.args, os.path.join(args_dir, "training_args.bin"))

    def train(self, train_ds, val_ds, model_path):
        """Train a fresh model and save it to ``model_path``.

        Args:
            train_ds: Training dataset with ``'img'`` and ``'label'`` columns.
            val_ds: Validation dataset with the same columns.
            model_path: Directory that receives the model and extractor.

        Raises:
            ValueError: If ``train_ds`` is empty.
        """
        self.getFeature()
        self._attach_transforms(train_ds, val_ds)

        pretrained = ViTForImageClassification.from_pretrained(
            self.PRETRAINED_CHECKPOINT,
            num_labels=self.num_labels,
            id2label=self.id2label,
            label2id=self.label2id,
        )
        # NOTE(review): building the model from the config alone discards the
        # pretrained weights just loaded, i.e. training starts from random
        # initialisation. Kept as in the original - confirm this is intended.
        model = ViTForImageClassification(pretrained.config)
        self._fit_and_save(model, train_ds, val_ds, model_path)

    def finetunning(self, train_ds, val_ds, model_path):
        """Continue training the model stored at ``model_path`` and save it back.

        Args:
            train_ds: Training dataset with ``'img'`` and ``'label'`` columns.
            val_ds: Validation dataset with the same columns.
            model_path: Directory to load the model/extractor from and to
                write the updated versions to.

        Raises:
            ValueError: If ``train_ds`` is empty.
        """
        self.getFeature(model_path)
        self._attach_transforms(train_ds, val_ds)

        model = ViTForImageClassification.from_pretrained(
            model_path,
            num_labels=self.num_labels,
            id2label=self.id2label,
            label2id=self.label2id,
        )
        self._fit_and_save(model, train_ds, val_ds, model_path)

    def getData(self, stock_code, sDate, eDate):
        """Build training and validation datasets for one stock and date range.

        Args:
            stock_code: Stock identifier passed to ``Stock2Vector``.
            sDate: Range start, passed through to ``Stock2Vector``.
            eDate: Range end, passed through to ``Stock2Vector``.

        Returns:
            ``(train_ds, val_ds)``: the first 90% of the samples, in the order
            returned by ``Stock2Vector``, and the remaining 10%.
        """
        data = self.stock2Vector.getTrainData(stock_code, sDate, eDate)
        X, Y = self.stock2Vector.getVectorData(data)
        print("Data count: ", len(X))

        to_pil = transforms.ToPILImage()
        images = [to_pil(torch.tensor(x)) for x in X]

        # Ordered split (no shuffling): head for training, tail for validation.
        split_point = int(len(images) * 0.9)
        train_ds = Dataset.from_dict(
            {'img': images[:split_point], 'label': Y[:split_point]}
        )
        val_ds = Dataset.from_dict(
            {'img': images[split_point:], 'label': Y[split_point:]}
        )
        return train_ds, val_ds
if __name__ == "__main__":
    # Script entry point: train a model for one stock on a single month of
    # data. Resources are resolved relative to the current working directory.
    PROJECT_HOME = os.getcwd()
    RESOURCE_PATH = os.path.join(PROJECT_HOME, "resources")
    model_path = os.path.join(RESOURCE_PATH, "model")
    stock_code = "252670"
    vitTrainer = VitTrainer(RESOURCE_PATH)
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220701", eDate="20220731")
    vitTrainer.train(train_ds, val_ds, model_path)

    # Disabled rolling schedule, kept for reference: train on 2020-07, then
    # fine-tune month by month through 2022-07. It lives in a bare string, so
    # none of it runs.
    # NOTE(review): several end dates use day 31 for shorter months (for
    # example "20200931"); presumably the value is only an upper bound for
    # the query - confirm against Stock2Vector before re-enabling.
    """
    print("ym: 2020-07")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200701", eDate="20200731")
    vitTrainer.train(train_ds, val_ds, model_path)
    print ("ym: 2020-08")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200725", eDate="20200831")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2020-09")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200825", eDate="20200931")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2020-10")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200925", eDate="20201031")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2020-11")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20201025", eDate="20201131")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2020-12")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20201125", eDate="20201231")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2021-01")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20201225", eDate="20210131")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2021-02")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210125", eDate="20210231")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2021-03")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210225", eDate="20210331")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2021-04")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210325", eDate="20210431")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2021-05")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210425", eDate="20210531")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2021-06")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210525", eDate="20210631")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2021-07")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210625", eDate="20210731")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2021-08")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210725", eDate="20210831")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2021-09")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210825", eDate="20210931")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2021-10")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210925", eDate="20211031")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2021-11")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20211025", eDate="20211131")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2021-12")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20211125", eDate="20211231")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2022-01")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20211225", eDate="20220131")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2022-02")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220125", eDate="20220231")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2022-03")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220225", eDate="20220331")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2022-04")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220325", eDate="20220431")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2022-05")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220425", eDate="20220531")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2022-06")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220525", eDate="20220631")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    print("ym: 2022-07")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220625", eDate="20220731")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    """