초점 손실이 객체 감지의 클래스 불균형을 해결하는 방법, 즉 불균형한 데이터 세트의 정확도를 향상시키기 위해 어려운 예제에 집중하여 학습하는 방법을 알아보세요.
초점 손실은 머신러닝 훈련에서 극심한 클래스 불균형 문제를 해결하기 위해 고안된 특수 목적 함수입니다. 학습 훈련, 특히 컴퓨터 비전 분야에서 극심한 클래스 불균형 문제를 해결하기 위해 고안된 특수 목적 함수입니다. 많은 객체 감지 시나리오에서 배경 예제(네거티브)의 수가 배경 예제(네거티브)의 수가 관심 대상(포지티브)의 수를 훨씬 초과하는 경우가 많습니다. 표준 손실 함수는 classify 쉬운 배경 예제의 방대한 양에 압도되어 모델의 학습 능력을 방해할 수 있습니다. 학습하는 데 방해가 될 수 있습니다. 초점 손실은 예측의 신뢰도에 따라 동적으로 손실 규모를 조정함으로써 이러한 문제를 완화합니다. 예측의 신뢰도에 따라 손실을 동적으로 조정하여 쉬운 예제의 가중치를 효과적으로 낮추고 모델이 학습에 집중하도록 합니다. 어려운 부정적 예시와 잘못 분류된 객체에 집중하도록 합니다.
초점 손실의 주요 동기는 1단계 물체 감지기와 같은 1단계 물체 감지기의 성능을 개선하는 것입니다. 레티나넷의 초기 버전과 같은 1단계 오브젝트 디텍터의 성능을 개선하는 것입니다. Ultralytics YOLO11. 이러한 시스템에서 감지기는 이미지를 스캔하여 이미지를 스캔하여 수천 개의 후보 위치를 생성합니다. 이미지의 대부분은 일반적으로 배경이므로, 배경과 객체의 비율은 배경과 물체의 비율은 보통 1000:1 이상이 될 수 있습니다.
개입하지 않으면 방대한 수의 배경 샘플에서 발생하는 작은 오류의 누적 효과로 인해 그라데이션 업데이트를 지배할 수 있습니다. 역전파. 이로 인해 최적화 알고리즘의 우선순위를 단순히 모든 것을 배경으로 분류하여 전체 오류를 최소화하는 데 우선순위를 두게 됩니다. 미묘한 특징을 학습하는 대신 전체 오류를 최소화하는 데 우선순위를 두게 됩니다. 초점 손실은 표준 손실 곡선을 재구성하여 모델이 이미 확신하는 예제에 대한 페널티를 줄임으로써 불이익을 줄이기 위해 표준 손실 곡선을 재구성하여 모델 가중치를 조정하여 까다로운 사례에 맞게 조정합니다.
초점 손실은 표준의 확장입니다. 교차 엔트로피 손실의 확장입니다. 그것은 올바른 클래스에 대한 신뢰도가 증가함에 따라 손실 기여도를 감소시키는 변조 인자를 도입합니다. 모델에서 "쉬운" 예시(예: 하늘이 선명하고 높은 확률로 식별되는 배경과 같이 '쉬운' 예시를 만나면 변조 계수는 손실을 0에 가깝게 만듭니다. 반대로, "어려운" 모델의 예측이 부정확하거나 불확실한 예시에서는 손실이 여전히 크게 남습니다.
이 동작은 흔히 감마로 표시되는 포커싱 매개변수에 의해 제어됩니다. 데이터 과학자는 이 매개변수를 조정하여 는 손실 함수가 잘 분류된 예제에 얼마나 적극적으로 가중치를 낮추는지를 조정할 수 있습니다. 이를 통해 불균형이 심한 훈련 데이터에 대해 보다 안정적인 매우 불균형한 학습 데이터에 대한 학습을 보다 안정적으로 수행하여 희귀한 클래스에 대한 정확도와 리콜률을 높일 수 있습니다.
불균형을 처리하는 기능 덕분에 초점 손실은 안전이 중요한 고정밀 환경에서 필수적인 요소입니다.
그리고 ultralytics 라이브러리는 사용자 지정 교육 파이프라인에 쉽게 통합할 수 있는 강력한 초점 손실 구현을 제공합니다.
커스텀 트레이닝 파이프라인에 쉽게 통합할 수 있습니다. 다음 예는 손실 함수를 초기화하고 예측 로그와 기준값 레이블 사이의
오차를 계산하는 방법을 보여줍니다.
import torch
from ultralytics.utils.loss import FocalLoss
# Initialize Focal Loss with a gamma of 1.5
criterion = FocalLoss(gamma=1.5)
# Example: Prediction logits (before activation) and Ground Truth labels (0 or 1)
preds = torch.tensor([[0.1], [2.5], [-1.0]], requires_grad=True)
targets = torch.tensor([[0.0], [1.0], [1.0]])
# Compute the loss
loss = criterion(preds, targets)
print(f"Focal Loss value: {loss.item():.4f}")
초점 손실을 손실 함수 환경의 관련 용어와 구별하는 것이 도움이 됩니다: