Skip to content

Commit

Permalink
val split
Browse files Browse the repository at this point in the history
  • Loading branch information
bw4sz committed Aug 1, 2020
1 parent 4015d78 commit a1f1e85
Showing 1 changed file with 37 additions and 37 deletions.
74 changes: 37 additions & 37 deletions experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,24 @@

##Train
#Train see config.yml for tfrecords path with weighted classes in cross entropy
model.read_data(validation_split=False)
model.read_data(validation_split=True)
class_weight = model.calc_class_weight()

## Train subnetwork
experiment.log_parameter("Train subnetworks", True)
with experiment.context_manager("spatial_subnetwork"):
print("Train spatial subnetwork")
model.read_data(mode="submodel",validation_split=False)
model.read_data(mode="submodel",validation_split=True)
model.train(submodel="spatial", class_weight=[class_weight, class_weight, class_weight])

with experiment.context_manager("spectral_subnetwork"):
print("Train spectral subnetwork")
model.read_data(mode="submodel",validation_split=False)
model.read_data(mode="submodel",validation_split=True)
model.train(submodel="spectral", class_weight=[class_weight, class_weight, class_weight])

#Train full model
experiment.log_parameter("Class Weighted", True)
model.read_data(validation_split=False)
model.read_data(validation_split=True)
model.train(class_weight=class_weight)

#Get Alpha score for the weighted spectral/spatial average. Higher alpha favors spatial network.
Expand All @@ -60,44 +60,44 @@

##Evaluate
#Evaluation scores, see config.yml for tfrecords path
#y_pred, y_true = model.evaluate(model.val_split)
y_pred, y_true = model.evaluate(model.val_split)

#Evaluation accuracy
#eval_acc = keras_metrics.CategoricalAccuracy()
#eval_acc.update_state(y_true, y_pred)
#experiment.log_metric("Evaluation Accuracy",eval_acc.result().numpy())
eval_acc = keras_metrics.CategoricalAccuracy()
eval_acc.update_state(y_true, y_pred)
experiment.log_metric("Evaluation Accuracy",eval_acc.result().numpy())

#macro, micro = metrics.f1_scores(y_true, y_pred)
#experiment.log_metric("MicroF1",micro)
#experiment.log_metric("MacroF1",macro)
macro, micro = metrics.f1_scores(y_true, y_pred)
experiment.log_metric("MicroF1",micro)
experiment.log_metric("MacroF1",macro)

#Confusion matrix
#class_labels = {
#0: "Unclassified",
#1 : "Healthy grass",
#2 : "Stressed grass",
#3 : "Artificial turf",
#4 : "Evergreen trees",
#5 : "Deciduous trees",
#6 : "Bare earth",
#7 : "Water",
#8 : "Residential buildings",
#9 : "Non-residential buildings",
#10 : "Roads",
#11 : "Sidewalks",
#12 : "Crosswalks",
#13 : "Major thoroughfares",
#14 : "Highways",
#15 : "Railways",
#16 : "Paved parking lots",
#17 : "Unpaved parking lots",
#18 : "Cars",
#19 : "Trains",
#20 : "Stadium seat"
#}

#print("Unique labels in ytrue {}, unique labels in y_pred {}".format(np.unique(np.argmax(y_true,1)),np.unique(np.argmax(y_pred,1))))
#experiment.log_confusion_matrix(y_true = y_true, y_predicted = y_pred, labels=list(class_labels.values()), title="Confusion Matrix")
class_labels = {
0: "Unclassified",
1 : "Healthy grass",
2 : "Stressed grass",
3 : "Artificial turf",
4 : "Evergreen trees",
5 : "Deciduous trees",
6 : "Bare earth",
7 : "Water",
8 : "Residential buildings",
9 : "Non-residential buildings",
10 : "Roads",
11 : "Sidewalks",
12 : "Crosswalks",
13 : "Major thoroughfares",
14 : "Highways",
15 : "Railways",
16 : "Paved parking lots",
17 : "Unpaved parking lots",
18 : "Cars",
19 : "Trains",
20 : "Stadium seat"
}

print("Unique labels in ytrue {}, unique labels in y_pred {}".format(np.unique(np.argmax(y_true,1)),np.unique(np.argmax(y_pred,1))))
experiment.log_confusion_matrix(y_true = y_true, y_predicted = y_pred, labels=list(class_labels.values()), title="Confusion Matrix")

#Predict
predict_tfrecords = glob.glob("/orange/ewhite/b.weinstein/Houston2018/tfrecords/predict/*.tfrecord")
Expand Down

0 comments on commit a1f1e85

Please sign in to comment.