Skip to content

Create shapenet_tfrecord_gen.py #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: mesh_rcnn_tfrecord_gen_pix3d
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 0 additions & 179 deletions official/vision/beta/projects/mesh_rcnn/data/create_pix3d_tf_record.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
{
"04379243": "table",
"02958343": "car",
"03001627": "chair",
"02691156": "airplane",
"04256520": "sofa",
"04090263": "rifle",
"03636649": "lamp",
"04530566": "watercraft",
"02828884": "bench",
"03691459": "loudspeaker",
"02933112": "cabinet",
"03211117": "display",
"04401088": "telephone",
"02924116": "bus",
"02808440": "bathtub",
"03467517": "guitar",
"03325088": "faucet",
"03046257": "clock",
"03991062": "flowerpot",
"03593526": "jar",
"02876657": "bottle",
"02871439": "bookshelf",
"03642806": "laptop",
"03624134": "knife",
"04468005": "train",
"02747177": "trash bin",
"03790512": "motorbike",
"03948459": "pistol",
"03337140": "file cabinet",
"02818832": "bed",
"03928116": "piano",
"04330267": "stove",
"03797390": "mug",
"02880940": "bowl",
"04554684": "washer",
"04004475": "printer",
"03513137": "helmet",
"03761084": "microwaves",
"04225987": "skateboard",
"04460130": "tower",
"02942699": "camera",
"02801938": "basket",
"02946921": "can",
"03938244": "pillow",
"03710193": "mailbox",
"03207941": "dishwasher",
"04099429": "rocket",
"02773838": "bag",
"02843684": "birdhouse",
"03261776": "earphone",
"03759954": "microphone",
"04074963": "remote",
"03085013": "keyboard",
"02834778": "bicycle",
"02954340": "cap",
"02858304": "boat",
"02992529": "mobile phone"
}
137 changes: 137 additions & 0 deletions official/vision/beta/projects/mesh_rcnn/data/shapenet_tfrecord_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import json
import logging
import os

import tensorflow as tf
from absl import app # pylint:disable=unused-import
from absl import flags

from official.vision.beta.data import tfrecord_lib
from official.vision.beta.data.tfrecord_lib import convert_to_feature

flags.DEFINE_multi_string('shapenet_dir', '', 'Directory containing '
'ShapeNet.')
flags.DEFINE_string('output_file_prefix', '', 'Path to output file')
flags.DEFINE_integer('num_shards', 32, 'Number of shards for output file.')

FLAGS = flags.FLAGS

logger = tf.get_logger()
logger.setLevel(logging.INFO)


def parse_obj_file(file):
"""
Parses relevant data out of a .obj file. This contains all of the model information.
Args:
file: file path to .obj file
Return:
vertices: vertices of object
faces: faces of object
"""
vertices = []
faces = []

obj_file = open(file, 'r')
lines = obj_file.readlines()

for line in lines:
lineID = line[0:2]

if lineID == "v ":
vertex = line[2:].split(" ")

for i, v in enumerate(vertex):
vertex[i] = float(v)

vertices.append(vertex)

if lineID == "f ":

face = line[2:].split(" ")

for i, f in enumerate(face):
face[i] = [int(x) - 1 for x in f.split("/")]

faces.append(face)

return vertices, faces


def create_tf_example(image):
model_id = image["model_id"]
label = image["label"]

temp_file_dir = os.join(image["shapenet_dir"], image["synset_id"])
model_vertices, model_faces = parse_obj_file(os.join(temp_file_dir, image["model_id"]))

feature_dict = {"model_id": convert_to_feature(model_id),
"label": convert_to_feature(label),
"vertices": convert_to_feature(model_vertices),
"faces": convert_to_feature(model_faces)}

example = tf.train.Example(
features=tf.train.Features(feature=feature_dict))

return example, 0


def generate_annotations(images, shapenet_dir):
for image in images:
yield {"shapenet_dir": shapenet_dir,
"label": image["label"],
"model_id": image["model_id"],
"synset_id": image["synset_id"]}


def _create_tf_record_from_shapenet_dir(shapenet_dir,
output_path,
num_shards):
"""Loads Shapenet json files and converts to tf.Record format.
Args:
images_info_file: shapenet_dir download directory
output_path: Path to output tf.Record file.
num_shards: Number of output files to create.
"""

logging.info('writing to output path: %s', output_path)

# create synset ID to label mapping dictionary
with open('C:/Users/Ethan/PycharmProjects/tf-models/official/vision/beta/projects/mesh_rcnn/data'
'/shapenet_synset_dict.json', "r") as dict_file:
synset_dict = json.load(dict_file)

# images list
images = []

for _, synset_directories, _ in os.walk(shapenet_dir[0]):
for synset_directory in synset_directories:
for _, object_directories, _ in os.walk(os.path.join(shapenet_dir[0], synset_directory)):
for object_directory in object_directories:
image = {"model_id": object_directory,
"label": synset_dict[synset_directory],
"shapenet_dir": shapenet_dir,
"synset_id": synset_directory}
images.append(image)

shapenet_annotations_iter = generate_annotations(
images=images, shapenet_dir=shapenet_dir)

num_skipped = tfrecord_lib.write_tf_record_dataset(
output_path, shapenet_annotations_iter, create_tf_example, num_shards)

logging.info('Finished writing, skipped %d annotations.', num_skipped)


def main(_):
assert FLAGS.shapenet_dir, '`shapenet_dir` missing.'

directory = os.path.dirname(FLAGS.output_file_prefix)
if not tf.io.gfile.isdir(directory):
tf.io.gfile.makedirs(directory)

_create_tf_record_from_shapenet_dir('shapenet_dir', 'tmp', 32)


if __name__ == '__main__':
app.run(main)
Loading