تعبیه Dinov2 برای طبقه بندی دقیق تصویر


نویسنده (ها): دکتری لیی گور آری

در ابتدا منتشر شده در به سمت هوش مصنوعیبشر

تعبیه Dinov2 برای طبقه بندی دقیق تصویرتعبیه Dinov2 برای طبقه بندی دقیق تصویر

مقدمه

آموزش یک طبقه بندی کننده تصویر با کارایی بالا به طور معمول به مقادیر زیادی از داده های دارای برچسب نیاز دارد. اما اگر بتوانید با حداقل داده ها و آموزش های سبک به نتایج سطح بالایی برسید؟

Dinov2 یک مدل بنیاد چشم انداز قدرتمند است که بردارهای نمایشی غنی از تصویر را تولید می کند ، همچنین به عنوان تعبیه شناخته می شود. برخلاف مدل های مبتنی بر متن مانند گیره، که بر روی تراز معنایی متمرکز است ، Dinov2 در ضبط ساختار بصری ، بافت و جزئیات مکانی برتری دارد-آن را برای کارهای طبقه بندی تصویر ریز در حوزه های تخصصی مانند تصویربرداری پزشکی و بیولوژیکی ایده آل می کند.

در این آموزش ، نحوه استفاده از Dinov2 را برای ساخت طبقه بندی کننده صفر با استفاده از آن بررسی خواهیم کرد K-Nearest همسایگان (K-NN) ، و چگونگی افزایش عملکرد قابل توجه با آموزش یک لایه خطی در بالای ویژگی های استخراج شده. با تشکر از تعبیه های با کیفیت Dinov2 ، می توانیم با استفاده از تنها تعداد کمی از تصاویر دارای برچسب ، یک طبقه بندی دقیق را آموزش دهیم.

در رمز کامل در نوت بوک COLAB تعبیه شده در زیر موجود است ، برای شما آماده است تا با داده های خود کاوش و سازگار شوید.

خط لوله: تصاویر توسط Dinov2 به بردارهای ویژگی رمزگذاری می شوند ، که سپس برای آموزش یک طبقه بندی خطی استفاده می شوند | تصویر توسط نویسنده.

پیشینه

دکورا (کوتاه برای دیزائیساکت با هیچ برچسب ها) ، که توسط متا ساخته شده است ، روشی برای آموزش مدل های بینایی به روشی خودبوشی و بدون برچسب است. مدل های تولید شده Dino مدل های بنیادی قدرتمند Vision هستند که قادر به استخراج ویژگی های غنی از تصاویر هستند. با اتصال سرهای مختلف در بالای ستون فقرات Dinov2 ، این مدل را می توان در کارهای مختلف دید ، مانند طبقه بندی تصویر ، تقسیم بندی ، تخمین عمق و موارد دیگر تنظیم کرد. اگرچه در این آموزش ما به ستون فقرات دینو آموزش نخواهیم داد ، درک این روش در ابتدا چگونه آموزش دیده است. اگر مشتاق پرش به کد هستید ، می توانید مستقیماً به سمت اجرای کد بخش

𝐃𝐈𝐍𝐎𝐯𝟏

نسخه اول از دکورا یک تکنیک خودآزمایی را معرفی کرد که در آن یک شبکه دانشجویی یاد می گیرد که خروجی یک شبکه معلم را پیش بینی کند. هم معلم و هم دانش آموز با همان معماری مشترک هستند: یک ستون فقرات ترانسفورماتور بینایی (VIT) و یک MLP 3 لایه (چند لایه پیج) سر. شبکه معلم نیز برای جلوگیری از فروپاشی با مرکز و تیز کردن معرفی شد. در حین آموزش ، معلم محصولات بزرگ (نماهای جهانی) از یک تصویر را دریافت می کند ، در حالی که دانش آموز هم کوچک (نمای محلی) و هم محصولات بزرگی را با همان تصویر پردازش می کند. محصولات زراعی از طریق شبکه ها پردازش می شوند و دانش آموز در تلاش است تا خروجی معلم تیز (دمای پایین) را پیش بینی کند. تیز کردن با استفاده از یک مقدار دمای پایین در نرم افزار شبکه معلم انجام می شود ، تا اعتماد به نفس معلم را به یک بعد غالب بالا ببرد ، تا راهنمایی های بهتری به دانش آموز انجام دهد.

وزن دانش آموزان با صلیب به روز می شودآنتروپی تابع هزینه، و وزن معلم به روز می شود تا میانگین حرکت نمایی شبکه دانشجویی باشد.

𝐃𝐈𝐍𝐎𝐯𝟐

dinov2 این چارچوب اصلی را با ادغام چندین پیشرفت کلیدی تقویت می کند. در هسته خود ، Dinov2 از اهداف دوگانه ای که به طور هم زمان کار می کنند استفاده می کند:

  1. هدف سطح تصویر – این هدف از Dino به ارث رسیده ، شبکه دانشجویی را ترغیب می کند تا با نمایندگی تصویر جهانی معلم مطابقت داشته باشد. این کار بر روی نشانه کلاس برای ضبط نمای جامع تصویر انجام می شود.
  2. سطح پچ عینی – الهام گرفته از عید، این هدف شامل نقاب زدن تکه های خاص در ورودی دانش آموز است. دانش آموز سپس سعی می کند این مناطق نقاب دار را با استفاده از تکه های قابل مشاهده اطراف به عنوان زمینه پیش بینی کند. صلیب آنتروپی بین ویژگی های پچ دانش آموز و معلم محاسبه می شود و درک ویژگی های محلی را ارتقا می بخشد.

این رویکرد عینی دوگانه هم درک سطح بالایی از تصویر را از طریق هدف سطح تصویر و یک درک دقیق محلی از طریق هدف سطح پچ تشویق می کند و در نتیجه نمایش های بصری غنی تر است.

از دست دادن تمرین نهایی ، مبلغ وزنی از دست دادن Dino و از دست دادن IBOT است که به طور موثری سیگنال های یادگیری جهانی و محلی را متعادل می کند.

علاوه بر این ، چندین بهینه سازی دیگر در Dinov2 معرفی شده است ، از جمله بهبود استراتژی های عادی سازی و منظم سازی ، یک طرح آموزش چند وضوح و آموزش در مورد یک تصویر با کیفیت بالا و سرپرستی مجموعه دادهبشر می توانید اطلاعات بیشتری در مورد آن بخوانید (در اینجا).

اجرای کد

پس از کاوش در معماری Dino و روند آموزش ستون فقرات آن ، در این آموزش از یک ستون فقرات از پیش آموزش داده شده Dinov2 برای استخراج بردارهای نمایش تصویر استفاده خواهیم کرد. ابتدا عملکرد صفر آن را با استفاده از a ارزیابی خواهیم کرد کنگره طبقه بندی کننده سپس ، ما با آموزش یک لایه طبقه بندی خطی واحد در بالا ، عملکرد را بهبود می بخشیم.

تنظیم محیط

از آنجا که ما از بغل کردن چهره برای بارگذاری مدل از پیش آموزش استفاده می کنیم ، اطمینان حاصل کنید که نشانه چهره بغل شما در شما تنظیم شده است Google Colab محیط

بعد ، کتابخانه های مورد نیاز را نصب و وارد کنید.

مجموعه داده نمای کلی

ما از EMDS-6 مجموعه داده های میکروارگانیسم ، که در ابتدا برای تقسیم بندی طراحی شده بود ، و آن را برای طبقه بندی در این آموزش تطبیق می دهد. این مجموعه داده شامل 21 کلاس میکروارگانیسم است که ویژگی های بصری مشابهی را به اشتراک می گذارد و آن را به یک کار طبقه بندی ریز دانه تبدیل می کند. با تنها 40 تصویر در هر کلاس و فقط 32 مورد برای آموزش استفاده می شود ، همچنین یک تنظیم کم به چالش کشیده است.

من داده ها را در 80 ٪ آموزش و 20 ٪ اعتبار سنجی از قبل تقسیم کرده ام. شما می توانید نسخه آماده شده ، ساختار یافته را به شرح زیر بارگیری کنید:

EMDS6_Data/
├── train/
│ ├── actinophrys/
│ ├── arcella/
│ └── ...
├── val/
│ ├── actinophrys/
│ ├── arcella/
│ └── ...

هر زیر فرم به نام یک کلاس نامگذاری شده است و حاوی تصاویر PNG از میکروارگانیسم مربوطه است. در زیر یک نمونه تصادفی از مجموعه داده ، یک تصویر از هر یک از 21 کلاس:

اکنون که مجموعه داده ها را بارگیری و مشاهده کرده ایم ، یک خط لوله داده برای بارگیری و پیش پردازش تصاویر تنظیم خواهیم کرد.

قسمت اول-طبقه بندی صفر-

در قسمت اول ، عملکرد صفر-شات را در Dinov2 با استفاده از طبقه بندی کننده KNN بررسی خواهیم کرد.

بارگیری داده های تصویر به مدل Dinov2

برای تهیه داده های ما برای استخراج ویژگی، ما استفاده می کنیم timmبرنامه های مناسب برای تعریف تبدیل تصویر بر اساس پیکربندی داده های مدل. سپس ما با استفاده از مجموعه داده های Pytorch آموزش و اعتبار سنجی ایجاد می کنیم ImageDataset کلاس ، استفاده از تبدیل ها به هر مجموعه. بالاخره ، DataLoaderS برای تغذیه تصاویر به مدل Dinov2 تنظیم شده است ، و از پیش پردازش مداوم و دسته بندی کارآمد تصاویر و برچسب ها اطمینان حاصل می کند.

def create_data_loaders(data_dir, batch_size=32, model_name='vit_small_patch14_dinov2', seed=42):
"""
Create data loaders using timm's transforms and dataset utilities.
"""

# Set generator with seed for reproducible data loading
g = torch.Generator()
g.manual_seed(seed)

# Create transforms
data_config = timm.data.resolve_model_data_config(model_name)
data_config['input_size'] = (3, 518, 518) # DINOv2's native resolution

train_transform = timm.data.create_transform(**data_config, is_training=True)
val_transform = timm.data.create_transform(**data_config, is_training=False)

# Create datasets using timm's Dataset class
train_dataset = ImageDataset(root=os.path.join(data_dir, 'train'), transform=train_transform)
val_dataset = ImageDataset(root=os.path.join(data_dir, 'val'), transform=val_transform)

# Create dataloaders
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2,
pin_memory=True,
generator=g
)

val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=2,
pin_memory=True,
generator=g
)

# Get class mappings
class_names = train_dataset.reader.class_to_idx
id2label = {v: k for k, v in class_names.items()}
label2id = class_names

print(f"Created data loaders:")
print(f" Training: {len(train_dataset)} samples, {len(train_loader)} batches")
print(f" Validation: {len(val_dataset)} samples, {len(val_loader)} batches")
print(f" Number of classes: {len(class_names)}")

return train_loader, val_loader, id2label, label2id

# Create data loaders
train_loader, val_loader, id2label, label2id = create_data_loaders(
data_dir='/content/EMDS6_Data',
batch_size=32, seed=0
)

استخراج ویژگی های Dinov2

با آماده شدن Dataloaders ما ، مرحله بعدی عبور تصاویر از طریق یک مدل Dinov2 از پیش آموزش داده شده ، برای استخراج تعبیه های ویژگی غنی است. با تنظیم num_classes=0، سر طبقه بندی را حذف می کنیم و بردارهای ویژگی خام را از ستون فقرات بدست می آوریم.

def extract_features(train_loader, val_loader, model_name='vit_small_patch14_dinov2'):
"""
Extract features using DINOv2 model from timm.
"""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a feature extractor using timm
model = timm.create_model(
model_name,
pretrained=True,
num_classes=0
).to(device)

model = model.eval()

# Function to extract features
def extract_batch_features(loader):
all_features = []
all_labels = []

with torch.no_grad():
for images, labels in tqdm(loader, desc="Extracting features"):
images = images.to(device)
features = model(images)
all_features.append(features.cpu())
all_labels.append(labels)

return torch.cat(all_features, dim=0), torch.cat(all_labels, dim=0)

# Extract features from train and validation sets
train_features, train_labels = extract_batch_features(train_loader)
print(f"Training features shape: {train_features.shape}")

val_features, val_labels = extract_batch_features(val_loader)
print(f"Validation features shape: {val_features.shape}")

return train_features, train_labels, val_features, val_labels

# Extract features
train_features, train_labels, val_features, val_labels = extract_features(
train_loader, val_loader
)

طبقه بندی صفر با knn

برای ارزیابی کیفیت تعبیه Dinov2 ، ما یک طبقه بندی کننده K-Nearest همسایگان (KNN) را مستقیماً روی ویژگی های استخراج شده اعمال می کنیم. این روش ساده شامل آموزش نیست – هر تصویر اعتبار سنجی را بر اساس نزدیکترین تعبیه شده از مجموعه آموزش طبقه بندی می کند. نتیجه: knn دقت از 83.9 ٪ که با توجه به چالش تمایز بین 21 کلاس ریز دانه ، نتیجه مناسبی است. گفته می شود ، ما می توانیم با آموزش یک طبقه بندی کننده خطی در بالا ، عملکرد را بیشتر کنیم.

قسمت دوم – آموزش سر طبقه بندی کننده خطی

در حالی که KNN بر اساس فاصله بین ویژگی ها طبقه بندی می شود ، اما با مرزهای تصمیم گیری خاص که کلاس ها را در مجموعه داده های ما جدا می کنند ، تنظیم نمی شود. با آموزش یک طبقه بندی خطی با تعبیه به عنوان ورودی ، می توانیم فضای ویژگی را بهتر شکل دهیم تا با مجموعه داده های خود مطابقت داشته باشد و کلاس ها را بهتر جدا کنیم.

ویژگی Dataloader Setup

برای بارگیری کارآمد داده ها در دسته ها در طول آموزش ، ما مجموعه جدیدی از DataLoaderS برای رسیدگی به تعبیه های Dinov2 که قبلاً استخراج شده بود.

def create_feature_dataloaders(train_features, train_labels, val_features, val_labels, batch_size=64, seed=42):
"""
Create data loaders for pre-extracted features.
"""

# Set generator with seed for reproducible data loading
g = torch.Generator()
g.manual_seed(seed)

# Use timm.data.Dataset for feature datasets
train_dataset = TensorDataset(train_features, train_labels)
val_dataset = TensorDataset(val_features, val_labels)

train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0,
pin_memory=True,
generator=g
)

val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
generator=g
)

print(f"Created feature dataloaders:")
print(f" Training: {len(train_dataset)} samples, {len(train_loader)} batches")
print(f" Validation: {len(val_dataset)} samples, {len(val_loader)} batches")

return train_loader, val_loader

# Create feature dataloaders
train_feature_loader, val_feature_loader = create_feature_dataloaders(
train_features, train_labels, val_features, val_labels, seed=0
)

تعیین یک سر طبقه بندی خطی

ما یک مدل Pytorch ساده متشکل از یک لایه ترکیبی را تعریف می کنیم و به دنبال آن یک لایه خطی کاملاً متصل است. ورودی به مدل یک بردار ویژگی Dinov2 است و خروجی یک نمره کلاس برای هر دسته میکروارگانیسم است. این سر سبک وزن برای دستیابی به نتایج قوی آسان است و کافی است.

class DINOv2Classifier(nn.Module):
"""Linear classifier for DINOv2 features."""
def __init__(self, input_dim, num_classes):
super().__init__()
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(input_dim, num_classes)
)

def forward(self, x):
return self.classifier(x)

# Create classifier model
feature_dim = train_features.shape[1]
num_classes = len(id2label)
classifier = DINOv2Classifier(feature_dim, num_classes).to(device)

آموزش رئیس طبقه بندی

در این مرحله ، پارامترهای حلقه آموزش را تعریف می کنیم و سر طبقه بندی خطی را آموزش می دهیم. این حلقه بر اساس دقت اعتبار سنجی ، بهترین وزن مدل را پیگیری می کند.

def train_model(classifier, train_loader, val_loader, num_epochs, lr):
"""Train the classifier on extracted DINOv2 features."""

# Setup training
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=lr, weight_decay=1e-5)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [], 'lr': []}
best_val_acc = 0.0

for epoch in range(num_epochs):
# Training phase
classifier.train()
train_loss, train_correct = 0.0, 0

for features, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Train"):
features, labels = features.to(device), labels.to(device)

# Forward & backward pass
outputs = classifier(features)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Track metrics
train_loss += loss.item() * features.size(0)
_, predicted = torch.max(outputs, 1)
train_correct += (predicted == labels).sum().item()

# Validation phase
classifier.eval()
val_loss, val_correct = 0.0, 0

with torch.no_grad():
for features, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Val"):
features, labels = features.to(device), labels.to(device)

outputs = classifier(features)
loss = criterion(outputs, labels)

val_loss += loss.item() * features.size(0)
_, predicted = torch.max(outputs, 1)
val_correct += (predicted == labels).sum().item()

# Calculate epoch metrics
train_size, val_size = len(train_loader.dataset), len(val_loader.dataset)
train_loss, train_acc = train_loss / train_size, train_correct / train_size
val_loss, val_acc = val_loss / val_size, val_correct / val_size

# Update the learning rate
scheduler.step()
current_lr = optimizer.param_groups[0]['lr']

# Store metrics
for key, value in zip(
['train_loss', 'train_acc', 'val_loss', 'val_acc', 'lr'],
[train_loss, train_acc, val_loss, val_acc, current_lr]):
history[key].append(value)

# Print results & save best model
print(f"\nEpoch {epoch+1}/{num_epochs}: train_acc={train_acc:.4f}, val_acc={val_acc:.4f}, lr={current_lr:.6f}")

if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(classifier.state_dict(), "/content/best_dinov2_classifier.pth")
print(f"✓ New best model saved: {val_acc:.4f}")

return history

# Train the classifier
history = train_model(classifier, train_feature_loader, val_feature_loader, num_epochs=20, lr=0.5)

با نگاهی به توطئه های آموزشی زیر ، ما شاهد افت شدید از دست دادن و افزایش شدید دقت در دوره های اولیه هستیم و این نشان می دهد که این مدل به طور مؤثر در حال یادگیری است. با پیشرفت آموزش و کاهش میزان یادگیری ، مدل به تدریج همگرا و تثبیت می شود. مدل بهترین عملکرد در Epoch 13 ذخیره می شود و به یک اعتبار سنجی چشمگیر می رسد دقت از 95.8 ٪، پیشرفت قابل توجهی نسبت به پایه Zero-Shot KNN!

سخنان پایانی

در این آموزش ، ما از تعبیه های غنی Dinov2 برای ساخت طبقه بندی دقیق میکروارگانیسم استفاده کردیم. با وجود مجموعه داده های کوچک و چالش برانگیز ، ما با آموزش یک سر خطی ساده ، 83.9 ٪ دقت صفر با KNN و 95.8 ٪ به دست آوردیم. Dinov2 برای سناریوها با برچسب های محدود و جزئیات بصری ریز و درشت مناسب است. با این حال ، ستون فقرات سنگین آن باعث می شود که برای برنامه های کاربردی در زمان واقعی یا استقرار در دستگاه های لبه کم منابع مناسب تر باشد. برای کارهایی که نیاز به درک معنایی عمیق تر دارند ، مدل های بینایی زبان مانند کلیپ ممکن است تعبیهات مناسب تری را فراهم کنند.

از خواندن شما متشکرم!

تبریک می گویم که این کار را در اینجا انجام دهید! اگر از آموزش لذت بردید ، برای تقویت عزت نفس الگوریتم و کمک به خوانندگان بیشتر در یافتن آن ، روی 👍x50 ضربه بزنید.

می خواهید بیشتر بدانید؟

کد کامل به عنوان نوت بوک COLAB:

منتشر شده از طریق به سمت هوش مصنوعی



منبع: https://towardsai.net/p/machine-learning/harness-dinov2-embeddings-for-accurate-image-classification

پاسخی بگذارید

نشانی ایمیل شما منتشر نخواهد شد. بخش‌های موردنیاز علامت‌گذاری شده‌اند *