====== Úkol 3 ====== ==== Trénování SVM klasifikátoru ==== Jedním z nejpopulárnějších klasifikátorů je [[https://en.wikipedia.org/wiki/Support-vector_machine|SVM]] (support vector machine). Mějme $n$ pozorování $\mathbf x_1,\dots,\mathbf x_n\in\mathbb R^d$. Pro jednoduchost předpokládejme, že se jedná o binární klasifikaci s labely $y_1,\dots,y_n\in\{-1,1\}$, a že bias je už zakomponován v datech. SVM hledá lineární oddělovací nadrovinu $\mathbf w^\top \mathbf x$ tak, že si pro každé pozorování $i$ zavede pomocnou proměnnou $\xi_i\ge 0$, která měří míru misklasifikace. Toho dosáhne pomocí podmínky $$ \xi_i \ge 1 - y_i\mathbf w^\top \mathbf x_i. $$ Vzhledem k tomu, že budeme chtít minimalizovat $\xi_i$, jeho minimální hodnoty dosáhneme pro $y_i\mathbf w^\top \mathbf x_i\ge 1$. Tedy pro label $y_i=1$ chceme podmínku $\mathbf w^\top \mathbf x_i\ge 1$ a podobně pro záporný label. Toto je silnější podmínka než klasické $\mathbf w^\top \mathbf x_i\ge 0$. Jinými slovy se SVM snaží o dosažení určité vzdálenost dobře klasifikovaných pozorování od oddělovací nadroviny, což přidává robustnost. Celý optimalizační problém pak jde zapsat jako \begin{align*} \operatorname*{minimalizuj}_{w,\xi}\qquad &C\sum_{i=1}^n\xi_i + \frac12||\mathbf w||^2 \\ \text{za podmínek}\qquad &\xi_i \ge 1 - y_i\mathbf w^\top \mathbf x_i, \\ &\xi_i\ge 0. \end{align*} V účelové funkci se objevila regularizace $\frac12||\mathbf w||^2$ a regularizační konstanta $C>0$. I když jde tento problém řešit v této formulaci, standardně se převede do duální formulace \begin{align*} \operatorname*{maximalizuj}_{z}\qquad &-\frac12\mathbf z^\top Q\mathbf z + \sum_{i=1}^nz_i \\ \text{za podmínek}\qquad &0\le z_i\le C, \end{align*} kde matice $Q$ má prvky $q_{ij}=y_iy_j\mathbf x_i^\top \mathbf x_j$. Po vyřešení tohoto problému se řešení dostane pomocí $\mathbf w=\sum_{i=1}^ny_iz_i\mathbf x_i$. Tento problém jde vyřešit pomocí metody [[https://en.wikipedia.org/wiki/Coordinate_descent|coordinate descent]]. Tato metoda v každé iteraci updatuje pouze jednu komponentu vektoru $\mathbf z$. To znamená, že optimalizaci přes $\mathbf z$ nahradíme optimalizaci přes $\mathbf z+d\mathbf e_i$, kde $\mathbf e_i$ je nulový vektor s jedničkou na komponentě $i$. V každé iteraci se pak řeší \begin{align*} \operatorname*{maximalizuj}_d\qquad &-\frac12(\mathbf z+d\mathbf e_i)^\top Q(\mathbf z+d\mathbf e_i) + \sum_{i=1}^nz_i + d\\ \text{za podmínek}\qquad &0\le z_i + d\le C. \end{align*} Tento optimalizační problém je jednoduchý, neboť $d\in\mathbb R$ a existuje řešení v uzavřené formě. ==== Zadání ==== Naimplementujte funkce ''Q = computeQ(X, y)'', ''w = computeW(X, y, z)'', ''z = solve_SVM_dual(Q, C; max_iter)'' a ''w = solve_SVM(X, y, C; kwargs...)''. Toto schéma ukazuje jak vstupní tak výstupní argumenty. Podrobněji tyto argumenty jsou: * ''X'': matice $n\times d$ vstupních dat; * ''y'': vektor $n\times 1$ vstupních labelů; * ''C'': kladná regularizační konstanta; * ''w'': řešení primárního problému $\mathbf w$; * ''z'': řešení duálního problému $\mathbf z$; * ''Q'': matice duálního problému $Q$. Funkce ''computeQ'' spočte matici ''Q'' a funkce ''computeW'' spočte ''w'' z duálního řešení. Funkce ''solve_SVM_dual'' dostává jako vstupy parametry duálního problému a udělá ''max_iter'' iterací metody coordinate descent. V každé iteraci se popořadě spočítá $n$ updatů účelové funkce dle algoritmu. Funkce ''solve_SVM'' zkombinuje předchozí funkce do jedné. Nakonec napište funkci ''w = iris(C; kwargs...)'', která načte iris dataset, jako pozitivní třídu bude uvažovat ''versicolor'', jako negativní třídu ''virginica'', jako featury použije ''PetalLength'' a ''PetalWidth'', tyto vstupy znormuje, přidá bias a nakonec použije výše napsané funkce na spočtení optimální oddělovací nadroviny ''w''. ==== Odevzdání a vyhodnocení ==== Úkoly budou automaticky vyhodnocovány systémem [[https://cw.felk.cvut.cz/brute/teacher/course/1292|BRUTE]]. Vypracovaný úkol je třeba před nahráním uložit do souboru s názvem ''run.jl'' a zabalit do formátu ''.zip''. Můžete použít balík ''RDatasets'' a Julia core knihovny (''Statistics'', ''LinearAlgebra'', ''Random'', ...).