Skip to content
This repository was archived by the owner on Dec 6, 2023. It is now read-only.

Commit 016a940

Browse files
committed
update benchmark script
1 parent 74f03d0 commit 016a940

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

benchmarks/bench_20newsgroups.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,29 +30,34 @@
3030
}
3131

3232
estimators['fm-3'] = clone(estimators['fm-2']).set_params(degree=3)
33+
estimators['fm-2-ada'] = clone(estimators['fm-2']).set_params(
34+
solver='adagrad', learning_rate=0.01, max_iter=20)
35+
estimators['fm-3-ada'] = clone(estimators['fm-3']).set_params(
36+
solver='adagrad', learning_rate=0.01, max_iter=20
37+
)
3338
estimators['polynet-3'] = (clone(estimators['polynet-2'])
3439
.set_params(degree=3, n_components=10))
3540

3641
if __name__ == '__main__':
3742
data_train = fetch_20newsgroups_vectorized(subset="train")
3843
data_test = fetch_20newsgroups_vectorized(subset="test")
39-
X_train = sp.csc_matrix(data_train.data)
40-
X_test = sp.csc_matrix(data_test.data)
44+
X_train_csc = sp.csc_matrix(data_train.data)
45+
X_test_csc = sp.csc_matrix(data_test.data)
46+
X_train_csr = sp.csr_matrix(data_train.data)
47+
X_test_csr = sp.csr_matrix(data_test.data)
4148

4249
y_train = data_train.target == 0 # atheism vs rest
4350
y_test = data_test.target == 0
4451

4552
print("20 newsgroups")
4653
print("=============")
47-
print("X_train.shape = {0}".format(X_train.shape))
48-
print("X_train.format = {0}".format(X_train.format))
49-
print("X_train.dtype = {0}".format(X_train.dtype))
54+
print("X_train.shape = {0}".format(X_train_csr.shape))
55+
print("X_train.dtype = {0}".format(X_train_csr.dtype))
5056
print("X_train density = {0}"
51-
"".format(X_train.nnz / np.product(X_train.shape)))
57+
"".format(X_train_csr.nnz / np.product(X_train_csr.shape)))
5258
print("y_train {0}".format(y_train.shape))
53-
print("X_test {0}".format(X_test.shape))
54-
print("X_test.format = {0}".format(X_test.format))
55-
print("X_test.dtype = {0}".format(X_test.dtype))
59+
print("X_test {0}".format(X_test_csr.shape))
60+
print("X_test.dtype = {0}".format(X_test_csr.dtype))
5661
print("y_test {0}".format(y_test.shape))
5762
print()
5863

@@ -62,6 +67,10 @@
6267

6368
for name, clf in sorted(estimators.items()):
6469
print("Training %s ... " % name, end="")
70+
if 'ada' in name:
71+
X_train, X_teest = X_train_csr, X_test_csr
72+
else:
73+
X_train, X_test = X_train_csc, X_test_csc
6574
t0 = time()
6675
clf.fit(X_train, y_train)
6776
train_time[name] = time() - t0

0 commit comments

Comments
 (0)