Skip to content

Commit

Permalink
Contrastive learning GP with cosine similarity
Browse files Browse the repository at this point in the history
  • Loading branch information
woodRock committed Oct 7, 2024
1 parent 8ae945e commit 6417d38
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 10 deletions.
9 changes: 9 additions & 0 deletions code/siamese/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Siamese network

## Genetic Programming

Run the following command to evaluate performance of Genetic Programming for contrastive learning.

```bash
python3 gp.py -nt 30 -g 200 -p 100
```
7 changes: 4 additions & 3 deletions code/siamese/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from deap import algorithms, base, creator, gp, tools
from sklearn.metrics import balanced_accuracy_score
Expand Down Expand Up @@ -122,9 +123,9 @@ def compile_trees(individual: List[gp.PrimitiveTree]) -> List[Callable]:
return [gp.compile(expr, pset) for expr in individual]

# Contrastive loss function
def contrastive_loss(output1: torch.Tensor, output2: torch.Tensor, label: float, margin: float = 1.0) -> torch.Tensor:
euclidean_distance = F.pairwise_distance(output1, output2)
loss = label * torch.pow(euclidean_distance, 2) + (1 - label) * torch.pow(torch.clamp(margin - euclidean_distance, min=0.0), 2)
def contrastive_loss(z1, z2, y, temperature=0.5):
similarity = nn.functional.cosine_similarity(z1, z2)
loss = y * torch.pow(1 - similarity, 2) + (1 - y) * torch.pow(torch.clamp(similarity - 0.1, min=0.0), 2)
return loss.mean()

# Evaluation function
Expand Down
44 changes: 37 additions & 7 deletions code/siamese/logs/results_0.log
Original file line number Diff line number Diff line change
@@ -1,12 +1,42 @@
INFO:__main__:gen nevals avg std min max
0 100 1.8426e+18 1.81428e+19 1.18593 1.82351e+20
INFO:__main__:1 75 3.95869e+08 2.32875e+09 1.15322 1.83026e+10
INFO:__main__:gen nevals avg std min max
0 100 0.536682 0.117834 0.371207 0.739102
INFO:__main__:1 83 0.437101 0.0844668 0.368868 0.7264
INFO:__main__:
Generation 1: Best Fitness = 1.1532
Generation 1: Best Fitness = 0.3689
Balanced accuracy: Train: 0.5000
Validation: 0.5000
INFO:__main__:2 72 466143 4.54385e+06 1.02204 4.56715e+07
INFO:__main__:2 82 0.409887 0.0679905 0.353061 0.712627
INFO:__main__:
Generation 2: Best Fitness = 1.0220
Balanced accuracy: Train: 0.5004
Generation 2: Best Fitness = 0.3531
Balanced accuracy: Train: 0.5000
Validation: 0.5000
INFO:__main__:3 77 0.401588 0.0607651 0.353061 0.718457
INFO:__main__:
Generation 3: Best Fitness = 0.3531
Balanced accuracy: Train: 0.5000
Validation: 0.5000
INFO:__main__:4 78 0.376162 0.0374512 0.346475 0.596387
INFO:__main__:
Generation 4: Best Fitness = 0.3465
Balanced accuracy: Train: 0.5000
Validation: 0.5000
INFO:__main__:5 76 0.3653 0.0428357 0.343384 0.634618
INFO:__main__:
Generation 5: Best Fitness = 0.3434
Balanced accuracy: Train: 0.5000
Validation: 0.5000
INFO:__main__:6 77 0.366268 0.0686927 0.343384 0.738176
INFO:__main__:
Generation 6: Best Fitness = 0.3434
Balanced accuracy: Train: 0.5000
Validation: 0.5000
INFO:__main__:7 82 0.366218 0.0583039 0.342481 0.706247
INFO:__main__:
Generation 7: Best Fitness = 0.3425
Balanced accuracy: Train: 0.5000
Validation: 0.5000
INFO:__main__:8 79 0.360883 0.054227 0.340341 0.707478
INFO:__main__:
Generation 8: Best Fitness = 0.3403
Balanced accuracy: Train: 0.5000
Validation: 0.5000

0 comments on commit 6417d38

Please sign in to comment.