diff --git a/src/brpc/amf.cpp b/src/brpc/amf.cpp index b251ccf236..98025431ab 100644 --- a/src/brpc/amf.cpp +++ b/src/brpc/amf.cpp @@ -22,9 +22,21 @@ #include "butil/find_cstr.h" #include "brpc/log.h" #include "brpc/amf.h" +#include "gflags/gflags.h" namespace brpc { +DEFINE_int32(amf_max_depth, 128, "Maximum nesting depth for AMF objects and arrays"); + +static bool CheckAMFDepth(int depth) { + if (depth > FLAGS_amf_max_depth) { + LOG(ERROR) << "AMF exceeds max depth! max=" + << FLAGS_amf_max_depth << ", actually=" << depth; + return false; + } + return true; +} + const char* marker2str(AMFMarker marker) { switch (marker) { case AMF_MARKER_NUMBER: return "number"; @@ -378,12 +390,17 @@ bool ReadAMFUnsupported(AMFInputStream* stream) { } static bool ReadAMFObjectBody(google::protobuf::Message* message, - AMFInputStream* stream); -static bool SkipAMFObjectBody(AMFInputStream* stream); + AMFInputStream* stream, + int depth); +static bool SkipAMFObjectBody(AMFInputStream* stream, int depth); static bool ReadAMFObjectField(AMFInputStream* stream, google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* field) { + const google::protobuf::FieldDescriptor* field, + int depth) { + if (!CheckAMFDepth(depth)) { + return false; + } const google::protobuf::Reflection* reflection = NULL; if (field) { reflection = message->GetReflection(); @@ -451,12 +468,12 @@ static bool ReadAMFObjectField(AMFInputStream* stream, LOG(WARNING) << "Can't set object to " << field->full_name(); } else { google::protobuf::Message* m = reflection->MutableMessage(message, field); - if (!ReadAMFObjectBody(m, stream)) { + if (!ReadAMFObjectBody(m, stream, depth + 1)) { return false; } } } else { - if (!SkipAMFObjectBody(stream)) { + if (!SkipAMFObjectBody(stream, depth + 1)) { return false; } } @@ -499,7 +516,11 @@ static bool ReadAMFObjectField(AMFInputStream* stream, } static bool ReadAMFObjectBody(google::protobuf::Message* message, - AMFInputStream* stream) { + AMFInputStream* stream, + int depth) { + if (!CheckAMFDepth(depth)) { + return false; + } const google::protobuf::Descriptor* desc = message->GetDescriptor(); std::string name; while (ReadAMFShortStringBody(&name, stream)) { @@ -519,14 +540,17 @@ static bool ReadAMFObjectBody(google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field = desc->FindFieldByName(name); RPC_VLOG_IF(field == NULL) << "Unknown field=" << desc->full_name() << "." << name; - if (!ReadAMFObjectField(stream, message, field)) { + if (!ReadAMFObjectField(stream, message, field, depth)) { return false; } } return true; } -static bool SkipAMFObjectBody(AMFInputStream* stream) { +static bool SkipAMFObjectBody(AMFInputStream* stream, int depth) { + if (!CheckAMFDepth(depth)) { + return false; + } std::string name; while (ReadAMFShortStringBody(&name, stream)) { if (name.empty()) { @@ -542,7 +566,7 @@ static bool SkipAMFObjectBody(AMFInputStream* stream) { } break; } - if (!ReadAMFObjectField(stream, NULL, NULL)) { + if (!ReadAMFObjectField(stream, NULL, NULL, depth)) { return false; } } @@ -550,7 +574,11 @@ static bool SkipAMFObjectBody(AMFInputStream* stream) { } static bool ReadAMFEcmaArrayBody(google::protobuf::Message* message, - AMFInputStream* stream) { + AMFInputStream* stream, + int depth) { + if (!CheckAMFDepth(depth)) { + return false; + } uint32_t count = 0; if (stream->cut_u32(&count) != 4u) { LOG(ERROR) << "stream is not long enough"; @@ -566,7 +594,7 @@ static bool ReadAMFEcmaArrayBody(google::protobuf::Message* message, const google::protobuf::FieldDescriptor* field = desc->FindFieldByName(name); RPC_VLOG_IF(field == NULL) << "Unknown field=" << desc->full_name() << "." << name; - if (!ReadAMFObjectField(stream, message, field)) { + if (!ReadAMFObjectField(stream, message, field, depth)) { return false; } } @@ -580,11 +608,11 @@ bool ReadAMFObject(google::protobuf::Message* msg, AMFInputStream* stream) { return false; } if ((AMFMarker)marker == AMF_MARKER_OBJECT) { - if (!ReadAMFObjectBody(msg, stream)) { + if (!ReadAMFObjectBody(msg, stream, 0)) { return false; } } else if ((AMFMarker)marker == AMF_MARKER_ECMA_ARRAY) { - if (!ReadAMFEcmaArrayBody(msg, stream)) { + if (!ReadAMFEcmaArrayBody(msg, stream, 0)) { return false; } } else if ((AMFMarker)marker != AMF_MARKER_NULL) { @@ -602,13 +630,17 @@ bool ReadAMFObject(google::protobuf::Message* msg, AMFInputStream* stream) { // [Reading AMFObject] -static bool ReadAMFObjectBody(AMFObject* obj, AMFInputStream* stream); -static bool ReadAMFEcmaArrayBody(AMFObject* obj, AMFInputStream* stream); -static bool ReadAMFArrayBody(AMFArray* arr, AMFInputStream* stream); +static bool ReadAMFObjectBody(AMFObject* obj, AMFInputStream* stream, int depth); +static bool ReadAMFEcmaArrayBody(AMFObject* obj, AMFInputStream* stream, int depth); +static bool ReadAMFArrayBody(AMFArray* arr, AMFInputStream* stream, int depth); static bool ReadAMFObjectField(AMFInputStream* stream, AMFObject* obj, - const std::string& name) { + const std::string& name, + int depth) { + if (!CheckAMFDepth(depth)) { + return false; + } uint8_t marker; if (stream->cut_u8(&marker) != 1u) { LOG(ERROR) << "stream is not long enough"; @@ -647,17 +679,17 @@ static bool ReadAMFObjectField(AMFInputStream* stream, } // fall through case AMF_MARKER_OBJECT: { - if (!ReadAMFObjectBody(obj->MutableObject(name), stream)) { + if (!ReadAMFObjectBody(obj->MutableObject(name), stream, depth + 1)) { return false; } } break; case AMF_MARKER_ECMA_ARRAY: { - if (!ReadAMFEcmaArrayBody(obj->MutableObject(name), stream)) { + if (!ReadAMFEcmaArrayBody(obj->MutableObject(name), stream, depth + 1)) { return false; } } break; case AMF_MARKER_STRICT_ARRAY: { - if (!ReadAMFArrayBody(obj->MutableArray(name), stream)) { + if (!ReadAMFArrayBody(obj->MutableArray(name), stream, depth + 1)) { return false; } } break; @@ -693,7 +725,10 @@ static bool ReadAMFObjectField(AMFInputStream* stream, return true; } -static bool ReadAMFObjectBody(AMFObject* obj, AMFInputStream* stream) { +static bool ReadAMFObjectBody(AMFObject* obj, AMFInputStream* stream, int depth) { + if (!CheckAMFDepth(depth)) { + return false; + } std::string name; while (ReadAMFShortStringBody(&name, stream)) { if (name.empty()) { @@ -709,14 +744,17 @@ static bool ReadAMFObjectBody(AMFObject* obj, AMFInputStream* stream) { } break; } - if (!ReadAMFObjectField(stream, obj, name)) { + if (!ReadAMFObjectField(stream, obj, name, depth)) { return false; } } return true; } -static bool ReadAMFEcmaArrayBody(AMFObject* obj, AMFInputStream* stream) { +static bool ReadAMFEcmaArrayBody(AMFObject* obj, AMFInputStream* stream, int depth) { + if (!CheckAMFDepth(depth)) { + return false; + } uint32_t count = 0; if (stream->cut_u32(&count) != 4u) { LOG(ERROR) << "stream is not long enough"; @@ -728,7 +766,7 @@ static bool ReadAMFEcmaArrayBody(AMFObject* obj, AMFInputStream* stream) { LOG(ERROR) << "Fail to read name from the stream"; return false; } - if (!ReadAMFObjectField(stream, obj, name)) { + if (!ReadAMFObjectField(stream, obj, name, depth)) { return false; } } @@ -742,11 +780,11 @@ bool ReadAMFObject(AMFObject* obj, AMFInputStream* stream) { return false; } if ((AMFMarker)marker == AMF_MARKER_OBJECT) { - if (!ReadAMFObjectBody(obj, stream)) { + if (!ReadAMFObjectBody(obj, stream, 0)) { return false; } } else if ((AMFMarker)marker == AMF_MARKER_ECMA_ARRAY) { - if (!ReadAMFEcmaArrayBody(obj, stream)) { + if (!ReadAMFEcmaArrayBody(obj, stream, 0)) { return false; } } else if ((AMFMarker)marker != AMF_MARKER_NULL) { @@ -757,7 +795,10 @@ bool ReadAMFObject(AMFObject* obj, AMFInputStream* stream) { return true; } -static bool ReadAMFArrayItem(AMFInputStream* stream, AMFArray* arr) { +static bool ReadAMFArrayItem(AMFInputStream* stream, AMFArray* arr, int depth) { + if (!CheckAMFDepth(depth)) { + return false; + } uint8_t marker; if (stream->cut_u8(&marker) != 1u) { LOG(ERROR) << "stream is not long enough"; @@ -796,17 +837,17 @@ static bool ReadAMFArrayItem(AMFInputStream* stream, AMFArray* arr) { } // fall through case AMF_MARKER_OBJECT: { - if (!ReadAMFObjectBody(arr->AddObject(), stream)) { + if (!ReadAMFObjectBody(arr->AddObject(), stream, depth + 1)) { return false; } } break; case AMF_MARKER_ECMA_ARRAY: { - if (!ReadAMFEcmaArrayBody(arr->AddObject(), stream)) { + if (!ReadAMFEcmaArrayBody(arr->AddObject(), stream, depth + 1)) { return false; } } break; case AMF_MARKER_STRICT_ARRAY: { - if (!ReadAMFArrayBody(arr->AddArray(), stream)) { + if (!ReadAMFArrayBody(arr->AddArray(), stream, depth + 1)) { return false; } } break; @@ -842,14 +883,17 @@ static bool ReadAMFArrayItem(AMFInputStream* stream, AMFArray* arr) { return true; } -static bool ReadAMFArrayBody(AMFArray* arr, AMFInputStream* stream) { +static bool ReadAMFArrayBody(AMFArray* arr, AMFInputStream* stream, int depth) { + if (!CheckAMFDepth(depth)) { + return false; + } uint32_t count = 0; if (stream->cut_u32(&count) != 4u) { LOG(ERROR) << "stream is not long enough"; return false; } for (uint32_t i = 0; i < count; ++i) { - if (!ReadAMFArrayItem(stream, arr)) { + if (!ReadAMFArrayItem(stream, arr, depth)) { return false; } } @@ -863,7 +907,7 @@ bool ReadAMFArray(AMFArray* arr, AMFInputStream* stream) { return false; } if ((AMFMarker)marker == AMF_MARKER_STRICT_ARRAY) { - if (!ReadAMFArrayBody(arr, stream)) { + if (!ReadAMFArrayBody(arr, stream, 0)) { return false; } } else if ((AMFMarker)marker != AMF_MARKER_NULL) { diff --git a/src/brpc/redis_reply.cpp b/src/brpc/redis_reply.cpp index e2053a360f..14c76f4841 100644 --- a/src/brpc/redis_reply.cpp +++ b/src/brpc/redis_reply.cpp @@ -26,6 +26,8 @@ namespace brpc { DEFINE_int32(redis_max_allocation_size, 64 * 1024 * 1024, "Maximum memory allocation size in bytes for a single redis request or reply (64MB by default)"); +DEFINE_int32(redis_max_reply_depth, 128, + "Maximum nesting depth for redis array replies"); //BAIDU_CASSERT(sizeof(RedisReply) == 24, size_match); const int RedisReply::npos = -1; @@ -94,12 +96,21 @@ bool RedisReply::SerializeTo(butil::IOBufAppender* appender) { } ParseError RedisReply::ConsumePartialIOBuf(butil::IOBuf& buf) { + return ConsumePartialIOBuf(buf, 0); +} + +ParseError RedisReply::ConsumePartialIOBuf(butil::IOBuf& buf, int depth) { + if (depth > FLAGS_redis_max_reply_depth) { + LOG(ERROR) << "redis reply exceeds max depth! max=" + << FLAGS_redis_max_reply_depth << ", actually=" << depth; + return PARSE_ERROR_ABSOLUTELY_WRONG; + } if (_type == REDIS_REPLY_ARRAY && _data.array.last_index >= 0) { // The parsing was suspended while parsing sub replies, // continue the parsing. RedisReply* subs = (RedisReply*)_data.array.replies; for (int i = _data.array.last_index; i < _length; ++i) { - ParseError err = subs[i].ConsumePartialIOBuf(buf); + ParseError err = subs[i].ConsumePartialIOBuf(buf, depth + 1); if (err != PARSE_OK) { return err; } @@ -257,7 +268,7 @@ ParseError RedisReply::ConsumePartialIOBuf(butil::IOBuf& buf) { // be continued in next calls by tracking _data.array.last_index. _data.array.last_index = 0; for (int64_t i = 0; i < count; ++i) { - ParseError err = subs[i].ConsumePartialIOBuf(buf); + ParseError err = subs[i].ConsumePartialIOBuf(buf, depth + 1); if (err != PARSE_OK) { return err; } diff --git a/src/brpc/redis_reply.h b/src/brpc/redis_reply.h index 34e64c00ec..13114be11a 100644 --- a/src/brpc/redis_reply.h +++ b/src/brpc/redis_reply.h @@ -146,6 +146,7 @@ class RedisReply { // by calling CopyFrom[Different|Same]Arena. DISALLOW_COPY_AND_ASSIGN(RedisReply); + ParseError ConsumePartialIOBuf(butil::IOBuf& buf, int depth); void FormatStringImpl(const char* fmt, va_list args, RedisReplyType type); void SetStringImpl(const butil::StringPiece& str, RedisReplyType type); diff --git a/test/brpc_redis_unittest.cpp b/test/brpc_redis_unittest.cpp index 43775d8583..80706b0623 100644 --- a/test/brpc_redis_unittest.cpp +++ b/test/brpc_redis_unittest.cpp @@ -31,6 +31,7 @@ namespace brpc { DECLARE_int32(idle_timeout_second); DECLARE_int32(redis_max_allocation_size); +DECLARE_int32(redis_max_reply_depth); } int main(int argc, char* argv[]) { @@ -98,6 +99,20 @@ static void RunRedisServer() { usleep(50000); } +class ScopedRedisMaxReplyDepth { +public: + explicit ScopedRedisMaxReplyDepth(int32_t depth) + : _old_depth(brpc::FLAGS_redis_max_reply_depth) { + brpc::FLAGS_redis_max_reply_depth = depth; + } + ~ScopedRedisMaxReplyDepth() { + brpc::FLAGS_redis_max_reply_depth = _old_depth; + } + +private: + int32_t _old_depth; +}; + class RedisTest : public testing::Test { protected: RedisTest() {} @@ -866,6 +881,30 @@ TEST_F(RedisTest, redis_reply_codec) { } } +TEST_F(RedisTest, redis_reply_rejects_deep_nested_arrays) { + ScopedRedisMaxReplyDepth scoped_depth(4); + + butil::IOBuf buf; + for (int i = 0; i <= brpc::FLAGS_redis_max_reply_depth; ++i) { + buf.append("*1\r\n"); + } + buf.append(":0\r\n"); + + butil::Arena arena; + brpc::RedisReply reply(&arena); + EXPECT_EQ(brpc::PARSE_ERROR_ABSOLUTELY_WRONG, reply.ConsumePartialIOBuf(buf)); + + buf.clear(); + for (int i = 0; i < brpc::FLAGS_redis_max_reply_depth; ++i) { + buf.append("*1\r\n"); + } + buf.append(":0\r\n"); + + brpc::RedisReply valid_reply(&arena); + EXPECT_EQ(brpc::PARSE_OK, valid_reply.ConsumePartialIOBuf(buf)); + EXPECT_TRUE(valid_reply.is_array()); +} + butil::Mutex s_mutex; std::unordered_map m; std::unordered_map int_map; diff --git a/test/brpc_rtmp_unittest.cpp b/test/brpc_rtmp_unittest.cpp index 5853f7778d..e6b9c62f67 100644 --- a/test/brpc_rtmp_unittest.cpp +++ b/test/brpc_rtmp_unittest.cpp @@ -35,12 +35,86 @@ #include "brpc/rtmp.h" #include "brpc/amf.h" +namespace brpc { +DECLARE_int32(amf_max_depth); +} + int main(int argc, char* argv[]) { testing::InitGoogleTest(&argc, argv); GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true); return RUN_ALL_TESTS(); } +namespace { +class ScopedAMFMaxDepth { +public: + explicit ScopedAMFMaxDepth(int32_t depth) : _old_depth(brpc::FLAGS_amf_max_depth) { + brpc::FLAGS_amf_max_depth = depth; + } + ~ScopedAMFMaxDepth() { + brpc::FLAGS_amf_max_depth = _old_depth; + } + +private: + int32_t _old_depth; +}; + +void AppendAMFStrictArrayHeader(std::string* out, uint32_t count) { + out->push_back((char)brpc::AMF_MARKER_STRICT_ARRAY); + out->push_back((char)((count >> 24) & 0xFF)); + out->push_back((char)((count >> 16) & 0xFF)); + out->push_back((char)((count >> 8) & 0xFF)); + out->push_back((char)(count & 0xFF)); +} + +void AppendAMFObjectHeader(std::string* out) { + out->push_back((char)brpc::AMF_MARKER_OBJECT); +} + +void AppendAMFEcmaArrayHeader(std::string* out, uint32_t count) { + out->push_back((char)brpc::AMF_MARKER_ECMA_ARRAY); + out->push_back((char)((count >> 24) & 0xFF)); + out->push_back((char)((count >> 16) & 0xFF)); + out->push_back((char)((count >> 8) & 0xFF)); + out->push_back((char)(count & 0xFF)); +} + +void AppendAMFShortStringBody(std::string* out, const char* name) { + const uint16_t len = strlen(name); + out->push_back((char)((len >> 8) & 0xFF)); + out->push_back((char)(len & 0xFF)); + out->append(name, len); +} + +void AppendAMFObjectEnd(std::string* out) { + AppendAMFShortStringBody(out, ""); + out->push_back((char)brpc::AMF_MARKER_OBJECT_END); +} + +std::string MakeNestedAMFObject(int depth) { + std::string out; + AppendAMFObjectHeader(&out); + for (int i = 0; i < depth; ++i) { + AppendAMFShortStringBody(&out, "x"); + AppendAMFObjectHeader(&out); + } + for (int i = 0; i <= depth; ++i) { + AppendAMFObjectEnd(&out); + } + return out; +} + +std::string MakeNestedAMFEcmaArray(int depth) { + std::string out; + AppendAMFEcmaArrayHeader(&out, depth == 0 ? 0 : 1); + for (int i = 0; i < depth; ++i) { + AppendAMFShortStringBody(&out, "x"); + AppendAMFEcmaArrayHeader(&out, i + 1 == depth ? 0 : 1); + } + return out; +} +} // namespace + class TestRtmpClientStream : public brpc::RtmpClientStream { public: TestRtmpClientStream() @@ -523,6 +597,64 @@ TEST(RtmpTest, amf) { ASSERT_EQ("heheda", info3.description()); } +TEST(RtmpTest, amf_rejects_deep_nested_arrays) { + ScopedAMFMaxDepth scoped_depth(4); + + std::string req_buf; + for (int i = 0; i <= brpc::FLAGS_amf_max_depth + 1; ++i) { + AppendAMFStrictArrayHeader(&req_buf, 1); + } + req_buf.push_back((char)brpc::AMF_MARKER_NULL); + + google::protobuf::io::ArrayInputStream zc_stream(req_buf.data(), req_buf.size()); + brpc::AMFInputStream istream(&zc_stream); + brpc::AMFArray arr; + EXPECT_FALSE(brpc::ReadAMFArray(&arr, &istream)); + + req_buf.clear(); + for (int i = 0; i < brpc::FLAGS_amf_max_depth; ++i) { + AppendAMFStrictArrayHeader(&req_buf, 1); + } + req_buf.push_back((char)brpc::AMF_MARKER_NULL); + + google::protobuf::io::ArrayInputStream zc_stream2(req_buf.data(), req_buf.size()); + brpc::AMFInputStream istream2(&zc_stream2); + brpc::AMFArray valid_arr; + EXPECT_TRUE(brpc::ReadAMFArray(&valid_arr, &istream2)); +} + +TEST(RtmpTest, amf_rejects_deep_nested_objects) { + ScopedAMFMaxDepth scoped_depth(4); + + std::string req_buf = MakeNestedAMFObject(brpc::FLAGS_amf_max_depth + 1); + google::protobuf::io::ArrayInputStream zc_stream(req_buf.data(), req_buf.size()); + brpc::AMFInputStream istream(&zc_stream); + brpc::AMFObject obj; + EXPECT_FALSE(brpc::ReadAMFObject(&obj, &istream)); + + req_buf = MakeNestedAMFObject(brpc::FLAGS_amf_max_depth); + google::protobuf::io::ArrayInputStream zc_stream2(req_buf.data(), req_buf.size()); + brpc::AMFInputStream istream2(&zc_stream2); + brpc::AMFObject valid_obj; + EXPECT_TRUE(brpc::ReadAMFObject(&valid_obj, &istream2)); +} + +TEST(RtmpTest, amf_rejects_deep_nested_ecma_arrays) { + ScopedAMFMaxDepth scoped_depth(4); + + std::string req_buf = MakeNestedAMFEcmaArray(brpc::FLAGS_amf_max_depth + 1); + google::protobuf::io::ArrayInputStream zc_stream(req_buf.data(), req_buf.size()); + brpc::AMFInputStream istream(&zc_stream); + brpc::AMFObject obj; + EXPECT_FALSE(brpc::ReadAMFObject(&obj, &istream)); + + req_buf = MakeNestedAMFEcmaArray(brpc::FLAGS_amf_max_depth); + google::protobuf::io::ArrayInputStream zc_stream2(req_buf.data(), req_buf.size()); + brpc::AMFInputStream istream2(&zc_stream2); + brpc::AMFObject valid_obj; + EXPECT_TRUE(brpc::ReadAMFObject(&valid_obj, &istream2)); +} + TEST(RtmpTest, successfully_play_streams) { PlayingDummyService rtmp_service; brpc::Server server;