-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_db.py
286 lines (248 loc) · 10.1 KB
/
test_db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
"""
Copyright Jeremy Nation <jeremy@jeremynation.me>.
Licensed under the GNU Affero General Public License (AGPL) v3.
Test src.db.
Optionally set the DB_DUMP_DIR environment variable to $DIR for TestDb.test_int to dump a modified
test_conn to $DIR/books.sql.
"""
import os
import shutil
from pathlib import Path
from sqlite3 import Connection, Row, connect
from typing import Any, Optional
from unittest.mock import patch
from src.command import Args, Command
from src.db import get_cmd
from src.util import (
LOG_LEVEL,
SchemaItemTypes,
get_db_fn,
get_log_records,
get_schema_fn,
read_text,
)
from tests.base import (
EXAMPLE_LIBRARY_DIR,
PLACEHOLDER_DIR_STR,
USER_SQL_FN,
VALID_DB_SQL_FN,
)
from tests.test_command import CommandTestCase
def _dump_db(conn: Connection, dump_fn: Path, temp_dir_str: str) -> None:
    """Write an SQL dump of conn to dump_fn, with temp_dir_str rewritten as PLACEHOLDER_DIR_STR.

    Each view is first materialized into a ``mat_<view>`` table so its rows land in
    the dump; those temporary tables are dropped again before returning. Use this
    to refresh the valid dump reference file.
    """
    view_rows = conn.execute(
        "SELECT name FROM sqlite_schema WHERE type='view' ORDER BY name"
    ).fetchall()
    materialized = []
    for view_row in view_rows:
        materialized.append(f"mat_{view_row['name']}")
        conn.execute(
            f"CREATE TABLE {materialized[-1]} AS SELECT * FROM {view_row['name']}"
        )
    # pylint: disable=import-outside-toplevel
    from pathlib import PurePosixPath
    from src.util import write_text
    # pylint: enable=import-outside-toplevel
    # Appending "x" and slicing it off yields each directory prefix *with* its
    # trailing separator, so the replacement keeps path boundaries intact.
    replacement = str(PurePosixPath(Path(PLACEHOLDER_DIR_STR) / "x"))[:-1]
    needle = str(Path(temp_dir_str) / "x")[:-1]
    rewritten = [replacement.join(line.split(needle)) for line in conn.iterdump()]
    write_text(dump_fn, "\n".join(rewritten))
    for table_name in materialized:
        conn.execute(f"DROP TABLE {table_name}")
def _get_relative_to(old_path_str: str, new_path_str: str) -> str:
    """Rebase old_path_str from the PLACEHOLDER_DIR_STR prefix onto new_path_str."""
    suffix = Path(old_path_str).relative_to(PLACEHOLDER_DIR_STR)
    return str(Path(new_path_str) / suffix)
class TestDb(CommandTestCase):
    """Test the command."""

    # Command under test; assigned once in setUpClass before the base class runs.
    CMD: Command

    @classmethod
    def setUpClass(cls) -> None:
        # Resolve the db command first so the base class setup can use it.
        cls.CMD = get_cmd()
        super().setUpClass()

    def _assert_view_data_correct(self, conn: Connection) -> None:
        """Assert that views in conn match materialized views, then drop materialized views."""
        rows = conn.execute(
            "SELECT name FROM sqlite_schema WHERE type='view' ORDER BY name"
        ).fetchall()
        for row in rows:
            # Materialized copies use the "mat_<view>" naming produced by _dump_db.
            mat_table = f"mat_{row['name']}"
            self.assertSequenceEqual(
                conn.execute(f"SELECT * FROM {mat_table}").fetchall(),
                conn.execute(f"SELECT * FROM {row['name']}").fetchall(),
            )
            conn.execute(f"DROP TABLE {mat_table}")

    def _get_conn(self, db_fn: str) -> Connection:
        """Get connection and close it during test cleanup."""
        conn = connect(db_fn)
        self.addCleanup(conn.close)
        return conn

    def _get_sql_data(self, conn: Connection) -> dict[str, Any]:
        """Get sql data from conn.

        Returns a mapping of each table/view name to its CREATE statement ("sql")
        and its full row contents ("data").
        """
        rows = conn.execute(
            "SELECT name, sql FROM sqlite_schema WHERE type IN ('table', 'view') ORDER BY name"
        ).fetchall()
        sql_data = {
            row["name"]: {
                "sql": row["sql"],
                "data": conn.execute(f"SELECT * FROM {row['name']}").fetchall(),
            }
            for row in rows
        }
        # Guard against silently comparing two empty schemas.
        self.assertTrue(sql_data)
        return sql_data

    def _standardize_conn(self, conn: Connection, new_path_str: str) -> None:
        """Standardize data in conn.

        Rewrites placeholder-rooted paths stored in the reference dump so they
        point under new_path_str (the temp directory used by this test run).
        """
        # self.schema presumably comes from CommandTestCase setup — confirm there.
        for item in self.schema:
            if isinstance(item, SchemaItemTypes.File):
                # File-type schema items each have a book_<name> table with
                # per-file path columns that need rebasing.
                rows = conn.execute(
                    f"""
                    SELECT file_full_path, metadata_directory, book_pkey, file_name
                    FROM book_{item.name}
                    """
                ).fetchall()
                for row in rows:
                    conn.execute(
                        f"""
                        UPDATE book_{item.name}
                        SET file_full_path=:file_full_path, metadata_directory=:metadata_directory
                        WHERE book_pkey=:book_pkey AND file_name=:file_name
                        """,
                        {
                            "file_full_path": _get_relative_to(
                                row["file_full_path"], new_path_str
                            ),
                            "metadata_directory": _get_relative_to(
                                row["metadata_directory"], new_path_str
                            ),
                            "book_pkey": row["book_pkey"],
                            "file_name": row["file_name"],
                        },
                    )
        # The book table itself also stores a metadata_directory to rebase.
        rows = conn.execute("SELECT pkey, metadata_directory FROM book").fetchall()
        for row in rows:
            conn.execute(
                "UPDATE book SET metadata_directory=:metadata_directory WHERE pkey=:pkey",
                {
                    "metadata_directory": _get_relative_to(
                        row["metadata_directory"], new_path_str
                    ),
                    "pkey": row["pkey"],
                },
            )

    @patch("src.db.db._BATCH_SIZE", 3)
    def _test_main(
        self, *, use_uuid_key: bool, db_dump_dir_str: Optional[str] = None
    ) -> None:
        """General method for testing db command."""
        db_fn = get_db_fn(self.l_dirs[0])
        # Pre-create the file so the command emits the "Overwriting" log line.
        db_fn.touch()
        main_args = (
            [Args.DIR_VARS.opt, "name1", ".", Args.DIR_VARS.opt, "name2", "."]
            + [Args.LIBRARY_DIRS.opt]
            + [str(l_dir) for l_dir in self.l_dirs]
            + [Args.USER_SQL_FILE.opt, str(USER_SQL_FN)]
        )
        if use_uuid_key:
            main_args.append(Args.USE_UUID_KEY.opt)
        with self.assertLogs(level=LOG_LEVEL) as cm:
            self._run_cmd_main(main_args)
        # With _BATCH_SIZE patched to 3, the books insert as a batch of 3 then 4.
        self.assertSequenceEqual(
            [
                f"Using schema from '{get_schema_fn(self.l_dirs[0])}'.",
                f"Collecting book data and assigning {'UUID' if use_uuid_key else 'integer'} keys.",
                f"Overwriting existing database file '{db_fn}'.",
                f"Creating '{db_fn}'.",
                "Inserted data type 'authors'.",
                "Inserted 3 books into database.",
                "Inserted 4 books into database.",
                f"Running user SQL file '{USER_SQL_FN}'.",
                f"Finished creating '{db_fn}'.",
            ],
            get_log_records(cm),
        )
        test_conn = self._get_conn(str(db_fn))
        test_conn.row_factory = Row
        if db_dump_dir_str is not None:
            # Maintenance hook: refresh the reference dump (see module docstring).
            _dump_db(
                test_conn,
                Path(db_dump_dir_str) / VALID_DB_SQL_FN.name,
                str(self.l_dirs[0]),
            )
        test_data = self._get_sql_data(test_conn)
        # Rebuild the expected database from the checked-in reference dump.
        valid_conn = self._get_conn(":memory:")
        valid_conn.row_factory = Row
        valid_conn.executescript(read_text(VALID_DB_SQL_FN))
        self._assert_view_data_correct(valid_conn)
        with valid_conn:
            self._standardize_conn(valid_conn, str(self.l_dirs[0]))
        valid_data = self._get_sql_data(valid_conn)
        # Both dicts are built from name-ordered queries, so zipping values
        # pairs each table/view with its counterpart.
        self.assertEqual(valid_data.keys(), test_data.keys())
        for v_data, t_data in zip(valid_data.values(), test_data.values()):
            if use_uuid_key:
                # UUID keys are nondeterministic; only row counts are comparable.
                self.assertEqual(len(v_data["data"]), len(t_data["data"]))
            else:
                self.assertEqual(v_data["sql"], t_data["sql"])
                self.assertEqual(v_data["data"], t_data["data"])

    def test_int(self) -> None:
        """Test db command with integer keys."""
        self._test_main(use_uuid_key=False, db_dump_dir_str=os.getenv("DB_DUMP_DIR"))

    def test_uuid(self) -> None:
        """Test db command with UUID keys."""
        self._test_main(use_uuid_key=True)

    def test_quickstart(self) -> None:
        """Test the quickstart example."""
        l_dir = self.t_dir / "example_library_dir"
        shutil.copytree(EXAMPLE_LIBRARY_DIR, l_dir)
        with self.assertLogs(level=LOG_LEVEL) as cm:
            self._run_cmd_main(
                [Args.LIBRARY_DIRS.opt, str(l_dir)]
                + [Args.OUTPUT_DIR.opt, str(self.t_dir)]
            )
        db_fn = get_db_fn(self.t_dir)
        self.assertSequenceEqual(
            [
                f"Using schema from '{get_schema_fn(l_dir)}'.",
                "Collecting book data and assigning integer keys.",
                f"Creating '{db_fn}'.",
                "Inserted data type 'authors'.",
                "Inserted 2 books into database.",
                f"Finished creating '{db_fn}'.",
            ],
            get_log_records(cm),
        )
        test_conn = self._get_conn(str(db_fn))
        self.assertEqual(
            2, test_conn.execute("SELECT count(*) FROM v_summary;").fetchone()[0]
        )

    def test_combos(self) -> None:
        """Test that the command runs without error, don't check output."""
        self._run_arg_combos(
            [
                [[Args.DIR_VARS.opt, "name1", ".", Args.DIR_VARS.opt, "name2", "."]],
                [[Args.LIBRARY_DIRS.opt] + [str(l_dir) for l_dir in self.l_dirs]],
                [[Args.OUTPUT_DIR.opt, str(self.l_dirs[0])], []],
                [[Args.SCHEMA.opt, str(get_schema_fn(self.l_dirs[0]))], []],
                [[Args.USE_UUID_KEY.opt], []],
                [[Args.USER_SQL_FILE.opt, str(USER_SQL_FN)], []],
            ]
        )
        # check short_opts
        self._run_cmd_main(
            [
                Args.DIR_VARS.short_opt,
                "name1",
                ".",
                Args.DIR_VARS.short_opt,
                "name2",
                ".",
            ]
            + [Args.LIBRARY_DIRS.short_opt]
            + [str(l_dir) for l_dir in self.l_dirs]
            + [Args.OUTPUT_DIR.short_opt, str(self.l_dirs[0])]
            + [Args.SCHEMA.short_opt, str(get_schema_fn(self.l_dirs[0]))]
        )

    def test_error_extra_arg(self) -> None:
        """Test that the command errors with an extra arg."""
        self._test_error_extra_arg([Args.LIBRARY_DIRS.opt, str(self.l_dirs[0])])