Files
DeepStock/VitTrainer.py
dsyoon 31394e7694 init
2022-08-20 01:16:38 +09:00

241 lines
8.7 KiB
Python
Executable File

# Converting between tensor, numpy array and PIL Image (https://qlsenddl-lab.tistory.com/37)
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import random
import numpy as np
import torch
from datasets import Dataset, load_metric, ClassLabel
from datasets import load_metric
from transformers import ViTConfig, TrainingArguments, Trainer, ViTForImageClassification
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from transformers import ViTFeatureExtractor
from torchvision.transforms import (CenterCrop, Compose, Normalize, RandomHorizontalFlip, RandomResizedCrop, Resize, ToTensor)
from stock.util.Stock2Vector import Stock2Vector
class VitTrainer:
    """Train or fine-tune a ViT image classifier on stock vectors converted to images.

    The classifier has three classes: ``none`` (0), ``sell`` (1) and ``buy`` (2).
    Training hyper-parameters are fixed in ``__init__`` via ``TrainingArguments``;
    checkpoints go to ``<RESOURCE_PATH>/model/stock_vit_predictor`` and logs to
    ``<RESOURCE_PATH>/model/logs``.
    """

    # Class-level defaults; the real values are assigned per instance.
    RESOURCE_PATH = None
    stock2Vector = None
    num_labels = None
    id2label = None
    label2id = None
    args = None
    feature_extractor = None
    _train_transforms = None
    _val_transforms = None
    _accuracy_metric = None  # loaded lazily by compute_metrics

    def __init__(self, RESOURCE_PATH):
        """Seed the RNGs and set up the data source, label maps and training args.

        Args:
            RESOURCE_PATH: Root directory handed to ``Stock2Vector`` and used as
                the parent of the ``model`` output directory.
        """
        self.set_seed(42)
        self.RESOURCE_PATH = RESOURCE_PATH
        self.stock2Vector = Stock2Vector(RESOURCE_PATH)

        self.num_labels = 3
        self.id2label = {0: 'none', 1: 'sell', 2: 'buy'}
        self.label2id = {'none': 0, 'sell': 1, 'buy': 2}

        self.args = TrainingArguments(
            os.path.join(self.RESOURCE_PATH, 'model', "stock_vit_predictor"),
            save_strategy="epoch",
            evaluation_strategy="epoch",
            learning_rate=2e-5,
            per_device_train_batch_size=32,
            per_device_eval_batch_size=32,
            weight_decay=0.01,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            logging_dir=os.path.join(self.RESOURCE_PATH, 'model', 'logs'),
            # Keep the 'img' column: the on-the-fly transforms read it.
            remove_unused_columns=False,
            num_train_epochs=4,
        )

    def set_seed(self, seed=42, n_gpu=0):
        """Seed Python's ``random``, NumPy and PyTorch.

        Args:
            seed: Seed value applied to every generator.
            n_gpu: When greater than 0, CUDA generators are seeded as well.
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(seed)

    def train_transforms(self, examples):
        """Add ``pixel_values`` to a batch using the training augmentation pipeline.

        ``getFeature`` must have been called first so the pipeline exists.
        """
        examples['pixel_values'] = [
            self._train_transforms(image.convert("RGB")) for image in examples['img']
        ]
        return examples

    def val_transforms(self, examples):
        """Add ``pixel_values`` to a batch using the validation pipeline.

        ``getFeature`` must have been called first so the pipeline exists.
        """
        examples['pixel_values'] = [
            self._val_transforms(image.convert("RGB")) for image in examples['img']
        ]
        return examples

    def collate_fn(self, examples):
        """Stack per-example tensors and labels into one model-ready batch.

        Args:
            examples: Sequence of dicts with a ``pixel_values`` tensor and a
                ``label`` value.

        Returns:
            Dict with stacked ``pixel_values`` and a ``labels`` tensor.
        """
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        labels = torch.tensor([example["label"] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels}

    def compute_metrics(self, eval_pred):
        """Return the accuracy of arg-max predictions against the reference labels.

        Args:
            eval_pred: ``(predictions, labels)`` pair; the arg-max is taken over
                axis 1 of ``predictions``.
        """
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        # Load the metric once instead of on every evaluation pass.
        if self._accuracy_metric is None:
            self._accuracy_metric = load_metric("accuracy")
        return self._accuracy_metric.compute(predictions=predictions, references=labels)

    def getFeature(self, model_path=None):
        """Load the feature extractor and build the train/validation transforms.

        Args:
            model_path: Directory of a previously saved feature extractor. When
                ``None``, ``google/vit-base-patch16-224-in21k`` is loaded.
        """
        if model_path is None:
            self.feature_extractor = ViTFeatureExtractor.from_pretrained(
                "google/vit-base-patch16-224-in21k"
            )
        else:
            self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_path)

        # NOTE(review): `feature_extractor.size` is passed straight to torchvision;
        # confirm the installed transformers version returns an int/tuple here.
        size = self.feature_extractor.size
        normalize = Normalize(
            mean=self.feature_extractor.image_mean,
            std=self.feature_extractor.image_std,
        )
        self._train_transforms = Compose(
            [
                RandomResizedCrop(size),
                RandomHorizontalFlip(),
                ToTensor(),
                normalize,
            ]
        )
        self._val_transforms = Compose(
            [
                Resize(size),
                CenterCrop(size),
                ToTensor(),
                normalize,
            ]
        )

    def _prepare_datasets(self, train_ds, val_ds):
        """Attach the on-the-fly transforms and print the shapes of one batch."""
        train_ds.set_transform(self.train_transforms)
        val_ds.set_transform(self.val_transforms)

        # Sanity check: pull one collated batch and show its tensor shapes.
        train_dataloader = DataLoader(train_ds, collate_fn=self.collate_fn, batch_size=32)
        batch = next(iter(train_dataloader))
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                print(k, v.shape)

    def _run_training(self, model, train_ds, val_ds, model_path):
        """Run the Trainer, then save model, feature extractor and training args."""
        trainer = Trainer(
            model,
            self.args,
            train_dataset=train_ds,
            eval_dataset=val_ds,
            data_collator=self.collate_fn,
            compute_metrics=self.compute_metrics,
            tokenizer=self.feature_extractor,
        )
        trainer.train()

        # Unwrap a distributed/parallel wrapper before saving.
        model_to_save = model.module if hasattr(model, "module") else model
        model_to_save.save_pretrained(model_path)
        self.feature_extractor.save_pretrained(model_path)
        # Use the instance's resource path so this also works when the module
        # is imported rather than run as a script.
        torch.save(
            self.args,
            os.path.join(self.RESOURCE_PATH, "model", "training_args.bin"),
        )

    def train(self, train_ds, val_ds, model_path):
        """Train a ViT classifier from a fresh (randomly initialised) config.

        The feature extractor always comes from the default pretrained
        checkpoint; ``model_path`` is only used as the save location.

        Args:
            train_ds: Training dataset with ``img`` and ``label`` columns.
            val_ds: Validation dataset with the same columns.
            model_path: Directory the model and feature extractor are saved to.
        """
        self.getFeature()
        self._prepare_datasets(train_ds, val_ds)

        # The model is built from a bare config, not from pretrained weights.
        configuration = ViTConfig(
            num_labels=self.num_labels,
            id2label=self.id2label,
            label2id=self.label2id,
        )
        model = ViTForImageClassification(configuration)
        self._run_training(model, train_ds, val_ds, model_path)

    def finetunning(self, train_ds, val_ds, model_path):
        """Continue training the model and feature extractor stored in ``model_path``.

        Args:
            train_ds: Training dataset with ``img`` and ``label`` columns.
            val_ds: Validation dataset with the same columns.
            model_path: Directory to load from and, after training, save back to.
        """
        self.getFeature(model_path)
        self._prepare_datasets(train_ds, val_ds)

        model = ViTForImageClassification.from_pretrained(
            model_path,
            num_labels=self.num_labels,
            id2label=self.id2label,
            label2id=self.label2id,
        )
        self._run_training(model, train_ds, val_ds, model_path)

    def getData(self, stock_code, sDate, eDate):
        """Build train/validation datasets for one stock over a date range.

        The vectors returned by ``Stock2Vector`` are converted to PIL images and
        split 90% / 10% in their original order (no shuffling).

        Args:
            stock_code: Stock code passed to ``Stock2Vector.getTrainData``.
            sDate: Start date passed to ``Stock2Vector.getTrainData``.
            eDate: End date passed to ``Stock2Vector.getTrainData``.

        Returns:
            ``(train_ds, val_ds)`` with an ``img`` column and a ``ClassLabel``
            ``label`` column.
        """
        # Reference: implementing Instance Normalization in NumPy and PyTorch
        # (https://ndb796.tistory.com/653)
        data = self.stock2Vector.getTrainData(stock_code, sDate, eDate)
        X, Y = self.stock2Vector.getVectorData(data)
        print("Data count: ", len(X))

        to_pil = transforms.ToPILImage()
        X = [to_pil(torch.tensor(x)) for x in X]

        split_point1 = int(len(X) * 0.9)
        train_ds = Dataset.from_dict({'img': X[:split_point1], 'label': Y[:split_point1]})
        val_ds = Dataset.from_dict({'img': X[split_point1:], 'label': Y[split_point1:]})

        # Re-type the integer label column as a ClassLabel with readable names.
        features = train_ds.features.copy()
        features["label"] = ClassLabel(
            num_classes=self.num_labels,
            names=[self.id2label[i] for i in range(self.num_labels)],
        )

        def adjust_labels(batch):
            # Values are unchanged; the map call exists to apply `features`.
            batch["label"] = list(batch["label"])
            return batch

        train_ds = train_ds.map(adjust_labels, batched=True, features=features)
        # Map the validation split itself; mapping train_ds here would make
        # evaluation run on the training data.
        val_ds = val_ds.map(adjust_labels, batched=True, features=features)
        return train_ds, val_ds
if __name__ == "__main__":
    # Script entry point: train a ViT classifier for one stock code on
    # July 2022 data and save it under ./resources/model (relative to the
    # current working directory).
    RESOURCE_PATH = os.path.join(os.getcwd(), "resources")
    model_path = os.path.join(RESOURCE_PATH, "model")

    trainer = VitTrainer(RESOURCE_PATH)
    train_ds, val_ds = trainer.getData("252670", sDate="20220701", eDate="20220731")
    trainer.train(train_ds, val_ds, model_path)