Skip to content

Commit

Permalink
fix(//core/conversion/evaluator): Custom to IValue that handles int[]
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
  • Loading branch information
narendasan committed Aug 5, 2020
1 parent 0e90f78 commit 68c934a
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 3 deletions.
4 changes: 3 additions & 1 deletion core/conversion/evaluators/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ cc_library(
"NodeEvaluatorRegistry.cpp",
"prim.cpp",
"aten.cpp",
"eval_macros.h"
"eval_macros.h",
"eval_util.h",
"eval_util.cpp"
],
deps = [
"//core/util:prelude",
Expand Down
105 changes: 105 additions & 0 deletions core/conversion/evaluators/eval_util.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#include "ATen/core/ivalue.h"
#include "ATen/core/List.h"
#include "core/util/prelude.h"
#include "ATen/core/functional.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace evaluators {

//TODO: Switch back to PyTorch canonical implimentation
c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v) {
if (v->node()->kind() != torch::jit::prim::Constant || v->type()->cast<c10::FunctionType>()) {
return c10::nullopt;
}
const torch::jit::Node* node = v->node();
const c10::TypePtr& type = v->type();
if (type->isSubtypeOf(c10::TensorType::get())) {
return node->t(c10::attr::value);
} else if (type->isSubtypeOf(c10::BoolType::get())) {
return (bool)node->i(c10::attr::value);
} else if (
type->isSubtypeOf(c10::NumberType::get()) &&
node->kindOf(c10::attr::value) == torch::jit::AttributeKind::i) {
return node->i(c10::attr::value);
} else if (
type->isSubtypeOf(c10::NumberType::get()) &&
node->kindOf(c10::attr::value) == torch::jit::AttributeKind::f) {
return node->f(c10::attr::value);
} else if (type->isSubtypeOf(c10::ListType::ofInts())) {
try {
const auto& is = node->is(c10::attr::value);
return is;
} catch (const std::exception& ex) {
const auto& ival = node->ival(c10::attr::value);
return ival;
}
} else if (type->isSubtypeOf(c10::ListType::ofFloats())) {
try {
const auto& fs = node->fs(c10::attr::value);
return fs;
} catch (const std::exception& ex) {
const auto& ival = node->ival(c10::attr::value);
return ival;
}
} else if (type->isSubtypeOf(c10::ListType::ofBools())) {
const auto bs = c10::fmap<bool>(node->is(c10::attr::value));
return bs;
} else if (type->isSubtypeOf(c10::ListType::ofTensors())) {
try {
const auto& ts = node->ts(c10::attr::value);
return ts;
} catch (const std::exception& ex) {
const auto& ival = node->ival(c10::attr::value);
return ival;
}
} else if (type->isSubtypeOf(c10::ListType::ofStrings())) {
try {
const auto& ss = node->ss(c10::attr::value);
auto vals = c10::impl::GenericList(c10::StringType::get());
for (const auto& str : ss) {
vals.push_back(str);
}
return vals;
} catch (const std::exception& ex) {
const auto& ival = node->ival(c10::attr::value);
return ival;
}
} else if (
type->cast<c10::ListType>() &&
node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) {
const auto& list = node->ival(c10::attr::value);
TRTORCH_ASSERT(list.isList(), "Is not a list");
return list;
} else if (
type->cast<c10::DictType>() &&
node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) {
const auto& dict = node->ival(c10::attr::value);
TRTORCH_ASSERT(dict.isGenericDict(), "Is not a dict");
return dict;
} else if (
type->cast<c10::TupleType>() &&
node->kindOf(c10::attr::value) == torch::jit::AttributeKind::ival) {
const auto& tup = node->ival(c10::attr::value);
TRTORCH_ASSERT(tup.isTuple(), "Is not a tuple");
return tup;
} else if (type == c10::StringType::get()) {
const auto& s = node->s(c10::attr::value);
return s;
} else if (type == c10::DeviceObjType::get()) {
auto d = c10::Device(node->s(c10::attr::value));
return d;
} else if (node->mustBeNone()) {
return torch::jit::IValue();
} else {
std::stringstream ss;
ss << "constant literal not supported for: " << type->str();
throw std::runtime_error(ss.str());
}
}

} // namespace evaluators
} // namespace conversion
} // namespace core
} // namespace trtorch
15 changes: 15 additions & 0 deletions core/conversion/evaluators/eval_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

#include "torch/csrc/jit/ir/ir.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace evaluators {

c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v);

} // namespace evaluators
} // namespace conversion
} // namespace core
} // namespace trtorch
5 changes: 3 additions & 2 deletions core/conversion/evaluators/prim.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <limits>

#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/constants.h"
//#include "torch/csrc/jit/ir/constants.h"
#include "ATen/core/functional.h"
#include "ATen/core/ivalue.h"
#include "ATen/core/List.h"
Expand All @@ -11,6 +11,7 @@

#include "core/conversion/evaluators/evaluators.h"
#include "core/conversion/evaluators/eval_macros.h"
#include "core/conversion/evaluators/eval_util.h"

namespace trtorch {
namespace core {