diff --git a/VitTrainer.py b/VitTrainer.py index 2b5564a..bea1205 100755 --- a/VitTrainer.py +++ b/VitTrainer.py @@ -7,9 +7,7 @@ import numpy as np import torch from datasets import Dataset, load_metric, ClassLabel from datasets import load_metric -from transformers import AutoConfig -from transformers import TrainingArguments, Trainer -from transformers import ViTForImageClassification +from transformers import ViTConfig, TrainingArguments, Trainer, ViTForImageClassification from torch.utils.data import DataLoader import torchvision.transforms as transforms from transformers import ViTFeatureExtractor @@ -126,11 +124,16 @@ class VitTrainer: if isinstance(v, torch.Tensor): print(k, v.shape) + """ model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', - num_labels=self.num_labels, - id2label=self.id2label, - label2id=self.label2id) - model = ViTForImageClassification(model.config) + num_labels=self.num_labels, + id2label=self.id2label, + label2id=self.label2id) + """ + configuration = ViTConfig(num_labels=self.num_labels, + id2label=self.id2label, + label2id=self.label2id) + model = ViTForImageClassification(configuration) trainer = Trainer( model,