This commit is contained in:
dosangyoon
2022-08-07 17:46:06 +09:00
parent 655eedc421
commit 57a906bc08
5 changed files with 53 additions and 26 deletions

View File

@@ -284,6 +284,15 @@ class Stock2Vector(HTS):
Y = np.asarray(Y, dtype='int64')
return X, Y
def getVectorData(self, data, type="avg10", VECTOR_SIZE = 32):
X, Y = [], []
df, minmax_df = self.preprocessData(data)
X = np.asarray(X)
Y = np.asarray(Y, dtype='int64')
return X, Y
def getDataset3D(self, data, VECTOR_SIZE = 299):
df, minmax_df = self.preprocessData(data)

View File

@@ -9,9 +9,11 @@ from datasets import Dataset, load_dataset
import torch
import torchvision.transforms as transforms
from transformers import ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer
from torchvision.transforms import (CenterCrop, Compose, Normalize, RandomHorizontalFlip, RandomResizedCrop, Resize, ToTensor)
try:
from transformers import ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer
from torchvision.transforms import (CenterCrop, Compose, Normalize, RandomHorizontalFlip, RandomResizedCrop, Resize, ToTensor)
except:
pass
from stock.util.Stock2Vector import Stock2Vector