EDA for Image Classification

Dana Rausch
Published in Geek Culture
Apr 11, 2021


Photo by Corina Rainer on Unsplash

Image classification has been my favorite project during my time in a data science program, but exploratory data analysis (EDA) for images isn't as straightforward as it is for other kinds of models. My cohort-mate and I found ourselves googling ideas and reading blog posts, unsure how to move forward. How are we supposed to analyze thousands of images to find patterns or learn … well, anything about the dataset we didn't already know?

This required going back to basics: putting aside all the fun, funky EDA we'd learned and focusing instead on what would affect the overall model. What do we need to know to make sure our CNN performs as well as possible? And what do we, as data scientists, need to know about our data to better understand the model itself?

Identifying Class Imbalance

Identifying class imbalance is an easy first step. We'll count the number of images assigned to each class and plot the counts in a bar chart, which makes any imbalance easy to spot. Class imbalance can cause a CNN to perform poorly on the underrepresented classes, which in turn hurts overall performance.

Note: the code below assumes that you've split the images into folders by class and created a variable for each class directory.

First, create a dictionary with the class name as the key and the number of images as the value. Here we pull the number of images straight from the folders in which they're stored.

import os

# Map each class name to the number of entries in its folder
number_classes = {
    'Class_0': len(os.listdir(healthy_dir)),
    'Class_1': len(os.listdir(blight_dir)),
    'Class_2': len(os.listdir(gray_dir)),
    'Class_3': len(os.listdir(rust_dir)),
}
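One caveat: `os.listdir` counts every entry in a folder, including hidden files such as `.DS_Store` on macOS, which can quietly inflate the counts. Below is a minimal sketch that counts only files with common image extensions, assuming one subfolder per class under a root directory such as `data/`. The `count_images` helper and the `IMAGE_EXTS` set are hypothetical, not part of the original code.

```python
from pathlib import Path

# Hypothetical set of extensions treated as images; adjust for your dataset
IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.bmp', '.gif'}

def count_images(root):
    """Return {class_folder_name: number_of_image_files} for each subfolder of root."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if not class_dir.is_dir():
            continue  # skip stray files sitting directly in root
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir()
            # hidden files like .DS_Store have an empty suffix, so they are skipped
            if f.is_file() and f.suffix.lower() in IMAGE_EXTS
        )
    return counts

# Usage (assuming a data/Class_0/, data/Class_1/, ... layout):
# number_classes = count_images('data')
```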

Next, plot a simple bar chart, using the keys for the x-axis and the values for the bar heights.

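import matplotlib.pyplot as plt

# One bar per class; bar height is the number of images in that class
plt.bar(list(number_classes.keys()), list(number_classes.values()), width=0.5)
plt.title("Number of Images by Class")
plt.xlabel('Class Name')
plt.ylabel('# Images')
plt.show()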
A simple bar chart to quickly identify class imbalance.
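To put a number on any imbalance the chart reveals, you can compute each class's share of the dataset and the ratio between the largest and smallest class. The sketch below also computes inverse-frequency class weights with the same formula scikit-learn uses for `class_weight='balanced'`, namely n_samples / (n_classes * count). Weighting the loss this way is one common response to imbalance, offered here as an option rather than as part of the original workflow. The `imbalance_summary` helper is hypothetical, and the counts in the usage line are made up for illustration.

```python
def imbalance_summary(counts):
    """Given {class_name: n_images}, return per-class shares, max/min ratio, and balanced weights."""
    total = sum(counts.values())
    n_classes = len(counts)
    # Fraction of the whole dataset that each class represents
    shares = {c: n / total for c, n in counts.items()}
    # How many times larger the biggest class is than the smallest
    ratio = max(counts.values()) / min(counts.values())
    # Inverse-frequency weights: rarer classes get weights above 1
    weights = {c: total / (n_classes * n) for c, n in counts.items()}
    return shares, ratio, weights

# Usage with made-up counts:
shares, ratio, weights = imbalance_summary(
    {'Class_0': 1000, 'Class_1': 500, 'Class_2': 250, 'Class_3': 250}
)
print(ratio)  # 4.0
```

If you train with Keras, `model.fit` accepts a `class_weight` dictionary, but it expects integer class indices as keys, so the class names would need to be mapped to the indices your generator assigns.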

Plotting Image Size

Consistent image size is crucial for deep learning: images are batched into arrays of a fixed shape, so mismatched dimensions will bring your project to a quick stop. Visualizing raw image sizes can also help you understand your dataset better. So let's get to it!

We're going to create another simple dictionary. You could use a single dictionary for both the chart above and the one below, but I wanted to show a different option that doesn't require a variable for each folder directory. We'll also set up a function that returns image dimensions.

from PIL import Image

directories = {
    'Class_0': 'data/Class_0/',
    'Class_1': 'data/Class_1/',
    'Class_2': 'data/Class_2/',
    'Class_3': 'data/Class_3/',
}

def get_dims(file):
    '''Returns (height, width) for an image file (RGB, RGBA, or grayscale)'''
    with Image.open(file) as im:
        w, h = im.size  # PIL reports size as (width, height)
    return h, w

Next, loop through the folders in the dictionary, get the dimensions of each image, and collect them in a DataFrame for plotting. The loop uses a Dask bag to read the dimensions in parallel and shows a progress bar while it works.

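import os

import matplotlib.pyplot as plt
import pandas as pd
from dask import bag, diagnostics

for n, d in directories.items():
    filelist = [os.path.join(d, f) for f in os.listdir(d)]
    # Read each image's dimensions in parallel with Dask
    dims = bag.from_sequence(filelist).map(get_dims)
    with diagnostics.ProgressBar():
        dims = dims.compute()
    # Count how many images share each (height, width) pair
    dim_df = pd.DataFrame(dims, columns=['height', 'width'])
    sizes = dim_df.groupby(['height', 'width']).size().reset_index().rename(columns={0: 'count'})
    # One scatter plot of image sizes per class
    sizes.plot.scatter(x='width', y='height')
    plt.title('Image Sizes (pixels) | {}'.format(n))
    plt.show()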
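The scatter plots are easy to scan, but a numeric summary helps when choosing a target input size. A minimal sketch, assuming `dims` is a list of `(height, width)` tuples like the one `dims.compute()` produces in the loop: it reports the most common size, how many images have it, and the fraction of images that differ from it. The `summarize_sizes` helper and the sample values are hypothetical.

```python
from collections import Counter

def summarize_sizes(dims):
    """dims: list of (height, width) tuples. Returns (most_common_size, its_count, fraction_other)."""
    counts = Counter(dims)
    size, n = counts.most_common(1)[0]
    # Share of images that would need resizing to match the most common size
    fraction_other = 1 - n / len(dims)
    return size, n, fraction_other

dims = [(256, 256), (256, 256), (256, 256), (480, 640)]
print(summarize_sizes(dims))  # ((256, 256), 3, 0.25)
```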

Viewing a Sampling of Images

Finally, let's take a look at a handful of images from each class. Although this isn't crucial for the model itself, it will help you familiarize yourself with the data. The more familiar you are with the data, the better you'll understand the model's outputs, which allows for informed iterations and, ultimately, a much smarter model.

I'd like to call out that an amazing cohort-mate of mine wrote the code below; check out her Medium page here 😊

First, set up the grid dimensions and a list of image paths to display. You'll need to create one such list for each class.

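import os

# Show images in a 4x4 grid
nrows = 4
ncols = 4

# Index for iterating over images
pic_index = 0
pic_index += 8
# Assumed: train_blight_names holds the file names in blight_dir
train_blight_names = os.listdir(blight_dir)
next_blight_pix = [os.path.join(blight_dir, fname)
                   for fname in train_blight_names[pic_index-8:pic_index]]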
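Slicing the file list always returns the same eight images for a given folder listing. If you'd rather see a different random subset each run, or a reproducible one, you could draw the sample with Python's `random` module instead. A sketch, assuming a folder path and a list of its file names as above; `sample_image_paths` and its `seed` parameter are hypothetical.

```python
import os
import random

def sample_image_paths(folder, filenames, k=8, seed=None):
    """Return k randomly chosen paths from filenames (all of them if fewer than k)."""
    rng = random.Random(seed)  # pass a seed for a reproducible sample
    # Sort first so the same seed gives the same sample regardless of listing order
    chosen = rng.sample(sorted(filenames), min(k, len(filenames)))
    return [os.path.join(folder, f) for f in chosen]

# Usage (hypothetical):
# next_blight_pix = sample_image_paths(blight_dir, os.listdir(blight_dir), k=8, seed=42)
```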

Then set up a function that plots the images in a grid.

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

def show_image_sample(img_paths):
    '''Display up to nrows x ncols images in a grid'''
    fig = plt.gcf()
    fig.set_size_inches(ncols * 4, nrows * 4)
    for i, img_path in enumerate(img_paths):
        sp = plt.subplot(nrows, ncols, i + 1)
        sp.axis('off')  # hide axis ticks and labels
        img = mpimg.imread(img_path)
        plt.imshow(img)
    plt.show()

show_image_sample(next_blight_pix)
A sampling of images for each class can teach us a lot! This particular project was to identify disease in corn leaves.

As we see above, a small sample from a single class can teach us a lot about our dataset. We can see that the images come in different sizes and shapes, and that there's quite a bit of variation within this class. Variation in background, brightness, orientation, and more can cause several issues in deep learning, and we can prepare for them simply by looking at this sample of the dataset.
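The differences in size and shape you spot by eye can also be counted. A minimal sketch, assuming a list of `(height, width)` tuples such as the `dims` list from the image-size loop: it tallies how many images are portrait, landscape, or roughly square. The helper name and the 5% tolerance for "square" are arbitrary choices for illustration.

```python
from collections import Counter

def orientation_counts(dims, tol=0.05):
    """Classify each (height, width) pair as 'portrait', 'landscape', or 'square'."""
    counts = Counter()
    for h, w in dims:
        ratio = w / h  # aspect ratio: width over height
        if abs(ratio - 1) <= tol:
            counts['square'] += 1
        elif ratio > 1:
            counts['landscape'] += 1
        else:
            counts['portrait'] += 1
    return dict(counts)

print(orientation_counts([(256, 256), (480, 640), (640, 480), (300, 310)]))
```

A mix of orientations suggests that resizing straight to a square input will distort some images, so cropping or padding may be worth considering.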

Hopefully this has given you a good start for EDA with image classification. Thanks for reading and good luck!
