119 lines
3.8 KiB
Python
119 lines
3.8 KiB
Python
# tensor / numpy / PIL Image conversion reference: https://qlsenddl-lab.tistory.com/37
|
|
|
|
import os
|
|
|
|
# Let Intel's OpenMP runtime tolerate being loaded more than once in this
# process (KMP_DUPLICATE_LIB_OK). Presumably set to avoid the "libiomp5 already
# initialized" abort when several libraries each bundle their own OpenMP copy.
# It is set here, before numpy/torch are imported below, so it is already in
# the environment when those libraries load.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
|
import random
|
|
import numpy as np
|
|
from datasets import Dataset, load_dataset
|
|
import torch
|
|
import torchvision.transforms as transforms
|
|
|
|
# Optional heavy dependencies: the module stays importable without them, but
# StockPredictor cannot build its model unless both imports succeed.
try:
    from transformers import ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer
    from torchvision.transforms import (CenterCrop, Compose, Normalize, RandomHorizontalFlip, RandomResizedCrop, Resize, ToTensor)
except Exception as import_error:
    # Best-effort as before, but do not swallow KeyboardInterrupt/SystemExit.
    # Also report the failure so a later NameError inside StockPredictor
    # is not a mystery.
    import warnings

    warnings.warn(
        f"transformers/torchvision unavailable; StockPredictor cannot load its model: {import_error!r}",
        stacklevel=2,
    )
|
|
from stock.util.Stock2Vector import Stock2Vector
|
|
|
|
|
|
class StockPredictor:
    """Classify array inputs as 'none' / 'sell' / 'buy' with a fine-tuned ViT.

    Construction loads the model and its feature-extractor configuration from
    ``<RESOURCE_PATH>/model`` and wraps them in a ``transformers.Trainer``.
    ``predict`` converts inputs to images and runs them through that Trainer.
    """

    # Class-level defaults; every instance overwrites these in __init__.
    RESOURCE_PATH = None
    stock2Vector = None
    model_dir = None
    predictor = None

    def __init__(self, RESOURCE_PATH):
        """Load the classifier stored under ``<RESOURCE_PATH>/model``.

        Args:
            RESOURCE_PATH: Root resource directory. Its ``model`` sub-directory
                must hold the saved ViT checkpoint and feature-extractor
                config. The path is also passed to ``Stock2Vector``.
        """
        self.RESOURCE_PATH = RESOURCE_PATH
        self.model_dir = os.path.join(RESOURCE_PATH, "model")
        self.stock2Vector = Stock2Vector(RESOURCE_PATH)

        # Seed Python / NumPy / torch RNGs for reproducible runs.
        self.set_seed(42)

        self.num_labels = 3
        self.id2label = {0: 'none', 1: 'sell', 2: 'buy'}
        self.label2id = {label: idx for idx, label in self.id2label.items()}

        # Turns each raw input (after wrapping it in a torch tensor) into a PIL image.
        self.trans = transforms.ToPILImage()
        self.predictor = self.loadModel()

    def set_seed(self, seed=42, n_gpu=0):
        """Seed Python's ``random``, NumPy and torch with ``seed``.

        Args:
            seed: Seed value applied to every generator.
            n_gpu: When positive, also call ``torch.cuda.manual_seed_all``.
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(seed)

    @staticmethod
    def _resize_and_crop_sizes(size):
        """Translate a feature extractor's ``size`` into torchvision arguments.

        Older transformers releases store ``size`` as a bare int. Newer image
        processors store a dict with either ``height``/``width`` or
        ``shortest_edge``, which ``Resize``/``CenterCrop`` do not accept.
        Ints and sequences are passed through unchanged.

        Returns:
            ``(resize_size, crop_size)`` for ``Resize`` and ``CenterCrop``.

        Raises:
            ValueError: ``size`` is a dict in an unrecognised layout.
        """
        if isinstance(size, dict):
            if "height" in size and "width" in size:
                height_width = (size["height"], size["width"])
                return height_width, height_width
            if "shortest_edge" in size:
                edge = size["shortest_edge"]
                return edge, (edge, edge)
            raise ValueError(f"Unsupported feature extractor size: {size!r}")
        return size, size

    def loadModel(self):
        """Build the inference pipeline and return it wrapped in a ``Trainer``.

        Side effect: sets ``self._test_transforms`` (resize, center-crop,
        to-tensor, normalize), which ``test_transforms`` applies per image.

        Returns:
            A ``transformers.Trainer`` around the fine-tuned
            ``ViTForImageClassification``. Despite the method name, this is a
            Trainer, not the bare model.
        """
        feature_extractor = ViTFeatureExtractor.from_pretrained(self.model_dir)

        resize_size, crop_size = self._resize_and_crop_sizes(feature_extractor.size)
        normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
        self._test_transforms = Compose(
            [
                Resize(resize_size),
                CenterCrop(crop_size),
                ToTensor(),
                normalize,
            ]
        )

        model = ViTForImageClassification.from_pretrained(
            self.model_dir,
            num_labels=self.num_labels,
            id2label=self.id2label,
            label2id=self.label2id,
        )

        # This Trainer is only used for Trainer.predict; no train or eval
        # dataset is ever supplied, so only inference-relevant arguments are set.
        # An evaluation strategy without an eval_dataset is rejected by recent
        # transformers releases, and load_best_model_at_end would require one.
        args = TrainingArguments(
            os.path.join(self.RESOURCE_PATH, 'model', "stock_vit_predictor"),
            per_device_eval_batch_size=762,
            logging_dir=os.path.join(self.RESOURCE_PATH, 'model', 'logs'),
            # Keep the 'img' column so the set_transform hook can read it.
            remove_unused_columns=False,
        )

        # NOTE(review): newer transformers versions prefer processing_class=
        # over tokenizer=; tokenizer= is kept for compatibility with older ones.
        return Trainer(
            model,
            args,
            data_collator=self.collate_fn,
            tokenizer=feature_extractor,
        )

    def test_transforms(self, examples):
        """Batch transform for ``Dataset.set_transform``: add ``pixel_values``.

        Each image in ``examples['img']`` is converted to RGB and passed
        through ``self._test_transforms``. Mutates and returns ``examples``.
        """
        examples['pixel_values'] = [self._test_transforms(image.convert("RGB")) for image in examples['img']]
        return examples

    def collate_fn(self, examples):
        """Stack per-example ``pixel_values`` tensors into one batch tensor.

        Labels are not forwarded; the returned batch holds ``pixel_values`` only.
        """
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        return {"pixel_values": pixel_values}

    def predict(self, X, Y=None):
        """Run the classifier over every sample in ``X``.

        Args:
            X: Sequence of array-likes. Each is wrapped in ``torch.tensor`` and
                converted with ``ToPILImage``, so presumably CHW or HW layout
                (TODO: confirm with callers).
            Y: Optional labels, one per sample. They are stored in a ``label``
                column but never reach the model (``collate_fn`` drops them).

        Returns:
            ``Trainer.predict(...).predictions``: presumably the logits, one
            row per sample with ``num_labels`` columns ordered as ``id2label``.
        """
        print("Data count: ", len(X))

        images = [self.trans(torch.tensor(x)) for x in X]

        # Dataset.from_dict expects every column to be a list of equal length,
        # so a label column is only added when labels were actually supplied.
        columns = {'img': images}
        if Y is not None:
            columns['label'] = Y
        test_ds = Dataset.from_dict(columns)

        # Images are transformed lazily, batch by batch, when accessed.
        test_ds.set_transform(self.test_transforms)

        outputs = self.predictor.predict(test_ds)
        return outputs.predictions