init
This commit is contained in:
398
VitTrainer.py
398
VitTrainer.py
@@ -2,156 +2,326 @@
|
||||
|
||||
import os
|
||||
os.environ['KMP_DUPLICATE_LIB_OK']='True'
|
||||
from datasets import Dataset
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset, load_dataset
|
||||
from datasets import load_metric
|
||||
from transformers import TrainingArguments, Trainer
|
||||
from transformers import ViTForImageClassification
|
||||
from torch.utils.data import DataLoader
|
||||
import torchvision.transforms as transforms
|
||||
from transformers import ViTFeatureExtractor
|
||||
from torchvision.transforms import (CenterCrop, Compose, Normalize, RandomHorizontalFlip, RandomResizedCrop, Resize, ToTensor)
|
||||
|
||||
from stock.util.Stock2Vector import Stock2Vector
|
||||
|
||||
PROJECT_HOME = os.path.join(os.path.dirname(__file__))
|
||||
RESOURCE_PATH = os.path.join(PROJECT_HOME, "resources")
|
||||
stock2Vector = Stock2Vector(RESOURCE_PATH)
|
||||
X, Y = stock2Vector.getDataset2D("252670")
|
||||
class VitTrainer:
|
||||
|
||||
trans = transforms.ToPILImage()
|
||||
X = [trans(torch.tensor([x])) for x in X]
|
||||
RESOURCE_PATH = None
|
||||
stock2Vector = None
|
||||
|
||||
split_point1 = int(len(X)*0.7)
|
||||
split_point2 = int(len(X)*0.9)
|
||||
train_X = X[:split_point1]
|
||||
train_Y = Y[:split_point1]
|
||||
valid_X = X[split_point1:split_point2]
|
||||
valid_Y = X[split_point1:split_point2]
|
||||
test_X = X[split_point2:]
|
||||
test_Y = X[split_point2:]
|
||||
num_labels = None
|
||||
id2label = None
|
||||
label2id = None
|
||||
|
||||
id2label = {0: '0', 1: '1', 2: '2'}
|
||||
label2id = {'0': 0, '1': 1, '2': 2}
|
||||
args = None
|
||||
|
||||
# load cifar10 (only small portion for demonstration purposes)
|
||||
train_data = {'img': train_X, 'label': train_Y}
|
||||
val_dsta = {'img': valid_X, 'label': valid_Y}
|
||||
test_data = {'img': test_X, 'label': test_Y}
|
||||
_train_transforms = None
|
||||
_val_transforms = None
|
||||
|
||||
train_ds = Dataset.from_dict(train_data)
|
||||
val_ds = Dataset.from_dict(val_dsta)
|
||||
test_ds = Dataset.from_dict(test_data)
|
||||
def __init__(self, RESOURCE_PATH):
|
||||
self.set_seed(42)
|
||||
|
||||
from transformers import ViTFeatureExtractor
|
||||
self.RESOURCE_PATH = RESOURCE_PATH
|
||||
self.stock2Vector = Stock2Vector(RESOURCE_PATH)
|
||||
|
||||
feature_extractor = ViTFeatureExtractor()
|
||||
self.num_labels = 3
|
||||
self.id2label = {0: 'none', 1: 'sell', 2: 'buy'}
|
||||
self.label2id = {'none': 0, 'sell': 1, 'buy': 2}
|
||||
|
||||
from torchvision.transforms import (CenterCrop,
|
||||
Compose,
|
||||
Normalize,
|
||||
RandomHorizontalFlip,
|
||||
RandomResizedCrop,
|
||||
Resize,
|
||||
ToTensor)
|
||||
self.args = TrainingArguments(
|
||||
f"stock_vit_predictor",
|
||||
save_strategy="epoch",
|
||||
evaluation_strategy="epoch",
|
||||
learning_rate=2e-5,
|
||||
per_device_train_batch_size=381,
|
||||
per_device_eval_batch_size=381,
|
||||
weight_decay=0.01,
|
||||
load_best_model_at_end=True,
|
||||
metric_for_best_model="accuracy",
|
||||
logging_dir='logs',
|
||||
remove_unused_columns=False,
|
||||
num_train_epochs=20,
|
||||
)
|
||||
|
||||
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
|
||||
_train_transforms = Compose(
|
||||
[
|
||||
RandomResizedCrop(feature_extractor.size),
|
||||
RandomHorizontalFlip(),
|
||||
ToTensor(),
|
||||
normalize,
|
||||
]
|
||||
)
|
||||
return
|
||||
|
||||
_val_transforms = Compose(
|
||||
[
|
||||
Resize(feature_extractor.size),
|
||||
CenterCrop(feature_extractor.size),
|
||||
ToTensor(),
|
||||
normalize,
|
||||
]
|
||||
)
|
||||
def set_seed(self, seed=42, n_gpu=0):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
def train_transforms(examples):
|
||||
examples['pixel_values'] = [_train_transforms(image.convert("RGB")) for image in examples['img']]
|
||||
return examples
|
||||
def train_transforms(self, examples):
|
||||
examples['pixel_values'] = [self._train_transforms(image.convert("RGB")) for image in examples['img']]
|
||||
return examples
|
||||
|
||||
def val_transforms(examples):
|
||||
examples['pixel_values'] = [_val_transforms(image.convert("RGB")) for image in examples['img']]
|
||||
return examples
|
||||
def val_transforms(self, examples):
|
||||
examples['pixel_values'] = [self._val_transforms(image.convert("RGB")) for image in examples['img']]
|
||||
return examples
|
||||
|
||||
# Set the transforms
|
||||
train_ds.set_transform(train_transforms)
|
||||
val_ds.set_transform(val_transforms)
|
||||
test_ds.set_transform(val_transforms)
|
||||
def collate_fn(self, examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
labels = torch.tensor([example["label"] for example in examples])
|
||||
return {"pixel_values": pixel_values, "labels": labels}
|
||||
|
||||
def compute_metrics(self, eval_pred):
|
||||
predictions, labels = eval_pred
|
||||
predictions = np.argmax(predictions, axis=1)
|
||||
metric = load_metric("accuracy")
|
||||
return metric.compute(predictions=predictions, references=labels)
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
def getFeature(self, model_path=None):
|
||||
if model_path == None:
|
||||
self.feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
|
||||
#self.feature_extractor = ViTFeatureExtractor()
|
||||
else:
|
||||
#self.feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
|
||||
self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_path)
|
||||
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
labels = torch.tensor([example["label"] for example in examples])
|
||||
return {"pixel_values": pixel_values, "labels": labels}
|
||||
normalize = Normalize(mean=self.feature_extractor.image_mean, std=self.feature_extractor.image_std)
|
||||
self._train_transforms = Compose(
|
||||
[
|
||||
RandomResizedCrop(self.feature_extractor.size),
|
||||
RandomHorizontalFlip(),
|
||||
ToTensor(),
|
||||
normalize,
|
||||
]
|
||||
)
|
||||
|
||||
train_dataloader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=4)
|
||||
train_data_loader = torch.utils.data.DataLoader(train_X,
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
num_workers=16)
|
||||
self._val_transforms = Compose(
|
||||
[
|
||||
Resize(self.feature_extractor.size),
|
||||
CenterCrop(self.feature_extractor.size),
|
||||
ToTensor(),
|
||||
normalize,
|
||||
]
|
||||
)
|
||||
return
|
||||
|
||||
batch = next(iter(train_dataloader))
|
||||
for k,v in batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
print(k, v.shape)
|
||||
def train(self, train_ds, val_ds, model_path):
|
||||
self.getFeature()
|
||||
|
||||
# Set the transforms
|
||||
train_ds.set_transform(self.train_transforms)
|
||||
val_ds.set_transform(self.val_transforms)
|
||||
|
||||
from transformers import ViTForImageClassification
|
||||
train_dataloader = DataLoader(train_ds, collate_fn=self.collate_fn, batch_size=4)
|
||||
|
||||
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
|
||||
num_labels=10,
|
||||
id2label=id2label,
|
||||
label2id=label2id)
|
||||
batch = next(iter(train_dataloader))
|
||||
for k,v in batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
print(k, v.shape)
|
||||
|
||||
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
|
||||
num_labels=self.num_labels,
|
||||
id2label=self.id2label,
|
||||
label2id=self.label2id)
|
||||
model = ViTForImageClassification(model.config)
|
||||
|
||||
from transformers import TrainingArguments, Trainer
|
||||
trainer = Trainer(
|
||||
model,
|
||||
self.args,
|
||||
train_dataset=train_ds,
|
||||
eval_dataset=val_ds,
|
||||
data_collator=self.collate_fn,
|
||||
compute_metrics=self.compute_metrics,
|
||||
tokenizer=self.feature_extractor
|
||||
)
|
||||
|
||||
metric_name = "accuracy"
|
||||
trainer.train()
|
||||
|
||||
args = TrainingArguments(
|
||||
f"test-cifar-10",
|
||||
save_strategy="epoch",
|
||||
evaluation_strategy="epoch",
|
||||
learning_rate=2e-5,
|
||||
per_device_train_batch_size=10,
|
||||
per_device_eval_batch_size=4,
|
||||
num_train_epochs=3,
|
||||
weight_decay=0.01,
|
||||
load_best_model_at_end=True,
|
||||
metric_for_best_model=metric_name,
|
||||
logging_dir='logs',
|
||||
remove_unused_columns=False,
|
||||
)
|
||||
# save trained model
|
||||
model_to_save = (model.module if hasattr(model, "module") else model) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(model_path)
|
||||
self.feature_extractor.save_pretrained(model_path)
|
||||
torch.save(self.args, os.path.join(RESOURCE_PATH, "model", "training_args.bin"))
|
||||
|
||||
return
|
||||
|
||||
from datasets import load_metric
|
||||
import numpy as np
|
||||
def finetunning(self, train_ds, val_ds, model_path):
|
||||
self.getFeature(model_path)
|
||||
|
||||
metric = load_metric("accuracy")
|
||||
# Set the transforms
|
||||
train_ds.set_transform(self.train_transforms)
|
||||
val_ds.set_transform(self.val_transforms)
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
predictions, labels = eval_pred
|
||||
predictions = np.argmax(predictions, axis=1)
|
||||
return metric.compute(predictions=predictions, references=labels)
|
||||
train_dataloader = DataLoader(train_ds, collate_fn=self.collate_fn, batch_size=4)
|
||||
|
||||
batch = next(iter(train_dataloader))
|
||||
for k,v in batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
print(k, v.shape)
|
||||
|
||||
import torch
|
||||
model = ViTForImageClassification.from_pretrained(model_path,
|
||||
num_labels=self.num_labels,
|
||||
id2label=self.id2label,
|
||||
label2id=self.label2id)
|
||||
trainer = Trainer(
|
||||
model,
|
||||
self.args,
|
||||
train_dataset=train_ds,
|
||||
eval_dataset=val_ds,
|
||||
data_collator=self.collate_fn,
|
||||
compute_metrics=self.compute_metrics,
|
||||
tokenizer=self.feature_extractor
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model,
|
||||
args,
|
||||
train_dataset=train_ds,
|
||||
eval_dataset=val_ds,
|
||||
data_collator=collate_fn,
|
||||
compute_metrics=compute_metrics,
|
||||
tokenizer=feature_extractor,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
# save trained model
|
||||
model_to_save = (model.module if hasattr(model, "module") else model) # Take care of distributed/parallel training
|
||||
model_to_save.save_pretrained(model_path)
|
||||
self.feature_extractor.save_pretrained(model_path)
|
||||
torch.save(self.args, os.path.join(RESOURCE_PATH, "model", "training_args.bin"))
|
||||
|
||||
trainer.train()
|
||||
return
|
||||
|
||||
def getData(self, stock_code, sDate, eDate):
|
||||
data = self.stock2Vector.getTrainData(stock_code, sDate, eDate)
|
||||
X, Y = self.stock2Vector.getDataset2D(data)
|
||||
print("Data count: ", len(X))
|
||||
|
||||
trans = transforms.ToPILImage()
|
||||
X = [trans(torch.tensor([x])) for x in X]
|
||||
|
||||
split_point1 = int(len(X) * 0.9)
|
||||
train_X = X[:split_point1]
|
||||
train_Y = Y[:split_point1]
|
||||
valid_X = X[split_point1:]
|
||||
valid_Y = Y[split_point1:]
|
||||
|
||||
# load cifar10 (only small portion for demonstration purposes)
|
||||
train_data = {'img': train_X, 'label': train_Y}
|
||||
val_dsta = {'img': valid_X, 'label': valid_Y}
|
||||
|
||||
train_ds = Dataset.from_dict(train_data)
|
||||
val_ds = Dataset.from_dict(val_dsta)
|
||||
|
||||
return train_ds, val_ds
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
PROJECT_HOME = os.getcwd()
|
||||
RESOURCE_PATH = os.path.join(PROJECT_HOME, "resources")
|
||||
model_path = os.path.join(RESOURCE_PATH, "model")
|
||||
|
||||
stock_code = "252670"
|
||||
vitTrainer = VitTrainer(RESOURCE_PATH)
|
||||
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200729", eDate="20200731")
|
||||
vitTrainer.train(train_ds, val_ds, model_path)
|
||||
|
||||
"""
|
||||
print("ym: 2020-07")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200701", eDate="20200731")
|
||||
vitTrainer.train(train_ds, val_ds, model_path)
|
||||
|
||||
print ("ym: 2020-08")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200725", eDate="20200831")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2020-09")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200825", eDate="20200931")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2020-10")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20200925", eDate="20201031")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2020-11")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20201025", eDate="20201131")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2020-12")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20201125", eDate="20201231")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-01")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20201225", eDate="20210131")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-02")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210125", eDate="20210231")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-03")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210225", eDate="20210331")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-04")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210325", eDate="20210431")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-05")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210425", eDate="20210531")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-06")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210525", eDate="20210631")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-07")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210625", eDate="20210731")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-08")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210725", eDate="20210831")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-09")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210825", eDate="20210931")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-10")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20210925", eDate="20212031")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-11")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20211025", eDate="20211131")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2021-12")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20211125", eDate="20211231")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2022-01")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20211225", eDate="20220131")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2022-02")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220125", eDate="20220231")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2022-03")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220225", eDate="20220331")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2022-04")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220325", eDate="20220431")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2022-05")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220425", eDate="20220531")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2022-06")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220525", eDate="20220631")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
|
||||
print("ym: 2022-07")
|
||||
train_ds, val_ds = vitTrainer.getData(stock_code, sDate="20220625", eDate="20220731")
|
||||
vitTrainer.finetunning(train_ds, val_ds, model_path)
|
||||
"""
|
||||
Reference in New Issue
Block a user