Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 77 additions & 33 deletions src/brpc/amf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,21 @@
#include "butil/find_cstr.h"
#include "brpc/log.h"
#include "brpc/amf.h"
#include "gflags/gflags.h"

namespace brpc {

DEFINE_int32(amf_max_depth, 128, "Maximum nesting depth for AMF objects and arrays");

static bool CheckAMFDepth(int depth) {
if (depth > FLAGS_amf_max_depth) {
LOG(ERROR) << "AMF exceeds max depth! max="
<< FLAGS_amf_max_depth << ", actually=" << depth;
return false;
}
return true;
}

const char* marker2str(AMFMarker marker) {
switch (marker) {
case AMF_MARKER_NUMBER: return "number";
Expand Down Expand Up @@ -378,12 +390,17 @@ bool ReadAMFUnsupported(AMFInputStream* stream) {
}

static bool ReadAMFObjectBody(google::protobuf::Message* message,
AMFInputStream* stream);
static bool SkipAMFObjectBody(AMFInputStream* stream);
AMFInputStream* stream,
int depth);
static bool SkipAMFObjectBody(AMFInputStream* stream, int depth);

static bool ReadAMFObjectField(AMFInputStream* stream,
google::protobuf::Message* message,
const google::protobuf::FieldDescriptor* field) {
const google::protobuf::FieldDescriptor* field,
int depth) {
if (!CheckAMFDepth(depth)) {
return false;
}
const google::protobuf::Reflection* reflection = NULL;
if (field) {
reflection = message->GetReflection();
Expand Down Expand Up @@ -451,12 +468,12 @@ static bool ReadAMFObjectField(AMFInputStream* stream,
LOG(WARNING) << "Can't set object to " << field->full_name();
} else {
google::protobuf::Message* m = reflection->MutableMessage(message, field);
if (!ReadAMFObjectBody(m, stream)) {
if (!ReadAMFObjectBody(m, stream, depth + 1)) {
return false;
}
}
} else {
if (!SkipAMFObjectBody(stream)) {
if (!SkipAMFObjectBody(stream, depth + 1)) {
return false;
}
}
Expand Down Expand Up @@ -499,7 +516,11 @@ static bool ReadAMFObjectField(AMFInputStream* stream,
}

static bool ReadAMFObjectBody(google::protobuf::Message* message,
AMFInputStream* stream) {
AMFInputStream* stream,
int depth) {
if (!CheckAMFDepth(depth)) {
return false;
}
const google::protobuf::Descriptor* desc = message->GetDescriptor();
std::string name;
while (ReadAMFShortStringBody(&name, stream)) {
Expand All @@ -519,14 +540,17 @@ static bool ReadAMFObjectBody(google::protobuf::Message* message,
const google::protobuf::FieldDescriptor* field = desc->FindFieldByName(name);
RPC_VLOG_IF(field == NULL) << "Unknown field=" << desc->full_name()
<< "." << name;
if (!ReadAMFObjectField(stream, message, field)) {
if (!ReadAMFObjectField(stream, message, field, depth)) {
return false;
}
}
return true;
}

static bool SkipAMFObjectBody(AMFInputStream* stream) {
static bool SkipAMFObjectBody(AMFInputStream* stream, int depth) {
if (!CheckAMFDepth(depth)) {
return false;
}
std::string name;
while (ReadAMFShortStringBody(&name, stream)) {
if (name.empty()) {
Expand All @@ -542,15 +566,19 @@ static bool SkipAMFObjectBody(AMFInputStream* stream) {
}
break;
}
if (!ReadAMFObjectField(stream, NULL, NULL)) {
if (!ReadAMFObjectField(stream, NULL, NULL, depth)) {
return false;
}
}
return true;
}

static bool ReadAMFEcmaArrayBody(google::protobuf::Message* message,
AMFInputStream* stream) {
AMFInputStream* stream,
int depth) {
if (!CheckAMFDepth(depth)) {
return false;
}
uint32_t count = 0;
if (stream->cut_u32(&count) != 4u) {
LOG(ERROR) << "stream is not long enough";
Expand All @@ -566,7 +594,7 @@ static bool ReadAMFEcmaArrayBody(google::protobuf::Message* message,
const google::protobuf::FieldDescriptor* field = desc->FindFieldByName(name);
RPC_VLOG_IF(field == NULL) << "Unknown field=" << desc->full_name()
<< "." << name;
if (!ReadAMFObjectField(stream, message, field)) {
if (!ReadAMFObjectField(stream, message, field, depth)) {
return false;
}
}
Expand All @@ -580,11 +608,11 @@ bool ReadAMFObject(google::protobuf::Message* msg, AMFInputStream* stream) {
return false;
}
if ((AMFMarker)marker == AMF_MARKER_OBJECT) {
if (!ReadAMFObjectBody(msg, stream)) {
if (!ReadAMFObjectBody(msg, stream, 0)) {
return false;
}
} else if ((AMFMarker)marker == AMF_MARKER_ECMA_ARRAY) {
if (!ReadAMFEcmaArrayBody(msg, stream)) {
if (!ReadAMFEcmaArrayBody(msg, stream, 0)) {
return false;
}
} else if ((AMFMarker)marker != AMF_MARKER_NULL) {
Expand All @@ -602,13 +630,17 @@ bool ReadAMFObject(google::protobuf::Message* msg, AMFInputStream* stream) {

// [Reading AMFObject]

static bool ReadAMFObjectBody(AMFObject* obj, AMFInputStream* stream);
static bool ReadAMFEcmaArrayBody(AMFObject* obj, AMFInputStream* stream);
static bool ReadAMFArrayBody(AMFArray* arr, AMFInputStream* stream);
static bool ReadAMFObjectBody(AMFObject* obj, AMFInputStream* stream, int depth);
static bool ReadAMFEcmaArrayBody(AMFObject* obj, AMFInputStream* stream, int depth);
static bool ReadAMFArrayBody(AMFArray* arr, AMFInputStream* stream, int depth);

static bool ReadAMFObjectField(AMFInputStream* stream,
AMFObject* obj,
const std::string& name) {
const std::string& name,
int depth) {
if (!CheckAMFDepth(depth)) {
return false;
}
uint8_t marker;
if (stream->cut_u8(&marker) != 1u) {
LOG(ERROR) << "stream is not long enough";
Expand Down Expand Up @@ -647,17 +679,17 @@ static bool ReadAMFObjectField(AMFInputStream* stream,
}
// fall through
case AMF_MARKER_OBJECT: {
if (!ReadAMFObjectBody(obj->MutableObject(name), stream)) {
if (!ReadAMFObjectBody(obj->MutableObject(name), stream, depth + 1)) {
return false;
}
} break;
case AMF_MARKER_ECMA_ARRAY: {
if (!ReadAMFEcmaArrayBody(obj->MutableObject(name), stream)) {
if (!ReadAMFEcmaArrayBody(obj->MutableObject(name), stream, depth + 1)) {
return false;
}
} break;
case AMF_MARKER_STRICT_ARRAY: {
if (!ReadAMFArrayBody(obj->MutableArray(name), stream)) {
if (!ReadAMFArrayBody(obj->MutableArray(name), stream, depth + 1)) {
return false;
}
} break;
Expand Down Expand Up @@ -693,7 +725,10 @@ static bool ReadAMFObjectField(AMFInputStream* stream,
return true;
}

static bool ReadAMFObjectBody(AMFObject* obj, AMFInputStream* stream) {
static bool ReadAMFObjectBody(AMFObject* obj, AMFInputStream* stream, int depth) {
if (!CheckAMFDepth(depth)) {
return false;
}
std::string name;
while (ReadAMFShortStringBody(&name, stream)) {
if (name.empty()) {
Expand All @@ -709,14 +744,17 @@ static bool ReadAMFObjectBody(AMFObject* obj, AMFInputStream* stream) {
}
break;
}
if (!ReadAMFObjectField(stream, obj, name)) {
if (!ReadAMFObjectField(stream, obj, name, depth)) {
return false;
}
}
return true;
}

static bool ReadAMFEcmaArrayBody(AMFObject* obj, AMFInputStream* stream) {
static bool ReadAMFEcmaArrayBody(AMFObject* obj, AMFInputStream* stream, int depth) {
if (!CheckAMFDepth(depth)) {
return false;
}
uint32_t count = 0;
if (stream->cut_u32(&count) != 4u) {
LOG(ERROR) << "stream is not long enough";
Expand All @@ -728,7 +766,7 @@ static bool ReadAMFEcmaArrayBody(AMFObject* obj, AMFInputStream* stream) {
LOG(ERROR) << "Fail to read name from the stream";
return false;
}
if (!ReadAMFObjectField(stream, obj, name)) {
if (!ReadAMFObjectField(stream, obj, name, depth)) {
return false;
}
}
Expand All @@ -742,11 +780,11 @@ bool ReadAMFObject(AMFObject* obj, AMFInputStream* stream) {
return false;
}
if ((AMFMarker)marker == AMF_MARKER_OBJECT) {
if (!ReadAMFObjectBody(obj, stream)) {
if (!ReadAMFObjectBody(obj, stream, 0)) {
return false;
}
} else if ((AMFMarker)marker == AMF_MARKER_ECMA_ARRAY) {
if (!ReadAMFEcmaArrayBody(obj, stream)) {
if (!ReadAMFEcmaArrayBody(obj, stream, 0)) {
return false;
}
} else if ((AMFMarker)marker != AMF_MARKER_NULL) {
Expand All @@ -757,7 +795,10 @@ bool ReadAMFObject(AMFObject* obj, AMFInputStream* stream) {
return true;
}

static bool ReadAMFArrayItem(AMFInputStream* stream, AMFArray* arr) {
static bool ReadAMFArrayItem(AMFInputStream* stream, AMFArray* arr, int depth) {
if (!CheckAMFDepth(depth)) {
return false;
}
uint8_t marker;
if (stream->cut_u8(&marker) != 1u) {
LOG(ERROR) << "stream is not long enough";
Expand Down Expand Up @@ -796,17 +837,17 @@ static bool ReadAMFArrayItem(AMFInputStream* stream, AMFArray* arr) {
}
// fall through
case AMF_MARKER_OBJECT: {
if (!ReadAMFObjectBody(arr->AddObject(), stream)) {
if (!ReadAMFObjectBody(arr->AddObject(), stream, depth + 1)) {
return false;
}
} break;
case AMF_MARKER_ECMA_ARRAY: {
if (!ReadAMFEcmaArrayBody(arr->AddObject(), stream)) {
if (!ReadAMFEcmaArrayBody(arr->AddObject(), stream, depth + 1)) {
return false;
}
} break;
case AMF_MARKER_STRICT_ARRAY: {
if (!ReadAMFArrayBody(arr->AddArray(), stream)) {
if (!ReadAMFArrayBody(arr->AddArray(), stream, depth + 1)) {
return false;
}
} break;
Expand Down Expand Up @@ -842,14 +883,17 @@ static bool ReadAMFArrayItem(AMFInputStream* stream, AMFArray* arr) {
return true;
}

static bool ReadAMFArrayBody(AMFArray* arr, AMFInputStream* stream) {
static bool ReadAMFArrayBody(AMFArray* arr, AMFInputStream* stream, int depth) {
if (!CheckAMFDepth(depth)) {
return false;
}
uint32_t count = 0;
if (stream->cut_u32(&count) != 4u) {
LOG(ERROR) << "stream is not long enough";
return false;
}
for (uint32_t i = 0; i < count; ++i) {
if (!ReadAMFArrayItem(stream, arr)) {
if (!ReadAMFArrayItem(stream, arr, depth)) {
return false;
}
}
Expand All @@ -863,7 +907,7 @@ bool ReadAMFArray(AMFArray* arr, AMFInputStream* stream) {
return false;
}
if ((AMFMarker)marker == AMF_MARKER_STRICT_ARRAY) {
if (!ReadAMFArrayBody(arr, stream)) {
if (!ReadAMFArrayBody(arr, stream, 0)) {
return false;
}
} else if ((AMFMarker)marker != AMF_MARKER_NULL) {
Expand Down
15 changes: 13 additions & 2 deletions src/brpc/redis_reply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ namespace brpc {

DEFINE_int32(redis_max_allocation_size, 64 * 1024 * 1024,
"Maximum memory allocation size in bytes for a single redis request or reply (64MB by default)");
DEFINE_int32(redis_max_reply_depth, 128,
"Maximum nesting depth for redis array replies");

//BAIDU_CASSERT(sizeof(RedisReply) == 24, size_match);
const int RedisReply::npos = -1;
Expand Down Expand Up @@ -94,12 +96,21 @@ bool RedisReply::SerializeTo(butil::IOBufAppender* appender) {
}

ParseError RedisReply::ConsumePartialIOBuf(butil::IOBuf& buf) {
return ConsumePartialIOBuf(buf, 0);
}

ParseError RedisReply::ConsumePartialIOBuf(butil::IOBuf& buf, int depth) {
if (depth > FLAGS_redis_max_reply_depth) {
LOG(ERROR) << "redis reply exceeds max depth! max="
<< FLAGS_redis_max_reply_depth << ", actually=" << depth;
return PARSE_ERROR_ABSOLUTELY_WRONG;
}
if (_type == REDIS_REPLY_ARRAY && _data.array.last_index >= 0) {
// The parsing was suspended while parsing sub replies,
// continue the parsing.
RedisReply* subs = (RedisReply*)_data.array.replies;
for (int i = _data.array.last_index; i < _length; ++i) {
ParseError err = subs[i].ConsumePartialIOBuf(buf);
ParseError err = subs[i].ConsumePartialIOBuf(buf, depth + 1);
if (err != PARSE_OK) {
return err;
}
Expand Down Expand Up @@ -257,7 +268,7 @@ ParseError RedisReply::ConsumePartialIOBuf(butil::IOBuf& buf) {
// be continued in next calls by tracking _data.array.last_index.
_data.array.last_index = 0;
for (int64_t i = 0; i < count; ++i) {
ParseError err = subs[i].ConsumePartialIOBuf(buf);
ParseError err = subs[i].ConsumePartialIOBuf(buf, depth + 1);
if (err != PARSE_OK) {
return err;
}
Expand Down
1 change: 1 addition & 0 deletions src/brpc/redis_reply.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class RedisReply {
// by calling CopyFrom[Different|Same]Arena.
DISALLOW_COPY_AND_ASSIGN(RedisReply);

ParseError ConsumePartialIOBuf(butil::IOBuf& buf, int depth);
void FormatStringImpl(const char* fmt, va_list args, RedisReplyType type);
void SetStringImpl(const butil::StringPiece& str, RedisReplyType type);

Expand Down
Loading
Loading