From f70942a07478a6d1c6d38abcbdc7c01d979cc2e3 Mon Sep 17 00:00:00 2001 From: dsyoon Date: Fri, 19 Aug 2022 22:06:51 +0900 Subject: [PATCH] init --- VitTrainer.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/VitTrainer.py b/VitTrainer.py index 2b5564a..bea1205 100755 --- a/VitTrainer.py +++ b/VitTrainer.py @@ -7,9 +7,7 @@ import numpy as np import torch from datasets import Dataset, load_metric, ClassLabel from datasets import load_metric -from transformers import AutoConfig -from transformers import TrainingArguments, Trainer -from transformers import ViTForImageClassification +from transformers import ViTConfig, TrainingArguments, Trainer, ViTForImageClassification from torch.utils.data import DataLoader import torchvision.transforms as transforms from transformers import ViTFeatureExtractor @@ -126,11 +124,16 @@ class VitTrainer: if isinstance(v, torch.Tensor): print(k, v.shape) + """ model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', - num_labels=self.num_labels, - id2label=self.id2label, - label2id=self.label2id) - model = ViTForImageClassification(model.config) + num_labels=self.num_labels, + id2label=self.id2label, + label2id=self.label2id) + """ + configuration = ViTConfig(num_labels=self.num_labels, + id2label=self.id2label, + label2id=self.label2id) + model = ViTForImageClassification(configuration) trainer = Trainer( model,