diff --git a/core/conversion/converters/impl/internal_ops.cpp b/core/conversion/converters/impl/internal_ops.cpp index b83312cebf..1689929ea2 100644 --- a/core/conversion/converters/impl/internal_ops.cpp +++ b/core/conversion/converters/impl/internal_ops.cpp @@ -1,3 +1,5 @@ +#include +#include #include "core/conversion/converters/converters.h" #include "core/util/prelude.h" #include "torch/torch.h" @@ -18,20 +20,18 @@ auto linear_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pat auto in = args[0].ITensorOrFreeze(ctx); auto out = in; if (in->getType() == nvinfer1::DataType::kBOOL) { - auto not_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kNOT); - TORCHTRT_CHECK(not_layer, "Unable to create not layer for attn_bias_from_attn_mask"); - not_layer->setName((util::node_info(n) + "_not").c_str()); - auto neg_inf = torch::tensor(-std::numeric_limits::infinity()); - auto neg_inf_itensor = tensor_to_const(ctx, neg_inf); - auto prod_layer = add_elementwise( + std::vector singleton_dims(in->getDimensions().nbDims, 1); + auto options = torch::TensorOptions().dtype(torch::kFloat32); + auto zero = tensor_to_const( + ctx, torch::full(singleton_dims, 0.0f, options), util::node_info(n) + "_zero"); + auto neg_inf = tensor_to_const( ctx, - nvinfer1::ElementWiseOperation::kPROD, - not_layer->getOutput(0), - neg_inf_itensor, - util::node_info(n) + "_mul"); - auto add_layer = add_elementwise( - ctx, nvinfer1::ElementWiseOperation::kSUM, prod_layer->getOutput(0), in, util::node_info(n) + "_add"); - out = add_layer->getOutput(0); + torch::full(singleton_dims, -std::numeric_limits::infinity(), options), + util::node_info(n) + "_neg_inf"); + auto select_layer = ctx->net->addSelect(*in, *zero, *neg_inf); + TORCHTRT_CHECK(select_layer, "Unable to create select layer for attn_bias_from_attn_mask"); + select_layer->setName(util::node_info(n).c_str()); + out = select_layer->getOutput(0); } auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out); LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions()); diff --git a/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp b/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp index 0b62b57cee..d3f7f57121 100644 --- a/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp +++ b/tests/core/conversion/converters/test_scaled_dot_product_attention.cpp @@ -61,6 +61,39 @@ TEST(Converters, ATenScaledDotProductAttnMaskFloatConvertsCorrectly) { ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0])); } +TEST(Converters, ATenScaledDotProductAttnMaskBoolDoesNotProduceNaN) { + const auto graph = R"IR( + graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor): + %0 : float = prim::Constant[value=0.]() + %false : bool = prim::Constant[value=0]() + %scale : NoneType = prim::Constant() + %enable_gqa : bool = prim::Constant[value=0]() + %3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale, %enable_gqa) + return (%3))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, &*g); + + auto query = at::arange(16, {at::kCUDA}).to(at::kFloat).reshape({1, 1, 4, 4}) / 16.0; + auto key = at::arange(16, {at::kCUDA}).to(at::kFloat).reshape({1, 1, 4, 4}) / 13.0; + auto value = at::arange(16, {at::kCUDA}).to(at::kFloat).reshape({1, 1, 4, 4}) / 11.0; + auto attn_mask = at::tensor( + {1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1}, + at::TensorOptions().dtype(at::kBool).device(at::kCUDA)) + .reshape({1, 1, 4, 4}); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {query, key, value, attn_mask}); + + torch_tensorrt::core::lowering::passes::UnpackScaledDotProductAttention(g); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {query, key, value, attn_mask}); + + ASSERT_FALSE(torch::isnan(jit_results[0]).any().item()); + ASSERT_FALSE(torch::isnan(trt_results[0]).any().item()); + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0])); +} + TEST(Converters, ATenScaledDotProductAttnMaskIntConvertsCorrectly) { const auto graph = R"IR( graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor):