深層距離学習における平均場理論

深層距離学習における平均場理論

はじめに

こんにちは、ZOZO研究所AppliedMLチームの古澤です。私たちは商品画像の検索の基礎として、深層距離学習という技術を研究しています。本記事では、本研究所からICLR2024に採択された「Mean Field Theory in Deep Metric Learning」という研究について紹介します。対象読者としては、機械学習系のエンジニアや学生を想定しています。

目次

Notation

この記事では以下のnotationを使用します。

  • 訓練データ: \mathsf{D} = \{ x_i, y_i\}^{|\mathsf{D}|}_{i=1}
  • 画像などのデータ: x_i
  • クラスラベル: y_i \in \mathsf{C} = \{1,\cdots,|\mathsf{C}|\}
  • クラス cに属する訓練データ: \mathsf{D}_{c}=\{x|(x,y)\in\mathsf{D}, y=c\}
  • データ xを長さ1のd次元ベクトルに写す特徴抽出器: \mathbf{F}_\theta(x)
  • 特徴空間の距離: d(\mathbf{F}, \mathbf{F}')=1-\mathbf{F}\cdot \mathbf{F}'\ge 0

深層距離学習

深層距離学習は、データ間の距離や類似度を学習するタスクであり、画像認識、顔認証、推薦システムなど、多くのアプリケーションで利用されています。

深層距離学習では、画像やその他の複雑なデータから、データの類似度を反映した特徴ベクトルを抽出することを目指しています。これを実現するためには、ベクトル空間内で類似するデータ点が近く、異なるデータ点が遠くなるように損失関数を設計することが重要です。

深層距離学習の損失関数の典型例として、Contrastive Loss1が知られています。

 \displaystyle
\mathsf{L}_\mathrm{Cont.}(\{ \mathbf{F}_\theta (x_i) \cdot \mathbf{F}_\theta (x_j)  \}_{i,j}) = \frac{1}{2|\mathsf{C}|} \sum_{c \in \mathsf{C}} \frac{1}{|\mathsf{D}_c|^2}\sum_{i,j\in \mathsf{D}_c} \Big[ d(\mathbf{F}_\theta(x_i),\mathbf{F}_\theta(x_j)) - m_{\; p} \Big]_+ \\ + \frac{1}{2|\mathsf{C}|}\sum_{c \neq c'}\frac{1}{|\mathsf{D}_c||\mathsf{D}_{c'}|} \sum_{i\in \mathsf{D}_c, j\in \mathsf{D}_{c'}} \Big[m_{\; n} - d(\mathbf{F}_\theta(x_i),\mathbf{F}_\theta(x_j) ) \Big]_+

ここで、 [x]_+ = \max(x,0)です。第一項は同じクラス内のデータ点(正例ペア)の相互作用を、第二項は異なるクラスのデータ点(負例ペア)の相互作用を表しています。 m_{\; p} m_{\; n})は正例(負例)ペア間の距離を制御するハイパーパラメータです。正例(負例)ペアの距離が m_{\; p} m_{\; n})よりも小さく(大きく)なるように学習が進むことを意味しています。

Contrastive Lossのようなペアに基づく損失関数の場合、可能なペアの組み合わせはデータ数が増えるに従って多項式的に増加します。このペアのうちの多くは学習に寄与しない簡単なペアであるため、学習が遅くなってしまうという課題があります。

一方で、深層距離学習の損失関数としてはペアを使用せず、分類問題と同じように訓練できる損失関数も存在します。

今回の研究では、統計物理学における解析手法である平均場理論を使用することで、ペアに基づく損失関数から、分類問題と同じように訓練できる損失関数を導出する方法を確立することを目的としています。

磁性体と平均場理論

今回の研究は、統計物理学における磁性体のモデルと深層距離学習とのアナロジーに基づいています。磁性体、特に強磁性体は、その物質の微視的な磁気モーメント(スピン)が互いに相互作用し、同じ方向を向くことで巨視的な磁化を示します。磁性体の相転移現象は、統計力学の典型的な問題として知られています。2

磁性体

磁性体の簡単なモデルとして、以下のような無限レンジモデルを考えてみます。無限レンジモデルのエネルギー関数は次のように表されます。

 \displaystyle
\mathsf{H}(\{ \mathbf{S}_i \cdot \mathbf{S}_j  \}_{i,j}) = -\frac{J}{2N} \sum^N_{i,j=1} \mathbf{S}_i \cdot \mathbf{S}_j

ここで、Nはスピンの総数で、 J\gt0とします。 \mathbf{S}_iは半径1の球面上に値をとるベクトルで、i番目のスピンを表します。この \mathsf{H}は、スピンが同じ方向を指している状態がエネルギー的に好まれることを示しています。

統計力学によると、温度Tでのスピン配位の確率分布はギブス分布に従います。

 \displaystyle
P(\{ \mathbf{S}_i \}_i) = \frac{\mathrm{e}^{-\mathsf{H}/T}}{Z}, \quad Z = \int \prod_i d^2\mathbf{S}_i \mathrm{e}^{-\mathsf{H}/T}

このモデルの熱力学的に重要な特性は、分配関数 Zから計算できます。

しかし、スピン間の相互作用のため、分配関数 Zの積分の解析的、数値的な計算は困難に見えます。3

平均場理論

次に平均場理論について説明します。平均場理論の基本的なアイデアは、各スピンが他のスピンと相互作用する際に、他のスピンをその「平均場」で近似し、揺らぎを無視することです。これにより、各スピンが他のスピンと直接相互作用するのではなく、平均場と相互作用するようにハミルトニアンを近似できます。

具体的には、 \mathbf{S}_i = \mathbf{M} + (\mathbf{S}_i -\mathbf{M})という恒等式を使用して \mathsf{H} {(\mathbf{S}_i -\mathbf{M})}_iの揺らぎに関してTaylor展開し、スピン間の揺らぎの交差項を無視します。この操作によりエネルギー関数と分配関数は次のようになります。

 \displaystyle
\mathsf{H}(\{ \mathbf{S}_i \cdot \mathbf{S}_j  \}_{i,j}) \simeq \mathsf{H}_\mathrm{mft}(\{ \mathbf{S}_i \cdot \mathbf{M} \}_{i}, \mathbf{M} \cdot \mathbf{M}) = \frac{JN}{2} \mathbf{M} \cdot \mathbf{M} - J\mathbf{M} \cdot \sum^N_{i=1} \mathbf{S}_i
 \displaystyle
P_\mathrm{mft}(\{ \mathbf{S}_i \}_i) = \frac{\mathrm{e}^{-\mathsf{H}_\mathrm{mft}/T}}{Z_\mathrm{mft}}, \quad Z_\mathrm{mft} = \int \prod_i d^2\mathbf{S}_i \mathrm{e}^{-\mathsf{H}_\mathrm{mft}/T}

ここで、平均場 \mathbf{M} -\log Z_\mathrm{mft}を最小化することで決定されます。この条件は次の自己整合方程式に帰着します。

 \displaystyle
\mathbf{M} = \frac{1}{N}\sum_i\mathbb{E}[\mathbf{S}_i]

この式は、平均場 \mathbf{M}が実際に他のスピンの平均を表しており、実際に平均値周りでの揺らぎに対して展開が行われていたことを示しています。

平均場理論により、新しく最適化パラメータが導入され、スピン間の相互作用が単純化されます。これにより、分配関数に含まれる積分は独立に実行可能となります。

磁性体と深層距離学習のアナロジー

次に、分配関数のT=0の極限を考えてみましょう。この極限では、分配関数に最も寄与するスピン配位はエネルギー関数 \mathsf{H}を最小化する配位となります。これは、T=0の極限で、元の問題が \mathsf{H}を最小化するスピン配位を見つける問題と同等であることを意味します。

平均場理論を適用すると、この問題はハミルトニアン \mathsf{H}_\mathrm{mft}を最小化する問題に変換されます。そして、この極限では、 -\log Z_\mathrm{mft} \{\mathbf{S}_i\}_iに対して最小化された \mathsf{H}_\mathrm{mft}に比例します。したがって、 \mathsf{H}_\mathrm{mft} \{\mathbf{S}_i\}_i \mathbf{M}について最小化する問題となることがわかります。すなわち、以下のように問題が変換されたことになります。

 \displaystyle
\min_{\{\mathbf{S}_i\}_i} \; \mathsf{H}(\{ \mathbf{S}_i \cdot \mathbf{S}_i  \}_{i,j}) \rightarrow \min_{\mathbf{M},\{\mathbf{S}_i\}_i} \mathsf{H}_\mathrm{mft}(\{ \mathbf{S}_i \cdot \mathbf{M} \}_{i}, \mathbf{M} \cdot \mathbf{M})

一方で、深層距離学習では、損失関数を最小化する最適な学習パラメータ \thetaを見つけることが目標であり、次のように書くことができます。

 \displaystyle
\min_{\theta} \; \mathsf{L}(\{ \mathbf{F}_\theta (x_i) \cdot \mathbf{F}_\theta (x_j)  \}_{i,j})

この形式は磁性体の問題と同様であり、どちらもペア間の相互作用の関数として記述されています。このアナロジーは、平均場理論を深層距離学習に適用することで、ペア間の相互作用を平均場との相互作用に置き換えることが可能であることを示唆しています。

深層学習への応用

次に、深層距離学習の損失関数に平均場理論を適用してみます。簡単のため、Contrastive Lossに平均場理論を適用し、その後、より複雑な損失関数への適用について議論します。

Contrastive Loss

平均場理論を適用するために、各クラスの平均場 \{\mathbf{M}_c\}_{c\in\mathsf{C}}を導入し、 \mathsf{L}_\mathrm{Cont.}をこれらの平均場周りの揺らぎに関して展開します。この際、平均場間の相対距離を制約する条件を課します。

 \displaystyle
\Big[m_{\; n} - d(\mathbf{M}_c,\mathbf{M}_{c'})\Big]_+ = 0 \quad \Bigl( c\neq c' \Bigr)

この条件は、0次の展開で \mathsf{L}_\mathrm{Cont.}を最小化する平均場のみを探索することを意味します。

磁性体の場合と同様、各特徴ベクトル \mathbf{F}_\theta(x_i)について、同じクラスの平均場 \mathbf{M}_cの周りで展開します。

展開の際にはやはり交差項は無視しますが、揺らぎの自己相互作用に関しては残しつつ、resummationを行います。結果として、以下のMeanFieldContrastive(MFCont)Lossが得られます。

 \displaystyle
\mathsf{L}_\mathrm{MFCont.} = \frac{1}{|\mathsf{C}|} \sum_{c \in \mathsf{C}} \frac{1}{|\mathsf{D}_c|}\sum_{i\in \mathsf{D}_c} \Big( \Big[d(\mathbf{F}_\theta(x_i),\mathbf{M}_c) - m_{\; p}\Big]_+ \\ + \sum_{c'(\neq c)}\Big[m_{\; n} - d(\mathbf{F}_\theta(x_i),\mathbf{M}_c)\Big]_+ \Big) \\ + \frac{\lambda_\mathrm{MF}}{|\mathsf{C}|} \sum_{c \neq c'} \Big[m_{\; n}-d(\mathbf{M}_c,\mathbf{M}_{c'})\Big]^2_+

ここで、平均場に対する制約条件は \lambda_\mathrm{MF}を用いてソフトに取り入れられています。以上の議論から、深層距離学習の場合でも、平均場を導入することでペアの相互作用を取り除き、通常の分類問題の形に帰着できることがわかります。

ClassWiseMultiSimilarity Loss

より発展的な損失関数として、MultiSimilarity Loss4のようなミニバッチ内のペア間の相互作用を考慮した損失関数の平均場理論について考察します。

Taylor展開を簡単にするためには、 x_i x_jに関して対称な損失関数が望ましいです。次のような新しい損失関数を導入しましょう。

 \displaystyle
\mathsf{L}_\mathrm{CWMS} = \frac{1}{\alpha |\mathsf{C}|} \sum_{c \in \mathsf{C}} \log\left[ 1+ \frac{\sum_{i,j \in \mathsf{D}_c, i \neq j} \mathrm{e}^{\alpha ( d(\mathbf{F}_\theta(x_i),\mathbf{F}_\theta(x_j)) - \delta)}}{2|\mathsf{D}_c|^2} \; \, \right] \\ + \frac{1}{2\beta |\mathsf{C}|} \sum_{c \neq c'} \log\left[ 1+ \frac{\sum_{i \in \mathsf{D}_c, j \in \mathsf{D}_{c'}} \mathrm{e}^{ - \beta ( d(\mathbf{F}_\theta(x_i),\mathbf{F}_\theta(x_j)) - \delta)}}{|\mathsf{D}_c||\mathsf{D}_{c'}|}  \; \, \right]

ここで、 \alpha\gt0 \beta\gt0 \delta\in[-1,1]はハイパーパラメータです。この損失関数はMultiSimilarity Lossに似ていますが、クラス別の負のサンプル間の相互作用を含むため、ClassWiseMultiSimilarity(CWMS)Lossと呼びます。

次に、この損失関数に平均場理論を適用してみましょう。最初と2番目の項のlogの中身は、Contrastive損失の正と負の相互作用に似た形をとります。このため、Contrastive Lossの議論を繰り返すことで、MeanFieldClassWiseMultiSimilarity(MFCWMS)Lossを導くことができます。

 \displaystyle
\mathsf{L}_\mathrm{MFCWMS} = \frac{1}{\alpha |\mathsf{C}|} \sum_{c \in \mathsf{C}} \log\left[ 1 + \frac{\sum_{i\in \mathsf{D}_c} \mathrm{e}^{\alpha ( d(\mathbf{F}_\theta(x_i),\mathbf{M}_c) - \delta)}}{|\mathsf{D}_c|} \; \, \right] \\ + \frac{1}{2\beta |\mathsf{C}|} \sum_{c \neq c'} \log\left[ 1+ \frac{\sum_{i \in \mathsf{D}_c} \mathrm{e}^{ - \beta ( d(\mathbf{F}_\theta(x_i),\mathbf{M}_{c'}) - \delta)}}{|\mathsf{D}_c|} + \frac{\sum_{j \in \mathsf{D}_{c'}} \mathrm{e}^{ - \beta ( d(\mathbf{M}_c,\mathbf{F}_\theta(x_j)) - \delta)}}{|\mathsf{D}_{c'}|} \; \, \right] \\ + \frac{\lambda_\mathrm{MF}}{|\mathsf{C}|} \sum_{c \neq c'} \left(\log\left[ 1+ \mathrm{e}^{ - \beta ( d(\mathbf{M}_c,\mathbf{M}_{c'}) - \delta)} \; \; \; \right]\right)^2

ここで、平均場の制約条件もソフトに導入しています。この損失関数は、各サンプル間の相互作用も取り入れており、より複雑な関係性まで取り入れられていることがわかります。

実験

今回は従来のMetric learningの評価手法とA Metric Learning Reality Check (MLRC)5で導入されたより公平なベンチマーク手法の両方を使用して評価しました。

評価指標

深層距離学習の単純な評価指標としてはPrecision@K(P@K)やRecall@K(R@K)などが考えられます。しかし、今回はMLRCに従い、Mean Average Precision at R(MAP@R)と呼ばれる指標を使用します。

 \displaystyle
\mathrm{MAP@R} = \frac{1}{R} \sum^R_{k=1} P(k)

各クエリに対し、 P(k)はk番目に近いデータ点が同じクラスの場合はP@k、そうでなければ0となる関数です。また、Rはクエリと同じクラスに属するデータ点の数を表します。

MAP@Rは最近傍のデータ点が同じクラスかという情報だけではなく、順位の情報も反映しています。このため、データ点がよりよくクラスタリングされているかを評価できます。

評価指標の比較 (figure from A Metric Learning Reality Check)

データセット

実験では、以下の4つの画像検索用データセットを使用しました。

  • CUB-200-2011 (CUB):200クラス・11788枚の鳥の画像データセット6
  • Cars-196 (Cars):196クラス・16185枚の車の画像データセット7
  • Stanford Online Products (SOP):22634クラス・120053枚の商品画像データセット8
  • InShop:7982クラス・52712枚のファッション商品画像データセット9

CUBとCarsと比較するとSOPとInShopは規模が大きいデータセットです。

MLRCのベンチマークではCUB、Cars、SOPを、通常の評価手法ではすべてのデータセットを使用しました。

実装の詳細

ベースとなる埋め込みモデル \mathbf{F}_\thetaには、ImageNetで事前学習されたBN-Inceptionネットワークを使用し、最後の線形層を所望の次元の特徴ベクトルを得られるように置き換えました。

MLRCの評価手法では、損失関数のハイパーパラメータ最適化のため、ベイズ最適化を50回繰り返します。データセットはtrain-valid(最初の半分のクラス)とtestデータセット(残り)に分割しました。さらに、train-validセットをクラスが重複しないように4分割し、4-foldのクロスバリデーションを実施しました。クロスバリデーションにおけるMAP@Rの平均をベイズ最適化の目的関数として使用しました。また、特徴ベクトルの次元は128次元に、バッチサイズは32に設定しました。

テスト段階では、最適なハイパーパラメータで同様にクロスバリデーションを実行し、4つの埋め込みモデルを得ました。それぞれの特徴ベクトルを独立に使用した場合のMAP@Rの平均と、4つの特徴ベクトルを結合して新しい512次元の特徴ベクトルを作成し、これをもとに計算したMAP@Rを評価に使用しました。テストでは上記の施行を10回繰り返し、各指標の平均値と95%信頼区間を報告しました。

一方、通常の評価手法では、データセットをクラス間が重複しないようにtrainとevaluationの2つに分割しました。各エポックでevaluationスコアを計算し、その最大値を最終的なevaluationスコアとして採用しました。また、特徴ベクトルの次元は512次元に、バッチサイズは128に設定しました。こちらも同様に10回繰り返し、その平均値と95%信頼区間を報告しました。

比較手法

比較手法としては、以下の損失関数を採用しました。

  • Contrastive (Cont.)
  • ClassWiseMultiSimilarity (CWMS)
  • MultiSimilarity (MS)
  • MultiSimilarity + Miner (MS+Miner)
  • ArcFace10
  • CosFace11
  • ProxyNCA12
  • ProxyAnchor (ProxyAnch.)13

最初の4つはペアに基づく損失関数で、残りの4つは分類問題と同様に訓練できる損失関数です。特にProxyAnchor Lossは非常に性能の良い損失関数と考えられています。

定量評価

MLRCベンチマークでは、MFContとMFCWMSは、ほとんどの場合で元の損失関数よりも優れたスコアを示しました。これは、平均場理論を適用することで、学習を単純化するだけでなく、より汎化性能が高い特徴空間を学習できることを意味しています。

CarsデータセットではProxyAnchorおよびCWMSがMFContとMFCWMSよりも優れたパフォーマンスを示しています。しかし、CUBおよびSOPデータセットでは、分離されたMAP@Rおよび連結されたMAP@Rの両方で他のベースライン手法を一貫して上回りました。

ベンチマーク結果

また、従来の評価方法では、Carsデータセットを除く全てのデータセットにおいて、MFContとMFCWMSはProxyAnchor Lossおよび元の損失関数の性能を上回りました。これはMLRCベンチマークとも一致する結果です。精度の向上は特に大きなデータセットで顕著でした。さらに、全てのデータセットにおいて、MFContとMFCWMSは他の損失関数よりも早く収束することが確認されました。

評価結果

まとめ

この記事では、統計物理学における解析手法である平均場理論を深層距離学習に適用した研究を紹介しました。特に、平均場理論を用いることで、学習が難しいペアに基づく損失関数を、分類問題のような損失関数に帰着させることができることを示しました。さらに、平均場理論をContrastive LossとCWMS Lossに適用し、新しい損失関数としてMFCont LossとMFCWM Lossを提案しました。

導出された損失関数の評価した結果、これらは比較手法と比較して、多くのデータセットで優れたパフォーマンスを示すことが確認されました。この結果は、深層距離学習における平均場理論の有効性を示唆しています。

ZOZO研究所では、一緒にサービスを作り上げてくれる方を募集中です。ご興味のある方は、以下のリンクからぜひご応募ください。

hrmos.co hrmos.co

参考文献・注釈


  1. Raia Hadsell, Sumit Chopra, and Yann LeCun, "Dimensionality reduction by learning an invariant mapping", In 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR’06), volume 2, pages 1735–1742. IEEE, 2006.
  2. Nishimori, Hidetoshi, and Gerardo Ortiz, "Elements of Phase Transitions and Critical Phenomena", (Oxford, 2010; online edn, Oxford Academic, 1 Jan. 2011),
  3. 無限レンジモデルの場合は、実は解析的に分配関数を計算可能です。しかし、より現実的なスピンが格子上にある場合などについては、通常、解析的な計算は困難です。
  4. Xun Wang, Xintong Han, Weilin Huang, Dengke Dong, and Matthew R Scott, "Multi-similarity loss with general pair weighting for deep metric learning", In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 5022–5030, 2019.
  5. Kevin Musgrave, Serge Belongie, and Ser-Nam Lim, "A metric learning reality check", In Computer Vision–ECCV 2020: 16th European Conference, Glasgow, UK, August 23–28, 2020, Proceedings, Part XXV 16, pages 681–699. Springer, 2020.
  6. Catherine Wah, Steve Branson, Peter Welinder, Pietro Perona, and Serge Belongie, "The caltech-ucsd birds-200-2011 dataset", 2011.
  7. Jonathan Krause, Michael Stark, Jia Deng, and Li Fei-Fei, "3d object representations for finegrained categorization", In Proceedings of the IEEE international conference on computer vision workshops, pages 554–561, 2013.
  8. Hyun Oh Song, Yu Xiang, Stefanie Jegelka, and Silvio Savarese, "Deep metric learning via lifted structured feature embedding", In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 4004–4012, 2016.
  9. Ziwei Liu, Ping Luo, Shi Qiu, Xiaogang Wang, and Xiaoou Tang, "Deepfashion: Powering robust clothes recognition and retrieval with rich annotations", In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 1096–1104, 2016.
  10. Jiankang Deng, Jia Guo, Niannan Xue, and Stefanos Zafeiriou, "Arcface: Additive angular margin loss for deep face recognition", In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp. 4690–4699, 2019.
  11. Feng Wang, Jian Cheng, Weiyang Liu, and Haijun Liu, "Additive margin softmax for face verification", IEEE Signal Processing Letters, 25(7):926–930, 2018a.
  12. Yair Movshovitz-Attias, Alexander Toshev, Thomas K Leung, Sergey Ioffe, and Saurabh Singh, "No fuss distance metric learning using proxies", In Proceedings of the IEEE international conference on computer vision, pp. 360–368, 2017.
  13. Sungyeon Kim, Dongwon Kim, Minsu Cho, and Suha Kwak, "Proxy anchor loss for deep metric learning", In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 3238–3247, 2020.
カテゴリー