import matplotlib.pyplot as plt

# It should look all random, since the weights are random at initialization.
# Visualize all patch embeddings
patches_viz = embeddings[0].detach().numpy()  # Shape: [196, 768]

plt.figure(figsize=(15, 8))
plt.imshow(patches_viz, aspect='auto', cmap='viridis')
plt.colorbar()
plt.title('Visualization of All Patch Embeddings')
plt.xlabel('Embedding Dimension')
plt.ylabel('Patch Number')
plt.show()
from transformers import SiglipVisionModel as HFSiglipVisionModel

our_state_dict = embd.state_dict()
hf_state_dict = {
    k.replace("vision_model.embeddings.", ""): v
    for k, v in vision_model.state_dict().items()
    if "vision_model.embeddings." in k
}
our_state_dict.update(hf_state_dict)
embd.load_state_dict(our_state_dict)

with torch.no_grad():
    our_output = embd(image_tensor)
    hf_output = vision_model.vision_model.embeddings(image_tensor)

print("Max difference between our output and HF output:", torch.max(torch.abs(our_output - hf_output)))  # = 0, so they match!
Max difference between our output and HF output: tensor(0.)
import math

class Head(nn.Module):
    """ A single head of the multi-head attention """

    def __init__(self, n_in, head_size, context_length):
        super().__init__()
        self.head_size = head_size
        self.key = nn.Linear(n_in, head_size, bias=False)
        self.query = nn.Linear(n_in, head_size, bias=False)
        self.value = nn.Linear(n_in, head_size, bias=False)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        # Scaled dot-product attention; no causal mask, since every patch can attend to every other patch.
        wei = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_size))
        wei = F.softmax(wei, dim=-1)
        out = wei @ v
        return out


class MultiHeadAttention(nn.Module):
    """ Multi-head attention implementation that concatenates every head's output """

    def __init__(self, num_head, n_in, head_size, context_length):
        super().__init__()
        self.head_size = head_size
        self.num_head = num_head
        # Use nn.ModuleList (not a plain Python list) so the heads' parameters are registered with the module.
        self.heads = nn.ModuleList([Head(n_in, head_size, context_length) for _ in range(num_head)])
        self.proj = nn.Linear(n_in, n_in)

    def forward(self, x):
        out = [h(x) for h in self.heads]
        out = torch.cat(out, dim=-1)
        out = self.proj(out)
        return out
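As a quick sanity check (not part of the original pipeline, values chosen for illustration), we can push a dummy batch of 196 patch tokens through the block with 12 heads of size 64, so the concatenated head outputs match hidden_size = 768:

# Illustrative shape check for the attention block above.
x = torch.randn(1, 196, 768)  # [batch, num_patches, hidden_size]
mha = MultiHeadAttention(num_head=12, n_in=768, head_size=64, context_length=196)
print(mha(x).shape)  # torch.Size([1, 196, 768])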
@dataclass
class SiglipVisionConfig:
    num_channels: int = 3
    image_size: int = 224
    patch_size: int = 16
    num_attention_heads: int = 12
    hidden_size: int = 768  # `embed_dim` --> `hidden_size`, just renamed it.
    attention_dropout: float = 0.0
    intermediate_size: int = 3072
    layer_norm_eps: float = 1e-6


class SiglipVisionEmbeddings(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.num_channels = config.num_channels
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=self.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,  # there won't be any overlapping since the stride is equal to the kernel size
            padding="valid",
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2  # the image is square, so the patch grid has (image_size / patch_size) patches per side
        self.num_positions = self.num_patches  # one position per patch in the sequence
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(  # a buffer for the position ids, a tensor of shape [1, num_patches]
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,  # non-persistent: this buffer is not saved in the state_dict
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        B, C, H, W = pixel_values.shape
        patch_embeds = self.patch_embedding(pixel_values)  # [batch_size, embed_dim, H/patch_size, W/patch_size]
        embeddings = patch_embeds.flatten(start_dim=2, end_dim=-1)  # [batch_size, embed_dim, num_patches]
        embeddings = embeddings.transpose(1, 2)  # [batch_size, num_patches, embed_dim]
        embeddings = embeddings + self.position_embedding(self.position_ids)  # [batch_size, num_patches, embed_dim]
        return embeddings
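A minimal usage sketch of the embedding module (hypothetical variable names, default config values from the dataclass above): a 224x224 RGB image becomes (224 / 16)^2 = 196 patch tokens, each of dimension 768.

# Illustrative only: construct the module and check the output shape.
cfg = SiglipVisionConfig()
vision_embeddings = SiglipVisionEmbeddings(cfg)
dummy_image = torch.randn(1, cfg.num_channels, cfg.image_size, cfg.image_size)  # [batch, channels, H, W]
out = vision_embeddings(dummy_image)
print(out.shape)  # torch.Size([1, 196, 768])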