Training a Vision Model

Seeing as AI is all the rage these days, and I work for a company that does AI stuff, I should probably train an AI. Also, I've been wanting to build a robot that learns while running. Let's see how this goes.

A couple of years back I read a couple of papers about a robot that was taught to walk in 20 minutes of real-time learning. This sounded pretty good! Time to build some droids - but I never found the time to play with it. A couple of days ago I finally did, and started looking into the first part: training a vision model. The aim is that the model goes from image -> embedding space -> image, and you train on the difference between the input and output images. The embedding can then be forwarded to another part of the robot (e.g. a prediction model). This is (my understanding of) the world models architecture.
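To make the image -> embedding -> image idea concrete, here is a minimal sketch of a linear autoencoder in plain Python. It is not the model used in this post: the class name `TinyAutoencoder`, the layer sizes, the learning rate and the mean-squared-error loss are all assumptions picked for illustration, and a real version would use a proper ML framework.

```python
import random


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


class TinyAutoencoder:
    """Hypothetical linear autoencoder: image -> embedding -> image."""

    def __init__(self, n_pixels, n_embed, seed=0):
        rng = random.Random(seed)
        # Small random weights; enc is n_embed x n_pixels, dec is n_pixels x n_embed.
        self.enc = [[rng.uniform(-0.1, 0.1) for _ in range(n_pixels)]
                    for _ in range(n_embed)]
        self.dec = [[rng.uniform(-0.1, 0.1) for _ in range(n_embed)]
                    for _ in range(n_pixels)]

    def encode(self, image):
        return matvec(self.enc, image)

    def decode(self, embedding):
        return matvec(self.dec, embedding)

    def loss(self, image):
        """Mean squared difference between the input and its reconstruction."""
        recon = self.decode(self.encode(image))
        return sum((r - x) ** 2 for r, x in zip(recon, image)) / len(image)

    def train_step(self, image, lr=0.1):
        """One gradient-descent step on the reconstruction loss."""
        z = self.encode(image)
        recon = self.decode(z)
        n = len(image)
        # Gradient of the loss with respect to each reconstructed pixel.
        err = [2.0 * (r - x) / n for r, x in zip(recon, image)]
        # Gradient with respect to the embedding (computed before dec changes).
        grad_z = [sum(self.dec[i][j] * err[i] for i in range(n))
                  for j in range(len(z))]
        for i in range(n):
            for j in range(len(z)):
                self.dec[i][j] -= lr * err[i] * z[j]
        for j in range(len(z)):
            for k in range(n):
                self.enc[j][k] -= lr * grad_z[j] * image[k]


# Two made-up 4-pixel "images", squeezed through a 2-number embedding.
images = [[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]
model = TinyAutoencoder(n_pixels=4, n_embed=2)
loss_before = sum(model.loss(img) for img in images)
for _ in range(300):
    for img in images:
        model.train_step(img)
loss_after = sum(model.loss(img) for img in images)
print(loss_before, loss_after)
```

The output of `encode` is the part that would be handed to the rest of the robot; the decoder only exists so there is something to compute a loss against.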

Because my aim was a real-time-learning system, I didn't want to prepare a massive dataset and train the vision model on that. Instead I want the vision model to be able to learn from a live video feed. My idea is simple and has probably been done a bunch before:

  1. (populate a fixed size dataset from the video feed)
  2. Train the model for one epoch on the current dataset
  3. Grab a new frame of video
  4. Evaluate the loss of the frame
  5. If the new frame has a higher loss than the worst (highest-loss) image in the dataset, replace the best (lowest-loss) image in the dataset
  6. Repeat
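The loop above can be sketched roughly as follows. This is one reading of the steps, not the actual code behind this post: it takes "worst" to mean the highest-loss image and "best" to mean the lowest-loss image, and the names `update_dataset`, `run_online`, `model_loss` and `train_epoch` are made up for the example. The "model" in the demo is a stand-in that just guesses a single brightness value.

```python
def update_dataset(dataset, losses, frame, frame_loss):
    """Step 5: if the new frame beats the highest loss in the dataset,
    overwrite the lowest-loss image with it. Returns True if replaced."""
    if frame_loss > max(losses):
        best = min(range(len(losses)), key=losses.__getitem__)
        dataset[best] = frame
        losses[best] = frame_loss
        return True
    return False


def run_online(frames, dataset_size, model_loss, train_epoch):
    """Drive the loop over any iterable of frames."""
    frames = iter(frames)
    # Step 1: fill the fixed-size dataset from the start of the feed.
    dataset = [next(frames) for _ in range(dataset_size)]
    while True:
        train_epoch(dataset)                  # step 2
        frame = next(frames, None)            # step 3
        if frame is None:                     # feed ended (a live feed never does)
            return dataset
        # Losses are re-evaluated because the model has just changed.
        losses = [model_loss(img) for img in dataset]
        update_dataset(dataset, losses, frame, model_loss(frame))  # steps 4-5


# Stand-in model: frames are single brightness values and the model
# predicts one constant, nudged toward the dataset mean each epoch.
state = {"guess": 0.0}


def model_loss(frame):
    return (frame - state["guess"]) ** 2


def train_epoch(dataset):
    mean = sum(dataset) / len(dataset)
    state["guess"] += 0.5 * (mean - state["guess"])


feed = [0.1, 0.2, 0.1, 0.9, 0.15, 0.8, 0.05, 0.95]
final_dataset = run_online(feed, dataset_size=3,
                           model_loss=model_loss, train_epoch=train_epoch)
print(final_dataset)
```

Keeping the dataset a fixed size means the per-epoch training cost stays constant no matter how long the feed runs.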

This works surprisingly well. Initially the model outputs a black image, so the images with high loss are the light ones. As the model improves, light and dark images end up with equal loss, so the dataset now contains both, and the model learns to output a light image when the input is light and a dark one when it is dark. The same thing then happens with high-contrast images (e.g. the top is light and the bottom is dark) and then with color.
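A small worked example of why the dataset drifts this way, assuming a mean-squared-error loss and made-up 4-pixel images with brightness between 0 and 1. The three "model stages" here (always black, always mid-gray, uniform brightness matching the input) are a simplification I'm assuming for illustration, not measurements from the real model.

```python
def mse(output, target):
    """Mean squared error between two equally sized pixel lists."""
    return sum((o - t) ** 2 for o, t in zip(output, target)) / len(target)


def flat(value, n=4):
    """An n-pixel image of uniform brightness."""
    return [value] * n


light = flat(0.9)
dark = flat(0.1)
contrast = [0.9, 0.9, 0.1, 0.1]  # top light, bottom dark

# Stage 1: the model outputs black, so light frames have the highest loss.
black_vs_light = mse(flat(0.0), light)
black_vs_dark = mse(flat(0.0), dark)

# Stage 2: the model outputs mid-gray, so light and dark frames tie.
gray_vs_light = mse(flat(0.5), light)
gray_vs_dark = mse(flat(0.5), dark)

# Stage 3: the model outputs the input's average brightness, so uniform
# frames are solved and high-contrast frames become the hard ones.
mean_of_contrast = sum(contrast) / len(contrast)
uniform_vs_light = mse(flat(0.9), light)
uniform_vs_contrast = mse(flat(mean_of_contrast), contrast)

print(black_vs_light > black_vs_dark)
print(abs(gray_vs_light - gray_vs_dark) < 1e-9)
print(uniform_vs_contrast > uniform_vs_light)
```

At each stage, the frames that get pulled into the dataset are exactly the ones the model is currently worst at.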

Initially it would take about 7 hours of training to get the model from zero to moving patches of color, but after twiddling some hyperparameters it takes about 10 minutes to get to OK. After that the improvements are very slow. But before you get your hopes up, even after 11 hours it still looks like this:

Why is it still so bad? I'm trying to train in real-time at 30FPS on an RTX 3070, so:

And of course, I really have no idea what I'm doing; I'm just messing around.

So where to next? Well, this isn't as good as I'd hoped, so I need a bit more time on the vision model. I want to look at CNNs (rather than plain DNNs), VAEs and possibly vision transformers.