# How generalizable can I expect a CNN to be?


neat briar

I'm in the process of labeling a dataset, and I'm trying to build up some intuition about how CNNs generalize, so that I can get a sense of what types of images I should focus on, and what sort of performance I should expect.

The end goal is to use this dataset to train a CNN to classify images into two classes, 'room within a house' and 'not a room in a house'. I have a dataset of tens of millions of unlabeled images that need to be filtered down to just images of residential room interiors, and I'm trying to automate this process.

This is a task that is easy for most pictures, but has tricky edge cases. For example, any image that is taken outdoors we can throw out right away. And a bedroom, or a closet full of clothes, those are usually very recognizable rooms in a home. But a room with a desk and an office chair is harder, even for a human, because it could be a home office or in an office building. The dataset is extremely varied, with rooms containing all sorts of things, some that provide obvious clues and some that do not.

I manually labeled about 10k images and used them to train a ResNet18, which achieves about 80% validation accuracy. There are tricky cases, but nevertheless I'm pretty sure a human could achieve the high 90s on this dataset, so I'm hoping to improve on that.

Currently, I suspect the issue is a kind of class imbalance. My dataset has an equal number of 'residential' and 'nonresidential' images; however, there are other classes I could consider balancing. For example, I could try to collect an equal number of each type of room. The problem is that there's an almost infinitely long tail of increasingly unique rooms, making this a daunting task! So I'm wondering: as my dataset grows, should I expect my CNN to learn more generalizable features automatically, or do I need to fastidiously identify and balance subtle classes?
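
If you do decide to balance finer-grained subgroups, one common alternative to collecting equal counts is to reweight how often each image is sampled. The sketch below assumes each labeled image carries a room-type tag; the tags (`bedroom`, `kitchen`, `home_office`) and the `inverse_frequency_weights` helper are hypothetical, purely for illustration. The `power` knob is also an assumption of this sketch: 1.0 equalizes subgroups in expectation, smaller values only partially up-weight rare ones.

```python
from collections import Counter


def inverse_frequency_weights(subgroups, power=1.0):
    """Per-sample sampling weights that up-weight rare subgroups (hypothetical helper).

    subgroups: one tag per training image, e.g. "bedroom".
    power: 1.0 gives every subgroup equal expected share; 0.5 is a softer compromise.
    """
    counts = Counter(subgroups)
    raw = [1.0 / counts[g] ** power for g in subgroups]
    total = sum(raw)
    return [w / total for w in raw]  # normalized so the weights sum to 1


# Hypothetical tags for a tiny labeled set: bedrooms dominate, home offices are rare.
tags = ["bedroom"] * 6 + ["kitchen"] * 3 + ["home_office"] * 1
weights = inverse_frequency_weights(tags)

# Expected share of each subgroup when sampling with these weights.
share = Counter()
for g, w in zip(tags, weights):
    share[g] += w
print({g: round(s, 3) for g, s in share.items()})
# → {'bedroom': 0.333, 'kitchen': 0.333, 'home_office': 0.333}
```

Per-sample weights like these can be passed to PyTorch's `torch.utils.data.WeightedRandomSampler`. Note that this only rebalances subgroups you have actually tagged; it does nothing for the untagged long tail, which is the harder part of the question.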


Here's an analogous, but simpler, example, which I'd like to understand:

Let's say I have a dataset of red circles and green squares. And let's say that what I care about is the square-ness or circle-ness of my images, and the fact that all the circles are red and all the squares are green is a coincidence. If I were to train a CNN on this hypothetical dataset, it would presumably be able to partition this dataset quite easily. However, the model wouldn't know that what I care about is shape, so the function it learns may depend heavily on color.

If I were then to pass an image of a blue circle through this model, what should I expect to happen? Should I expect the model to have any capability of understanding a blue circle is more like a red circle than a green square? Or would the blue circle be more or less meaningless to the model?
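
To make the thought experiment concrete, here is a minimal NumPy sketch, assuming 32x32 images with one filled shape on a black background. Instead of a trained CNN it uses a hand-written rule, `color_score` (a hypothetical stand-in), to show that color alone separates this training set perfectly, and that a blue circle carries none of the signal that rule relies on.

```python
import numpy as np

SIZE = 32
rng = np.random.default_rng(0)
RED, GREEN, BLUE = (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)


def make_image(shape, rgb, radius, cx, cy):
    """Render a filled circle or square (half-width `radius`) on a black RGB canvas."""
    yy, xx = np.mgrid[0:SIZE, 0:SIZE]
    if shape == "circle":
        mask = (xx - cx) ** 2 + (yy - cy) ** 2 <= radius ** 2
    else:
        mask = (np.abs(xx - cx) <= radius) & (np.abs(yy - cy) <= radius)
    img = np.zeros((SIZE, SIZE, 3))
    img[mask] = rgb
    return img


def sample(shape, rgb):
    """Random size and position, fully inside the canvas."""
    r = rng.integers(5, 10)
    cx, cy = rng.integers(r, SIZE - r, size=2)
    return make_image(shape, rgb, r, cx, cy)


# Training set: every circle is red, every square is green (label 1 = circle).
images = [sample("circle", RED) for _ in range(50)] + [sample("square", GREEN) for _ in range(50)]
labels = np.array([1] * 50 + [0] * 50)


def color_score(img):
    """Mean red minus mean green: a 'model' that ignores shape entirely."""
    return img[..., 0].mean() - img[..., 1].mean()


preds = np.array([int(color_score(im) > 0) for im in images])
print("train accuracy:", (preds == labels).mean())  # → 1.0

blue_circle = sample("circle", BLUE)
print("blue circle score:", color_score(blue_circle))  # → 0.0
```

The blue circle lands exactly on the rule's decision boundary, so whichever label it receives comes from the tie-breaking convention (`> 0`), not from its shape. A trained CNN is less crude than this, but to the extent it has latched onto color, the same problem applies.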

If I then included that single blue circle in the dataset, it probably wouldn't help all that much, because most of the training samples would contain only red circles and green squares. I'd probably have to add many blue circle samples in order to get good performance.

But what happens when more, differently colored circles are added? Or other types of rounded shapes? Would they all need to be balanced? (And if so, how does this not result in datasets needing to be almost infinitely large?)

wraith vessel

> If I were then to pass an image of a blue circle through this model, what should I expect to happen?

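One possible output from the model is 'red circle'. Color is easier for a CNN to learn than shape: recognizing shape requires learning a stack of edge detectors and pattern detectors, whereas color can be read directly from each pixel's channel values. Since color never varies within a class, color alone separates the training set, so the model tends to take that easier path.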
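
One way to see why color is the cheaper cue: it can be read off with no spatial processing at all. The sketch below (NumPy, random weights; `pointwise_then_pool` is a hypothetical toy, not a real architecture) builds a "network" from only 1x1 convolutions followed by global average pooling. Shuffling the pixels destroys the circle but leaves the output unchanged, so such a network can tell red from green but can never tell a circle from a square of the same color and pixel count. Distinguishing shapes needs larger kernels that look at neighboring pixels, i.e. edge and pattern detectors.

```python
import numpy as np

rng = np.random.default_rng(0)


def pointwise_then_pool(img, weights, bias):
    """Only 1x1 convolutions (a per-pixel linear map) + ReLU + global average pooling."""
    feats = np.maximum(img.reshape(-1, 3) @ weights + bias, 0.0)  # (pixels, channels)
    return feats.mean(axis=0)


# A red disc on a black 16x16 canvas.
yy, xx = np.mgrid[0:16, 0:16]
img = np.zeros((16, 16, 3))
img[(xx - 8) ** 2 + (yy - 8) ** 2 <= 25] = (1.0, 0.0, 0.0)

# Scramble the pixel positions: the shape is destroyed, the colors are kept.
flat = img.reshape(-1, 3)
scrambled = flat[rng.permutation(flat.shape[0])].reshape(img.shape)

W = rng.normal(size=(3, 8))
b = rng.normal(size=8)
same = np.allclose(pointwise_then_pool(img, W, b), pointwise_then_pool(scrambled, W, b))
print(same)  # → True: this model cannot see shape, yet it still sees color
```

This only illustrates the asymmetry; it is not a description of how a trained ResNet actually allocates its capacity.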


> But what happens when more, differently colored circles are added? Or other types of rounded shapes?

With more varied colors in the dataset, color no longer separates the classes, so the model can no longer take the shortcut of skipping edge detection and pattern recognition. It would therefore achieve better results at classifying circles.
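
A small NumPy sketch of that effect, assuming the same kind of toy setup (32x32 images, one filled shape on black) but with every shape given a random color, a simple form of color augmentation. The color shortcut from the red/green case stops working, while a hand-written shape feature (`shape_rule`, a hypothetical stand-in for the edge and pattern detectors a CNN would have to learn) still separates the classes.

```python
import numpy as np

SIZE = 32
rng = np.random.default_rng(1)


def make_image(shape, rgb, radius, cx, cy):
    """Filled circle or square (half-width `radius`) on a black RGB canvas."""
    yy, xx = np.mgrid[0:SIZE, 0:SIZE]
    if shape == "circle":
        mask = (xx - cx) ** 2 + (yy - cy) ** 2 <= radius ** 2
    else:
        mask = (np.abs(xx - cx) <= radius) & (np.abs(yy - cy) <= radius)
    img = np.zeros((SIZE, SIZE, 3))
    img[mask] = rgb
    return img


def sample(shape):
    # Color randomization: each image gets a random color, so color says
    # nothing about whether it is a circle or a square.
    rgb = rng.uniform(0.2, 1.0, size=3)
    r = rng.integers(5, 10)
    cx, cy = rng.integers(r, SIZE - r, size=2)
    return make_image(shape, rgb, r, cx, cy)


def color_rule(img):
    """The shortcut from the red/green set: predict 'circle' if redder than green."""
    return int(img[..., 0].mean() > img[..., 1].mean())


def shape_rule(img):
    """Stand-in shape detector: a pixelated disc fills well under 90% of its
    bounding box (about pi/4 in the limit), a square fills all of it."""
    mask = img.sum(axis=-1) > 0
    rows, cols = np.nonzero(mask)
    box = (rows.max() - rows.min() + 1) * (cols.max() - cols.min() + 1)
    return int(mask.sum() / box < 0.9)


images = [sample("circle") for _ in range(50)] + [sample("square") for _ in range(50)]
labels = np.array([1] * 50 + [0] * 50)  # 1 = circle

color_acc = np.mean([color_rule(im) == y for im, y in zip(images, labels)])
shape_acc = np.mean([shape_rule(im) == y for im, y in zip(images, labels)])
# The shape feature scores 1.00; the color shortcut should land near chance (0.5).
print(f"color shortcut: {color_acc:.2f}, shape feature: {shape_acc:.2f}")
```

For real photos, torchvision's `ColorJitter` transform is one way to apply this kind of color randomization during training.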


> how does this not result in datasets needing to be almost infinitely large?

I don't think the dataset needs to be infinitely large; this is a very simple use case, as there are no major variations in the shapes or patterns to be learned. In fact, a huge amount of data isn't necessarily what does the work: for this use case, the model could easily converge on learning the shape.