-
Notifications
You must be signed in to change notification settings - Fork 0
/
hg_rule_extractor.py
583 lines (508 loc) · 24.2 KB
/
hg_rule_extractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
#coding: utf8
import sys
import math
import heapq
import operator
from rule_formatters import RuleFormatter, GrexRuleFormatter, CdecT2SRuleFormatter, CdecT2TRuleFormatter
from tree import TreeNode, NonTerminalNode, TerminalNode
from collections import defaultdict
from helpers import computeSpans, Alignment, compute_generations, Rule, Span, enumerate_subsets
from hypergraph import Hypergraph, NodeWithSpan, Edge
from itertools import izip
# Turns a line of input into a hypergraph.
# Returns a hypergraph, a log weight to be used when combining this HG with others,
# and a boolean that indicates whether this tree is one of a k-best list.
# Returns None on error, including some quirky cases caused by Berkeley parser.
# Input format is one of these for each sentence:
# Format 1 (just 1-best trees):
# tree1
# Format 2:
# log p(tree1 | sent) ( tree1 )
# log p(tree2 | sent) ( tree2 )
# ...
# (empty line)
# Format 3:
# log p(sent) log p(tree1, sent) ( tree1 )
# log p(sent) log p(tree2, sent) ( tree2 )
# ...
# (empty line)
# Note the blank line after each sentence's k-best list
# and the extra parentheses that need to be stripped
def hypergraph_from_line(line):
    """Parse one k-best-list line into (tree, weight, is_one_of_kbest).

    Accepts three input formats (see the comment block above): a bare tree,
    "score<TAB>tree", or "sent_logprob<TAB>joint_logprob<TAB>tree".
    Returns None for unparseable input, including Berkeley parser's
    "Don't have a N-best tree" filler lines, empty "(())" parses, and
    -infinity scores.
    """
    line = line.strip()
    # Berkeley parser will output lines like "Don't have a 7-best tree"
    # when you ask it for 10 best, but it only has 6.
    if line.startswith('Don\'t have a'):
        return None
    is_one_of_kbest = True
    parts = line.split('\t')
    if len(parts) == 1:
        # Format 1: a bare tree, no score column -> cannot be combined.
        score = 0.0
        line, = parts
        is_one_of_kbest = False
    elif len(parts) == 2:
        # Format 2: log p(tree | sent) followed by the tree.
        score, line = parts
        score = float(score)
    else:
        # Format 3: log p(sent), log p(tree, sent), tree.
        sent_prob, joint_prob, line = parts
        score = float(joint_prob) - float(sent_prob)
    # BUG FIX: by this point score is always a float, so the old comparison
    # `score == '-Infinity'` (float vs. string) could never be true.
    # Compare against -inf instead; also reject empty "(())" parses.
    if score == float('-inf') or line == '(())':
        return None
    score = math.exp(score)
    # Strip the extra parens
    # TODO: What if the trees have different root nodes?
    if line.startswith('( (') and line.endswith(') )'):
        line = line[2:-2]
    tree = TreeNode.from_string(line)
    return tree, score, is_one_of_kbest
# Input is a list of (tree, weight) pairs.
# Weights are normalized to sum to 1, then the trees are combined into a single hypergraph.
def combine_trees(trees_to_combine):
    """Merge a list of (tree, weight) pairs into a single hypergraph.

    Weights are normalized to sum to 1 (uniform weights are substituted
    when they sum to zero).  Returns None for an empty input list.
    """
    if not trees_to_combine:
        return None
    total = sum(weight for _, weight in trees_to_combine)
    uniform = 1.0 / len(trees_to_combine)
    hypergraphs = []
    for tree, weight in trees_to_combine:
        # Fall back to a uniform distribution when all weights are zero.
        normalized = weight / total if total != 0.0 else uniform
        computeSpans(tree)
        hg = Hypergraph.from_tree(tree, normalized)
        hg.sanity_check()
        hypergraphs.append(hg)
    combined = hypergraphs[0]
    for other in hypergraphs[1:]:
        combined.combine(other)
    return combined
def read_tree_file(stream):
    # Generator over sentences: reads k-best lists of trees (one tree per
    # line, sentences separated by blank lines), combines each sentence's
    # trees into one hypergraph, and yields it.  Yields None for sentences
    # that fail to parse so callers stay in sync with the other streams.
    trees_to_combine = []
    while True:
        line = stream.readline()
        if not line:
            # EOF
            break
        if '[' in line or ']' in line:
            # Square brackets clash with the bracketed-tree syntax; the
            # input must pre-escape them as -LSB- / -RSB-.
            print >>sys.stderr, 'Square brackets found in input. Please escape these to -LSB- and -RSB-'
            sys.exit(1)
        line = line.decode('utf-8').strip()
        if not line:
            # Blank line terminates a sentence's k-best list.
            yield combine_trees(trees_to_combine)
            trees_to_combine = []
            continue
        parts = hypergraph_from_line(line)
        if parts == None:
            # Unparseable line (e.g. Berkeley "Don't have a N-best tree").
            yield None
            continue
        else:
            tree, score, can_combine = parts
            trees_to_combine.append((tree, score))
            if not can_combine:
                # Format-1 input (bare trees): each line is its own sentence.
                assert len(trees_to_combine) == 1
                yield combine_trees(trees_to_combine)
                trees_to_combine = []
    if len(trees_to_combine) > 0:
        # Flush a trailing k-best list that wasn't followed by a blank line.
        yield combine_trees(trees_to_combine)
def read_string_file(stream):
    # Generator over sentences in string mode (s2t/t2s): yields one flat
    # Hypergraph built from each surface string in the stream.
    while True:
        line = stream.readline()
        if '[' in line or ']' in line:
            # Square brackets clash with tree syntax; require -LSB-/-RSB- escapes.
            print >>sys.stderr, 'Square brackets found in input. Please escape these to -LSB- and -RSB-'
            sys.exit(1)
        if not line:
            # EOF
            break
        line = line.decode('utf-8').strip()
        yield Hypergraph.from_surface_string(line)
# Reads the next line from a file stream representing alignments.
# TODO: Allow this to add probabilities to alignment links.
def read_alignment_file(stream):
    """Yield one Alignment per line of the (UTF-8 encoded) stream.

    TODO: Allow this to add probabilities to alignment links.
    """
    # readline() returns '' exactly at EOF, which is the iter() sentinel.
    for raw_line in iter(stream.readline, ''):
        yield Alignment.from_string(raw_line.decode('utf-8').strip())
# Determines whether source_node and target_node are node-aligned.
# Two nodes are aligned if the alignment links emanating from their terminals
# align only to terminals of the other, or to NULL.
# The one exception is if both nodes have no alignment links coming from their
# terminals at all. In this case the nodes are not considered to be aligned.
def are_aligned(source_span, target_span, source_terminals, target_terminals, s2t_word_alignments, t2s_word_alignments):
    """Decide whether the two spans are node-aligned.

    Two spans are aligned iff every alignment link leaving either span's
    terminals lands inside the other span (or nowhere), AND at least one
    link exists — two mutually unaligned spans do not count as aligned.
    """
    src_terms = source_terminals[source_span.start:source_span.end]
    tgt_terms = target_terminals[target_span.start:target_span.end]
    # Reject spans whose source terminals carry no alignment links at all.
    if not any(s2t_word_alignments[term] for term in src_terms):
        return False
    # No source terminal may link outside the target span ...
    for src in src_terms:
        if any(tgt not in tgt_terms for tgt in s2t_word_alignments[src]):
            return False
    # ... and no target terminal may link outside the source span.
    for tgt in tgt_terms:
        if any(src not in src_terms for src in t2s_word_alignments[tgt]):
            return False
    return True
# Extracts all available rules from the node pair source_node and target_node
# Each rule will come from a pair of edges, one with head at source_node and
# the other with head at target_node.
# These two edges must have the same number of non-terminal children.
def extract_rules(source_node, target_node, s2t_node_alignments, t2s_node_alignments, source_root, target_root, max_rule_size, source_terminals, target_terminals, s2t_word_alignments, t2s_word_alignments):
    """Yield all Rules extractable from the aligned pair (source_node, target_node).

    Each rule pairs one edge headed at source_node with one headed at
    target_node; the two edges must have the same number of non-terminal
    tails, every non-terminal tail must be 1:1 aligned across the pair,
    and neither edge may exceed max_rule_size tails.
    """
    for source_edge in source_node.get_child_edges(source_root):
        if len(source_edge.tails) > max_rule_size:
            continue
        source_nt_count = len([node for node in source_edge.tails if not node.is_terminal(source_root)])
        for target_edge in target_node.get_child_edges(target_root):
            if len(target_edge.tails) > max_rule_size:
                continue
            target_nt_count = len([node for node in target_edge.tails if not node.is_terminal(target_root)])
            if source_nt_count != target_nt_count:
                continue
            # Work out which of the source_node's children correspond to which
            # of the target_node's.  Note: each node should have at most 1 alignment.
            # (Locals renamed: the old code rebound the target_terminals
            # parameter here, shadowing it for the rest of the loop body.)
            target_edge_terminals = set([node for node in target_edge.tails if node.is_terminal(target_root)])
            target_edge_nonterminals = set([node for node in target_edge.tails if not node.is_terminal(target_root)])
            child_alignments = []
            for node in source_edge.tails:
                possible_alignments = target_edge_terminals if node.is_terminal(source_root) else target_edge_nonterminals
                child_alignments.append(list(possible_alignments & s2t_node_alignments[node]))
                assert len(child_alignments[-1]) <= 1
            # Build up lists of the parts that make up the source and target RHS.
            # index is the number of NT-NT pairs that have been added to the rule so far;
            # the rule_part_maps give, for each NT, what opposite-side NT it
            # corresponds to along with its index.
            index = 1
            s2t_rule_part_map = {}
            t2s_rule_part_map = {}
            unused_target_children = set(target_edge.tails)
            has_unaligned_nt = False
            for i, s in enumerate(source_edge.tails):
                if len(child_alignments[i]) == 1:
                    t = child_alignments[i][0]
                    unused_target_children.remove(t)
                    if not s.is_terminal(source_root):
                        s2t_rule_part_map[s] = (t, index)
                        t2s_rule_part_map[t] = (s, index)
                        index += 1
                elif not source_edge.tails[i].is_terminal(source_root):
                    # A source non-terminal with no target counterpart kills the rule.
                    has_unaligned_nt = True
                    break
            # Likewise any leftover target non-terminal kills the rule.
            has_unaligned_nt |= False in [node.is_terminal(target_root) for node in unused_target_children]
            if has_unaligned_nt:
                continue
            weight = source_root.weights[source_edge] * target_root.weights[target_edge]
            # At this point, all information defining the rule has been calculated.
            # All that remains is to output it, which requires turning various bits
            # of information into string form.
            # Calculate the node-to-node and word-to-word alignments within this rule.
            alignments = []
            for i, source_part in enumerate(source_edge.tails):
                for j, target_part in enumerate(target_edge.tails):
                    # BUG FIX: the two t2s membership tests previously used
                    # source_node (the rule's head) instead of source_part (the
                    # tail being paired with target_part), so tail-to-tail links
                    # recorded in the t2s maps could never match.
                    if target_part in s2t_node_alignments[source_part] or source_part in t2s_node_alignments[target_part] or \
                            target_part in s2t_word_alignments[source_part] or source_part in t2s_word_alignments[target_part]:
                        alignments.append((i, j))
            yield Rule(source_edge, target_edge, s2t_rule_part_map, t2s_rule_part_map, alignments, weight)
def find_best_minimal_alignment(node, target_nodes, taken_target_nodes, target_generations):
    """Return the single best minimal alignment for `node`, or None.

    Only same-kind alignments are allowed (terminal-to-terminal,
    NT-to-NT).  Among candidates the narrowest span wins; ties (e.g.
    unary chains) go to the node lowest in the tree, i.e. with the
    greatest generation.  Nodes already claimed by another source node
    are excluded.
    """
    # Terminals may only align to terminals, non-terminals to non-terminals.
    candidates = [t for t in target_nodes if t.is_terminal_flag == node.is_terminal_flag]
    if not candidates:
        return None
    # A minimal alignment has the narrowest span among all candidates.
    narrowest = min(t.span.end - t.span.start for t in candidates)
    candidates = [t for t in candidates if t.span.end - t.span.start == narrowest]
    # Whatever we align to must not already be taken by something else.
    candidates = [t for t in candidates if t not in taken_target_nodes]
    if not candidates:
        return None
    return max(candidates, key=target_generations.__getitem__)
# Removes non-minimal node alignments from a hypergraph.
# If a node is aligned to multiple opposite nodes,
# its minimal alignment is the smallest one span-wise.
# If that's a tie, the minimal alignment is the one lowest
# in the tree, i.e. the one with the highest generation.
def minimize_alignments(source_root, target_root, s2t, t2s):
    """Reduce the node-alignment maps so each node keeps only its minimal alignment.

    If a node is aligned to multiple opposite nodes, its minimal alignment
    is the smallest one span-wise; ties go to the node lowest in the tree
    (highest generation).  Returns new (minimal_s2t, minimal_t2s) maps in
    which every node has at most one alignment, and the two roots are
    forced to align only to each other.

    NOTE(review): still reads the module-level `args.t2s` flag, matching
    the rest of this file's convention.
    """
    if args.t2s:
        # Tree-to-string mode: the target side is flat, so fake generations
        # (root=0, virtual NTs=1, terminals=2) instead of computing them.
        target_generations = {}
        for node in target_root.nodes:
            target_generations[node] = 2 if node.is_terminal_flag else 1
        target_generations[target_root.start] = 0
    else:
        target_generations = compute_generations(target_root)
    taken_target_nodes = set()
    minimal_s2t = defaultdict(set)
    minimal_t2s = defaultdict(set)
    for source_node in source_root.topsort():
        target_nodes = s2t[source_node]
        target_node = find_best_minimal_alignment(source_node, target_nodes, taken_target_nodes, target_generations)
        if target_node is not None:
            taken_target_nodes.add(target_node)
            minimal_s2t[source_node].add(target_node)
            minimal_t2s[target_node].add(source_node)
    # Ensure the root nodes aren't aligned to anything other than each other.
    # BUG FIX: this block previously referenced the module-level globals
    # source_tree / target_tree instead of the source_root / target_root
    # parameters — it only worked because __main__ happened to bind globals
    # of those names to the same objects.
    for node in minimal_s2t.iterkeys():
        if target_root.start in minimal_s2t[node]:
            minimal_s2t[node].remove(target_root.start)
    for node in minimal_t2s.iterkeys():
        if source_root.start in minimal_t2s[node]:
            minimal_t2s[node].remove(source_root.start)
    minimal_s2t[source_root.start] = set([target_root.start])
    minimal_t2s[target_root.start] = set([source_root.start])
    for k, v in minimal_s2t.iteritems():
        assert len(v) <= 1
    for k, v in minimal_t2s.iteritems():
        assert len(v) <= 1
    return minimal_s2t, minimal_t2s
def build_word_alignment_maps(source_terminals, target_terminals, alignment):
    """Build bidirectional terminal-node word-alignment maps from link indices.

    alignment.links is an iterable of (source_index, target_index) pairs.
    Out-of-range links are reported to stderr and skipped.
    Returns (s2t_word_alignments, t2s_word_alignments), each a
    defaultdict(list) keyed by terminal node.
    """
    s2t_word_alignments = defaultdict(list)
    t2s_word_alignments = defaultdict(list)
    source_len = len(source_terminals)
    target_len = len(target_terminals)
    for s, t in alignment.links:
        if not (0 <= s < source_len) or not (0 <= t < target_len):
            # BUG FIX: the old code warned but then indexed anyway — a
            # negative index silently wrapped around, and an over-range
            # index raised IndexError.  Skip bad links instead.
            sys.stderr.write('Invalid alignment link: %d-%d\n' % (s, t))
            continue
        s_node = source_terminals[s]
        t_node = target_terminals[t]
        s2t_word_alignments[s_node].append(t_node)
        t2s_word_alignments[t_node].append(s_node)
    return s2t_word_alignments, t2s_word_alignments
def build_node_alignment_maps(source_tree, target_tree, are_aligned, minimal_only=False):
    """Build node-to-node alignment maps in both directions.

    Uses the supplied are_aligned(source_span, target_span) predicate;
    the two roots are always considered aligned to each other.  When
    minimal_only is set, the maps are reduced to minimal alignments.
    """
    s2t_node_alignments = defaultdict(set)
    t2s_node_alignments = defaultdict(set)
    for s_node in source_tree.nodes:
        matches = set(t_node for t_node in target_tree.nodes
                      if are_aligned(s_node.span, t_node.span))
        # Assign (not update) so every source node gets an entry, even empty.
        s2t_node_alignments[s_node] = matches
        for t_node in matches:
            t2s_node_alignments[t_node].add(s_node)
    # The two roots are always node-aligned, even without any alignment links.
    s2t_node_alignments[source_tree.start].add(target_tree.start)
    t2s_node_alignments[target_tree.start].add(source_tree.start)
    if minimal_only:
        return minimize_alignments(source_tree, target_tree, s2t_node_alignments, t2s_node_alignments)
    return s2t_node_alignments, t2s_node_alignments
# Returns a set of spans in the target_tree that are well aligned to
# nodes in source_nodes, as given by the are_aligned function.
def find_aligned_spans(target_tree, source_nodes, are_aligned):
    """Return the set of target-side Spans well aligned to some source node.

    Enumerates every (start, size) span over the target sentence and keeps
    those for which at least one node in source_nodes satisfies are_aligned.
    """
    aligned = set()
    sentence_len = max(node.span.end for node in target_tree.nodes)
    for size in range(1, sentence_len + 1):
        for start in range(sentence_len - size + 1):
            # any() short-circuits on the first aligned source node.
            if any(are_aligned(node.span, Span(start, start + size))
                   for node in source_nodes):
                aligned.add(Span(start, start + size))
    return aligned
# Generates all the (ordered) sets of nodes that can exactly cover the input span,
# using at most max_nt_count non-terminal nodes.
def find_child_sets(nodes, span, max_nt_count):
    """Generate every (ordered) node sequence that exactly tiles `span`,
    using at most max_nt_count non-terminal nodes."""
    for first in nodes:
        # A tiling must begin exactly at the span's left edge.
        if first.span.start != span.start:
            continue
        is_nonterminal = not first.is_terminal_flag
        # A non-terminal consumes one unit of the remaining NT budget.
        if is_nonterminal and max_nt_count == 0:
            continue
        if first.span.end == span.end:
            # `first` covers the span by itself.
            yield [first]
        else:
            remaining_budget = max_nt_count - (1 if is_nonterminal else 0)
            if remaining_budget < 0:
                remaining_budget = 0
            # Recurse on the uncovered remainder of the span.
            for rest in find_child_sets(nodes, Span(first.span.end, span.end), remaining_budget):
                yield [first] + rest
def add_t2s_virtual_nodes(target_tree, source_tree, are_aligned):
    """Insert a virtual 'X' node into target_tree for every target span
    that is well aligned to some node of source_tree."""
    spans = find_aligned_spans(target_tree, source_tree.nodes, are_aligned)
    # Process narrow spans first.
    for span in sorted(spans, key=lambda s: s.end - s.start):
        virtual = NodeWithSpan('X', span, False, True)
        # A whole-sentence span must coincide with the existing root node.
        assert span != target_tree.start.span or virtual == target_tree.start
        target_tree.nodes.add(virtual)
def add_t2s_virtual_edges(target_tree, source_tree, are_aligned):
    """For every well-aligned target span, add edges from its virtual 'X'
    node to each set of child nodes that exactly tiles the span."""
    spans = find_aligned_spans(target_tree, source_tree.nodes, are_aligned)
    for span in sorted(spans, key=lambda s: s.end - s.start):
        virtual = NodeWithSpan('X', span, False, True)
        # Note: virtual should == target_tree.root when span covers the whole sentence.
        assert span != target_tree.start.span or virtual == target_tree.start
        # Children may be terminals, or non-terminals over aligned spans.
        usable = set(node for node in target_tree.nodes
                     if node.is_terminal_flag or node.span in spans)
        # The +1000 effectively lifts the size cap for tiling purposes.
        for child_set in find_child_sets(usable, span, args.max_rule_size + 1000):
            # Don't let a node be its own child.
            if virtual in child_set:
                continue
            target_tree.add(Edge(virtual, tuple(child_set)))
def add_experimental_virtual_edges(target_tree, source_tree, s2t_node_alignments, t2s_node_alignments, target_terminals):
    # Experimental (minimal-rules + t2s) edge construction: for each source
    # edge, projects its tails through the node alignments and adds composed
    # edges on the source side plus matching virtual edges on the target side.
    def project(source_node):
        # Map a source node to its (unique) aligned target node, or None.
        alignments = s2t_node_alignments[source_node]
        #assert len(alignments) <= 1 # TODO: Could unaligned words invalidate this?
        return list(alignments)[0] if len(alignments) == 1 else None
    # Derivation[source_node] will hold the minimal way(s) of representing
    # source_node using minimal constituents.
    # For terminals and well-aligned NTs, there is only one such way: using the node itself.
    # For NTs that are not node aligned, we will find sets of minimally
    # aligned children that cover source_node.
    derivations = {}
    for source_node in source_tree.topsort():
        derivations[source_node] = []
        if source_node.is_terminal_flag:
            derivation = (source_node,)
            derivations[source_node].append((derivation, []))
        elif project(source_node) != None:
            derivation = (source_node,)
            derivations[source_node].append((derivation, []))
        else:
            # Unaligned NT: cover it by combining its children's derivations,
            # remembering which edges were skipped along the way.
            for edge in source_tree.head_index[source_node]:
                for subset in enumerate_subsets([derivations[tail] for tail in edge.tails]):
                    derivation = reduce(operator.add, [derivation for derivation, _ in subset])
                    skipped_edges = reduce(operator.add, [edges for _, edges in subset])
                    for node in derivation:
                        assert len(s2t_node_alignments[node]) >= 1 or node.is_terminal_flag
                    derivations[source_node].append((derivation, [edge] + skipped_edges))
    # Iterate over a copy since composed edges are added to source_tree below.
    for edge in source_tree.edges.copy():
        source_head = edge.head
        for target_head in s2t_node_alignments[source_head]:
            for source_subset in enumerate_subsets([derivations[tail] for tail in edge.tails]):
                source_tails = reduce(operator.add, [derivation for derivation, _ in source_subset])
                composed_edge = Edge(source_head, source_tails)
                skipped_edges = reduce(operator.add, [edges for _, edges in source_subset])
                if len(skipped_edges) > 0:
                    # Record the constituent edges this composed edge stands for.
                    composed_edge.composed_edges = tuple([edge] + skipped_edges)
                    composed_edge.is_composed = True
                    assert len(edge.composed_edges) == 0
                if composed_edge != edge:
                    assert len(skipped_edges) > 0
                    source_tree.add(composed_edge)
                # Build the matching target-side edge: project each NT tail,
                # then fill uncovered target positions with terminals.
                for target_subset in enumerate_subsets([list(s2t_node_alignments[tail]) for tail in source_tails if not tail.is_terminal_flag]):
                    target_tails = target_subset
                    for i in range(*target_head.span):
                        is_included = False
                        for tail in target_tails:
                            if i >= tail.span.start and i < tail.span.end:
                                is_included = True
                                break
                        if not is_included:
                            # Position i is not covered by any NT tail; add its terminal.
                            target_tails.append(target_terminals[i])
                    target_tails = tuple(sorted(target_tails, key=lambda node: node.span.start))
                    virtual_edge = Edge(target_head, target_tails)
                    target_tree.add(virtual_edge)
    return
    # NOTE(review): everything below is unreachable (dead code kept from an
    # earlier, simpler projection strategy) because of the return above.
    for source_node in source_tree.topsort():
        head = project(source_node)
        if head == None:
            print >>sys.stderr, str(source_node), 'is unaligned'
            continue
        else:
            print >>sys.stderr, str(source_node), 'is aligned to', str(head)
        for edge in source_tree.head_index[source_node]:
            tails = []
            valid = True
            for tail in edge.tails:
                projection = project(tail)
                if projection is None:
                    valid = False
                    break
                tails.append(projection)
            if valid:
                virtual_edge = Edge(head, tuple(tails))
                target_tree.add(virtual_edge)
                print >>sys.stderr, head, tails
# Takes two hypergraphs representing source and target trees, as well as a word
# alignment, and finds all rules extractable there from.
def handle_sentence(source_tree, target_tree, alignment, formatter):
    # Extracts and prints every rule for one sentence: takes the two
    # hypergraphs (source/target trees) and a word alignment, decorates the
    # hypergraphs with virtual/composed nodes and edges, then walks each
    # aligned node pair and prints the formatted rules to stdout.
    # Identify the terminal nodes in both trees
    source_terminals = sorted([node for node in source_tree.nodes if node.is_terminal_flag], key=lambda node: node.span.start)
    target_terminals = sorted([node for node in target_tree.nodes if node.is_terminal_flag], key=lambda node: node.span.start)
    # Build word alignment maps
    s2t_word_alignments, t2s_word_alignments = build_word_alignment_maps(source_terminals, target_terminals, alignment)
    # Define this little helper function
    spans_are_aligned = lambda source_span, target_span: are_aligned(source_span, target_span, source_terminals, target_terminals, s2t_word_alignments, t2s_word_alignments)
    source_tree.add_virtual_nodes_only(args.virtual_size, False)
    if not args.t2s:
        target_tree.add_virtual_nodes_only(args.virtual_size, False)
    else:
        # Tree-to-string: fabricate 'X' nodes over well-aligned target spans.
        add_t2s_virtual_nodes(target_tree, source_tree, spans_are_aligned)
        #target_tree.add_virtual_nodes_only(1000, True, lambda nodes: 'X')
    # Build node alignments maps
    s2t_node_alignments, t2s_node_alignments = build_node_alignment_maps(source_tree, target_tree, spans_are_aligned, args.minimal_rules)
    # Add virtual nodes and edges to the tree structures
    source_tree.add_virtual_nodes(args.virtual_size, False)
    if not args.t2s:
        target_tree.add_virtual_nodes(args.virtual_size, False)
    else:
        aligned_spans = set([node.span for node in t2s_node_alignments.keys()])
        if not args.minimal_rules:
            in_aligned_spans = lambda source_span, target_span: target_span in aligned_spans
            add_t2s_virtual_edges(target_tree, source_tree, in_aligned_spans)
    # Add composed edges to the tree structures
    if args.minimal_rules:
        if not args.t2s:
            s2t_aligned_nodes = set(node for node, alignments in s2t_node_alignments.iteritems() if len(alignments) > 0)
            t2s_aligned_nodes = set(node for node, alignments in t2s_node_alignments.iteritems() if len(alignments) > 0)
            source_tree.add_minimal_composed_edges(args.max_rule_size, s2t_aligned_nodes)
            target_tree.add_minimal_composed_edges(args.max_rule_size, t2s_aligned_nodes)
        else:
            add_experimental_virtual_edges(target_tree, source_tree, s2t_node_alignments, t2s_node_alignments, target_terminals)
    else:
        source_tree.add_composed_edges(args.max_rule_size)
        if not args.t2s:
            target_tree.add_composed_edges(args.max_rule_size)
    print >>sys.stderr, 'Source tree contains %d nodes and %d edges' % (len(source_tree.nodes), len(source_tree.edges))
    print >>sys.stderr, 'Target tree contains %d nodes and %d edges' % (len(target_tree.nodes), len(target_tree.edges))
    # Finally extract rules
    # (iterate a copy: extraction must not be affected by any map mutation)
    for source_node, target_nodes in s2t_node_alignments.copy().iteritems():
        for target_node in target_nodes:
            if not source_node.is_terminal(source_tree) and not target_node.is_terminal(target_tree):
                for rule in extract_rules(source_node, target_node, s2t_node_alignments, t2s_node_alignments, source_tree, target_tree, args.max_rule_size, source_terminals, target_terminals, s2t_word_alignments, t2s_word_alignments):
                    print formatter.format_rule(rule).encode('utf-8')
    sys.stdout.flush()
def get_formatter(formatter_type):
    """Map a formatter name to a rule-formatter instance.

    Supported names: 'grex', 'cdec' (tree-to-string), 'cdec_t2t'.
    Raises AssertionError for any other name.
    """
    if formatter_type == 'grex':
        return GrexRuleFormatter()
    elif formatter_type == 'cdec':
        return CdecT2SRuleFormatter()
    elif formatter_type == 'cdec_t2t':
        return CdecT2TRuleFormatter()
    # BUG FIX: `assert False and 'msg' % x` short-circuits to `assert False`,
    # so the message was never shown — and `python -O` stripped the check
    # entirely.  Raise explicitly (same exception type) with the message.
    raise AssertionError('Unknown formatter type: %s' % formatter_type)
if __name__ == "__main__":
    # CLI driver: reads parallel source trees, target trees, and word
    # alignments, and prints extracted rules sentence by sentence.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('source_trees')
    parser.add_argument('target_trees')
    parser.add_argument('alignments')
    parser.add_argument('--suppress_counts', action='store_true', help='Emulate old rule extractor by hiding the count field(s)')
    parser.add_argument('--virtual_size', '-v', type=int, default=1, help='Maximum number of components in a virtual node')
    parser.add_argument('--minimal_rules', '-m', action='store_true', help='Only extract minimal rules')
    parser.add_argument('--max_rule_size', '-s', type=int, default=5, help='Maximum number of parts (terminal or non-terminal) in the RHS of a rule')
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--s2t', action='store_true', help='String-to-tree mode. Target side file should contain (tokenized) sentences instead of trees.')
    group.add_argument('--t2s', action='store_true', help='Tree-to-string mode. Source side file should contain (tokenized) sentences instead of trees.')
    parser.add_argument('--debug', '-d', action='store_true', help='Debug mode')
    parser.add_argument('--output_format', '-f', required=False, choices=['grex', 'cdec', 'cdec_t2t'], help='Output grammar format. Options include grex (default for tree-to-tree), cdec (default for tree-to-string), and cdec_t2t')
    args = parser.parse_args()
    source_tree_file = open(args.source_trees)
    target_tree_file = open(args.target_trees)
    alignment_file = open(args.alignments)
    if args.output_format == None:
        # Default: grex for tree-to-tree, cdec for tree-to-string.
        args.output_format = 'grex' if not args.t2s else 'cdec'
    formatter = get_formatter(args.output_format)
    # In string modes the corresponding side is read as plain sentences.
    read_source_file = read_tree_file if not args.s2t else read_string_file
    read_target_file = read_tree_file if not args.t2s else read_string_file
    sentence_number = 1
    # izip keeps the three streams in lockstep, one entry per sentence.
    for source_tree, target_tree, alignment in izip(read_source_file(source_tree_file), read_target_file(target_tree_file), read_alignment_file(alignment_file)):
        print 'Sentence', sentence_number
        # Can happen if Berkeley gives borked trees
        if source_tree == None or target_tree == None:
            pass
        else:
            try:
                handle_sentence(source_tree, target_tree, alignment, formatter)
            except Exception as e:
                # NOTE(review): without --debug, per-sentence errors are
                # silently swallowed so one bad sentence doesn't kill the run.
                if args.debug:
                    raise
        sys.stdout.flush()
        sentence_number += 1