Introduction to Machine Learning Classification: K Nearest Neighbors - ECE-180D-WS-2023/Knowledge-Base-Wiki GitHub Wiki
Introduction to Machine Learning Classification: K-Nearest Neighbors
by Jonathan Lai
Introduction
Machine learning has been around for quite some time, but it has advanced rapidly in recent years. Its applications today are plentiful and diverse, from automated application screening in the corporate world to autonomous navigation in modern-day vehicles. ChatGPT, “an AI-powered chatbot” that has drawn enormous attention recently, also has its roots in machine learning (Hetler). Machine learning has many practical applications, but what exactly is it, and how does it work? In this article, I will discuss the fundamental concepts of machine learning and introduce one of its basic algorithms, K-Nearest Neighbors. I will also provide examples of applying the algorithm and discuss its pros and cons.
Machine Learning Classification
While there are multiple categories of machine learning, this article focuses on classification tasks. Our goal in classification is to determine which category, or class, an input belongs to. For example, a classification problem could be: given an iris with known petal width and sepal length, is it of the species Setosa or Versicolor? A machine learning model could predict the species of our iris after receiving the input values of petal width and sepal length. How is it able to do this? First, we need to discuss what a machine learning model is, what it means to train it, and ultimately how it is used for classification.
A machine learning model is one that can be “trained to recognize certain types of patterns” (QuinnRadich). Training the model requires a dataset and an algorithm that learns from it. For the iris classification example above, our data would be a list of irises with their petal widths and sepal lengths (the features) and their true species (the labels). An algorithm then learns from this iris data, producing a trained model that can predict the species of a new, unlabeled iris from its petal width and sepal length. There are countless machine learning algorithms, each with its own way to train and predict, but the one I will talk about today is k-Nearest Neighbors, or k-NN.
K-Nearest Neighbors Algorithm
The k-Nearest Neighbors algorithm uses “proximity to make predictions about [an unknown] data point” (IBM). That is, the algorithm relies on the heuristic that data points close to each other are likely to have the same label. Take a look at Figure 1: using your best intuition, what would you guess the shape of the “?” to be?
Figure 1
Because the “?” data point lies close to other triangles, you would probably guess that it is also a triangle. This is exactly what the K-Nearest Neighbors algorithm does for classification. You might be wondering how “closeness” is measured. The most common choice is the Euclidean distance, which is the length of the straight line segment between two points, whether on a plane or, more generally, in a space with one dimension per feature.
The ‘k’ in k-Nearest Neighbors is the number of nearest labeled data points the algorithm consults when determining the label of an unknown data point. For example, if k is 1 (1-NN), the algorithm assigns the unknown data point the same label as its single nearest neighbor. If k is 3 (3-NN), the algorithm looks at the unknown point’s three nearest neighbors and determines the label from those three.
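As a minimal sketch in Python (the original article contains no code), the Euclidean distance between two points p and q with the same number of features is the square root of the sum of squared per-feature differences, sqrt(Σ (p_i − q_i)²). The name `euclidean_distance` is illustrative; Python's standard library also provides `math.dist` (Python 3.8+), which computes the same quantity.

```python
import math
from typing import Sequence


def euclidean_distance(p: Sequence[float], q: Sequence[float]) -> float:
    """Length of the straight line segment between points p and q."""
    if len(p) != len(q):
        raise ValueError("points must have the same number of features")
    # Sum the squared per-feature differences, then take the square root.
    return math.sqrt(sum((pi - qi) ** 2 for pi, qi in zip(p, q)))


# Two points on a plane: the differences are 3 and 4, so the distance is 5.
print(euclidean_distance((1.0, 2.0), (4.0, 6.0)))  # 5.0
```

The same function works unchanged for any number of features, which is how k-NN measures closeness beyond two dimensions.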
Figure 2 (the label of a data point is often referred to as its class)
Notice that the nearest neighbors do not always agree on the label. In these situations, the algorithm takes a majority vote among the nearest neighboring points. In Figure 2, the three points closest to the unknown data point (the star) consist of two green points and one blue point. Because the majority is green, a 3-NN model would classify the star as green. For a 6-NN model, however, the nearest six points include four blue and two green, so it would classify the star as blue.
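The voting rule can be sketched as a brute-force Python function. This assumes the training data is a list of `(features, label)` pairs and that Euclidean distance is the metric; `knn_predict` is a hypothetical helper, not a library API. The sample points are invented so that their neighbor counts match the description of Figure 2, with the star placed at the origin.

```python
import math
from collections import Counter
from typing import Hashable, List, Sequence, Tuple

Point = Sequence[float]


def knn_predict(train: List[Tuple[Point, Hashable]], query: Point, k: int) -> Hashable:
    """Classify `query` by majority vote among its k nearest training points."""
    if not 1 <= k <= len(train):
        raise ValueError("k must be between 1 and the number of training points")
    # Sort the labeled points by Euclidean distance to the query; keep the k closest.
    nearest = sorted(train, key=lambda item: math.dist(item[0], query))[:k]
    # Count the labels of those neighbors. Labels are counted nearest-first, and
    # most_common() lists equal counts in first-seen order, so a tie goes to the
    # label that appears first among the nearest neighbors.
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]


# Hypothetical points arranged to match the description of Figure 2;
# the star (the unknown point) sits at the origin.
points = [
    ((1.0, 0.0), "green"),   # distance 1.0
    ((0.0, 1.5), "blue"),    # distance 1.5
    ((0.0, 2.0), "green"),   # distance 2.0
    ((3.0, 0.0), "blue"),    # distance 3.0
    ((0.0, -4.0), "blue"),   # distance 4.0
    ((-5.0, 0.0), "blue"),   # distance 5.0
    ((0.0, 6.0), "green"),   # distance 6.0
]
star = (0.0, 0.0)
print(knn_predict(points, star, k=3))  # green (2 green vs 1 blue)
print(knn_predict(points, star, k=6))  # blue  (4 blue vs 2 green)
```

Majority voting alone does not say how to break ties; this sketch favors the label seen first among the nearest neighbors. In a two-class problem, choosing an odd k avoids ties altogether.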
Real-World Example
The fundamental concept of KNN should now be clear, so it’s time to see KNN applied to a real-world classification task. Here we use the MiniPlaces dataset, which consists of images of different places, such as playgrounds and bedrooms. To measure the distance between two images, we can take the RGB values at each pixel, compute the L2-norm (Euclidean) distance between corresponding pixels of the two images, and then average these per-pixel distances over the whole image.
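One reading of this image distance is sketched below: take the L2 norm of the RGB difference at each pixel, then average those per-pixel distances. How exactly the original experiment averaged is an assumption here, and the nested-list image representation and the `image_distance` name are hypothetical; a practical version would use array operations (for example, NumPy) on full-size images.

```python
import math
from typing import List, Tuple

Pixel = Tuple[int, int, int]   # (R, G, B)
Image = List[List[Pixel]]      # rows of pixels


def image_distance(a: Image, b: Image) -> float:
    """Average, over all pixels, of the L2 distance between corresponding RGB values."""
    if len(a) != len(b) or any(len(row_a) != len(row_b) for row_a, row_b in zip(a, b)):
        raise ValueError("images must have the same width and height")
    total, count = 0.0, 0
    for row_a, row_b in zip(a, b):
        for pixel_a, pixel_b in zip(row_a, row_b):
            total += math.dist(pixel_a, pixel_b)  # L2 norm of the RGB difference
            count += 1
    if count == 0:
        raise ValueError("images must contain at least one pixel")
    return total / count


# Two hypothetical 1x2 images: the first pixels differ by (3, 4, 0), the second match.
img_a = [[(0, 0, 0), (255, 255, 255)]]
img_b = [[(3, 4, 0), (255, 255, 255)]]
print(image_distance(img_a, img_b))  # 2.5, the mean of 5.0 and 0.0
```

Any function like this can serve as the "closeness" measure in k-NN, since the algorithm only needs to rank training images by distance to the query image.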
Figure 3
There are two kinds of classification tasks for this dataset: binary classification (indoor vs. outdoor) and multi-class classification (the actual scene classes, such as playground). In the binary task, KNN achieved an accuracy of 69.5% when using the nearest 8 neighbors to determine the label, a relatively promising result. However, KNN predicts with only 14.9% accuracy in the multi-class task, suggesting that the algorithm may be too simple to handle more complex cases.
Discussion
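As shown above, KNN is a relatively simple machine learning algorithm to implement, and because k is its main setting, tuning a kNN model is straightforward. It is often one of the first machine learning algorithms that people learn. Furthermore, k-Nearest Neighbors is a nonparametric model, which means it makes no assumptions about the form of the underlying data distribution; beyond the proximity heuristic described earlier, it lets the training data shape the decision boundary directly. How well it copes with noisy data depends on k: with k = 1, a single mislabeled training point can flip predictions in its neighborhood, while a larger k averages over more neighbors and smooths out such noise.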
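A common way to tune k, sketched here as an assumption rather than the procedure used in the MiniPlaces experiment, is to hold out a validation set and keep the k with the highest validation accuracy. The helpers (`knn_predict`, `accuracy`, `choose_k`) and the toy data are hypothetical; the training set deliberately includes one mislabeled point inside the "a" cluster.

```python
import math
from collections import Counter


def knn_predict(train, query, k):
    # Majority vote among the k training points closest to `query`.
    nearest = sorted(train, key=lambda item: math.dist(item[0], query))[:k]
    return Counter(label for _, label in nearest).most_common(1)[0][0]


def accuracy(train, validation, k):
    # Fraction of validation points whose predicted label matches the true label.
    correct = sum(knn_predict(train, x, k) == y for x, y in validation)
    return correct / len(validation)


def choose_k(train, validation, candidates):
    # Score every candidate k; max() returns the earliest candidate among ties.
    scores = {k: accuracy(train, validation, k) for k in candidates}
    best = max(scores, key=scores.get)
    return best, scores


# Hypothetical data: two clusters, plus one mislabeled "b" point inside cluster "a".
train = [
    ((0.0, 0.0), "a"), ((0.0, 1.0), "a"), ((1.0, 0.0), "a"),
    ((5.0, 5.0), "b"), ((5.0, 6.0), "b"), ((6.0, 5.0), "b"),
    ((0.5, 0.5), "b"),  # label noise
]
validation = [((0.4, 0.6), "a"), ((0.2, 0.1), "a"), ((5.5, 5.5), "b")]

best, scores = choose_k(train, validation, candidates=[1, 3, 5])
print(scores)
print(best)  # 3
```

With k = 1, the mislabeled point alone decides the label of the validation point next to it; with k = 3 or 5, the correctly labeled neighbors outvote it.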
However, there are also some key drawbacks to this algorithm. The first, and most important, is that KNN is fast to train but slow to test (make a prediction). The training phase is fast since it only involves storing the entire training dataset in memory. When making a prediction on new data, however, KNN must compute the distance between the new data point and every point in the training set, so the cost of each prediction grows with both the number of training points and the number of features. This becomes computationally expensive for large datasets of high-dimensional data. In real-world applications, we can generally afford long training times but not slow predictions. For instance, ChatGPT users will not notice that the model took months to train, but every user will notice how quickly text is generated; taking hours to generate one response would be unacceptable.
Another major issue facing KNN is the “curse of dimensionality.” As the number of dimensions in the data increases, the distances between data points become increasingly similar, making it difficult to distinguish points that are close from those that are far away. This can lead to inaccurate predictions and poor performance of the KNN algorithm. For instance, the data in the two examples used earlier to explain the algorithm are only two-dimensional, an ideal situation that rarely occurs in real-world data. RGB images, when analyzed pixel-wise, have 3 × width × height dimensions (a 128 × 128 RGB image, for example, has 49,152), while feature vectors produced by common feature-extraction methods often have thousands of dimensions.
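To make this cost asymmetry concrete, here is a minimal brute-force sketch. `BruteForceKNN` is a hypothetical class, not a library API: fitting only stores the data, while each prediction computes one distance per stored point. The `distance_calls` counter exists only to make that work visible.

```python
import math
from collections import Counter


class BruteForceKNN:
    """k-NN classifier: fitting stores the data; predicting scans all of it."""

    def __init__(self, k):
        self.k = k
        self.points, self.labels = [], []
        self.distance_calls = 0  # counts distance computations, for illustration only

    def fit(self, points, labels):
        # "Training" is just memorization: copy the data, no model parameters to fit.
        self.points, self.labels = list(points), list(labels)
        return self

    def predict_one(self, query):
        # One distance per stored point, so the work grows with the training-set size.
        scored = []
        for point, label in zip(self.points, self.labels):
            scored.append((math.dist(point, query), label))
            self.distance_calls += 1
        nearest = sorted(scored, key=lambda pair: pair[0])[: self.k]
        return Counter(label for _, label in nearest).most_common(1)[0][0]


model = BruteForceKNN(k=3).fit([(0, 0), (1, 1), (9, 9), (10, 10)], ["a", "a", "b", "b"])
print(model.predict_one((0.5, 0.5)))  # a
print(model.distance_calls)  # 4: one per stored point
```

With n stored points of d features each, every prediction here does on the order of n × d arithmetic, plus a sort of the n distances.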
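The effect can be demonstrated with a small simulation, sketched under the assumption that data points are drawn uniformly at random from a unit cube. For one random query it computes the relative contrast, (farthest distance − nearest distance) / nearest distance; as the dimension grows this ratio shrinks, meaning the nearest and farthest points end up at nearly the same distance. The function name and parameters are illustrative, and the exact numbers depend on the random seed.

```python
import math
import random


def relative_contrast(dim, n_points=200, seed=0):
    """(farthest - nearest) / nearest distance from one random query to random points."""
    rng = random.Random(seed)
    # Hypothetical data: points drawn uniformly from the unit cube [0, 1]^dim.
    query = [rng.random() for _ in range(dim)]
    points = [[rng.random() for _ in range(dim)] for _ in range(n_points)]
    distances = [math.dist(query, p) for p in points]
    nearest, farthest = min(distances), max(distances)
    return (farthest - nearest) / nearest


for dim in (2, 10, 100, 1000):
    print(dim, round(relative_contrast(dim), 3))
```

The trend, not the specific values, is the point: in two dimensions the nearest point is many times closer than the farthest, while in a thousand dimensions the two distances differ by only a small fraction.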
Closing Remarks
K-Nearest Neighbors is simple and fast to implement, and in many scenarios it predicts accurately enough. However, always keep in mind that it also has downsides, such as its difficulty handling high-dimensional data and its slow predictions on large training sets. After all, KNN is only one of many techniques that can be used in a machine learning classification task. Other methods include logistic regression, support vector machines, and decision trees. You should always choose an appropriate model based on your research purpose and the data you have. This discussion is merely the tip of the iceberg of machine learning. Those who are interested in diving deeper should explore k-Nearest Neighbors and other classification techniques with real data to truly see what machine learning is capable of.
Citations
QuinnRadich. “What Is a Machine Learning Model?” Microsoft Learn, https://learn.microsoft.com/en-us/windows/ai/windows-ml/what-is-a-machine-learning-model.
“What Is the K-Nearest Neighbors Algorithm?” IBM, https://www.ibm.com/topics/knn.
Chouinard, Jean-Christophe. “K-Nearest Neighbors (KNN) in Python.” JC Chouinard, https://www.jcchouinard.com/k-nearest-neighbors/. [Source of Figure 2]
Hetler, Amanda. “Bard vs. ChatGPT: What's the Difference?” WhatIs.com, TechTarget, 8 Feb. 2023, https://www.techtarget.com/whatis/feature/Bard-vs-ChatGPT-Whats-the-difference.