init
This commit is contained in:
@@ -7,9 +7,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, load_metric, ClassLabel
|
from datasets import Dataset, load_metric, ClassLabel
|
||||||
from datasets import load_metric
|
from datasets import load_metric
|
||||||
from transformers import AutoConfig
|
from transformers import ViTConfig, TrainingArguments, Trainer, ViTForImageClassification
|
||||||
from transformers import TrainingArguments, Trainer
|
|
||||||
from transformers import ViTForImageClassification
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
from transformers import ViTFeatureExtractor
|
from transformers import ViTFeatureExtractor
|
||||||
@@ -126,11 +124,16 @@ class VitTrainer:
|
|||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
print(k, v.shape)
|
print(k, v.shape)
|
||||||
|
|
||||||
|
"""
|
||||||
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
|
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
|
||||||
num_labels=self.num_labels,
|
num_labels=self.num_labels,
|
||||||
id2label=self.id2label,
|
id2label=self.id2label,
|
||||||
label2id=self.label2id)
|
label2id=self.label2id)
|
||||||
model = ViTForImageClassification(model.config)
|
"""
|
||||||
|
configuration = ViTConfig(num_labels=self.num_labels,
|
||||||
|
id2label=self.id2label,
|
||||||
|
label2id=self.label2id)
|
||||||
|
model = ViTForImageClassification(configuration)
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model,
|
model,
|
||||||
|
|||||||
Reference in New Issue
Block a user