import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class SiglipVisionConfig:
    num_channels: int = 3
    image_size: int = 224
    patch_size: int = 16
    num_attention_heads: int = 12
    hidden_size: int = 768
    attention_dropout: float = 0.0
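
# Sketch (not part of the original listing): the shapes these defaults imply.
# A 224x224 image split into 16x16 patches yields (224 // 16) ** 2 = 196 patch
# tokens, and each of the 12 heads attends over 768 // 12 = 64 dimensions.
_cfg = SiglipVisionConfig()  # hypothetical name, used only for this check
assert (_cfg.image_size // _cfg.patch_size) ** 2 == 196
assert _cfg.hidden_size // _cfg.num_attention_heads == 64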
class SiglipAttention(nn.Module):
    """Multi-head self-attention over patch embeddings."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.dropout = config.attention_dropout
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, hidden_states):
        # hidden_states are the patch embeddings: (batch_size, num_patches, embed_dim)
        B, T, C = hidden_states.shape
        # Project to queries, keys, and values, each still (B, T, C)
        q_states = self.q_proj(hidden_states)
        k_states = self.k_proj(hidden_states)
        v_states = self.v_proj(hidden_states)
        # Split the embedding dim into heads: (B, num_heads, T, head_dim)
        q_states = q_states.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        k_states = k_states.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        v_states = v_states.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)
        # Scaled dot-product attention: softmax(QK^T / sqrt(head_dim)) @ V
        attn_weights = (q_states @ k_states.transpose(-2, -1)) * (1.0 / math.sqrt(k_states.size(-1)))
        attn_weights = F.softmax(attn_weights, dim=-1).to(q_states.dtype)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_outs = attn_weights @ v_states
        # Merge heads back: (B, num_heads, T, head_dim) -> (B, T, embed_dim)
        attn_outs = attn_outs.transpose(1, 2).contiguous().reshape(B, T, C)
        attn_outs = self.out_proj(attn_outs)
        return attn_outs
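
# Sketch (not part of the original listing): the same attention core expressed
# with PyTorch's built-in F.scaled_dot_product_attention, which applies the
# 1/sqrt(head_dim) scaling internally. For q/k/v already shaped
# (B, num_heads, T, head_dim), it should match the manual
# softmax(QK^T / sqrt(d)) @ V path above up to floating-point error.
def sdpa_core(q_states, k_states, v_states, dropout_p=0.0):  # hypothetical helper
    return F.scaled_dot_product_attention(q_states, k_states, v_states, dropout_p=dropout_p)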
# Smoke test: a dummy batch of patch embeddings
batch_size = 1
num_patches = 196
embed_dim = 768
hidden_states = torch.randn(batch_size, num_patches, embed_dim)
config = SiglipVisionConfig(
    attention_dropout=0.0,
    num_attention_heads=12,
    hidden_size=768,
)
attention = SiglipAttention(config)
output = attention(hidden_states)
print(f"Input shape: {hidden_states.shape}")
print(f"Output shape: {output.shape}")