init
This commit is contained in:
@@ -7,9 +7,7 @@ import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset, load_metric, ClassLabel
|
||||
from datasets import load_metric
|
||||
from transformers import AutoConfig
|
||||
from transformers import TrainingArguments, Trainer
|
||||
from transformers import ViTForImageClassification
|
||||
from transformers import ViTConfig, TrainingArguments, Trainer, ViTForImageClassification
|
||||
from torch.utils.data import DataLoader
|
||||
import torchvision.transforms as transforms
|
||||
from transformers import ViTFeatureExtractor
|
||||
@@ -126,11 +124,16 @@ class VitTrainer:
|
||||
if isinstance(v, torch.Tensor):
|
||||
print(k, v.shape)
|
||||
|
||||
"""
|
||||
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
|
||||
num_labels=self.num_labels,
|
||||
id2label=self.id2label,
|
||||
label2id=self.label2id)
|
||||
model = ViTForImageClassification(model.config)
|
||||
num_labels=self.num_labels,
|
||||
id2label=self.id2label,
|
||||
label2id=self.label2id)
|
||||
"""
|
||||
configuration = ViTConfig(num_labels=self.num_labels,
|
||||
id2label=self.id2label,
|
||||
label2id=self.label2id)
|
||||
model = ViTForImageClassification(configuration)
|
||||
|
||||
trainer = Trainer(
|
||||
model,
|
||||
|
||||
Reference in New Issue
Block a user