# tensor <-> numpy <-> PIL Image conversion (https://qlsenddl-lab.tistory.com/37)
import os

# Must be set before numpy/torch are imported so that a duplicated OpenMP
# runtime does not abort the process.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import random

import numpy as np
import torch
from datasets import Dataset, load_dataset
from datasets import load_metric
from transformers import TrainingArguments, Trainer
from transformers import ViTForImageClassification
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from transformers import ViTFeatureExtractor
from torchvision.transforms import (CenterCrop,
                                    Compose,
                                    Normalize,
                                    RandomHorizontalFlip,
                                    RandomResizedCrop,
                                    Resize,
                                    ToTensor)

from stock.util.Stock2Vector import Stock2Vector


class VitTrainer:
    """Train / fine-tune a ViT image classifier on data from ``Stock2Vector``.

    Samples are converted to PIL images and classified into three labels
    (``none`` / ``sell`` / ``buy``). Models, the feature extractor and the
    training arguments are written below ``<RESOURCE_PATH>/model``.
    """

    PRETRAINED_NAME = "google/vit-base-patch16-224-in21k"

    RESOURCE_PATH = None
    stock2Vector = None
    num_labels = None
    id2label = None
    label2id = None
    args = None
    feature_extractor = None
    _train_transforms = None
    _val_transforms = None
    _metric = None  # accuracy metric, loaded lazily by compute_metrics

    def __init__(self, RESOURCE_PATH):
        """Seed the RNGs and build the label maps and ``TrainingArguments``.

        Args:
            RESOURCE_PATH: Root directory handed to ``Stock2Vector`` and used
                as the parent of the ``model`` output directory.
        """
        self.set_seed(42)
        self.RESOURCE_PATH = RESOURCE_PATH
        self.stock2Vector = Stock2Vector(RESOURCE_PATH)

        self.num_labels = 3
        self.id2label = {0: 'none', 1: 'sell', 2: 'buy'}
        self.label2id = {'none': 0, 'sell': 1, 'buy': 2}

        self.args = TrainingArguments(
            os.path.join(self.RESOURCE_PATH, 'model', "stock_vit_predictor"),
            save_strategy="epoch",
            evaluation_strategy="epoch",
            learning_rate=2e-5,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            weight_decay=0.01,
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            logging_dir=os.path.join(self.RESOURCE_PATH, 'model', 'logs'),
            # Keep the raw 'img' column: the on-the-fly transforms read it.
            remove_unused_columns=False,
            num_train_epochs=14,
        )

    def set_seed(self, seed=42, n_gpu=0):
        """Seed ``random``, NumPy and torch; also all CUDA devices if ``n_gpu > 0``."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(seed)

    def train_transforms(self, examples):
        """Add ``pixel_values`` to a batch using the training augmentation pipeline."""
        examples['pixel_values'] = [
            self._train_transforms(image.convert("RGB")) for image in examples['img']
        ]
        return examples

    def val_transforms(self, examples):
        """Add ``pixel_values`` to a batch using the validation pipeline."""
        examples['pixel_values'] = [
            self._val_transforms(image.convert("RGB")) for image in examples['img']
        ]
        return examples

    def collate_fn(self, examples):
        """Stack per-example tensors and labels into one model-ready batch dict."""
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        labels = torch.tensor([example["label"] for example in examples])
        return {"pixel_values": pixel_values, "labels": labels}

    def compute_metrics(self, eval_pred):
        """Return the accuracy of the arg-max predictions against the labels.

        Args:
            eval_pred: ``(predictions, labels)`` pair as passed by ``Trainer``;
                the arg-max is taken over axis 1 of ``predictions``.
        """
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        # Load the metric once instead of on every evaluation pass.
        if self._metric is None:
            self._metric = load_metric("accuracy")
        return self._metric.compute(predictions=predictions, references=labels)

    def _target_size(self):
        """Return the feature extractor's image size in a torchvision-friendly form.

        A ``{"height": ..., "width": ...}`` dict becomes a ``(height, width)``
        tuple; any other value is passed through unchanged.
        """
        size = self.feature_extractor.size
        if isinstance(size, dict) and "height" in size and "width" in size:
            return (size["height"], size["width"])
        return size

    def getFeature(self, model_path=None):
        """Load the feature extractor and build the train/val transform pipelines.

        Args:
            model_path: Directory of a previously saved feature extractor.
                When ``None`` the pretrained ``PRETRAINED_NAME`` one is used.
        """
        source = self.PRETRAINED_NAME if model_path is None else model_path
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(source)

        normalize = Normalize(mean=self.feature_extractor.image_mean,
                              std=self.feature_extractor.image_std)
        size = self._target_size()

        self._train_transforms = Compose(
            [
                RandomResizedCrop(size),
                RandomHorizontalFlip(),
                ToTensor(),
                normalize,
            ]
        )
        self._val_transforms = Compose(
            [
                Resize(size),
                CenterCrop(size),
                ToTensor(),
                normalize,
            ]
        )

    def _prepare(self, train_ds, val_ds, model_path=None):
        """Load the feature extractor and attach the transforms to both datasets."""
        self.getFeature(model_path)
        train_ds.set_transform(self.train_transforms)
        val_ds.set_transform(self.val_transforms)

    def _preview_batch(self, train_ds):
        """Print the tensor shapes of the first collated batch as a sanity check."""
        train_dataloader = DataLoader(train_ds, collate_fn=self.collate_fn, batch_size=4)
        batch = next(iter(train_dataloader))
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                print(k, v.shape)

    def _fit(self, model, train_ds, val_ds, model_path):
        """Run ``Trainer`` on ``model`` and save model, extractor and arguments."""
        trainer = Trainer(
            model,
            self.args,
            train_dataset=train_ds,
            eval_dataset=val_ds,
            data_collator=self.collate_fn,
            compute_metrics=self.compute_metrics,
            tokenizer=self.feature_extractor
        )
        trainer.train()

        # Unwrap a distributed/parallel wrapper before saving.
        model_to_save = model.module if hasattr(model, "module") else model
        model_to_save.save_pretrained(model_path)
        self.feature_extractor.save_pretrained(model_path)
        # Use the instance's resource path: a module-level RESOURCE_PATH only
        # exists when this file is executed as a script.
        torch.save(self.args,
                   os.path.join(self.RESOURCE_PATH, "model", "training_args.bin"))

    def train(self, train_ds, val_ds, model_path):
        """Train a new classifier and save it to ``model_path``.

        Args:
            train_ds: Training dataset with ``img`` and ``label`` columns.
            val_ds: Validation dataset with the same columns.
            model_path: Output directory for the model and feature extractor.
        """
        self._prepare(train_ds, val_ds)
        self._preview_batch(train_ds)

        model = ViTForImageClassification.from_pretrained(self.PRETRAINED_NAME,
                                                          num_labels=self.num_labels,
                                                          id2label=self.id2label,
                                                          label2id=self.label2id)
        # NOTE(review): rebuilding from the config alone discards the pretrained
        # weights loaded just above, so training starts from a fresh
        # initialization. Kept as in the original - confirm this is intended.
        model = ViTForImageClassification(model.config)

        self._fit(model, train_ds, val_ds, model_path)

    def finetunning(self, train_ds, val_ds, model_path):
        """Continue training the model stored at ``model_path`` and save it back.

        Args:
            train_ds: Training dataset with ``img`` and ``label`` columns.
            val_ds: Validation dataset with the same columns.
            model_path: Directory to load the model/feature extractor from and
                to write the updated ones to.
        """
        self._prepare(train_ds, val_ds, model_path)
        self._preview_batch(train_ds)

        model = ViTForImageClassification.from_pretrained(model_path,
                                                          num_labels=self.num_labels,
                                                          id2label=self.id2label,
                                                          label2id=self.label2id)

        self._fit(model, train_ds, val_ds, model_path)

    def getData(self, stock_code, sDate, eDate):
        """Build train/validation datasets for one stock and date range.

        The samples returned by ``Stock2Vector`` are converted to PIL images;
        the first 90% (in returned order) become the training set and the
        remainder the validation set.

        Args:
            stock_code: Stock identifier passed to ``Stock2Vector.getTrainData``.
            sDate: Start date, passed through unchanged.
            eDate: End date, passed through unchanged.

        Returns:
            ``(train_ds, val_ds)`` datasets with ``img`` and ``label`` columns.
        """
        data = self.stock2Vector.getTrainData(stock_code, sDate, eDate)
        # X, Y = self.stock2Vector.getDataset2D(data)
        X, Y = self.stock2Vector.getVectorData(data)
        print("Data count: ", len(X))

        # assumes each x has a layout ToPILImage accepts - TODO confirm
        trans = transforms.ToPILImage()
        X = [trans(torch.tensor(x)) for x in X]

        split_point1 = int(len(X) * 0.9)
        train_data = {'img': X[:split_point1], 'label': Y[:split_point1]}
        val_data = {'img': X[split_point1:], 'label': Y[split_point1:]}

        train_ds = Dataset.from_dict(train_data)
        val_ds = Dataset.from_dict(val_data)
        return train_ds, val_ds


if __name__ == "__main__":
    PROJECT_HOME = os.getcwd()
    RESOURCE_PATH = os.path.join(PROJECT_HOME, "resources")
    model_path = os.path.join(RESOURCE_PATH, "model")

    stock_code = "252670"
    vitTrainer = VitTrainer(RESOURCE_PATH)

    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220701", eDate="20220731")
    vitTrainer.train(train_ds, val_ds, model_path)

    # Disabled rolling monthly schedule: one initial train() followed by a
    # finetunning() pass per month.
    # NOTE(review): the 2021-10 window ends at "20212031", which looks like a
    # typo for "20211031", and several end dates (e.g. "20200931", "20210231")
    # are not real calendar days - check how Stock2Vector treats them before
    # re-enabling.
    """
    print("ym: 2020-07")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200701", eDate="20200731")
    vitTrainer.train(train_ds, val_ds, model_path)

    print ("ym: 2020-08")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200725", eDate="20200831")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2020-09")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200825", eDate="20200931")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2020-10")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200925", eDate="20201031")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2020-11")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20201025", eDate="20201131")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2020-12")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20201125", eDate="20201231")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2021-01")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20201225", eDate="20210131")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2021-02")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210125", eDate="20210231")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2021-03")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210225", eDate="20210331")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2021-04")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210325", eDate="20210431")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2021-05")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210425", eDate="20210531")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2021-06")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210525", eDate="20210631")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2021-07")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210625", eDate="20210731")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2021-08")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210725", eDate="20210831")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2021-09")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210825", eDate="20210931")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2021-10")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210925", eDate="20212031")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2021-11")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20211025", eDate="20211131")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2021-12")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20211125", eDate="20211231")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2022-01")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20211225", eDate="20220131")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2022-02")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220125", eDate="20220231")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2022-03")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220225", eDate="20220331")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2022-04")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220325", eDate="20220431")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2022-05")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220425", eDate="20220531")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2022-06")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220525", eDate="20220631")
    vitTrainer.finetunning(train_ds, val_ds, model_path)

    print("ym: 2022-07")
    train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220625", eDate="20220731")
    vitTrainer.finetunning(train_ds, val_ds, model_path)
    """