📅 Deadline: 8.1.2025 21:59
🏦 Points: 5
In this assignment, you will implement the EM algorithm for estimating the class priors on unlabeled test data. You will then use the estimated priors to account for the prior shift in the plug-in Bayes classifier. You can find the complete description of the assignment in the Assignment.
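For orientation, the updates below are one common way to write EM for prior-shift adaptation, together with the adapted plug-in Bayes rule. The notation is an assumption, not necessarily the assignment's: p_{train}(y|x) and p_{train}(y) come from the training data, x_1, ..., x_N are the unlabeled test samples, s is the iteration index (starting from \hat{p}^{(0)}_{test}(y) = p_{train}(y)), and \ell(y, d) is the loss of decision d when the true class is y.

```latex
% E-step: re-weight the training posteriors by the current prior ratio
\hat{p}^{(s)}_{test}(y \mid x_i) =
  \frac{\dfrac{\hat{p}^{(s)}_{test}(y)}{p_{train}(y)}\, p_{train}(y \mid x_i)}
       {\sum_{y'} \dfrac{\hat{p}^{(s)}_{test}(y')}{p_{train}(y')}\, p_{train}(y' \mid x_i)}

% M-step: the new prior estimate is the mean posterior over the N test samples
\hat{p}^{(s+1)}_{test}(y) = \frac{1}{N} \sum_{i=1}^{N} \hat{p}^{(s)}_{test}(y \mid x_i)

% Plug-in Bayes classifier using the adapted posteriors
h(x) = \arg\min_{d} \sum_{y} \ell(y, d)\, \hat{p}_{test}(y \mid x)
```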
You are provided with a template containing the following files:
Your objective is to implement the following functions, all of which can be found in main.py: e_step, which estimates p_{test}(y|x); m_step, which estimates p_{test}(y); compute_test_priors, which runs the EM algorithm to estimate p_{test}(y); and bayes_classifier_with_prior_shift, which implements the plug-in Bayes classifier adapted to the prior shift.
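The four functions fit together as in the minimal pure-Python sketch below. It is an illustration under stated assumptions, not the template's implementation: the signatures and argument names (train_posteriors, train_prior, test_prior, loss), the list-of-lists representation, the tolerance, and the iteration cap are all hypothetical, and the real main.py most likely works with NumPy arrays. The sketch also assumes strictly positive training priors and accepts a general loss matrix for the classifier.

```python
from typing import List

Matrix = List[List[float]]

# NOTE: hypothetical signatures; the template in main.py may differ.

def e_step(train_posteriors: Matrix, train_prior: List[float],
           test_prior: List[float]) -> Matrix:
    """Estimate p_test(y|x): weight p_train(y|x) by p_test(y)/p_train(y), renormalise."""
    result = []
    for row in train_posteriors:
        # Assumes every train_prior[k] > 0 (otherwise the ratio is undefined).
        weighted = [p * test_prior[k] / train_prior[k] for k, p in enumerate(row)]
        total = sum(weighted)
        result.append([w / total for w in weighted])
    return result


def m_step(test_posteriors: Matrix) -> List[float]:
    """Estimate p_test(y) as the mean of p_test(y|x) over all test samples."""
    n = len(test_posteriors)
    num_classes = len(test_posteriors[0])
    return [sum(row[k] for row in test_posteriors) / n for k in range(num_classes)]


def compute_test_priors(train_posteriors: Matrix, train_prior: List[float],
                        max_iter: int = 1000, tol: float = 1e-8) -> List[float]:
    """Run EM, starting from the training prior, until the prior stops changing."""
    test_prior = list(train_prior)
    for _ in range(max_iter):
        new_prior = m_step(e_step(train_posteriors, train_prior, test_prior))
        converged = max(abs(a - b) for a, b in zip(new_prior, test_prior)) < tol
        test_prior = new_prior
        if converged:
            break
    return test_prior


def bayes_classifier_with_prior_shift(train_posteriors: Matrix, train_prior: List[float],
                                      test_prior: List[float], loss: Matrix) -> List[int]:
    """Pick, per sample, the decision with minimal expected loss under p_test(y|x).

    loss[y][d] is the (assumed) loss of deciding d when the true class is y.
    """
    decisions = []
    for row in e_step(train_posteriors, train_prior, test_prior):
        risks = [sum(row[y] * loss[y][d] for y in range(len(row)))
                 for d in range(len(loss[0]))]
        decisions.append(min(range(len(risks)), key=risks.__getitem__))
    return decisions
```

A typical call sequence would be prior = compute_test_priors(posteriors, train_prior) followed by bayes_classifier_with_prior_shift(posteriors, train_prior, prior, loss). With a 0/1 loss matrix, the rule reduces to taking the class with the largest adapted posterior; initialising EM at the training prior means that, without any shift, the classifier starts from the uncompensated decisions.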
Go make yourself a coffee while you wait.
https://arxiv.org/abs/2106.11695
After completing your implementation, you can test your solution using the following commands before submitting it to BRUTE:
python main.py test-cases/public/instances/instance_1.json --plot
Expected output:
EM Algorithm - E Step:
e_step: Test OK
EM Algorithm - M Step:
m_step: Test OK
EM Algorithm - Estimated priors:
estimated_test_prior: Test OK
Bayes Classifier:
risks bayes_classifier: Test OK
Without prior compensation; Loss: 0.13
With prior compensation; Loss: 0.09
Improvement: 0.04
python main.py test-cases/public/instances/instance_2.json --plot
Expected output:
EM Algorithm - E Step:
e_step: Test OK
EM Algorithm - M Step:
m_step: Test OK
EM Algorithm - Estimated priors:
estimated_test_prior: Test OK
Bayes Classifier:
risks bayes_classifier: Test OK
Without prior compensation; Loss: 0.43
With prior compensation; Loss: 0.37
Improvement: 0.06
python main.py test-cases/public/instances/instance_3.json --plot
Expected output:
EM Algorithm - E Step:
e_step: Test OK
EM Algorithm - M Step:
m_step: Test OK
EM Algorithm - Estimated priors:
estimated_test_prior: Test OK
Bayes Classifier:
risks bayes_classifier: Test OK
Without prior compensation; Loss: 0.13
With prior compensation; Loss: 0.1
Improvement: 0.03