This write-up is basically my interpretation of this excellent video explanation by Yannic Kilcher. Go check it out.

DINOv2 is a self-supervised pretraining method for vision transformers (ViTs). It is much more complex to write and explain than standard ViT pretraining. However, DINO gets much better performance for a given parameter, data, and compute budget.

Remember, DINO is a training method. The model that DINO is applied to is a vision transformer. The end result of DINO, just like any other pretraining method, is a really good embedding model. Of course, a DINO-pretrained model can be fine-tuned for downstream tasks too.

DINO-trained ViTs are particularly good as visual embedding models. Such ViTs can support a linear classifier or a vector search algorithm on top of them.
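
As a rough illustration of the vector-search use case, here is a minimal sketch in plain Python. The embeddings are made-up 3-dimensional vectors standing in for ViT outputs (real ones are much longer), and `cosine` and `nearest` are hypothetical helper names, not part of any library.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def nearest(query, index):
    """Return the key in `index` whose embedding is most similar to `query`."""
    return max(index, key=lambda name: cosine(query, index[name]))

# Made-up embeddings; in practice these would come from the pretrained ViT.
index = {
    "cat_1": [0.9, 0.1, 0.0],
    "cat_2": [0.8, 0.2, 0.1],
    "car_1": [0.0, 0.1, 0.9],
}
query = [0.85, 0.15, 0.05]  # assumed embedding of a new cat photo
best = nearest(query, index)
print(best)
```

A brute-force scan like this is fine for a handful of vectors; larger collections would normally use a dedicated nearest-neighbor index.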

DINO stands for self-distillation with no labels.

What happens in DINO is this:

# gs, gt: student and teacher networks (same architecture)
# C: center, a running mean of the teacher outputs
# tps, tpt: student and teacher softmax temperatures (student's is higher)
# alpha, beta: teacher and center momentum rates

gt.params = gs.params # init "teacher" from "student"
for x in dataloader:
   x1, x2 = aug(x), aug(x) # Perform random augs on image x

   s1, s2 = gs(x1), gs(x2) # pass x1,2 aug into "student"
   t1, t2 = gt(x1), gt(x2) # pass x1,2 aug into "teacher"

   loss = (H(s1, t2) + H(s2, t1))/2
   loss.backward() # back-propagate; only the student receives gradients
   update(gs) # SGD

   gt.params = alpha * gt.params + (1-alpha) * gs.params
   C = beta * C + (1-beta) * cat([t1,t2]).mean(dim=0)

def H(s, t):
  t = t.detach() # cut off teacher gradient
  s = softmax(s / tps, dim=1) 
  t = softmax( (t - C) / tpt, dim=1) # Center and sharpen
  loss = - (t * log(s)).sum(dim=1).mean()
  return loss

There’s a lot to unpack here.

First, we have two identical models at the start. We call them ‘student’ and ‘teacher’, which is terrible naming in my opinion, but the alternatives are even worse. At any rate, teacher gt and student gs start out as the same randomly initialized ViT.

Next, we get an image of fixed shape from the dataset. ViTs are usually fed fixed-shape images. Say this is a 224 x 224 RGB image of a cat. It is stored in tensor x.

aug is a function that returns a randomly augmented copy of its input. x1 and x2 are two different augmented versions of x. x1 and x2 could even be small, non-overlapping crops of x.
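
To make the cropping idea concrete, here is a minimal sketch of a random-crop augmentation on a toy image. It is only an assumption about what aug might do: `random_crop` is a hypothetical helper, the image is a small list of lists rather than a tensor, and real pipelines also apply things like flips and color changes.

```python
import random

def random_crop(image, size, rng):
    """Cut a random size x size window out of a 2D image (list of rows)."""
    height, width = len(image), len(image[0])
    top = rng.randrange(height - size + 1)
    left = rng.randrange(width - size + 1)
    return [row[left:left + size] for row in image[top:top + size]]

rng = random.Random(0)  # seeded so the sketch is repeatable

# Toy 8x8 single-channel "image" whose pixel value encodes its position.
x = [[r * 8 + c for c in range(8)] for r in range(8)]

# Two independent random views of the same image.
x1 = random_crop(x, 4, rng)
x2 = random_crop(x, 4, rng)
```

Because the two crops are drawn independently, they may overlap a lot, a little, or not at all.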

Now both x1 and x2 pass through both networks, and both teacher and student apply a softmax at the end to turn their outputs into a probability distribution.

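We calculate a loss that pushes the two models to agree. We want the student to output the same probability distribution as the teacher, and that is what the cross-entropy -t*log(s) measures. Note that the views are crossed: the student’s output on x1 is compared with the teacher’s output on x2, and vice versa. We penalize deviations between student and teacher, but we only penalize the student: update(gs) is the only gradient step.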
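
A small numeric sketch may help here. For a fixed teacher distribution, the cross-entropy is smallest when the student outputs that same distribution. The logits below are made up, and `softmax` and `cross_entropy` are hypothetical stand-ins for the operations in the pseudocode.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax with a temperature."""
    scaled = [v / temperature for v in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(t, s):
    """-sum(t * log(s)) for two probability vectors."""
    return -sum(ti * math.log(si) for ti, si in zip(t, s))

teacher = softmax([2.0, 0.5, -1.0])
matching_student = softmax([2.0, 0.5, -1.0])    # agrees with the teacher
different_student = softmax([-1.0, 0.5, 2.0])   # prefers the opposite class

loss_match = cross_entropy(teacher, matching_student)
loss_diff = cross_entropy(teacher, different_student)
print(loss_match < loss_diff)
```

Note that the matching loss is not zero: it equals the entropy of the teacher distribution, which is the lowest value the student can reach for that target.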

The teacher just tosses the gradient away and instead gets two extra steps. From the teacher’s logits we subtract the center C, and we use a lower temperature tpt for its softmax. This makes the teacher’s distribution more “peak”-y. The student’s distribution is more spread out because of its higher temperature, and it doesn’t use the centering C.
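
Here is a minimal sketch of how the same logits turn into a flatter student distribution and a peakier teacher distribution. The logits, the center C, and the temperature values are all illustrative assumptions, chosen only so that the student’s temperature is the higher one.

```python
import math

def softmax(logits, temperature):
    """Numerically stable softmax with a temperature."""
    scaled = [v / temperature for v in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [0.30, 0.20, 0.10]   # made-up raw network outputs
C = [0.05, 0.10, 0.00]        # made-up center
tps, tpt = 0.1, 0.04          # illustrative temperatures, student's is higher

# Student: temperature only.
student = softmax(logits, tps)

# Teacher: subtract the center, then use the lower temperature.
teacher = softmax([v - c for v, c in zip(logits, C)], tpt)

print(max(student), max(teacher))
```

With these numbers the teacher puts clearly more mass on its top entry than the student does.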

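Teacher parameters are then updated from the student parameters via an exponential moving average: gt.params = alpha * gt.params + (1-alpha) * gs.params. The center C is updated the same way, from the mean teacher output over the batch, with momentum beta.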
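
The moving-average update can be sketched on a toy parameter list. `ema_update` is a hypothetical helper, the student is frozen for simplicity, and alpha = 0.9 is chosen only so the effect shows up within a few steps; practical values sit much closer to 1.

```python
def ema_update(teacher_params, student_params, alpha):
    """One exponential-moving-average step of the teacher toward the student."""
    return [alpha * t + (1 - alpha) * s
            for t, s in zip(teacher_params, student_params)]

teacher = [0.0, 0.0]
student = [1.0, -1.0]   # pretend the student stays fixed
alpha = 0.9             # illustrative momentum

for _ in range(50):
    teacher = ema_update(teacher, student, alpha)

print(teacher)
```

After n steps against a fixed student, the remaining gap shrinks by a factor of alpha**n, so the teacher trails the student smoothly instead of jumping to it.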

I hear you asking: “but Tornike, why doesn’t this setup, complex as it is, just collapse to both models spitting out a uniform distribution, reaching the lowest loss with a useless model?”

So, what prevents the models from agreeing on outputting a uniform distribution?

I don’t know for sure. It seems that the combination of centering with C and sharpening with tpt prevents that. A nice article about this explains why we can avoid collapse:

Mode collapse is prevented precisely because all samples in the mini-batch cannot take on the same value after batch normalization

As far as I understand, subtracting C from the current teacher representation is precisely what prevents the collapse. For collapse to happen, a model has to output the exact same distribution for all augmented inputs. However, if we subtract C from all representations and also sharpen the distribution, the collapsed state becomes a really unstable parameter setup. Any slight parameter deviation from the collapsed state would push the models into non-collapsed states. Basically, we rig the game so that approaching collapse becomes harder and harder.
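
This is a toy numeric illustration of that intuition, not a proof. Two teacher outputs are almost identical and both strongly favor the first dimension. Without centering, the sharpened targets are the same near-one-hot vector. With centering, the shared component is removed and the low temperature amplifies the tiny remaining difference. As a simplification, C here is the plain batch mean rather than a running average, and all numbers are made up.

```python
import math

def softmax(logits, temperature):
    """Numerically stable softmax with a temperature."""
    scaled = [v / temperature for v in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]

tpt = 0.04  # illustrative teacher temperature

# Near-collapsed batch: both samples strongly favor dimension 0.
batch = [[3.00, 1.00], [3.02, 1.00]]

# Simplification: center = per-dimension mean over this batch.
C = [sum(col) / len(batch) for col in zip(*batch)]

# Sharpening only: both targets are almost exactly [1, 0].
plain = [softmax(t, tpt) for t in batch]

# Centering then sharpening: the targets now disagree with each other.
centered = [softmax([v - c for v, c in zip(t, C)], tpt) for t in batch]

print(plain)
print(centered)
```

In this sketch the two centered targets land on opposite sides of 0.5 in the first dimension, so the student can no longer satisfy both with one constant output.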