티스토리 뷰

반응형

TorchMetrics란?

TorchMetrics는 PyTorch 에서 사용할 수 있는 Metric 구현 라이브러리이다. 구현된 Metric은 100가지가 넘으며, 쉬운 API 구성으로 훈련 및 평가시 사용할 수 있다.

특징

  • PyTorch에서 사용 가능한 다양한 메트릭 함수 제공
  • 메트릭 컬렉션을 제공하여 여러 메트릭 함수를 한번에 사용 가능
  • GPU를 사용하여 빠른 속도로 메트릭을 계산할 수 있음
  • 분산 학습 호환 및 자동 동기화
  • metrics 계산 시 배치 단위로 누적 연산이 가능함

설치

터미널에서 아래 명령어를 실행하여 설치한다.

pip3 install torchmetrics

함수형 사용법

  • torchmetrics.functional 로부터 함수를 가져와 사용한다.
  • 아래 코드처럼 단순 계산에 사용할 수 있다.
  • 입력되는 preds, target과 출력 형태 모두 torch.tensor 이다.
import torch
import torchmetrics

# simulate a classification problem
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target, task="multiclass", num_classes=5)

모듈형 사용법

  • 배치나 에폭 단위로 누적되어 계산할 때 사용할 수 있다.
  • 함수를 호출할 때에는 함수형처럼 현재의 입력에 대한 결과만 나온다.
  • compute() 함수를 호출하면 누적된 결과값을 확인할 수 있다.
import torch
from torchmetrics.classification import Accuracy

# initialize metric
metric = Accuracy(task="multiclass", num_classes=5)

n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")

# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")

# Reseting internal state such that metric ready for new data
metric.reset()

새로운 Custom Metric 구현 방법

새로운 Custom Metric 구현 방법은 다음과 같습니다.

  1. Metric 클래스를 상속하여 새로운 Metric 클래스를 작성합니다.
  1. __init__() 메서드에서 누적으로 계산할 변수를 add_state() 를 사용해서 추가해줍니다.
  1. update() 메서드를 구현하여 측정하고자 하는 지표를 업데이트합니다. 이건 클래스가 호출될때 동작됩니다.
  1. compute() 메서드를 구현하여 업데이트한 지표를 계산합니다.

위와 같은 방법으로 Custom Metric을 구현할 수 있습니다. Metric 클래스의 다른 메서드들에 대해서는 torchmetrics 문서를 참고하시기 바랍니다.

from torchmetrics import Metric

class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.shape == target.shape

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        return self.correct.float() / self.total

위의 custom accuracy로 계산을 하면 아래와 같이 계산할 수 있다.

acc = MyAccuracy()
print(acc(torch.tensor([1, 0]), torch.tensor([1, 0])))
# tensor(1.)
print(acc(torch.tensor([1, 0]), torch.tensor([1, 1])))
# tensor(0.5)
print(acc.compute())
# tensor(0.75)

코드를 보면 acc를 호출해서 사용하면 해당 입력(배치)에 대한 acc 결과가 나오고 acc.compute()를 호출할 때 현재까지의 누적 acc 결과가 출력된다.

Metric 내부 동작

MyAccuracy 클래스 구현체를 보면 현재 입력(배치)의 결과를 반환해주는 코드가 없는데 어떻게 가능한지 살펴봤다.

Metric은 내부적으로 아래와 같이 동작한다. (full_state_update에 따라 다를 수 있음)

  1. global state를 캐시에 저장한다.
  1. reset()을 호출하여 default state로 초기화한다.
  1. update()를 호출하여 입력된 batch의 state를 update한다.
  1. compute()를 호출하여 입력된 batch에 대해 metric을 계산한다.
  1. 캐시에 있는 global state에 현재의 batch state를 추가하여 새로운 global state로 업데이트한다.

위 동작으로 인해 각 acc 호출시에는 현재 batch의 metric 값을 반환하고, compute() 호출시에는 전체 누적된 결과의 metric을 반환할 수 있게 된다.

GPU에서의 사용방법

기본적으로 torchmetrics의 Metric은 torch.nn.Module의 subclass이기 때문에 연산을 위해서는 같은 device 위에서 동작해야 한다.

또한 Metric은 항상 cpu에서 초기화 되기 때문에 to()함수를 통해 원하는 gpu device로 올려줘야 한다.

.device property를 사용해서 현재 어떤 device에 올라가 있는지 확인할 수 있다.

from torchmetrics.classification import BinaryAccuracy

target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0))
preds = torch.tensor([0, 1, 0, 0], device=torch.device("cuda", 0))

# Metric states are always initialized on cpu, and needs to be moved to
# the correct device
confmat = BinaryAccuracy().to(torch.device("cuda", 0))
out = confmat(preds, target)
print(out.device) # cuda:0

MetricCollection 사용방법

torchmetrics는 한 metric 내부에 또다른 metric 구현하는 것을 권장하지 않는다. 따라서 여러가지 metric을 한번에 계산하고 싶을 때에는 리스트 형태로 입력하는 MetricCollection을 제공한다.

Metric은 기본적으로 torch.nn.Module의 형태이기 때문에, list나 dict가 아닌 ModuleList, ModuleDict로 입력 받아야 한다.

출력은 dictionary 형태로 출력 되며, key값은 클래스와 동일하게 출력된다.

from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall
target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
metric_collection = MetricCollection([
    MulticlassAccuracy(num_classes=3, average="micro"),
    MulticlassPrecision(num_classes=3, average="macro"),
    MulticlassRecall(num_classes=3, average="macro")
])
print(metric_collection(preds, target))
# {'MulticlassAccuracy': tensor(0.1250),
# 'MulticlassPrecision': tensor(0.0667),
# 'MulticlassRecall': tensor(0.1111)}

MetricCollection은 내부의 Metric 클래스들이 모두 동일한 call signature(동일 매개변수)를 가져야 한다.

MetricCollection의 추가적인 장점으로는 자동으로 compute_group을 찾아 중복되는 state의 연산을 방지하여 속도가 빠르다는 것이다.

예를들어 Accuracy, Recall, Precision 같은 경우 모두 TP, FN, FP, TN 등의 state(stat_score)로 관리하게 되는데 이는 각각 호출되면 중복되는 연산이 있을 수 있는데 이를 자동으로 감지하여 한번 계산된 값은 다른 class에서 참고하여 사용하는 방식으로 구현되어있다. (공식문서에 따르면 기본 연산보다 2~3배 빠르다고 함) 이때에는 반드시 .update() 함수를 사용해서 구현해야 한다.

  • 예시 1

MetricCollection을 dict 형태로 입력하는 것도 가능하다. 이렇게 하면 입력 key값과 동일한 출력의 key값이 반환된다.

metrics = MetricCollection({'micro_recall': MulticlassRecall(num_classes=3, average='micro'),
														'macro_recall': MulticlassRecall(num_classes=3, average='macro')})
same_metric = metrics.clone()
pprint(metrics(preds, target))
# {'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}
pprint(same_metric(preds, target))
# {'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}
  • 예시 2

MetricCollection은 nested 형태로 입력하는 것도 가능하다. prefix와 postfix를 입력하여 각각의 출력 dict key값에 반영할 수 있다. (각 Metric 별 쉽게 구분 가능)

metrics = MetricCollection([
     MetricCollection([
         MulticlassAccuracy(num_classes=3, average='macro'),
         MulticlassPrecision(num_classes=3, average='macro')
     ], postfix='_macro'),
     MetricCollection([
         MulticlassAccuracy(num_classes=3, average='micro'),
         MulticlassPrecision(num_classes=3, average='micro')
     ], postfix='_micro'),
 ], prefix='valmetrics/')
pprint(metrics(preds, target))  
# {'valmetrics/MulticlassAccuracy_macro': tensor(0.1111),
#  'valmetrics/MulticlassAccuracy_micro': tensor(0.1250),
#  'valmetrics/MulticlassPrecision_macro': tensor(0.0667),
#  'valmetrics/MulticlassPrecision_micro': tensor(0.1250)}
  • 예시 3

compute_groups는 default로 True가 입력되어 자동적으로 적용할 수 있지만 어떤 metric들끼리 state를 공유할지 직접 지정해줄 수도 있다. 추후에 .compute_groups argument를 출력하면 어떤 Metric끼리 연결되어있는지 확인할 수 있다.

이때 state update는 metrics(pred, target)처럼 forward 방식이 아닌, metrics.update(pred, target)와 같이 .udpate() 함수를 사용해야 한다.

metrics = MetricCollection(
   MulticlassRecall(num_classes=3, average='macro'),
   MulticlassPrecision(num_classes=3, average='macro'),
   MeanSquaredError(),
   compute_groups=[['MulticlassRecall', 'MulticlassPrecision'], ['MeanSquaredError']]
)
metrics.update(preds, target)
pprint(metrics.compute())
# {'MeanSquaredError': tensor(2.3750), 'MulticlassPrecision': tensor(0.0667), 'MulticlassRecall': tensor(0.1111)}
pprint(metrics.compute_groups)
# {0: ['MulticlassRecall', 'MulticlassPrecision'], 1: ['MeanSquaredError']}

Memory management

기본적으로 Metrics는 state를 관리하는 형태이기 때문에 크게 두가지로 나뉜다.

  1. Metrics with tensor states: 이 유형은 update함에 따라 메모리가 증가하지 않는 유형이다. 예를들어 tp, fn, fp, tn으로 관리하게 되면, 이 값들은 숫자가 증가할 뿐 메모리의 크기가 커지지 않기 때문에 메모리 측면에서 유리하다.
  1. Metrics with list states: 이 유형은 리스트 형태의 state이다. update함에 따라 원소가 하나씩 추가하게 되고, 리스트가 커짐에 따라 메모리 사용량도 증가한다.

metric의 현재 입력 상태를 .metric_state property로 확인할 수 있다. (1.1 버전 이상에서만 가능) 그럼 현재 사용중인 metric의 state가 list인지 single tensor인지 확인 할 수 있다.

아래 코드는 list로 구성되어있는 SpearmanCorrCoef 예시로 update()를 호출할수록 state에서 들고있는 리스트가 점점 많아지게 된다.

import torch
from torchmetrics.regression import SpearmanCorrCoef

gen = torch.manual_seed(42)
metric = SpearmanCorrCoef()
metric(torch.rand(2,), torch.rand(2,))
print(metric.metric_state)
# {'preds': [tensor([0.8823, 0.9150])], 'target': [tensor([0.3829, 0.9593])]}
metric(torch.rand(2,), torch.rand(2,))
print(metric.metric_state)
# {'preds': [tensor([0.8823, 0.9150]), tensor([0.3904, 0.6009])], 'target': [tensor([0.3829, 0.9593]), tensor([0.2566, 0.7936])]}

아래는 메모리를 효율적으로 사용하는 방법이다.

  • 메트릭을 사용한 후에는 항상 reset 메소드를 호출한다. 그렇지 않은 경우, 동일한 메트릭의 여러 인스턴스가 동일한 스크립트에서 생성되면 메모리 누수가 발생할 수 있다.
  • 새로운 메트릭을 초기화하는 대신 동일한 메트릭 인스턴스를 재사용한다. reset 메소드를 호출하면 메트릭이 초기 상태로 돌아가므로 동일한 인스턴스를 재사용할 수 있다. 그러나 train, valid, teest 등에서 서로 다른 인스턴스를 사용하는 것이 좋다.
  • 데이터셋이 작아 누적이나 반복 평가가 필요없는 경우, 내부 상태를 유지하지 않으므로 함수형 API를 사용하는 것이 좋다.

참고자료

[1] torchmetrics 공식문서


Uploaded by N2T

반응형
댓글
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday