Warning
This page is located in archive.

Úkol 3

Trénování SVM klasifikátoru

Jedním z nejpopulárnějších klasifikátorů je SVM (support vector machine). Mějme $n$ pozorování $\mathbf x_1,\dots,\mathbf x_n\in\mathbb R^d$. Pro jednoduchost předpokládejme, že se jedná o binární klasifikaci s labely $y_1,\dots,y_n\in\{-1,1\}$, a že bias je už zakomponován v datech. SVM hledá lineární oddělovací nadrovinu $\mathbf w^\top \mathbf x$ tak, že si pro každé pozorování $i$ zavede pomocnou proměnnou $\xi_i\ge 0$, která měří míru misklasifikace. Toho dosáhne pomocí podmínky $$ \xi_i \ge 1 - y_i\mathbf w^\top \mathbf x_i. $$ Vzhledem k tomu, že budeme chtít minimalizovat $\xi_i$, jeho minimální hodnoty (která se rovná nule) dosáhneme pro $y_i\mathbf w^\top \mathbf x_i\ge 1$. Tedy pro label $y_i=1$ chceme podmínku $\mathbf w^\top \mathbf x_i\ge 1$ a pro label $y_i=-1$ chceme podmínku $\mathbf w^\top \mathbf x_i\le -1$. Toto jsou silnější podmínky než klasické $\mathbf w^\top \mathbf x_i\ge 0$ a $\mathbf w^\top \mathbf x_i\le 0$. Jinými slovy se SVM snaží o dosažení určité vzdálenost dobře klasifikovaných pozorování od oddělovací nadroviny, což přidává robustnost.

Celý optimalizační problém pak jde zapsat jako \begin{align*} \operatorname*{minimalizuj}_{w,\xi}\qquad &C\sum_{i=1}^n\xi_i + \frac12||\mathbf w||^2 \\ \text{za podmínek}\qquad &\xi_i \ge 1 - y_i\mathbf w^\top \mathbf x_i, \\ &\xi_i\ge 0. \end{align*} V účelové funkci se objevila regularizace $\frac12||\mathbf w||^2$ a regularizační konstanta $C>0$. I když jde tento problém řešit v této formulaci, standardně se převede do duální formulace \begin{align*} \operatorname*{maximalizuj}_{z}\qquad &-\frac12\mathbf z^\top Q\mathbf z + \sum_{i=1}^nz_i \\ \text{za podmínek}\qquad &0\le z_i\le C, \end{align*} kde matice $Q$ má prvky $q_{ij}=y_iy_j\mathbf x_i^\top \mathbf x_j$. Mezi primárním a duálním problémem je velká souvislost. Zaprvé je optimální hodnota obou problémů stejná. Zadruhé když se vyřeší duální problém, tak řešení primárního problému se dostane pomocí $\mathbf w=\sum_{i=1}^ny_iz_i\mathbf x_i$.

Duální problém jde vyřešit pomocí metody coordinate descent. Tato metoda v každé iteraci updatuje pouze jednu komponentu vektoru $\mathbf z$. Tedy v každé iteraci zafixujeme nějaké $i$ a optimalizaci přes $\mathbf z$ nahradíme optimalizaci přes $\mathbf z+d\mathbf e_i$, kde $\mathbf e_i$ je nulový vektor s jedničkou na komponentě $i$. V každé iteraci se pak řeší \begin{align*} \operatorname*{maximalizuj}_d\qquad &-\frac12(\mathbf z+d\mathbf e_i)^\top Q(\mathbf z+d\mathbf e_i) + \sum_{i=1}^nz_i + d\\ \text{za podmínek}\qquad &0\le z_i + d\le C. \end{align*} Tento optimalizační problém je jednoduchý, neboť $d\in\mathbb R$ a existuje řešení v uzavřené formě. Toto se opakuje pro předem daný počet epoch, kde jedna epocha udělá výše zmíněné iterace postupně pro $i=1$, $i=2$ až po $i=n$.

Zadání

Naimplementujte funkce Q = computeQ(X, y), w = computeW(X, y, z), z = solve_SVM_dual(Q, C; max_epoch) a w = solve_SVM(X, y, C; kwargs…). Toto schéma ukazuje jak vstupní tak výstupní argumenty. Podrobněji tyto argumenty jsou:

  • X: matice $n\times d$ vstupních dat;
  • y: vektor $n$ vstupních labelů;
  • C: kladná regularizační konstanta;
  • w: vektor řešení primárního problému $\mathbf w$;
  • z: vektor řešení duálního problému $\mathbf z$;
  • Q: matice duálního problému $Q$.

Funkce computeQ spočte matici Q a funkce computeW spočte w z duálního řešení. Funkce solve_SVM_dual dostává jako vstupy parametry duálního problému a udělá max_epoch epoch (tedy n*max_epoch iterací) metody coordinate descent. Počáteční řešení bude $z=0$. Funkce solve_SVM zkombinuje předchozí funkce do jedné.

Nakonec napište funkci w = iris(C; kwargs…), která načte iris dataset, jako pozitivní třídu bude uvažovat versicolor, jako negativní třídu virginica, jako featury použije PetalLength a PetalWidth, tyto vstupy znormuje, přidá bias (jako poslední sloupec matice $X$) a nakonec použije výše napsané funkce na spočtení optimální oddělovací nadroviny w.

Rozmyslete si, které funkce musíte volat s klíčovými argumenty kwargs.

Doporučené testování

Psaní následujících testů není nutné. Může vám však pomoci v hledání chyb. Jako testy můžete například použít:

  • Ověřit správnost duálního řešení pomocí vyzkoušení velkého počtu různých hodnot $d$.
  • Ověřit rovnost optimálních hodnot primárního a duálního problému.
  • Ověřit správnou propagaci kwargs jejich nahrazením různými hodnotami max_epoch.
  • Cokoli dalšího.
courses/b0b36jul/hw/hw3.txt · Last modified: 2023/09/20 16:12 by machava2