"""Fine-tune a ViT image classifier on a small slice of CIFAR-10.

The script loads 5,000 training and 2,000 test images from the Hugging Face
``cifar10`` dataset and carves a 10% validation split out of the training
slice. It then fine-tunes ``google/vit-base-patch16-224-in21k`` with the
``Trainer`` API, evaluating validation accuracy after every epoch.
"""
import os

# Work around Intel OpenMP aborting with "OMP: Error #15" when more than one
# OpenMP runtime ends up loaded in the process. Set first, before importing
# any library that may load such a runtime.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)
from transformers import (
    Trainer,
    TrainingArguments,
    ViTFeatureExtractor,
    ViTForImageClassification,
)

CHECKPOINT = "google/vit-base-patch16-224-in21k"

# Load CIFAR-10 (only a small portion, for demonstration purposes).
train_ds, test_ds = load_dataset('cifar10', split=['train[:5000]', 'test[:2000]'])

# Split the training slice into training (90%) + validation (10%).
splits = train_ds.train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']

# Map between integer class ids and the dataset's human-readable label names.
id2label = dict(enumerate(train_ds.features['label'].names))
label2id = {label: idx for idx, label in id2label.items()}

feature_extractor = ViTFeatureExtractor.from_pretrained(CHECKPOINT)


def _resolve_sizes(size):
    """Convert an image-processor ``size`` into torchvision ``(resize, crop)`` sizes.

    Older ``transformers`` releases expose ``size`` as a plain int. Newer ones
    use a dict with either ``height``/``width`` or ``shortest_edge`` keys, which
    torchvision transforms do not accept. An int is returned unchanged for both
    values, so behaviour with old releases is identical.

    Returns:
        ``(resize_size, crop_size)``. Each is an int or an ``(h, w)`` tuple.

    Raises:
        KeyError: if ``size`` is a dict with none of the expected keys.
    """
    if isinstance(size, int):
        return size, size
    if "height" in size and "width" in size:
        hw = (size["height"], size["width"])
        return hw, hw
    edge = size["shortest_edge"]
    # Resize the shorter edge, then crop a square of the same side length.
    return edge, (edge, edge)


resize_size, crop_size = _resolve_sizes(feature_extractor.size)
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

# Training: random crop + horizontal flip for augmentation.
_train_transforms = Compose(
    [
        RandomResizedCrop(crop_size),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)

# Validation/test: deterministic resize + center crop.
_val_transforms = Compose(
    [
        Resize(resize_size),
        CenterCrop(crop_size),
        ToTensor(),
        normalize,
    ]
)


def train_transforms(examples):
    """Add augmented ``pixel_values`` tensors to a batch dict of examples.

    Each entry of ``examples['img']`` (presumably a PIL image, since ``.convert``
    is called on it) is converted to RGB and run through ``_train_transforms``.
    The batch dict is mutated in place and returned, as ``set_transform`` expects.
    """
    examples['pixel_values'] = [_train_transforms(image.convert("RGB")) for image in examples['img']]
    return examples


def val_transforms(examples):
    """Add deterministic ``pixel_values`` tensors to a batch dict of examples.

    Same as ``train_transforms`` but uses the non-random ``_val_transforms``.
    """
    examples['pixel_values'] = [_val_transforms(image.convert("RGB")) for image in examples['img']]
    return examples


# Transforms are applied lazily, on access, rather than precomputed.
train_ds.set_transform(train_transforms)
val_ds.set_transform(val_transforms)
test_ds.set_transform(val_transforms)


def collate_fn(examples):
    """Stack per-example ``pixel_values`` and ``label`` into model-ready tensors.

    Returns a dict with ``pixel_values`` (stacked along a new dim 0) and
    ``labels``; ``labels`` is the keyword the model uses for computing the loss.
    """
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}


# Sanity check: print the shape of every tensor in one training batch.
train_dataloader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=4)
batch = next(iter(train_dataloader))
for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        print(k, v.shape)

model = ViTForImageClassification.from_pretrained(
    CHECKPOINT,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

metric_name = "accuracy"

args = TrainingArguments(
    "test-cifar-10",
    save_strategy="epoch",
    # NOTE(review): newer transformers releases rename this to
    # ``eval_strategy``; update if the installed version rejects it.
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs',
    # Keep the raw 'img' column: the on-the-fly transforms above read it, and
    # it is not an argument of the model's forward().
    remove_unused_columns=False,
)


def compute_metrics(eval_pred):
    """Return ``{"accuracy": float}`` for the Trainer's evaluation loop.

    ``eval_pred`` unpacks into ``(predictions, labels)``. Predictions are
    per-class scores; the predicted class is their argmax along axis 1.
    Computed directly with numpy because ``datasets.load_metric`` no longer
    exists in recent ``datasets`` releases.
    """
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return {"accuracy": float(np.mean(predictions == np.asarray(labels)))}


trainer = Trainer(
    model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # Passed so the Trainer saves the preprocessing config with checkpoints.
    tokenizer=feature_extractor,
)
trainer.train()