diff --git a/FB15k-237/completion.py b/FB15k-237/completion.py new file mode 100755 index 0000000..d14a6f5 --- /dev/null +++ b/FB15k-237/completion.py @@ -0,0 +1,224 @@ +#!/usr/bin/python + +import numpy as np +import sys + +relation = sys.argv[1] + +dataPath_ = 'tasks/' + relation +featurePath = dataPath_ + '/path_to_use.txt' +feature_stats = dataPath_ + '/path_stats.txt' +relationId_path = 'relation2id.txt' +ent_id_path = '/home/xwhan/RL_KB/data/FB15k-237/' + 'entity2id.txt' +rel_id_path = '/home/xwhan/RL_KB/data/FB15k-237/' + 'relation2id.txt' +test_data_path = '/home/xwhan/RL_KB/data/FB15k-237/tasks/' + relation + '/sort_test.pairs' + +def bfs_two(e1,e2,path,kb): + start = 0 + end = len(path) + left = set() + right = set() + left.add(e1) + right.add(e2) + + left_path = [] + right_path = [] + while(start < end): + left_step = path[start] + left_next = set() + right_step = path[end-1] + right_next = set() + + if len(left) < len(right): + left_path.append(left_step) + start += 1 + #print 'left',start + for triple in kb: + if triple[2] == left_step and triple[0] in left: + left_next.add(triple[1]) + left = left_next + + else: + right_path.append(right_step) + end -= 1 + #print 'right', end + for triple in kb: + if triple[2] == right_step and triple[1] in right: + right_next.add(triple[0]) + right = right_next + + if len(right & left) != 0: + #print right & left + return True + + return False + +def get_features(): + stats = {} + f = open(feature_stats) + path_freq = f.readlines() + f.close() + for line in path_freq: + path = line.split('\t')[0] + num = int(line.split('\t')[1]) + stats[path] = num + max_freq = np.max(stats.values()) + + relation2id = {} + f = open(relationId_path) + content = f.readlines() + f.close() + for line in content: + relation2id[line.split()[0]] = int(line.split()[1]) + + useful_paths = [] + named_paths = [] + f = open(featurePath) + paths = f.readlines() + f.close() + + for line in paths: + path = line.rstrip() + + if path not in stats: + continue + elif max_freq > 1 and stats[path] < 2: + continue + + length = len(path.split(' -> ')) + + if length <= 10: + pathIndex = [] + pathName = [] + relations = path.split(' -> ') + + for rel in relations: + pathName.append(rel) + rel_id = relation2id[rel] + pathIndex.append(rel_id) + useful_paths.append(pathIndex) + named_paths.append(pathName) + + print 'How many paths used: ', len(useful_paths) + return useful_paths, named_paths + +f1 = open(ent_id_path) +f2 = open(rel_id_path) +content1 = f1.readlines() +content2 = f2.readlines() +f1.close() +f2.close() + +entity2id = {} +relation2id = {} +for line in content1: + entity2id[line.split()[0]] = int(line.split()[1]) + +for line in content2: + relation2id[line.split()[0]] = int(line.split()[1]) + +ent_vec_E = np.loadtxt(dataPath_ + '/entity2vec.unif') +rel_vec_E = np.loadtxt(dataPath_ + '/relation2vec.unif') +rel = '/' + relation.replace("@", "/") +relation_vec_E = rel_vec_E[relation2id[rel],:] + +ent_vec_R = np.loadtxt(dataPath_ + '/entity2vec.bern') +rel_vec_R = np.loadtxt(dataPath_ + '/relation2vec.bern') +M = np.loadtxt(dataPath_ + '/A.bern') +M = M.reshape([-1,100,100]) +relation_vec_R = rel_vec_R[relation2id[rel],:] +M_vec = M[relation2id[rel],:,:] + +_, named_paths = get_features() +path_weights = [] +for path in named_paths: + weight = 1.0/len(path) + path_weights.append(weight) +path_weights = np.array(path_weights) +f = open(dataPath_ + '/graph.txt') +kb_lines = f.readlines() +f.close() +kb = [] +for line in kb_lines: + e1 = line.split()[0] + rel = line.split()[1] + 
e2 = line.split()[2] + kb.append((e1,e2,rel)) + +f = open(test_data_path) +test_data = f.readlines() +f.close() +test_pairs = [] +test_labels = [] +for line in test_data: + e1 = line.split(',')[0].replace('thing$','') + e1 = '/' + e1[0] + '/' + e1[2:] + e2 = line.split(',')[1].split(':')[0].replace('thing$','') + e2 = '/' + e2[0] + '/' + e2[2:] + test_pairs.append((e1,e2)) + label = 1 if line[-2] == '+' else 0 + test_labels.append(label) + +scores_E = [] +scores_R = [] +scores_rl = [] + +print 'How many queries: ', len(test_pairs) +for idx, sample in enumerate(test_pairs): + e1_vec_E = ent_vec_E[entity2id[sample[0]],:] + e2_vec_E = ent_vec_E[entity2id[sample[1]],:] + score_E = -np.sum(np.square(e1_vec_E + relation_vec_E - e2_vec_E)) + scores_E.append(score_E) + + e1_vec_R = ent_vec_R[entity2id[sample[0]],:] + e2_vec_R = ent_vec_R[entity2id[sample[1]],:] + e1_vec_rel = np.matmul(e1_vec_R, M_vec) + e2_vec_rel = np.matmul(e2_vec_R, M_vec) + score_R = -np.sum(np.square(e1_vec_rel + relation_vec_R - e2_vec_rel)) + scores_R.append(score_R) + + features = [] + for path in named_paths: + features.append(int(bfs_two(sample[0], sample[1], path, kb))) + #features = features*path_weights + score_rl = sum(features) + scores_rl.append(score_rl) + +rank_stats_E = zip(scores_E, test_labels) +rank_stats_R = zip(scores_R, test_labels) +rank_stats_rl = zip(scores_rl, test_labels) +rank_stats_E.sort(key = lambda x:x[0], reverse=True) +rank_stats_R.sort(key = lambda x:x[0], reverse=True) +rank_stats_rl.sort(key = lambda x:x[0], reverse=True) + +correct = 0 +ranks = [] +for idx, item in enumerate(rank_stats_E): + if item[1] == 1: + correct += 1 + ranks.append(correct/(1.0+idx)) +ap1 = np.mean(ranks) +print 'TransE: ', ap1 + +correct = 0 +ranks = [] +for idx, item in enumerate(rank_stats_R): + if item[1] == 1: + correct += 1 + ranks.append(correct/(1.0+idx)) +ap2 = np.mean(ranks) +print 'TransR: ', ap2 + +correct = 0 +ranks = [] +for idx, item in enumerate(rank_stats_rl): + if item[1] == 1: + correct += 1 + ranks.append(correct/(1.0+idx)) +ap3 = np.mean(ranks) +print 'RL: ', ap3 + + + + + diff --git a/FB15k-237/createQueries.py b/FB15k-237/createQueries.py new file mode 100644 index 0000000..8409905 --- /dev/null +++ b/FB15k-237/createQueries.py @@ -0,0 +1,31 @@ +import sys + + +relation = sys.argv[1] +dataPath = 'tasks/' + relation + '/all_data' +outPath1 ='/home/xwhan/ForWH/fb15k/queries/' + relation +outPath2 = '/home/xwhan/ForWH/fb15k/test_queries/' + relation + +f = open(dataPath) +content = f.readlines() +f.close() + +newlines = [] +for line in content: + e1 = line.split()[0] + e2 = line.split()[1] + e1 = 'thing$' + e1[1:].replace('/','_') + e2 = 'thing$' + e2[1:].replace('/','_') + newline = e1 + '\t' + e2 + '\n' + newlines.append(newline) + +train_lines = newlines[:int(0.7*len(newlines))] +test_lines = newlines[int(0.7*len(newlines)):] + +g1 = open(outPath1, 'w') +g1.writelines(train_lines) +g1.close() + +g2 = open(outPath2, 'w') +g2.writelines(test_lines) +g2.close() diff --git a/FB15k-237/eval_transX.py b/FB15k-237/eval_transX.py new file mode 100755 index 0000000..1798649 --- /dev/null +++ b/FB15k-237/eval_transX.py @@ -0,0 +1,132 @@ +#!/usr/bin/python + +import sys +import numpy as np + +relation = sys.argv[1] + +dataPath_ = '/home/xwhan/RL_KB/data/FB15k-237/tasks/' + relation + +ent_id_path = '/home/xwhan/RL_KB/data/FB15k-237/' + 'entity2id.txt' +rel_id_path = '/home/xwhan/RL_KB/data/FB15k-237/' + 'relation2id.txt' +test_data_path = '/home/xwhan/RL_KB/data/FB15k-237/tasks/' + relation + 
'/sort_test.pairs' + +f1 = open(ent_id_path) +f2 = open(rel_id_path) +content1 = f1.readlines() +content2 = f2.readlines() +f1.close() +f2.close() + +entity2id = {} +relation2id = {} +for line in content1: + entity2id[line.split()[0]] = int(line.split()[1]) + +for line in content2: + relation2id[line.split()[0]] = int(line.split()[1]) + +ent_vec = np.loadtxt(dataPath_ + '/entity2vec.vec') +rel_vec = np.loadtxt(dataPath_ + '/relation2vec.vec') +M = np.loadtxt(dataPath_ + '/A.vec') +M = M.reshape([rel_vec.shape[0],-1]) + + +f = open(test_data_path) +test_data = f.readlines() +f.close() +test_pairs = [] +test_labels = [] +# queries = set() +for line in test_data: + e1 = line.split(',')[0].replace('thing$','') + e1 = '/' + e1[0] + '/' + e1[2:] + e2 = line.split(',')[1].split(':')[0].replace('thing$','') + e2 = '/' + e2[0] + '/' + e2[2:] + test_pairs.append((e1,e2)) + label = 1 if line[-2] == '+' else 0 + test_labels.append(label) + + +aps = [] +query = test_pairs[0][0] +y_true = [] +y_score = [] + +score_all = [] + +rel = '/' + relation.replace("@", "/") +d_r = np.expand_dims(rel_vec[relation2id[rel],:],1) +w_r = np.expand_dims(M[relation2id[rel],:],1) + +for idx, sample in enumerate(test_pairs): + #print 'query node: ', sample[0], idx + if sample[0] == query: + h = np.expand_dims(ent_vec[entity2id[sample[0]],:],1) + t = np.expand_dims(ent_vec[entity2id[sample[1]],:],1) + + h_ = h - np.matmul(w_r.transpose(), h)*w_r + t_ = t - np.matmul(w_r.transpose(), t)*w_r + + + score = -np.sum(np.square(h_ + d_r - t_)) + + score_all.append(score) + y_score.append(score) + y_true.append(test_labels[idx]) + else: + query = sample[0] + count = zip(y_score, y_true) + count.sort(key = lambda x:x[0], reverse=True) + #print count + ranks = [] + correct = 0 + for idx_, item in enumerate(count): + if item[1] == 1: + correct += 1 + ranks.append(correct/(1.0+idx_)) + if len(ranks)==0: + ranks.append(0) + aps.append(np.mean(ranks)) + if len(aps) % 10 == 0: + print 'How many queries:', len(aps) + print np.mean(aps) + y_true = [] + y_score = [] + h = np.expand_dims(ent_vec[entity2id[sample[0]],:],1) + t = np.expand_dims(ent_vec[entity2id[sample[1]],:],1) + + h_ = h - np.matmul(w_r.transpose(), h)*w_r + t_ = t - np.matmul(w_r.transpose(), t)*w_r + + + score = -np.sum(np.square(h_ + d_r - t_)) + + score_all.append(score) + y_score.append(score) + y_true.append(test_labels[idx]) + + +mean_ap = np.mean(aps) +print 'MAP: ', mean_ap + +score_label = zip(score_all, test_labels) +stats = sorted(score_label, key = lambda x:x[0], reverse=True) + +correct = 0 +ranks = [] +for idx, item in enumerate(stats): + if item[1] == 1: + correct += 1 + ranks.append(correct/(1.0+idx)) +ap1 = np.mean(ranks) +print 'TransX: ', ap1 + + + + + + + + + diff --git a/FB15k-237/evaluate.py b/FB15k-237/evaluate.py new file mode 100755 index 0000000..d8af845 --- /dev/null +++ b/FB15k-237/evaluate.py @@ -0,0 +1,296 @@ +#!/usr/bin/python + +import sys +import numpy as np +from DFS.KB import * +from sklearn import linear_model + +relation = sys.argv[1] + +dataPath_ = '/home/xwhan/RL_KB/data/FB15k-237/tasks/' + relation +featurePath = dataPath_ + '/path_to_use.txt' +feature_stats = dataPath_ + '/path_stats.txt' +relationId_path = '/home/xwhan/RL_KB/data/FB15k-237/relation2id.txt' + +def train(kb, kb_inv, named_paths): + f = open(dataPath_ + '/_train.pairs') + train_data = f.readlines() + f.close() + train_pairs = [] + train_labels = [] + for line in train_data: + e1 = line.split(',')[0].replace('thing$','') + e1 = '/' + e1[0] + '/' + e1[2:] + e2 = 
line.split(',')[1].split(':')[0].replace('thing$','') + e2 = '/' + e2[0] + '/' + e2[2:] + train_pairs.append((e1,e2)) + label = 1 if line[-2] == '+' else 0 + train_labels.append(label) + training_features = [] + for sample in train_pairs: + feature = [] + for path in named_paths: + feature.append(int(bfs_two(sample[0], sample[1], path, kb, kb_inv))) + training_features.append(feature) + regr = linear_model.LinearRegression() + regr.fit(training_features, train_labels) + print("training error: %.5f" + % np.mean((regr.predict(training_features) - train_labels) ** 2)) + weights = regr.coef_ + print weights + return weights + +def get_features(): + stats = {} + f = open(feature_stats) + path_freq = f.readlines() + f.close() + for line in path_freq: + path = line.split('\t')[0] + num = int(line.split('\t')[1]) + stats[path] = num + max_freq = np.max(stats.values()) + + relation2id = {} + f = open(relationId_path) + content = f.readlines() + f.close() + for line in content: + relation2id[line.split()[0]] = int(line.split()[1]) + + useful_paths = [] + named_paths = [] + f = open(featurePath) + paths = f.readlines() + f.close() + + for line in paths: + path = line.rstrip() + + # if path not in stats: + # continue + # elif max_freq > 1 and stats[path] < 2: + # continue + + length = len(path.split(' -> ')) + + if length <= 10: + pathIndex = [] + pathName = [] + relations = path.split(' -> ') + + for rel in relations: + pathName.append(rel) + rel_id = relation2id[rel] + pathIndex.append(rel_id) + useful_paths.append(pathIndex) + named_paths.append(pathName) + + print 'How many paths used: ', len(useful_paths) + return useful_paths, named_paths + +def evaluate_logic(): + kb = KB() + kb_inv = KB() + + f = open(dataPath_ + '/graph.txt') + kb_lines = f.readlines() + f.close() + + for line in kb_lines: + e1 = line.split()[0] + rel = line.split()[1] + e2 = line.split()[2] + kb.addRelation(e1,rel,e2) + kb_inv.addRelation(e2,rel,e1) + + _, named_paths = get_features() + + #path_weights = train(kb, kb_inv, named_paths) + + path_weights = [] + for path in named_paths: + weight = 1.0/len(path) + path_weights.append(weight) + + path_weights = np.array(path_weights) + + f = open(dataPath_ + '/sort_all.pairs') + test_data = f.readlines() + f.close() + print 'predict all' + test_pairs = [] + test_labels = [] + # queries = set() + for line in test_data: + e1 = line.split(',')[0].replace('thing$','') + e1 = '/' + e1[0] + '/' + e1[2:] + e2 = line.split(',')[1].split(':')[0].replace('thing$','') + e2 = '/' + e2[0] + '/' + e2[2:] + test_pairs.append((e1,e2)) + label = 1 if line[-2] == '+' else 0 + test_labels.append(label) + + # f = open(dataPath_ + '/topk.pairs') + # test_data = f.readlines() + # f.close() + # test_pairs = [] + # test_labels = [] + # for line in test_data: + # e1 = line.split()[0] + # e2 = line.split()[1] + # label = int(line.split()[2]) + # test_pairs.append((e1,e2)) + # test_labels.append(label) + + aps = [] + query = test_pairs[0][0] + y_true = [] + y_score = [] + + score_all = [] + + for idx, sample in enumerate(test_pairs): + #print 'query node: ', sample[0], idx + if sample[0] == query: + features = [] + for path in named_paths: + features.append(int(bfs_two(sample[0], sample[1], path, kb, kb_inv))) + + features = features*path_weights + + #score = np.inner(features, path_weights) + score = np.sum(features) + + score_all.append(score) + y_score.append(score) + y_true.append(test_labels[idx]) + else: + query = sample[0] + count = zip(y_score, y_true) + count.sort(key = lambda x:x[0], reverse=True) 
+ ranks = [] + correct = 0 + for idx_, item in enumerate(count): + if item[1] == 1: + correct += 1 + ranks.append(correct/(1.0+idx_)) + #break + if len(ranks) ==0: + ranks.append(0) + #print np.mean(ranks) + aps.append(np.mean(ranks)) + if len(aps) % 10 == 0: + print 'How many queries:', len(aps) + print np.mean(aps) + y_true = [] + y_score = [] + features = [] + for path in named_paths: + features.append(int(bfs_two(sample[0], sample[1], path, kb, kb_inv))) + + features = features*path_weights + #score = np.inner(features, path_weights) + score = np.sum(features) + + score_all.append(score) + y_score.append(score) + y_true.append(test_labels[idx]) + # print y_score, y_true + + count = zip(y_score, y_true) + count.sort(key = lambda x:x[0], reverse=True) + ranks = [] + correct = 0 + for idx_, item in enumerate(count): + if item[1] == 1: + correct += 1 + ranks.append(correct/(1.0+idx_)) + #break + #if len(ranks) ==0: + # ranks.append(0) + aps.append(np.mean(ranks)) + + score_label = zip(score_all, test_labels) + score_label_ranked = sorted(score_label, key = lambda x:x[0], reverse=True) + + hits = 0 + for idx, item in enumerate(score_label_ranked): + if item[1] == 1: + hits += 1 + if idx == 9: + print 'P@10: ', hits/10.0 + elif idx ==99: + print 'P@100: ', hits/100.0 + break + + mean_ap = np.mean(aps) + print 'MAP: ', mean_ap + + +def bfs_two(e1,e2,path,kb,kb_inv): + start = 0 + end = len(path) + left = set() + right = set() + left.add(e1) + right.add(e2) + + left_path = [] + right_path = [] + while(start < end): + left_step = path[start] + left_next = set() + right_step = path[end-1] + right_next = set() + + if len(left) < len(right): + left_path.append(left_step) + start += 1 + #print 'left',start + # for triple in kb: + # if triple[2] == left_step and triple[0] in left: + # left_next.add(triple[1]) + # left = left_next + for entity in left: + try: + for path_ in kb.getPathsFrom(entity): + if path_.relation == left_step: + left_next.add(path_.connected_entity) + except Exception as e: + print len(left) + print 'not such entity' + return False + left = left_next + + else: + right_path.append(right_step) + end -= 1 + #print 'right', end + # for triple in kb: + # if triple[2] == right_step and triple[1] in right: + # right_next.add(triple[0]) + # right = right_next + for entity in right: + try: + for path_ in kb_inv.getPathsFrom(entity): + if path_.relation == right_step: + right_next.add(path_.connected_entity) + except Exception as e: + print 'no such entity' + print len(right) + return False + right = right_next + + + if len(right & left) != 0: + return True + return False + + +if __name__ == '__main__': + evaluate_logic() + # evaluate(relation) + # test(relation) + + diff --git a/FB15k-237/find_train.py b/FB15k-237/find_train.py new file mode 100644 index 0000000..1240913 --- /dev/null +++ b/FB15k-237/find_train.py @@ -0,0 +1,28 @@ +import sys + +relation_name = sys.argv[1] + +dataPath = 'tasks/' + relation_name + '/train.pairs' +outPath = 'tasks/' + relation_name + '/train_pos' + +f = open(dataPath) +content = f.readlines() +f.close() + +rel = '/' + relation_name.replace("@", "/") + +newlines = [] +for line in content: + line = line.rstrip() + label = line[-1] + if label == '+': + e1 = line.split(',')[0].replace('thing$','') + e1 = '/' + e1[0] + '/' + e1[2:] + e2 = line.split(',')[1].split(':')[0].replace('thing$','') + e2 = '/' + e2[0] + '/' + e2[2:] + newline = e1 + '\t' + e2 + '\t' + rel + '\n' + newlines.append(newline) + +g = open(outPath, 'w') +g.writelines(newlines) +g.close() 
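Note (reviewer sketch, not part of the patch): completion.py, evaluate.py and find_train.py above all repeat the same parsing of a *.pairs line back into Freebase-style IDs. A minimal sketch of that conversion, using a made-up sample line (the IDs below are hypothetical):

# Hypothetical sample line; format assumed from the scripts above ("thing$<id1>,thing$<id2>: +/-")
line = 'thing$m_0abc12,thing$m_0def34: +\n'

e1 = line.split(',')[0].replace('thing$', '')            # 'm_0abc12'
e1 = '/' + e1[0] + '/' + e1[2:]                          # '/m/0abc12'
e2 = line.split(',')[1].split(':')[0].replace('thing$', '')
e2 = '/' + e2[0] + '/' + e2[2:]                          # '/m/0def34'
label = 1 if line[-2] == '+' else 0                      # '+' marks a positive pair

print((e1, e2), label)                                   # ('/m/0abc12', '/m/0def34') 1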
diff --git a/FB15k-237/for_transE.py b/FB15k-237/for_transE.py new file mode 100644 index 0000000..e703215 --- /dev/null +++ b/FB15k-237/for_transE.py @@ -0,0 +1,29 @@ +import sys + +relation = sys.argv[1] + +datapath = './tasks/' + relation +'/' + +file1 = open(datapath + 'graph.txt') +content1 = file1.readlines() +file1.close() + +new_lines = [] +for line in content1: + e1 = line.split()[0] + rel = line.split()[1] + if rel[-4:] == '_inv': + continue + e2 = line.split()[2] + newline = e1 + '\t' + e2 + '\t' + rel + '\n' + new_lines.append(newline) + +file2 = open(datapath + 'train_pos') +content2 = file2.readlines() +file2.close() + +data = new_lines + content2 + +g = open(datapath + 'transE','w') +g.writelines(data) +g.close() \ No newline at end of file diff --git a/FB15k-237/get_stats.py b/FB15k-237/get_stats.py new file mode 100644 index 0000000..3832c69 --- /dev/null +++ b/FB15k-237/get_stats.py @@ -0,0 +1,18 @@ +from collections import Counter + +f = open('raw.kb') +content = f.readlines() +f.close() + +all_relations = [] +for line in content: + rel = line.split()[1] + all_relations.append(rel) + +relation_stats = Counter(all_relations).items() +relation_stats = sorted(relation_stats, key = lambda x:x[1], reverse=True) + +g = open('rel_stats', 'w') +for item in relation_stats: + g.write(item[0] + '\t' + str(item[1]) + '\n') +g.close() \ No newline at end of file diff --git a/FB15k-237/policy_1.sh b/FB15k-237/policy_1.sh new file mode 100755 index 0000000..2003465 --- /dev/null +++ b/FB15k-237/policy_1.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +relation=$1 +gpuid=$2 + +relation1=${relation//@/\/} +h="/" +relation2=$h$relation1 +cd ~/RL_KB/data/FB15k-237 +echo $relation +python find_train.py $relation +graphpath="tasks/" +graphpath="$graphpath$relation/graph.txt" +grep -v $relation2 full_data.txt > $graphpath +cd ../.. +CUDA_VISIBLE_DEVICES=$gpuid python sl_policy.py $relation diff --git a/FB15k-237/policy_2.sh b/FB15k-237/policy_2.sh new file mode 100755 index 0000000..ee016a5 --- /dev/null +++ b/FB15k-237/policy_2.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +relation=$1 +gpuid=$2 + + +cd ../.. +CUDA_VISIBLE_DEVICES=$gpuid python policy_agent.py $relation retrain +CUDA_VISIBLE_DEVICES=$gpuid python policy_agent.py $relation test +testpath="data/FB15k-237/tasks/" +testpath="$testpath$relation/" +cd $testpath +#sort test_all.pairs > sort_all.pairs +cd ../../../.. 
+pwd +python evaluate.py $relation diff --git a/FB15k-237/process.py b/FB15k-237/process.py new file mode 100644 index 0000000..46cc4d2 --- /dev/null +++ b/FB15k-237/process.py @@ -0,0 +1,71 @@ +relation_id = set() +entity_id = set() + +f1 = open('/home/xwhan/RL_KB/data/FB15k-237/relation2id.txt') +relations = f1.readlines() +f1.close() + +f2 = open('/home/xwhan/RL_KB/data/FB15k-237/entity2id.txt') +entities = f2.readlines() +f2.close() + +for line in relations: + relation_id.add(line.split()[0]) + +for line in entities: + entity_id.add(line.split()[0]) + +g = open('/home/xwhan/RL_KB/data/FB15k-237/full_data.txt') +full_data = g.readlines() +g.close() + +g1 = open('/home/xwhan/RL_KB/data/FB15k-237/kb_env.txt') +kb_env = g1.readlines() +g1.close() + +new_full_data = [] +new_kb_env = [] +for line in full_data: + e1 = line.split()[0] + rel = line.split()[1] + e2 = line.split()[2] + if (e1 in entity_id) and (e2 in entity_id) and (rel in relation_id): + new_full_data.append(line) + +print len(kb_env) +print len(full_data) + +for line in kb_env: + e1 = line.split()[0] + e2 = line.split()[1] + rel = line.split()[2] + if (e1 in entity_id) and (e2 in entity_id) and (rel in relation_id): + new_kb_env.append(line) + + +g2 = open('/home/xwhan/RL_KB/data/FB15k-237/new_full_data.txt','w') +g2.writelines(new_full_data) +g2.close() + +g3 = open('/home/xwhan/RL_KB/data/FB15k-237/new_kb_env.txt','w') +g3.writelines(new_kb_env) +g3.close() + +f = open('/home/xwhan/RL_KB/data/FB15k-237/raw.kb') +raw_data = f.readlines() +f.close() + +kb = [] +for line in raw_data: + e1 = line.split()[0] + rel = line.split()[1] + e2 = line.split()[2] + if (e1 in entity_id) and (e2 in entity_id) and (rel in relation_id): + kb.append(line) +f = open('/home/xwhan/RL_KB/data/FB15k-237/kb.txt','w') +f.writelines(kb) +f.close() + + + + diff --git a/FB15k-237/run_transX.sh b/FB15k-237/run_transX.sh new file mode 100755 index 0000000..21d34d5 --- /dev/null +++ b/FB15k-237/run_transX.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +echo $1 +relation=$1 +echo $relation +# python transX.py $relation +./transX -relation $relation +echo "done" +./eval_transX.py $relation \ No newline at end of file diff --git a/FB15k-237/test1.py b/FB15k-237/test1.py new file mode 100644 index 0000000..48ac8d4 --- /dev/null +++ b/FB15k-237/test1.py @@ -0,0 +1,5 @@ +import numpy as np + +M = np.loadtxt('A.bern') +M = M.reshape([-1,100,100]) +print M[0,:,:].shape \ No newline at end of file diff --git a/FB15k-237/transE.sh b/FB15k-237/transE.sh new file mode 100755 index 0000000..c69650b --- /dev/null +++ b/FB15k-237/transE.sh @@ -0,0 +1,16 @@ +#!/bin/bash -e + +echo $1 +relation=$1 +python for_transE.py $1 +echo "FINISHING BUILDING DATA" +echo $1 +./Train_TransE -relation $1 +echo "TransE finished" +testpath="tasks/" +testpath="$testpath$relation/" +cd $testpath +sort test_all.pairs > sort_all.pairs +pwd +cd ../.. 
+python transE_eval.py $1 diff --git a/FB15k-237/transE_eval.py b/FB15k-237/transE_eval.py new file mode 100755 index 0000000..d868daf --- /dev/null +++ b/FB15k-237/transE_eval.py @@ -0,0 +1,121 @@ +# import cPickle +import sys +import numpy as np + +relation = sys.argv[1] + +dataPath_ = './tasks/' + relation +ent_id_path = './entity2id.txt' +rel_id_path = './relation2id.txt' +test_data_path = dataPath_ + '/sort_test.pairs' + +f1 = open(ent_id_path) +f2 = open(rel_id_path) +content1 = f1.readlines() +content2 = f2.readlines() +f1.close() +f2.close() + +entity2id = {} +relation2id = {} +for line in content1: + entity2id[line.split()[0]] = int(line.split()[1]) + +for line in content2: + relation2id[line.split()[0]] = int(line.split()[1]) + + +ent_vec = np.loadtxt(dataPath_ + '/entity2vec.unif') +rel_vec = np.loadtxt(dataPath_ + '/relation2vec.unif') + +f = open(test_data_path) +test_data = f.readlines() +f.close() + +test_pairs = [] +test_labels = [] +# queries = set() +for line in test_data: + e1 = line.split(',')[0].replace('thing$','') + e1 = '/' + e1[0] + '/' + e1[2:] + e2 = line.split(',')[1].split(':')[0].replace('thing$','') + e2 = '/' + e2[0] + '/' + e2[2:] + test_pairs.append((e1,e2)) + label = 1 if line[-2] == '+' else 0 + test_labels.append(label) + + +aps = [] +query = test_pairs[0][0] +y_true = [] +y_score = [] +query_samples = [] + +score_all = [] + +rel = '/' + relation.replace("@", "/") +relation_vec = rel_vec[relation2id[rel],:] + +g = open(dataPath_ + '/topk.pairs','w') + +for idx, sample in enumerate(test_pairs): + #print 'query node: ', sample[0], idx + if sample[0] == query: + e1_vec = ent_vec[entity2id[sample[0]],:] + e2_vec = ent_vec[entity2id[sample[1]],:] + score = -np.sum(np.square(e1_vec + relation_vec - e2_vec)) + score_all.append(score) + y_score.append(score) + y_true.append(test_labels[idx]) + query_samples.append(sample) + else: + query = sample[0] + count = zip(y_score, y_true, query_samples) + count.sort(key = lambda x:x[0], reverse=True) + for idx_, item in enumerate(count): + if idx_ <= 40: + g.write(item[2][0]+'\t'+item[2][1]+'\t'+str(item[1])+'\n') + + ranks = [] + correct = 0 + for idx_, item in enumerate(count): + if item[1] == 1: + correct += 1 + ranks.append(correct/(1.0+idx_)) + #break + if len(ranks)==0: + ranks.append(0) + #print np.mean(ranks) + aps.append(np.mean(ranks)) + if len(aps) % 10 == 0: + print('How many queries:', len(aps)) + print(np.mean(aps)) + y_true = [] + y_score = [] + query_samples = [] + e1_vec = ent_vec[entity2id[sample[0]],:] + e2_vec = ent_vec[entity2id[sample[1]],:] + + score = -np.sum(np.square(e1_vec + relation_vec - e2_vec)) + score_all.append(score) + y_score.append(score) + y_true.append(test_labels[idx]) + query_samples.append(sample) + +g.close() +score_label = zip(score_all, test_labels) +score_label_ranked = sorted(score_label, key = lambda x:x[0], reverse=True) + +hits = 0 +for idx, item in enumerate(score_label_ranked): + if item[1] == 1: + hits += 1 + if idx == 9: + print('P@10: ', hits/10.0) + elif idx ==99: + print('P@100: ', hits/100.0) + break + +mean_ap = np.mean(aps) +print('MAP: ', mean_ap) + diff --git a/FB15k-237/transR.sh b/FB15k-237/transR.sh new file mode 100755 index 0000000..d5ee583 --- /dev/null +++ b/FB15k-237/transR.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +echo $1 +./Train_TransR -relation "$1" +echo "TransR finished" +python transR_eval.py $1 \ No newline at end of file diff --git a/FB15k-237/transR_eval.py b/FB15k-237/transR_eval.py new file mode 100755 index 0000000..c24c262 --- /dev/null +++ 
b/FB15k-237/transR_eval.py @@ -0,0 +1,116 @@ +#import cPickle +import sys +import numpy as np + +relation = sys.argv[1] + +dataPath_ = './tasks/' + relation + +ent_id_path = './entity2id.txt' +rel_id_path = './relation2id.txt' +test_data_path = dataPath_ + '/sort_test.pairs' + + +f1 = open(ent_id_path) +f2 = open(rel_id_path) +content1 = f1.readlines() +content2 = f2.readlines() +f1.close() +f2.close() + +entity2id = {} +relation2id = {} +for line in content1: + entity2id[line.split()[0]] = int(line.split()[1]) + +for line in content2: + relation2id[line.split()[0]] = int(line.split()[1]) + + +ent_vec = np.loadtxt(dataPath_ + '/entity2vec.bern') +rel_vec = np.loadtxt(dataPath_ + '/relation2vec.bern') +M = np.loadtxt(dataPath_ + '/A.bern') +M = M.reshape([-1,100,100]) + +f = open(test_data_path) +test_data = f.readlines() +f.close() + +test_pairs = [] +test_labels = [] +# queries = set() +for line in test_data: + e1 = line.split(',')[0].replace('thing$','') + e1 = '/' + e1[0] + '/' + e1[2:] + e2 = line.split(',')[1].split(':')[0].replace('thing$','') + e2 = '/' + e2[0] + '/' + e2[2:] + test_pairs.append((e1,e2)) + label = 1 if line[-2] == '+' else 0 + test_labels.append(label) + + +aps = [] +query = test_pairs[0][0] +y_true = [] +y_score = [] + +score_all = [] + +rel = '/' + relation.replace("@", "/") +relation_vec = rel_vec[relation2id[rel],:] +M_vec = M[relation2id[rel],:,:] + +for idx, sample in enumerate(test_pairs): + #print 'query node: ', sample[0], idx + if sample[0] == query: + e1_vec = ent_vec[entity2id[sample[0]],:] + e2_vec = ent_vec[entity2id[sample[1]],:] + + e1_vec_rel = np.matmul(e1_vec, M_vec) + e2_vec_rel = np.matmul(e2_vec, M_vec) + score = -np.sum(np.square(e1_vec_rel + relation_vec - e2_vec_rel)) + + score_all.append(score) + y_score.append(score) + y_true.append(test_labels[idx]) + else: + query = sample[0] + count = zip(y_score, y_true) + count.sort(key = lambda x:x[0], reverse=True) + #print count + ranks = [] + correct = 0 + for idx_, item in enumerate(count): + if item[1] == 1: + correct += 1 + ranks.append(correct/(1.0+idx_)) + aps.append(np.mean(ranks)) + if len(aps) % 5 == 0: + print('#queries:', len(aps)) + print(np.mean(aps)) + y_true = [] + y_score = [] + e1_vec_rel = np.matmul(e1_vec, M_vec) + e2_vec_rel = np.matmul(e2_vec, M_vec) + score = -np.sum(np.square(e1_vec_rel + relation_vec - e2_vec_rel)) + score_all.append(score) + y_score.append(score) + y_true.append(test_labels[idx]) + + +score_label = zip(score_all, test_labels) +score_label_ranked = sorted(score_label, key = lambda x:x[0], reverse=True) + +hits = 0 +for idx, item in enumerate(score_label_ranked): + if item[1] == 1: + hits += 1 + if idx == 9: + print('P@10: ', hits/10.0) + elif idx ==99: + print('P@100: ', hits/100.0) + break + +mean_ap = np.mean(aps) +print('MAP: ', mean_ap) + diff --git a/FB15k-237/transX.py b/FB15k-237/transX.py new file mode 100755 index 0000000..8a8cb53 --- /dev/null +++ b/FB15k-237/transX.py @@ -0,0 +1,61 @@ +#!/usr/bin/python + +import sys + +relation = sys.argv[1] + +f1 = open('relation2id.txt') +content1 = f1.readlines() +f1.close() + +relation2id = {} +for line in content1: + rel = line.split()[0] + id_ = line.split()[1] + relation2id[rel] = id_ + +f2 = open('entity2id.txt') +content2 = f2.readlines() +f2.close() + +entity2id = {} +for line in content2: + entity = line.split()[0] + id_ = line.split()[1] + entity2id[entity] = id_ + + +taskPath = 'tasks/' + relation + '/' +f = open(taskPath + 'transE') +triples = f.readlines() +f.close() + +size = len(triples) + +g = 
open(taskPath + 'triple2id.txt','w') +g.write(str(size) + '\n') + +for line in triples: + e1 = line.split('\t')[0] + e2 = line.split('\t')[1] + rel = line[:-1].split('\t')[2] + if e1 not in entity2id or e2 not in entity2id: + continue + e1_id = entity2id[e1] + e2_id = entity2id[e2] + rel_id = relation2id[rel] + g.write(e1_id + '\t' + e2_id + '\t' + rel_id + ' ' + '\n') + +g.close() + +g = open(taskPath + 'entity2id.txt','w') +g.write(str(len(entity2id.keys())) + '\n') +for item in entity2id.items(): + g.write(item[0] + '\t' + item[1] +'\n') +g.close() + +g = open(taskPath + 'relation2id.txt','w') +g.write(str(len(relation2id.keys())) + '\n') +for item in relation2id.items(): + g.write(item[0] + '\t' + item[1] + '\n') +g.close() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2c2afbc --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +tensorflow +scikit-learn \ No newline at end of file diff --git a/scripts/BFS/BFS.py b/scripts/BFS/BFS.py index 11cdb2e..d5d4812 100755 --- a/scripts/BFS/BFS.py +++ b/scripts/BFS/BFS.py @@ -1,7 +1,4 @@ -try: - from Queue import Queue -except ImportError: - from queue import Queue +from queue import Queue import random def BFS(kb, entity1, entity2): @@ -28,7 +25,7 @@ def test(): class foundPaths(object): def __init__(self, kb): self.entities = {} - for entity, relations in kb.entities.iteritems(): + for entity, relations in kb.entities.items(): self.entities[entity] = (False, "", "") def isFound(self, entity): diff --git a/scripts/BFS/KB.py b/scripts/BFS/KB.py index 863bdc0..ba2fc24 100755 --- a/scripts/BFS/KB.py +++ b/scripts/BFS/KB.py @@ -3,7 +3,7 @@ def __init__(self): self.entities = {} def addRelation(self, entity1, relation, entity2): - if self.entities.has_key(entity1): + if entity1 in self.entities: self.entities[entity1].append(Path(relation, entity2)) else: self.entities[entity1] = [Path(relation, entity2)] @@ -23,16 +23,15 @@ def removePath(self, entity1, entity2): def pickRandomIntermediatesBetween(self, entity1, entity2, num): #TO DO: COULD BE IMPROVED BY NARROWING THE RANGE OF RANDOM EACH TIME ITERATIVELY CHOOSE AN INTERMEDIATE - from sets import Set import random - res = Set() + res = set() if num > len(self.entities) - 2: raise ValueError('Number of Intermediates picked is larger than possible', 'num_entities: {}'.format(len(self.entities)), 'num_itermediates: {}'.format(num)) for i in range(num): - itermediate = random.choice(self.entities.keys()) + itermediate = random.choice(list(self.entities.keys())) while itermediate in res or itermediate == entity1 or itermediate == entity2: - itermediate = random.choice(self.entities.keys()) + itermediate = random.choice(list(self.entities.keys())) res.add(itermediate) return list(res) diff --git a/scripts/env.py b/scripts/env.py index 3f11aeb..a040fb6 100644 --- a/scripts/env.py +++ b/scripts/env.py @@ -27,7 +27,7 @@ def __init__(self, dataPath, task=None): self.path_relations = [] # Knowledge Graph for path finding - f = open(dataPath + 'kb_env_rl.txt') + f = open(dataPath + 'kb_env.txt') kb_all = f.readlines() f.close() diff --git a/scripts/evaluate.py b/scripts/evaluate.py index a665429..9fd6be6 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -5,47 +5,49 @@ from BFS.KB import * from sklearn import linear_model from keras.models import Sequential -from keras.layers import Dense, Activation +from keras.layers import Dense, Activation, Input +from utils import dataPath relation = sys.argv[1] -dataPath_ = 
'../NELL-995/tasks/' + relation +dataPath_ = dataPath + '/tasks/' + relation featurePath = dataPath_ + '/path_to_use.txt' feature_stats = dataPath_ + '/path_stats.txt' -relationId_path = '../NELL-995/relation2id.txt' +relationId_path = dataPath + '/relation2id.txt' + +def import_file(filename): + with open(filename) as f: + content = f.readlines() + return content def train(kb, kb_inv, named_paths): - f = open(dataPath_ + '/train.pairs') - train_data = f.readlines() - f.close() - train_pairs = [] + train_data = import_file(dataPath_ + '/train.pairs') train_labels = [] + training_features = [] for line in train_data: - e1 = line.split(',')[0].replace('thing$','') - e2 = line.split(',')[1].split(':')[0].replace('thing$','') + e1, e2 = line.split(':')[0].replace('thing$', '/').replace('_', '/').split(',') if (e1 not in kb.entities) or (e2 not in kb.entities): continue - train_pairs.append((e1,e2)) + label = 1 if line[-2] == '+' else 0 train_labels.append(label) - training_features = [] - for sample in train_pairs: + feature = [] for path in named_paths: - feature.append(int(bfs_two(sample[0], sample[1], path, kb, kb_inv))) + feature.append(int(bfs_two(e1, e2, path, kb, kb_inv))) training_features.append(feature) + model = Sequential() - input_dim = len(named_paths) - model.add(Dense(1, activation='sigmoid' ,input_dim=input_dim)) + model.add(Input(shape=(len(named_paths), ))) + model.add(Dense(1, activation='sigmoid')) model.compile(optimizer = 'rmsprop', loss='binary_crossentropy', metrics=['accuracy']) - model.fit(training_features, train_labels, nb_epoch=300, batch_size=128) + model.fit(np.array(training_features), np.array(train_labels), epochs=300, batch_size=128) return model + def get_features(): stats = {} - f = open(feature_stats) - path_freq = f.readlines() - f.close() + path_freq = import_file(feature_stats) for line in path_freq: path = line.split('\t')[0] num = int(line.split('\t')[1]) @@ -53,20 +55,15 @@ def get_features(): max_freq = np.max(stats.values()) relation2id = {} - f = open(relationId_path) - content = f.readlines() - f.close() + content = import_file(relationId_path) for line in content: relation2id[line.split()[0]] = int(line.split()[1]) - useful_paths = [] - named_paths = [] - f = open(featurePath) - paths = f.readlines() - f.close() - - print len(paths) + paths = import_file(featurePath) + print('#total paths imported: ', len(paths)) + named_paths = [] + useful_paths = [] for line in paths: path = line.rstrip() @@ -84,17 +81,13 @@ def get_features(): useful_paths.append(pathIndex) named_paths.append(pathName) - print 'How many paths used: ', len(useful_paths) + print('#paths used: ', len(useful_paths)) return useful_paths, named_paths def evaluate_logic(): kb = KB() kb_inv = KB() - - f = open(dataPath_ + '/graph.txt') - kb_lines = f.readlines() - f.close() - + kb_lines = import_file(dataPath_ + '/graph.txt') for line in kb_lines: e1 = line.split()[0] rel = line.split()[1] @@ -106,18 +99,12 @@ def evaluate_logic(): model = train(kb, kb_inv, named_paths) - - f = open(dataPath_ + '/sort_test.pairs') - test_data = f.readlines() - f.close() + test_data = import_file(dataPath_ + '/sort_test.pairs') test_pairs = [] test_labels = [] # queries = set() for line in test_data: - e1 = line.split(',')[0].replace('thing$','') - # e1 = '/' + e1[0] + '/' + e1[2:] - e2 = line.split(',')[1].split(':')[0].replace('thing$','') - # e2 = '/' + e2[0] + '/' + e2[2:] + e1, e2 = line.split(':')[0].replace('thing$', '/').replace('_', '/').split(',') if (e1 not in kb.entities) or (e2 not in 
kb.entities): + continue + test_pairs.append((e1,e2)) @@ -148,7 +135,7 @@ def evaluate_logic(): y_true.append(test_labels[idx]) else: query = sample[0] - count = zip(y_score, y_true) + count = list(zip(y_score, y_true)) count.sort(key = lambda x:x[0], reverse=True) ranks = [] correct = 0 @@ -181,7 +168,7 @@ def evaluate_logic(): y_true.append(test_labels[idx]) # print y_score, y_true - count = zip(y_score, y_true) + count = list(zip(y_score, y_true)) count.sort(key = lambda x:x[0], reverse=True) ranks = [] correct = 0 @@ -195,7 +182,7 @@ def evaluate_logic(): score_label_ranked = sorted(score_label, key = lambda x:x[0], reverse=True) mean_ap = np.mean(aps) - print 'RL MAP: ', mean_ap + print('RL MAP: ', mean_ap) def bfs_two(e1,e2,path,kb,kb_inv): diff --git a/scripts/fact_prediction_eval.py b/scripts/fact_prediction_eval.py index 52213c6..e3a9266 100755 --- a/scripts/fact_prediction_eval.py +++ b/scripts/fact_prediction_eval.py @@ -3,16 +3,18 @@ import numpy as np import sys from BFS.KB import * +from utils import dataPath + relation = sys.argv[1] -dataPath_ = '../NELL-995/tasks/' + relation +dataPath_ = dataPath + '/tasks/' + relation featurePath = dataPath_ + '/path_to_use.txt' feature_stats = dataPath_ + '/path_stats.txt' -relationId_path ='../NELL-995/' + 'relation2id.txt' -ent_id_path = '../NELL-995/' + 'entity2id.txt' -rel_id_path = '../NELL-995/' + 'relation2id.txt' -test_data_path = '../NELL-995/tasks/' + relation + '/sort_test.pairs' +relationId_path = dataPath + 'relation2id.txt' +ent_id_path = dataPath + 'entity2id.txt' +rel_id_path = dataPath + 'relation2id.txt' +test_data_path = dataPath_ + '/sort_test.pairs' def bfs_two(e1,e2,path,kb,kb_inv): start = 0 @@ -39,9 +41,9 @@ def bfs_two(e1,e2,path,kb,kb_inv): if path_.relation == left_step: left_next.add(path_.connected_entity) except Exception as e: - print 'left', len(left) - print left - print 'not such entity' + print('left', len(left)) + print(left) + print('no such entity') return False left = left_next @@ -54,8 +56,8 @@ def bfs_two(e1,e2,path,kb,kb_inv): if path_.relation == right_step: right_next.add(path_.connected_entity) except Exception as e: - print 'right', len(right) - print 'no such entity' + print('right', len(right)) + print('no such entity') return False right = right_next @@ -110,7 +112,7 @@ def get_features(): useful_paths.append(pathIndex) named_paths.append(pathName) - print 'How many paths used: ', len(useful_paths) + print('How many paths used: ', len(useful_paths)) return useful_paths, named_paths f1 = open(ent_id_path) @@ -182,9 +184,9 @@ def get_features(): scores_R = [] scores_rl = [] -print 'How many queries: ', len(test_pairs) +print('How many queries: ', len(test_pairs)) for idx, sample in enumerate(test_pairs): - print 'Query No.%d of %d' % (idx, len(test_pairs)) + print('Query No.%d of %d' % (idx, len(test_pairs))) e1_vec_E = ent_vec_E[entity2id[sample[0]],:] e2_vec_E = ent_vec_E[entity2id[sample[1]],:] score_E = -np.sum(np.square(e1_vec_E + relation_vec_E - e2_vec_E)) @@ -218,7 +220,7 @@ def get_features(): correct += 1 ranks.append(correct/(1.0+idx)) ap1 = np.mean(ranks) -print 'TransE: ', ap1 +print('TransE: ', ap1) correct = 0 ranks = [] @@ -227,7 +229,7 @@ def get_features(): correct += 1 ranks.append(correct/(1.0+idx)) ap2 = np.mean(ranks) -print 'TransR: ', ap2 +print('TransR: ', ap2) correct = 0 @@ -237,7 +239,7 @@ def get_features(): correct += 1 ranks.append(correct/(1.0+idx)) ap3 = np.mean(ranks) -print 'RL: ', ap3 +print('RL: ', ap3) f1 = open(ent_id_path) f2 = open(rel_id_path) @@ -300,7 
+302,7 @@ def get_features(): correct += 1 ranks.append(correct/(1.0+idx)) ap4 = np.mean(ranks) -print 'TransH: ', ap4 +print('TransH: ', ap4) ent_vec_D = np.loadtxt(dataPath_ + '/entity2vec.vec_D') rel_vec_D = np.loadtxt(dataPath_ + '/relation2vec.vec_D') @@ -335,5 +337,5 @@ def get_features(): correct += 1 ranks.append(correct/(1.0+idx)) ap5 = np.mean(ranks) -print 'TransD: ', ap5 +print('TransD: ', ap5) diff --git a/scripts/link_prediction_eval.sh b/scripts/link_prediction_eval.sh index 54874e3..a841bba 100755 --- a/scripts/link_prediction_eval.sh +++ b/scripts/link_prediction_eval.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/bash -e relation=$1 diff --git a/scripts/networks.py b/scripts/networks.py index 359239e..5ddbfb7 100644 --- a/scripts/networks.py +++ b/scripts/networks.py @@ -1,34 +1,34 @@ import tensorflow as tf def policy_nn(state, state_dim, action_dim, initializer): - w1 = tf.get_variable('W1', [state_dim, 512], initializer = initializer, regularizer=tf.contrib.layers.l2_regularizer(0.01)) - b1 = tf.get_variable('b1', [512], initializer = tf.constant_initializer(0.0)) + w1 = tf.compat.v1.get_variable('W1', [state_dim, 512], initializer = initializer, regularizer=tf.keras.regularizers.L2(l2=0.01)) + b1 = tf.compat.v1.get_variable('b1', [512], initializer = tf.constant_initializer(0.0)) h1 = tf.nn.relu(tf.matmul(state, w1) + b1) - w2 = tf.get_variable('w2', [512, 1024], initializer = initializer, regularizer=tf.contrib.layers.l2_regularizer(0.01)) - b2 = tf.get_variable('b2', [1024], initializer = tf.constant_initializer(0.0)) + w2 = tf.compat.v1.get_variable('w2', [512, 1024], initializer = initializer, regularizer=tf.keras.regularizers.L2(l2=0.01)) + b2 = tf.compat.v1.get_variable('b2', [1024], initializer = tf.constant_initializer(0.0)) h2 = tf.nn.relu(tf.matmul(h1, w2) + b2) - w3 = tf.get_variable('w3', [1024, action_dim], initializer = initializer, regularizer=tf.contrib.layers.l2_regularizer(0.01)) - b3 = tf.get_variable('b3', [action_dim], initializer = tf.constant_initializer(0.0)) + w3 = tf.compat.v1.get_variable('w3', [1024, action_dim], initializer = initializer, regularizer=tf.keras.regularizers.L2(l2=0.01)) + b3 = tf.compat.v1.get_variable('b3', [action_dim], initializer = tf.constant_initializer(0.0)) action_prob = tf.nn.softmax(tf.matmul(h2,w3) + b3) return action_prob def value_nn(state, state_dim, initializer): - w1 = tf.get_variable('w1', [state_dim, 64], initializer = initializer) - b1 = tf.get_variable('b1', [64], initializer = tf.constant_initializer(0.0)) + w1 = tf.compat.v1.get_variable('w1', [state_dim, 64], initializer = initializer) + b1 = tf.compat.v1.get_variable('b1', [64], initializer = tf.constant_initializer(0.0)) h1 = tf.nn.relu(tf.matmul(state,w1) + b1) - w2 = tf.get_variable('w2', [64,1], initializer = initializer) - b2 = tf.get_variable('b2', [1], initializer = tf.constant_initializer(0.0)) + w2 = tf.compat.v1.get_variable('w2', [64,1], initializer = initializer) + b2 = tf.compat.v1.get_variable('b2', [1], initializer = tf.constant_initializer(0.0)) value_estimated = tf.matmul(h1, w2) + b2 return tf.squeeze(value_estimated) def q_network(state, state_dim, action_space, initializer): - w1 = tf.get_variable('w1', [state_dim, 128], initializer=initializer) - b1 = tf.get_variable('b1', [128], initializer = tf.constant_initializer(0)) + w1 = tf.compat.v1.get_variable('w1', [state_dim, 128], initializer=initializer) + b1 = tf.compat.v1.get_variable('b1', [128], initializer = tf.constant_initializer(0)) h1 = tf.nn.relu(tf.matmul(state, w1) + b1) - w2 = 
tf.get_variable('w2', [128, 64], initializer = initializer) - b2 = tf.get_variable('b2', [64], initializer = tf.constant_initializer(0)) + w2 = tf.compat.v1.get_variable('w2', [128, 64], initializer = initializer) + b2 = tf.compat.v1.get_variable('b2', [64], initializer = tf.constant_initializer(0)) h2 = tf.nn.relu(tf.matmul(h1, w2) + b2) - w3 = tf.get_variable('w3', [64, action_space], initializer = initializer) - b3 = tf.get_variable('b3', [action_space], initializer = tf.constant_initializer(0)) + w3 = tf.compat.v1.get_variable('w3', [64, action_space], initializer = initializer) + b3 = tf.compat.v1.get_variable('b3', [action_space], initializer = tf.constant_initializer(0)) action_values = tf.matmul(h2, w3) + b3 return [w1,b1,w2,b2,w3,b3,action_values] diff --git a/scripts/pathfinder.sh b/scripts/pathfinder.sh index 8c13ade..a02b71f 100755 --- a/scripts/pathfinder.sh +++ b/scripts/pathfinder.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/bash -e relation=$1 python sl_policy.py $relation diff --git a/scripts/policy_agent.py b/scripts/policy_agent.py index df4088f..f0e6bb6 100644 --- a/scripts/policy_agent.py +++ b/scripts/policy_agent.py @@ -1,5 +1,3 @@ -from __future__ import division -from __future__ import print_function import tensorflow as tf import numpy as np import collections @@ -13,34 +11,37 @@ from env import Env +# Disable eager execution +# Else leads to an error in tf v2 due to incompatibility with placeholders. +tf.compat.v1.disable_eager_execution() + relation = sys.argv[1] task = sys.argv[2] graphpath = dataPath + 'tasks/' + relation + '/' + 'graph.txt' relationPath = dataPath + 'tasks/' + relation + '/' + 'train_pos' class PolicyNetwork(object): - def __init__(self, scope = 'policy_network', learning_rate = 0.001): - self.initializer = tf.contrib.layers.xavier_initializer() - with tf.variable_scope(scope): - self.state = tf.placeholder(tf.float32, [None, state_dim], name = 'state') - self.action = tf.placeholder(tf.int32, [None], name = 'action') - self.target = tf.placeholder(tf.float32, name = 'target') + self.initializer = tf.keras.initializers.GlorotUniform() + with tf.compat.v1.variable_scope(scope): + self.state = tf.compat.v1.placeholder(tf.float32, [None, state_dim], name='state') + self.action = tf.compat.v1.placeholder(tf.int32, [None], name='action') + self.target = tf.compat.v1.placeholder(tf.float32, name='target') self.action_prob = policy_nn(self.state, state_dim, action_space, self.initializer) action_mask = tf.cast(tf.one_hot(self.action, depth = action_space), tf.bool) self.picked_action_prob = tf.boolean_mask(self.action_prob, action_mask) - self.loss = tf.reduce_sum(-tf.log(self.picked_action_prob)*self.target) + sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope=scope)) - self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate) + self.loss = tf.reduce_sum(-tf.math.log(self.picked_action_prob) * self.target) + sum(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, scope=scope)) + self.optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate) self.train_op = self.optimizer.minimize(self.loss) def predict(self, state, sess = None): - sess = sess or tf.get_default_session() + sess = sess or tf.compat.v1.get_default_session() return sess.run(self.action_prob, {self.state:state}) def update(self, state, target, action, sess=None): - sess = sess or tf.get_default_session() + sess = sess or tf.compat.v1.get_default_session() feed_dict = { self.state: state, self.target: target, self.action: action 
} _, loss = sess.run([self.train_op, self.loss], feed_dict) return loss @@ -179,15 +180,15 @@ def REINFORCE(training_pairs, policy_nn, num_episodes): def retrain(): print('Start retraining') - tf.reset_default_graph() + tf.compat.v1.reset_default_graph() policy_network = PolicyNetwork(scope = 'supervised_policy') f = open(relationPath) training_pairs = f.readlines() f.close() - saver = tf.train.Saver() - with tf.Session() as sess: + saver = tf.compat.v1.train.Saver() + with tf.compat.v1.Session() as sess: saver.restore(sess, 'models/policy_supervised_' + relation) print("sl_policy restored") episodes = len(training_pairs) @@ -198,7 +199,7 @@ def retrain(): print('Retrained model saved') def test(): - tf.reset_default_graph() + tf.compat.v1.reset_default_graph() policy_network = PolicyNetwork(scope = 'supervised_policy') f = open(relationPath) @@ -210,12 +211,12 @@ def test(): success = 0 - saver = tf.train.Saver() + saver = tf.compat.v1.train.Saver() path_found = [] path_relation_found = [] path_set = set() - with tf.Session() as sess: + with tf.compat.v1.Session() as sess: saver.restore(sess, 'models/policy_retrained' + relation) print('Model reloaded') diff --git a/scripts/sl_policy.py b/scripts/sl_policy.py index f70c055..738c8e6 100644 --- a/scripts/sl_policy.py +++ b/scripts/sl_policy.py @@ -1,5 +1,3 @@ -from __future__ import division -from __future__ import print_function import tensorflow as tf import numpy as np from itertools import count @@ -12,6 +10,11 @@ from BFS.BFS import BFS import time + +# Disable eager execution +# Else leads to an error in tf v2 due to incompatibility with placeholders. +tf.compat.v1.disable_eager_execution() + relation = sys.argv[1] # episodes = int(sys.argv[2]) graphpath = dataPath + 'tasks/' + relation + '/' + 'graph.txt' @@ -20,30 +23,31 @@ class SupervisedPolicy(object): """docstring for SupervisedPolicy""" def __init__(self, learning_rate = 0.001): - self.initializer = tf.contrib.layers.xavier_initializer() - with tf.variable_scope('supervised_policy'): - self.state = tf.placeholder(tf.float32, [None, state_dim], name = 'state') - self.action = tf.placeholder(tf.int32, [None], name = 'action') + self.initializer = tf.keras.initializers.GlorotUniform() + with tf.compat.v1.variable_scope('supervised_policy'): + self.state = tf.compat.v1.placeholder(tf.float32, [None, state_dim], name='state') + self.action = tf.compat.v1.placeholder(tf.int32, [None], name='action') self.action_prob = policy_nn(self.state, state_dim, action_space, self.initializer) action_mask = tf.cast(tf.one_hot(self.action, depth = action_space), tf.bool) self.picked_action_prob = tf.boolean_mask(self.action_prob, action_mask) - self.loss = tf.reduce_sum(-tf.log(self.picked_action_prob)) + sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope = 'supervised_policy')) - self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate) + self.loss = tf.reduce_sum(-tf.math.log(self.picked_action_prob)) + sum(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, scope='supervised_policy')) + self.optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate = learning_rate) self.train_op = self.optimizer.minimize(self.loss) def predict(self, state, sess = None): - sess = sess or tf.get_default_session() + sess = sess or tf.compat.v1.get_default_session() return sess.run(self.action_prob, {self.state: state}) def update(self, state, action, sess = None): - sess = sess or tf.get_default_session() + sess = sess or tf.compat.v1.get_default_session() _, loss 
= sess.run([self.train_op, self.loss], {self.state: state, self.action: action}) return loss + def train(): - tf.reset_default_graph() + tf.compat.v1.reset_default_graph() policy_nn = SupervisedPolicy() f = open(relationPath) @@ -52,9 +56,9 @@ def train(): num_samples = len(train_data) - saver = tf.train.Saver() - with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) + saver = tf.compat.v1.train.Saver() + with tf.compat.v1.Session() as sess: + sess.run(tf.compat.v1.global_variables_initializer()) if num_samples > 500: num_samples = 500 else: @@ -66,11 +70,10 @@ def train(): env = Env(dataPath, train_data[episode%num_samples]) sample = train_data[episode%num_samples].split() - try: good_episodes = teacher(sample[0], sample[1], 5, env, graphpath) - except Exception as e: - print('Cannot find a path') + except KeyError as e: + print('Cannot find a path, %s' % e) continue for item in good_episodes: @@ -84,7 +87,7 @@ def train(): policy_nn.update(state_batch, action_batch) saver.save(sess, 'models/policy_supervised_' + relation) - print('Model saved') + print('Model saved at %s' % 'models/policy_supervised_' + relation) def test(test_episodes): diff --git a/scripts/utils.py b/scripts/utils.py index 19e4284..c7bfec8 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -1,5 +1,3 @@ -from __future__ import division -from __future__ import print_function import random from collections import namedtuple, Counter import numpy as np @@ -21,7 +19,7 @@ max_steps = 50 max_steps_test = 50 -dataPath = '../NELL-995/' +dataPath = '../FB15k-237/' Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
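Note (reviewer sketch, not part of the patch): the evaluation scripts in this diff (completion.py, transE_eval.py, transR_eval.py, eval_transX.py, evaluate.py) all repeat the same ranking metric: sort candidate pairs by score and average precision@k over the positive hits. A self-contained restatement of that computation, with an illustrative helper name:

import numpy as np

def average_precision(scores, labels):
    # Sort pairs by score, highest first, then average precision@k at each positive hit,
    # mirroring the correct/(1.0+idx) loop used throughout the eval scripts above.
    ranked = sorted(zip(scores, labels), key=lambda x: x[0], reverse=True)
    correct, precisions = 0, []
    for idx, (_, label) in enumerate(ranked):
        if label == 1:
            correct += 1
            precisions.append(correct / (1.0 + idx))
    return np.mean(precisions) if precisions else 0.0

# e.g. average_precision([0.9, 0.7, 0.5], [1, 0, 1]) == (1/1 + 2/3) / 2 ≈ 0.833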