This commit is contained in:
dsyoon
2022-08-20 10:26:15 +09:00
parent b6dedc9275
commit 9a6e552cb0

View File

@@ -368,12 +368,12 @@ class Stock2Vector(HTS):
size = len(label) size = len(label)
batch_X, batch_Y = [], [] batch_X, batch_Y = [], []
CHANNEL_SIZE = 4 CHANNEL_SIZE = 3
for i in range(VECTOR_SIZE*CHANNEL_SIZE-1, size): for i in range(VECTOR_SIZE*CHANNEL_SIZE-1, size):
X = np.zeros((CHANNEL_SIZE, VECTOR_SIZE, VECTOR_SIZE)) X = np.zeros((CHANNEL_SIZE, VECTOR_SIZE, VECTOR_SIZE))
s = i - VECTOR_SIZE*CHANNEL_SIZE + 1 s = i - VECTOR_SIZE*CHANNEL_SIZE + 1
e = s+VECTOR_SIZE e = s+VECTOR_SIZE
for c in range(0, 4): for c in range(0, CHANNEL_SIZE):
if c > 0: if c > 0:
s = e s = e
e += VECTOR_SIZE e += VECTOR_SIZE