This commit is contained in:
dsyoon
2022-08-19 22:06:51 +09:00
parent 5975584e86
commit f70942a074

View File

@@ -7,9 +7,7 @@ import numpy as np
import torch
from datasets import Dataset, load_metric, ClassLabel
from datasets import load_metric
from transformers import AutoConfig
from transformers import TrainingArguments, Trainer
from transformers import ViTForImageClassification
from transformers import ViTConfig, TrainingArguments, Trainer, ViTForImageClassification
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from transformers import ViTFeatureExtractor
@@ -126,11 +124,16 @@ class VitTrainer:
if isinstance(v, torch.Tensor):
print(k, v.shape)
"""
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
num_labels=self.num_labels,
id2label=self.id2label,
label2id=self.label2id)
model = ViTForImageClassification(model.config)
num_labels=self.num_labels,
id2label=self.id2label,
label2id=self.label2id)
"""
configuration = ViTConfig(num_labels=self.num_labels,
id2label=self.id2label,
label2id=self.label2id)
model = ViTForImageClassification(configuration)
trainer = Trainer(
model,