hotfix
jhyuklee committed Jan 18, 2022
1 parent 87fa48b commit 00a2c36
Showing 2 changed files with 5 additions and 5 deletions.
Makefile: 4 changes (2 additions, 2 deletions)
@@ -466,8 +466,8 @@ p-serve: dump-dir large-index
 
 # Serve demo
 serve-demo: model-name dump-dir large-index
-	CUDA_VISIBLE_DEVICES=4 make q-serve MODEL_NAME=$(MODEL_NAME) Q_PORT=$(Q_PORT)
-	CUDA_VISIBLE_DEVICES=4 make p-serve DUMP_DIR=$(DUMP_DIR) Q_PORT=$(Q_PORT) I_PORT=$(I_PORT)
+	CUDA_VISIBLE_DEVICES=0 make q-serve MODEL_NAME=$(MODEL_NAME) Q_PORT=$(Q_PORT)
+	CUDA_VISIBLE_DEVICES=0 make p-serve DUMP_DIR=$(DUMP_DIR) Q_PORT=$(Q_PORT) I_PORT=$(I_PORT)
 
 # Evaluation using the open QA demo (used for benchmark)
 eval-demo: nq-open-data
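
The Makefile change only switches which physical GPU the demo is pinned to: CUDA_VISIBLE_DEVICES restricts and renumbers the GPUs a process can see, so with CUDA_VISIBLE_DEVICES=0 both q-serve and p-serve now run on the machine's first GPU rather than the fifth. A small illustrative check of how the variable remaps device ordinals, assuming PyTorch is installed (this snippet is not part of the repository):

import os

# Pin this process to physical GPU 0 before any CUDA context is created;
# inside the process it then shows up as the only device, cuda:0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch

if torch.cuda.is_available():
    print(torch.cuda.device_count())      # 1: only the selected GPU is visible
    print(torch.cuda.get_device_name(0))  # name of physical GPU 0

Whatever physical card is selected, the serving code keeps addressing it as cuda:0, which is why only the Makefile needs to change when the demo moves to a different GPU.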
densephrases/model.py: 6 changes (3 additions, 3 deletions)
@@ -5,7 +5,7 @@
 
 from densephrases import Options
 from densephrases.utils.single_utils import load_encoder
-from densephrases.utils.open_utils import load_phrase_index, get_query2vec, load_qa_pairs
+from densephrases.utils.open_utils import load_phrase_index, get_query2vec
 from densephrases.utils.data_utils import TrueCaser
 
 logger = logging.getLogger(__name__)
@@ -67,7 +67,7 @@ def search(self, query='', retrieval_unit='phrase', top_k=10, truecase=True, ret
             query = [self.truecase.get_true_case(query) if query == query.lower() else query for query in batch_query]
 
         # Get question vector
-        outs = self.query2vec(batch_query)
+        outs = list(self.query2vec(batch_query))
         start = np.concatenate([out[0] for out in outs], 0)
         end = np.concatenate([out[1] for out in outs], 0)
         query_vec = np.concatenate([start, end], 1)
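
The list(...) wrap matters if get_query2vec hands back a generator: the surrounding code iterates outs twice, once to gather the start vectors and once for the end vectors, and a generator is exhausted after the first pass, so the second np.concatenate would receive an empty sequence and raise. A minimal reproduction of the failure and the fix, using a made-up stand-in for query2vec rather than the real one:

import numpy as np

def fake_query2vec(queries, batch_size=2):
    # Illustrative stand-in: yields one (start, end) pair of vector batches per mini-batch.
    for i in range(0, len(queries), batch_size):
        n = len(queries[i:i + batch_size])
        yield np.ones((n, 4)), np.zeros((n, 4))

queries = ["q1", "q2", "q3"]

# Buggy pattern: the first comprehension consumes the generator, so the second
# sees nothing and np.concatenate raises "need at least one array to concatenate".
outs = fake_query2vec(queries)
start = np.concatenate([out[0] for out in outs], 0)
# end = np.concatenate([out[1] for out in outs], 0)  # would raise ValueError

# Fixed pattern (what the commit does): materialize once, iterate as often as needed.
outs = list(fake_query2vec(queries))
start = np.concatenate([out[0] for out in outs], 0)
end = np.concatenate([out[1] for out in outs], 0)
query_vec = np.concatenate([start, end], 1)  # shape (3, 8)

Materializing the batches once costs a little memory but lets the start and end halves be gathered independently, which is what the two concatenations in the hunk above expect.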
@@ -110,7 +110,7 @@ def search(self, query='', retrieval_unit='phrase', top_k=10, truecase=True, ret
 
     def set_encoder(self, load_dir, device='cuda'):
         self.args.load_dir = load_dir
-        self.model, self.tokenizer, self.config = load_encoder(device, self.args)
+        self.model, self.tokenizer, self.config = load_encoder(device, self.args, query_only=True)
         self.query2vec = get_query2vec(
             query_encoder=self.model, tokenizer=self.tokenizer, args=self.args, batch_size=64
         )
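
Passing query_only=True presumably asks load_encoder to build only the question-side encoder, which is all set_encoder needs at serving time: the phrase dump is pre-encoded and answered from the index, so the phrase-encoding half of the model would only sit idle on the GPU. A hypothetical, self-contained sketch of that flag pattern; QueryEncoder, FullEncoder, and this load_encoder body are made up for illustration and are not DensePhrases code:

class QueryEncoder:
    """Stand-in for the question-side encoder, all the demo needs."""

class FullEncoder:
    """Stand-in for the full model with both question and phrase encoders."""
    def __init__(self):
        self.query_encoder = QueryEncoder()

def load_encoder(device, args, query_only=False):
    # Hypothetical loader: with query_only=True only the lighter question encoder
    # is constructed, since the phrase side never runs at query time.
    # device and args are unused in this toy sketch.
    model = QueryEncoder() if query_only else FullEncoder()
    tokenizer, config = object(), object()  # placeholders for the real return values
    return model, tokenizer, config

model, tokenizer, config = load_encoder("cuda", args=None, query_only=True)
print(type(model).__name__)  # prints: QueryEncoder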
