This commit is contained in:
dosangyoon
2022-08-07 17:53:10 +09:00
parent 69c52b5e1b
commit 0796547466
2 changed files with 3 additions and 6 deletions

View File

@@ -1,6 +1,4 @@
# tensor - numpy - PILImage 변환 (https://qlsenddl-lab.tistory.com/37)
from PIL import Image
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import random
@@ -192,7 +190,8 @@ class VitTrainer:
def getData(self, stock_code, sDate, eDate):
data = self.stock2Vector.getTrainData(stock_code, sDate, eDate)
X, Y = self.stock2Vector.getDataset2D(data)
#X, Y = self.stock2Vector.getDataset2D(data)
X, Y = self.stock2Vector.getVectorData(data)
print("Data count: ", len(X))
trans = transforms.ToPILImage()
@@ -215,8 +214,6 @@ class VitTrainer:
if __name__ == "__main__":
image = Image.open("img.png")
PROJECT_HOME = os.getcwd()
RESOURCE_PATH = os.path.join(PROJECT_HOME, "resources")
model_path = os.path.join(RESOURCE_PATH, "model")

View File

@@ -285,7 +285,7 @@ class Stock2Vector(HTS):
return X, Y
def getVectorData(self, data, type="avg10", VECTOR_SIZE = 32):
X, Y = [], []
X, Y = np.zeros((VECTOR_SIZE, VECTOR_SIZE, 4)), np.zeros((32,1))
df, minmax_df = self.preprocessData(data)
X = np.asarray(X)