"""Fine-tune a pretrained Vision Transformer (ViT) on a small CIFAR-10 subset.

Running this file as a script:

1. loads the first 5000 training and 2000 test images of CIFAR-10 and holds
   out 10% of the training images for validation,
2. attaches torchvision preprocessing built from the checkpoint's feature
   extractor settings,
3. prints the tensor shapes of one collated batch as a sanity check,
4. fine-tunes ``ViTForImageClassification`` with the Hugging Face ``Trainer``,
   evaluating and checkpointing once per epoch.

Provenance: introduced as ``VitTrainer.py`` (mode 100755) by commit
fc6400ba77578c52a15e339082f08b2e4542abf2 ("init", dosangyoon,
Fri, 5 Aug 2022 14:35:44 +0900).
"""
import os

# Set before torch/numpy are imported, exactly as the original script did.
# NOTE(review): this presumably works around the "duplicate OpenMP runtime"
# abort seen on some platforms - confirm it is still required.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import numpy as np  # noqa: E402  (must follow the environment tweak above)
import torch  # noqa: E402
from torch.utils.data import DataLoader  # noqa: E402

MODEL_CHECKPOINT = "google/vit-base-patch16-224-in21k"
OUTPUT_DIR = "test-cifar-10"
METRIC_NAME = "accuracy"


def resolve_crop_size(size):
    """Convert a feature-extractor ``size`` into a torchvision-compatible size.

    Args:
        size: Either a plain size (returned unchanged) or a dict using the
            ``{"height": h, "width": w}`` or ``{"shortest_edge": s}`` layout.

    Returns:
        ``size`` itself when it is not a dict, ``(h, w)`` for a height/width
        dict, or ``s`` for a shortest-edge dict.

    Raises:
        ValueError: If ``size`` is a dict with none of the supported keys.
    """
    if not isinstance(size, dict):
        return size
    if "height" in size and "width" in size:
        return (size["height"], size["width"])
    if "shortest_edge" in size:
        return size["shortest_edge"]
    raise ValueError(f"Unsupported image size specification: {size!r}")


def build_transforms(feature_extractor):
    """Build the batched train/eval preprocessing callables.

    Args:
        feature_extractor: Object exposing ``size``, ``image_mean`` and
            ``image_std`` (here: the ViT feature extractor of the checkpoint).

    Returns:
        A ``(train_transforms, val_transforms)`` pair. Each callable takes a
        batch dict with an ``'img'`` list of PIL images, adds a
        ``'pixel_values'`` list of tensors to it and returns the same dict.
    """
    # Imported lazily so the pure helpers in this module stay importable
    # on machines without torchvision.
    from torchvision.transforms import (CenterCrop,
                                        Compose,
                                        Normalize,
                                        RandomHorizontalFlip,
                                        RandomResizedCrop,
                                        Resize,
                                        ToTensor)

    size = resolve_crop_size(feature_extractor.size)
    normalize = Normalize(mean=feature_extractor.image_mean,
                          std=feature_extractor.image_std)

    # Training uses random crop + flip augmentation.
    _train_transforms = Compose(
        [
            RandomResizedCrop(size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )
    # Evaluation uses a deterministic resize + center crop.
    _val_transforms = Compose(
        [
            Resize(size),
            CenterCrop(size),
            ToTensor(),
            normalize,
        ]
    )

    def train_transforms(examples):
        examples['pixel_values'] = [
            _train_transforms(image.convert("RGB")) for image in examples['img']
        ]
        return examples

    def val_transforms(examples):
        examples['pixel_values'] = [
            _val_transforms(image.convert("RGB")) for image in examples['img']
        ]
        return examples

    return train_transforms, val_transforms


def collate_fn(examples):
    """Stack per-example dicts into one model-ready batch.

    Args:
        examples: Sequence of dicts, each holding a ``"pixel_values"`` tensor
            and a ``"label"`` value. All tensors must share one shape.

    Returns:
        Dict with ``"pixel_values"`` (the tensors stacked along a new leading
        dimension) and ``"labels"`` (a tensor built from the labels).
    """
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}


def compute_metrics(eval_pred):
    """Compute top-1 accuracy for a ``(predictions, labels)`` pair.

    Accuracy is computed directly with NumPy instead of the deprecated
    ``datasets.load_metric("accuracy")``; the returned mapping has the same
    ``{"accuracy": float}`` shape.

    Args:
        eval_pred: Pair of ``predictions`` (2-D scores, one row per example)
            and ``labels`` (one class id per example).

    Returns:
        ``{"accuracy": fraction_of_rows_whose_argmax_equals_the_label}``.
    """
    predictions, labels = eval_pred
    predicted_ids = np.argmax(predictions, axis=1)
    accuracy = float(np.mean(predicted_ids == np.asarray(labels)))
    return {METRIC_NAME: accuracy}


def load_cifar10_splits(train_split='train[:5000]', test_split='test[:2000]',
                        val_fraction=0.1, seed=None):
    """Load a CIFAR-10 subset and carve a validation set out of the train part.

    Args:
        train_split: Split expression for the training portion.
        test_split: Split expression for the test portion.
        val_fraction: Fraction of the training portion held out for validation.
        seed: Optional seed for the train/validation shuffle. ``None`` (the
            default) keeps the original unseeded behaviour.

    Returns:
        Tuple ``(train_ds, val_ds, test_ds)``.
    """
    from datasets import load_dataset

    # Only a small portion is loaded, for demonstration purposes.
    train_ds, test_ds = load_dataset('cifar10', split=[train_split, test_split])
    splits = train_ds.train_test_split(test_size=val_fraction, seed=seed)
    return splits['train'], splits['test'], test_ds


def main():
    """Run the full fine-tuning pipeline and return the fitted ``Trainer``."""
    from transformers import (Trainer,
                              TrainingArguments,
                              ViTFeatureExtractor,
                              ViTForImageClassification)

    train_ds, val_ds, test_ds = load_cifar10_splits()

    # Label maps are stored in the model config so predictions are readable.
    id2label = dict(enumerate(train_ds.features['label'].names))
    label2id = {label: idx for idx, label in id2label.items()}

    feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
    train_transforms, val_transforms = build_transforms(feature_extractor)

    # Transforms are applied on the fly whenever examples are accessed.
    train_ds.set_transform(train_transforms)
    val_ds.set_transform(val_transforms)
    # test_ds is prepared like val_ds but is not evaluated by this script.
    test_ds.set_transform(val_transforms)

    # Sanity check: collate one small batch and show the tensor shapes.
    train_dataloader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=4)
    batch = next(iter(train_dataloader))
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            print(k, v.shape)

    model = ViTForImageClassification.from_pretrained(
        MODEL_CHECKPOINT,
        num_labels=len(id2label),  # derived from the dataset, not hard-coded
        id2label=id2label,
        label2id=label2id,
    )

    args = TrainingArguments(
        OUTPUT_DIR,
        save_strategy="epoch",
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=10,
        per_device_eval_batch_size=4,
        num_train_epochs=3,
        weight_decay=0.01,
        load_best_model_at_end=True,
        metric_for_best_model=METRIC_NAME,
        logging_dir='logs',
        # Keep the raw 'img' column: the on-the-fly transforms read it.
        remove_unused_columns=False,
    )

    trainer = Trainer(
        model,
        args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        tokenizer=feature_extractor,
    )

    trainer.train()
    return trainer


if __name__ == "__main__":
    main()