|
30 | 30 | }
|
31 | 31 |
|
32 | 32 | estimators['fm-3'] = clone(estimators['fm-2']).set_params(degree=3)
|
| 33 | +estimators['fm-2-ada'] = clone(estimators['fm-2']).set_params( |
| 34 | + solver='adagrad', learning_rate=0.01, max_iter=20) |
| 35 | +estimators['fm-3-ada'] = clone(estimators['fm-3']).set_params( |
| 36 | + solver='adagrad', learning_rate=0.01, max_iter=20 |
| 37 | +) |
33 | 38 | estimators['polynet-3'] = (clone(estimators['polynet-2'])
|
34 | 39 | .set_params(degree=3, n_components=10))
|
35 | 40 |
|
36 | 41 | if __name__ == '__main__':
|
37 | 42 | data_train = fetch_20newsgroups_vectorized(subset="train")
|
38 | 43 | data_test = fetch_20newsgroups_vectorized(subset="test")
|
39 |
| - X_train = sp.csc_matrix(data_train.data) |
40 |
| - X_test = sp.csc_matrix(data_test.data) |
| 44 | + X_train_csc = sp.csc_matrix(data_train.data) |
| 45 | + X_test_csc = sp.csc_matrix(data_test.data) |
| 46 | + X_train_csr = sp.csr_matrix(data_train.data) |
| 47 | + X_test_csr = sp.csr_matrix(data_test.data) |
41 | 48 |
|
42 | 49 | y_train = data_train.target == 0 # atheism vs rest
|
43 | 50 | y_test = data_test.target == 0
|
44 | 51 |
|
45 | 52 | print("20 newsgroups")
|
46 | 53 | print("=============")
|
47 |
| - print("X_train.shape = {0}".format(X_train.shape)) |
48 |
| - print("X_train.format = {0}".format(X_train.format)) |
49 |
| - print("X_train.dtype = {0}".format(X_train.dtype)) |
| 54 | + print("X_train.shape = {0}".format(X_train_csr.shape)) |
| 55 | + print("X_train.dtype = {0}".format(X_train_csr.dtype)) |
50 | 56 | print("X_train density = {0}"
|
51 |
| - "".format(X_train.nnz / np.product(X_train.shape))) |
| 57 | + "".format(X_train_csr.nnz / np.product(X_train_csr.shape))) |
52 | 58 | print("y_train {0}".format(y_train.shape))
|
53 |
| - print("X_test {0}".format(X_test.shape)) |
54 |
| - print("X_test.format = {0}".format(X_test.format)) |
55 |
| - print("X_test.dtype = {0}".format(X_test.dtype)) |
| 59 | + print("X_test {0}".format(X_test_csr.shape)) |
| 60 | + print("X_test.dtype = {0}".format(X_test_csr.dtype)) |
56 | 61 | print("y_test {0}".format(y_test.shape))
|
57 | 62 | print()
|
58 | 63 |
|
|
62 | 67 |
|
63 | 68 | for name, clf in sorted(estimators.items()):
|
64 | 69 | print("Training %s ... " % name, end="")
|
| 70 | + if 'ada' in name: |
| 71 | + X_train, X_teest = X_train_csr, X_test_csr |
| 72 | + else: |
| 73 | + X_train, X_test = X_train_csc, X_test_csc |
65 | 74 | t0 = time()
|
66 | 75 | clf.fit(X_train, y_train)
|
67 | 76 | train_time[name] = time() - t0
|
|
0 commit comments