Instance recognition: Siamese VAE does better
woodRock committed Oct 7, 2024
1 parent d001286 commit 51a3da1
Showing 3 changed files with 28 additions and 12 deletions.
Binary file modified code/siamese/best_model.pth
Binary file not shown.
Binary file modified code/siamese/best_vae_model.pth
Binary file not shown.
40 changes: 28 additions & 12 deletions code/siamese/vae.py
@@ -11,12 +11,17 @@ def __init__(self, input_dim, hidden_dim, latent_dim):
         super().__init__()
         self.encoder = nn.Sequential(
             nn.Linear(input_dim, hidden_dim),
-            nn.ReLU(),
+            nn.GELU(),
+            nn.Dropout(0.1),
             nn.Linear(hidden_dim, hidden_dim),
-            nn.ReLU()
+            nn.GELU(),
+            nn.Dropout(0.1),
+            nn.Linear(hidden_dim, hidden_dim // 2),
+            nn.GELU(),
+            nn.Dropout(0.1)
         )
-        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
-        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
+        self.fc_mu = nn.Linear(hidden_dim // 2, latent_dim)
+        self.fc_logvar = nn.Linear(hidden_dim // 2, latent_dim)

     def forward(self, x):
         h = self.encoder(x)
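For readability, here is the encoder as it stands after this commit, assembled from the added lines above. A sketch, assuming the file's existing `torch` / `torch.nn` imports; the body of `forward` past `h = self.encoder(x)` is cut off by the diff, and returning `(mu, logvar)` is inferred from how `classify` unpacks the encoder output further down.

```python
import torch
import torch.nn as nn

class VAEEncoder(nn.Module):
    # Post-commit encoder: GELU activations, dropout regularization, and a
    # funnel down to hidden_dim // 2 before the mu / logvar heads.
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(0.1),
        )
        self.fc_mu = nn.Linear(hidden_dim // 2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim // 2, latent_dim)

    def forward(self, x):
        h = self.encoder(x)
        # The diff cuts forward off here; returning (mu, logvar) is inferred
        # from how classify() consumes the encoder output.
        return self.fc_mu(h), self.fc_logvar(h)
```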
@@ -28,10 +33,15 @@ class VAEDecoder(nn.Module):
     def __init__(self, latent_dim, hidden_dim, output_dim):
         super().__init__()
         self.decoder = nn.Sequential(
-            nn.Linear(latent_dim, hidden_dim),
-            nn.ReLU(),
+            nn.Linear(latent_dim, hidden_dim // 2),
+            nn.GELU(),
+            nn.Dropout(0.1),
+            nn.Linear(hidden_dim // 2, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(0.1),
             nn.Linear(hidden_dim, hidden_dim),
-            nn.ReLU(),
+            nn.GELU(),
+            nn.Dropout(0.1),
             nn.Linear(hidden_dim, output_dim)
         )
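The decoder now mirrors the encoder's funnel in reverse (latent -> hidden_dim // 2 -> hidden_dim -> output). A quick smoke test of the widening path, with hypothetical dimensions (the commit sets hidden_dim=512 and latent_dim=256 below; input_dim comes from the data loader, so 2080 here is only a placeholder):

```python
import torch
import torch.nn as nn

# Hypothetical dimensions for a shape check only.
input_dim, hidden_dim, latent_dim = 2080, 512, 256

decoder = nn.Sequential(
    nn.Linear(latent_dim, hidden_dim // 2),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(hidden_dim // 2, hidden_dim),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(hidden_dim, hidden_dim),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(hidden_dim, input_dim),
)

z = torch.randn(4, latent_dim)          # a batch of latent codes
x_recon = decoder(z)
assert x_recon.shape == (4, input_dim)  # widening mirror of the encoder funnel
```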

@@ -43,7 +53,12 @@ def __init__(self, input_dim, hidden_dim, latent_dim, num_classes):
         super().__init__()
         self.encoder = VAEEncoder(input_dim, hidden_dim, latent_dim)
         self.decoder = VAEDecoder(latent_dim, hidden_dim, input_dim)
-        self.classifier = nn.Linear(latent_dim, num_classes)
+        self.classifier = nn.Sequential(
+            nn.Linear(latent_dim, hidden_dim // 2),
+            nn.GELU(),
+            nn.Dropout(0.1),
+            nn.Linear(hidden_dim // 2, num_classes)
+        )

     def reparameterize(self, mu, logvar):
         std = torch.exp(0.5 * logvar)
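The diff truncates `reparameterize` right after the `std` line. The conventional completion of the reparameterization trick, consistent with `classify` sampling `z` from `(mu, logvar)` below, would be (an assumption, shown as a standalone sketch):

```python
import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); sampling stays differentiable
    # with respect to mu and logvar.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std
```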
@@ -68,6 +83,7 @@ def classify(self, x):
         z = self.reparameterize(mu, logvar)
         return self.classifier(z)

+# The rest of the code remains the same
 def vae_loss(x, x_recon, mu, logvar):
     recon_loss = nn.functional.mse_loss(x_recon, x, reduction='sum')
     kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
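`vae_loss` is likewise truncated before its return; summing the two terms is the conventional ending. The `combined_loss` helper below is hypothetical, sketching how the four weights named near the bottom of this diff might mix the loss components (the actual `train_epoch` is not part of this commit's visible diff). Note that the new weights zero out the VAE and balanced-accuracy terms entirely:

```python
import torch
import torch.nn as nn

def vae_loss(x, x_recon, mu, logvar):
    # The two components shown in the hunk above; returning their sum
    # is an assumed, conventional completion of the truncated function.
    recon_loss = nn.functional.mse_loss(x_recon, x, reduction='sum')
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_div

def combined_loss(contrastive, vae, ce, bal_acc_term,
                  alpha=0.5, beta=0.0, gamma=0.5, delta=0.0):
    # Hypothetical helper: a weighted sum matching the comment
    # "Contrastive loss, VAE loss, Cross entropy, Balanced accuracy score".
    # With this commit's weights (0.5, 0.0, 0.5, 0.0), only the contrastive
    # and cross-entropy terms contribute.
    return alpha * contrastive + beta * vae + gamma * ce + delta * bal_acc_term
```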
@@ -145,18 +161,18 @@ def main():

     # Model parameters
     input_dim = next(iter(train_loader))[0].shape[1]  # Get input dimension from data
-    hidden_dim = 256
-    latent_dim = 128
+    hidden_dim = 512
+    latent_dim = 256
     num_classes = 2  # Assuming binary classification, adjust if needed
-    learning_rate = 1e-4
+    learning_rate = 1e-3

     model = ContrastiveVAE(input_dim, hidden_dim, latent_dim, num_classes).to(device)
     optimizer = optim.Adam(model.parameters(), lr=learning_rate)

     num_epochs = 200
     best_val_accuracy = 0
     # Contrastive loss, VAE loss, Cross entropy, Balanced accuracy score
-    alpha, beta, gamma, delta = 0.3, 0.3, 0.2, 0.2  # Weights for different loss components
+    alpha, beta, gamma, delta = 0.5, 0.0, 0.5, 0.0  # Weights for different loss components

     for epoch in range(num_epochs):
         train_loss, train_accuracy, train_balanced_accuracy = train_epoch(model, train_loader, optimizer, device, alpha, beta, gamma, delta)
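A smoke test with the new hyperparameters. The import path and `input_dim` are assumptions (the script reads `input_dim` from the data loader, and importing `vae` directly presumes the file guards `main()` behind `if __name__ == '__main__'`):

```python
import torch
from vae import ContrastiveVAE  # i.e. code/siamese/vae.py on the import path

input_dim, hidden_dim, latent_dim, num_classes = 2080, 512, 256, 2
model = ContrastiveVAE(input_dim, hidden_dim, latent_dim, num_classes)
model.eval()  # disable dropout for the sanity check

x = torch.randn(8, input_dim)  # dummy batch
logits = model.classify(x)     # encode -> reparameterize -> classify
assert logits.shape == (8, num_classes)
```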