This commit is contained in:
dsyoon
2022-08-19 22:06:51 +09:00
parent 5975584e86
commit f70942a074

View File

@@ -7,9 +7,7 @@ import numpy as np
import torch import torch
from datasets import Dataset, load_metric, ClassLabel from datasets import Dataset, load_metric, ClassLabel
from datasets import load_metric from datasets import load_metric
from transformers import AutoConfig from transformers import ViTConfig, TrainingArguments, Trainer, ViTForImageClassification
from transformers import TrainingArguments, Trainer
from transformers import ViTForImageClassification
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import torchvision.transforms as transforms import torchvision.transforms as transforms
from transformers import ViTFeatureExtractor from transformers import ViTFeatureExtractor
@@ -126,11 +124,16 @@ class VitTrainer:
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
print(k, v.shape) print(k, v.shape)
"""
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
num_labels=self.num_labels, num_labels=self.num_labels,
id2label=self.id2label, id2label=self.id2label,
label2id=self.label2id) label2id=self.label2id)
model = ViTForImageClassification(model.config) """
configuration = ViTConfig(num_labels=self.num_labels,
id2label=self.id2label,
label2id=self.label2id)
model = ViTForImageClassification(configuration)
trainer = Trainer( trainer = Trainer(
model, model,