티스토리 뷰

반응형
AnoGAN

AnoGAN

안녕하세요, 오늘 정리할 논문은 AnoGAN 입니다. 실제 논문 제목은 "Unsupervised Anomaly Detection with Generative Adversarial Networks to Guide Marker Discovery" 입니다.

기존 일반 딥러닝 분류 모델에는 다음과 같은 문제가 있습니다.

  • Anomaly에 대한 Annotation 수작업이 필요하다.
  • 정상/비정상 데이터에 대해서 비정상 데이터의 개수가 압도적으로 적다.

    e.g. 공장에서의 비정상 데이터, 특정 질병 데이터

위의 문제를 해결하기 위해서 Unsupervised 방법으로 GAN 을 결합하였는데요, 비정상 데이터 없이 정상 데이터만 학습하는 방식입니다.

학습방법

  1. 아래 Fig.1 처럼 GAN 모델에는 Healthy Data만 입력하여 정상 데이터에 대한 분포를 학습한다.
  1. 학습이 잘 된 GAN 모델에 정상/비정상(Unseen) 데이터를 입력하여 anomaly score를 비교하여 분류한다.

GAN 모델 구조

일반적인 DCGAN과 같이 convolution으로 이루어진 2개의 모델 Generator, Discriminator를 학습합니다.

Generator 학습

latent space에서 z vector를 랜덤으로 추출하고 Generator에 입력하여 fake image를 생성하고, 이 fake image를 Discriminator에 입력하여 진짜 이미지라고 판단하도록 D를 속일 수 있게 학습합니다.

이때, D의 weight는 udpate하지 않습니다.

Discriminator 학습

생성된 이미지 fake image와 실제 이미지 real image에 대해 각각 fake/real로 잘 판단할 수 있도록 학습합니다.

이때, G의 weight는 update하지 않습니다.

GAN 학습 시 Loss 함수

GAN의 Loss는 다음과 같습니다. D와 G가 서로 경쟁하는 Adversarial 학습입니다.

D는 V가 maximum이 되도록 학습하고, (D(x)=1,D(G(z))=0D(x)=1, D(G(z))=0일때 maximum이 됨)

G는 V가 minimum이 되도록 학습합니다. (D(G(z))=1D(G(z))=1일때 minimum이 됨)

아래와 같은 구조로 일반적인 DCGAN처럼 학습하게 되고, normal/anomalous 이미지의 feature에 대해서 서로 다른 분포를 갖게 되는 것을 확인할 수 있습니다.

Anomaly 를 계산하는 Loss 함수

위의 설명까지는 DCGAN과 완전히 동일합니다. 이제 anomaly score를 구하는 loss와 anomaly한 영역을 추출하는 방법을 알아보겠습니다. 아래 Loss는 GAN 모델을 학습하기 위한 것이 아니고, 입력 이미지와 가장 유사한 이미지를 생성하기 위해 latent space의 z 벡터를 업데이트(back propagation)하기 위한 Loss 입니다.

Loss는 Residual Loss와 Discrimination Loss 두가지를 합하여 계산하게 됩니다.

Redisual Loss

아래 수식을 통해 구하게 되고 x\bold xG(zγ)G(\bold z_\gamma)는 각각 real image와 생성된 fake image로 둘의 차이를 계산하는 것이기 때문에, MAE Loss로 생각하시면 됩니다.

Discrimintation Loss

마찬가지로 아래 수식을 통해 구합니다. 아래에서 f\bold f는 Discriminator 에 이미지를 입력했을 때의 출력 feature vector를 반환하는 함수라고 생각하시면 됩니다. 따라서 Discriminator를 통한 x\bold x의 feature와 G(zγ)G(\bold z_\gamma)의 feature 차이를 계산하는 것으로 feature 단위의 MAE Loss로 생각하면 됩니다.

전체 Loss

위에서 구한 2개의 로스의 합으로 각각 1λ,λ1-\lambda, \lambda 의 가중치를 곱하여 더해줍니다. 최종적으로 아래의 식이 입력 이미지와 가장 유사한 이미지를 생성해내는 latent vector z를 업데이트 할 Loss가 됩니다.

Anomaly score를 구하는 방법

어떤 이미지(real image)를 입력했을 때, 정상인지 비정상인지 분류하는 방법은 다음과 같습니다.

  1. 랜덤으로 생성한 latent vector z를 Generator에 입력하여 fake image를 생성한다.
  1. 위에서 구한 L(zγ)L(\bold z_\gamma)의 backpropagation을 500회 반복하여 latent vector z를 update 한다.
  1. 2번으로부터 얻어진 vector z를 Generator에 통과하여 생성된 이미지를 입력 이미지(real image)와 가장 유사한 이미지라고 판단하여, 이때 생성된 이미지와 입력 이미지를 비교하여 anomaly score를 구합니다.
  1. 아래 수식은 anomaly score를 계산하는 식으로 R(x)R(\bold x)는 Residual Loss, D(x)D(\bold x)는 Discrimination Loss와 구하는 방법이 같습니다. 논문에서는 λ\lambda를 0.1로 두고 사용합니다.

코드 구현

GAN의 코드는 다른데서 쉽게 찾아 볼 수 있기 때문에, latent vector z를 update하는 코드만 알아보겠습니다.

우선 loss는 MSE를 optimizer는 Adam을 사용하였습니다. 최초의 z_vector는 정규분포에서 랜덤하게 가져오고, 이때 required_grad를 True로 해서 gradient를 계산하도록 해줍니다. 또한 z_vector를 update하기 위해 Adam 옵티마이저의 parameter로 넣어 줍니다.

# set loss function
criterion = torch.nn.MSELoss()

# get one random z vector
z_vector = torch.randn(1, 64, 1, 1, device=device, requires_grad=True)

# set optimizer to update z vector
optimizer = torch.optim.Adam([z_vector])

z_vector로부터 fake 이미지를 생성하고 위의 loss 수식에 따라 MSE(x, G(z)), MSE(f(x), f(G(z)))를 loss로 계산하여 업데이트 해줍니다.


# iteration to update z vector
for i in tqdm(range(2001)):

    # generate fake from z
    fake = netG(z_vector)

    # get feature from discriminator
    f_real = netD(real)
    f_fake = netD(fake)

    # get loss
    lossR = criterion(real, fake)
    lossD = criterion(f_real, f_fake)
    loss = (1 - alpha) * lossR + alpha * lossD

    # update z vector
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

전체 코드는 여기서 확인할 수 있습니다.

MNIST 학습 결과

아래는 MNIST로 직접 학습한 결과로 숫자 2는 학습에 사용하였고, 6은 사용하지 않았습니다.

왼쪽에서 오른쪽으로 갈 수록 latent vector z가 업데이트 되며 입력 이미지와 가장 유사한 이미지를 생성하는 latent vector z가 update 됩니다.

6은 학습된 이미지 분포에 포함되지 않기 때문에(abnormal이기 때문에) 0과 6 사이의 애매한 이미지가 최종적으로 생성되었습니다.

빨간 부분은 입력 이미지와 생성된 이미지의 차영상이고 해당 부분을 anomaly 영역으로 판단하게 됩니다.

anomaly score를 비교하면 2는 0.027, 6은 0.067로 적절한 threshold를 취하면 normal/abnormal을 분류를 할 수 있을 것 같습니다.

참고자료

[1] 논문링크: https://arxiv.org/abs/1703.05921

[2] pytorch 코드: https://github.com/jmjeon94/AnoGAN-pytorch

반응형
댓글
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday