diff --git a/scene_viz.py b/scene_viz.py index 7db7236..4f4c988 100644 --- a/scene_viz.py +++ b/scene_viz.py @@ -46,7 +46,7 @@ def __init__(self, objects: List[ObjDescriptor], prop: dict): self.inv_focal[None] = 1. / self.focal[None] self.num_objects = len(objects) - max_tri_num = max([obj.tri_num for obj in objects]) + self.num_prims = sum([obj.tri_num for obj in objects]) self.cam_orient = ti.Vector.field(3, float, ()) self.cam_t = ti.Vector.field(3, float, ()) @@ -60,15 +60,16 @@ def __init__(self, objects: List[ObjDescriptor], prop: dict): self.aabbs = ti.Vector.field(3, float, (self.num_objects, 2)) self.normals = ti.Vector.field(3, float) - self.meshes = ti.Vector.field(3, float) # leveraging SSDS, shape (N, mesh_num, 3) - vector3d + self.prims = ti.Vector.field(3, float) # leveraging SSDS, shape (N, mesh_num, 3) - vector3d self.precom_vec = ti.Vector.field(3, float) self.pixels = ti.Vector.field(3, float, (1024, 1024)) # maximum size: 1024 - self.bitmasked_nodes = ti.root.dense(ti.i, self.num_objects).bitmasked(ti.j, max_tri_num) - self.bitmasked_nodes.place(self.normals) - self.bitmasked_nodes.bitmasked(ti.k, 3).place(self.meshes) # for simple shapes, this would be efficient - self.bitmasked_nodes.dense(ti.k, 3).place(self.precom_vec) - self.mesh_cnt = ti.field(int, self.num_objects) + self.dense_nodes = ti.root.dense(ti.i, self.num_prims) + self.dense_nodes.place(self.normals) + self.dense_nodes.dense(ti.j, 3).place(self.prims, self.precom_vec) # for simple shapes, this would be efficient + + # pos0: start_idx, pos1: number of primitives, pos2: obj_id (being triangle / sphere? Others to be added, like cylinder, etc.) + self.obj_info = ti.field(int, (self.num_objects, 3)) self.initialze(objects) def set_width(self, val: int): @@ -91,19 +92,26 @@ def local_to_global(self): return forward, lateral, elevate def initialze(self, objects: List[ObjDescriptor]): + acc_prim_num = 0 for i, obj in enumerate(objects): for j, (mesh, normal) in enumerate(zip(obj.meshes, obj.normals)): - self.normals[i, j] = vec3(normal) - for k, vec in enumerate(mesh): - self.meshes[i, j, k] = vec3(vec) + cur_id = acc_prim_num + j + self.prims[cur_id, 0] = vec3(mesh[0]) + self.prims[cur_id, 1] = vec3(mesh[1]) if mesh.shape[0] > 2: # not a sphere - self.precom_vec[i, j, 0] = self.meshes[i, j, 1] - self.meshes[i, j, 0] - self.precom_vec[i, j, 1] = self.meshes[i, j, 2] - self.meshes[i, j, 0] - self.precom_vec[i, j, 2] = self.meshes[i, j, 0] + self.prims[cur_id, 2] = vec3(mesh[2]) + self.precom_vec[cur_id, 0] = self.prims[cur_id, 1] - self.prims[cur_id, 0] + self.precom_vec[cur_id, 1] = self.prims[cur_id, 2] - self.prims[cur_id, 0] + self.precom_vec[cur_id, 2] = self.prims[cur_id, 0] else: - self.precom_vec[i, j, 0] = self.meshes[i, j, 0] - self.precom_vec[i, j, 1] = self.meshes[i, j, 1] - self.mesh_cnt[i] = obj.tri_num + self.precom_vec[cur_id, 0] = self.prims[cur_id, 0] + self.precom_vec[cur_id, 1] = self.prims[cur_id, 1] + self.normals[cur_id] = vec3(normal) + self.obj_info[i, 0] = acc_prim_num + self.obj_info[i, 1] = obj.tri_num + self.obj_info[i, 2] = obj.type + acc_prim_num += obj.tri_num + self.aabbs[i, 0] = vec3(obj.aabb[0]) # unrolled self.aabbs[i, 1] = vec3(obj.aabb[1])