From 39f88328370362814a73b2a3118ae9607cfa9c0c Mon Sep 17 00:00:00 2001 From: kai lin Date: Wed, 19 Nov 2025 15:04:27 -0500 Subject: [PATCH] refactor: centralize aws-chunked encoding in ChunkingInterceptor Move aws-chunked encoding logic from individual HTTP clients to a centralized ChunkingInterceptor for better separation of concerns. - Add ChunkingInterceptor to handle aws-chunked encoding at request level - Remove custom chunking logic from CRT, Curl, and Windows HTTP clients - Simplify HTTP clients to focus on transport-only responsibilities - Maintain full backwards compatibility with existing APIs unit test for chunking stream added logic to detect custom http client and smart default reversing logic to check for chunked mode changing chunking interceptor to use array instead of vector --- .../aws/core/client/ClientConfiguration.h | 16 ++ .../include/aws/core/http/HttpClient.h | 5 + .../include/aws/core/http/crt/CRTHttpClient.h | 2 + .../aws/core/http/curl/CurlHttpClient.h | 2 + .../http/windows/IXmlHttpRequest2HttpClient.h | 2 + .../core/http/windows/WinHttpSyncHttpClient.h | 2 + .../core/http/windows/WinINetSyncHttpClient.h | 2 + .../smithy/client/AwsSmithyClientBase.h | 9 +- .../client/features/ChunkingInterceptor.h | 225 ++++++++++++++++++ .../source/auth/signer/AWSAuthV4Signer.cpp | 22 +- .../source/client/AWSClient.cpp | 7 +- .../source/client/ClientConfiguration.cpp | 1 + .../source/client/UserAgent.cpp | 9 + .../source/http/crt/CRTHttpClient.cpp | 7 +- .../source/http/curl/CurlHttpClient.cpp | 25 +- .../source/http/windows/WinSyncHttpClient.cpp | 46 +--- .../smithy/client/AwsSmithyClientBase.cpp | 8 +- .../utils/stream/ChunkingInterceptorTest.cpp | 162 +++++++++++++ 18 files changed, 458 insertions(+), 94 deletions(-) create mode 100644 src/aws-cpp-sdk-core/include/smithy/client/features/ChunkingInterceptor.h create mode 100644 tests/aws-cpp-sdk-core-tests/utils/stream/ChunkingInterceptorTest.cpp diff --git a/src/aws-cpp-sdk-core/include/aws/core/client/ClientConfiguration.h b/src/aws-cpp-sdk-core/include/aws/core/client/ClientConfiguration.h index 0249d7d0124..360b39a1622 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/client/ClientConfiguration.h +++ b/src/aws-cpp-sdk-core/include/aws/core/client/ClientConfiguration.h @@ -78,6 +78,16 @@ namespace Aws WHEN_REQUIRED, }; + /** + * Control HTTP client chunking implementation mode. + * DEFAULT: Use SDK's ChunkingInterceptor for aws-chunked encoding + * CLIENT_IMPLEMENTATION: Rely on HTTP client's native chunking (default for custom clients) + */ + enum class HttpClientChunkedMode { + DEFAULT, + CLIENT_IMPLEMENTATION, + }; + struct RequestCompressionConfig { UseRequestCompression useRequestCompression=UseRequestCompression::ENABLE; size_t requestMinCompressionSizeBytes = 10240; @@ -493,6 +503,12 @@ namespace Aws * https://docs.aws.amazon.com/sdkref/latest/guide/feature-account-endpoints.html */ Aws::String accountIdEndpointMode = "preferred"; + + /** + * Control HTTP client chunking implementation mode. + * Default is set automatically: CLIENT_IMPLEMENTATION for custom clients, DEFAULT for AWS clients. + */ + HttpClientChunkedMode httpClientChunkedMode = HttpClientChunkedMode::CLIENT_IMPLEMENTATION; /** * Configuration structure for credential providers in the AWS SDK. * This structure allows passing configuration options to credential providers diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h index cb6e928e768..d38c77f4dc0 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h @@ -48,6 +48,11 @@ namespace Aws */ virtual bool SupportsChunkedTransferEncoding() const { return true; } + /** + * Returns true if this is a default AWS SDK HTTP client implementation. + */ + virtual bool IsDefaultAwsHttpClient() const { return false; } + /** * Stops all requests in progress and prevents any others from initiating. */ diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h index e5cb2533387..a0a87619042 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h @@ -52,6 +52,8 @@ namespace Aws Aws::Utils::RateLimits::RateLimiterInterface* readLimiter, Aws::Utils::RateLimits::RateLimiterInterface* writeLimiter) const override; + bool IsDefaultAwsHttpClient() const override { return true; } + private: // Yeah, I know, but someone made MakeRequest() const and didn't think about the fact that // making an HTTP request most certainly mutates state. It was me. I'm the person that did that, and diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/curl/CurlHttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/curl/CurlHttpClient.h index 924cd59d830..087ed8d2c6f 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/curl/CurlHttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/curl/CurlHttpClient.h @@ -37,6 +37,8 @@ class AWS_CORE_API CurlHttpClient: public HttpClient Aws::Utils::RateLimits::RateLimiterInterface* readLimiter = nullptr, Aws::Utils::RateLimits::RateLimiterInterface* writeLimiter = nullptr) const override; + bool IsDefaultAwsHttpClient() const override { return true; } + static void InitGlobalState(); static void CleanupGlobalState(); diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/windows/IXmlHttpRequest2HttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/windows/IXmlHttpRequest2HttpClient.h index 24a427edee7..995b20197a6 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/windows/IXmlHttpRequest2HttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/windows/IXmlHttpRequest2HttpClient.h @@ -54,6 +54,8 @@ namespace Aws */ virtual bool SupportsChunkedTransferEncoding() const override { return false; } + bool IsDefaultAwsHttpClient() const override { return true; } + protected: /** * Override any configuration on request handle. diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinHttpSyncHttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinHttpSyncHttpClient.h index 61e3b0c4a3c..4c630780e37 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinHttpSyncHttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinHttpSyncHttpClient.h @@ -42,6 +42,8 @@ namespace Aws */ const char* GetLogTag() const override { return "WinHttpSyncHttpClient"; } + bool IsDefaultAwsHttpClient() const override { return true; } + private: // WinHttp specific implementations void* OpenRequest(const std::shared_ptr& request, void* connection, const Aws::StringStream& ss) const override; diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinINetSyncHttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinINetSyncHttpClient.h index 52a1ce2d8f4..51b2680c4a6 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinINetSyncHttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinINetSyncHttpClient.h @@ -39,6 +39,8 @@ namespace Aws * Gets log tag for use in logging in the base class. */ const char* GetLogTag() const override { return "WinInetSyncHttpClient"; } + + bool IsDefaultAwsHttpClient() const override { return true; } private: // WinHttp specific implementations diff --git a/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h b/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h index b808fe2bf54..bd3f2704380 100644 --- a/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h +++ b/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -20,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -99,8 +101,13 @@ namespace client m_serviceUserAgentName(std::move(serviceUserAgentName)), m_httpClient(std::move(httpClient)), m_errorMarshaller(std::move(errorMarshaller)), - m_interceptors{Aws::MakeShared("AwsSmithyClientBase", *m_clientConfig)} + m_interceptors({ + Aws::MakeShared("AwsSmithyClientBase", *m_clientConfig), + Aws::MakeShared("AwsSmithyClientBase", + m_httpClient->IsDefaultAwsHttpClient() ? Aws::Client::HttpClientChunkedMode::DEFAULT : m_clientConfig->httpClientChunkedMode) + }) { + baseInit(); } diff --git a/src/aws-cpp-sdk-core/include/smithy/client/features/ChunkingInterceptor.h b/src/aws-cpp-sdk-core/include/smithy/client/features/ChunkingInterceptor.h new file mode 100644 index 00000000000..2ed637db883 --- /dev/null +++ b/src/aws-cpp-sdk-core/include/smithy/client/features/ChunkingInterceptor.h @@ -0,0 +1,225 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace smithy { +namespace client { +namespace features { + +static const size_t AWS_DATA_BUFFER_SIZE = 65536; +static const char* ALLOCATION_TAG = "ChunkingInterceptor"; +static const char* CHECKSUM_HEADER_PREFIX = "x-amz-checksum-"; + +template +class AwsChunkedStreamBuf : public std::streambuf { +public: + AwsChunkedStreamBuf(Aws::Http::HttpRequest* request, + const std::shared_ptr& stream, + size_t bufferSize = DataBufferSize) + : m_request(request), + m_stream(stream), + m_data(bufferSize) + { + assert(m_stream != nullptr); + if (m_stream == nullptr) { + AWS_LOGSTREAM_ERROR("AwsChunkedStream", "stream is null"); + } + assert(m_request != nullptr); + if (m_request == nullptr) { + AWS_LOGSTREAM_ERROR("AwsChunkedStream", "request is null"); + } + } + +protected: + int_type underflow() override { + if (gptr() && gptr() < egptr()) { + return traits_type::to_int_type(*gptr()); + } + + // only read and write to chunked stream if the underlying stream + // is still in a valid state and we have buffer space + if (m_stream->good() && m_chunkingBufferPos >= m_chunkingBufferSize) { + // Reset buffer for new data only when buffer is consumed + m_chunkingBufferPos = 0; + m_chunkingBufferSize = 0; + + // Check if we have enough space for worst-case chunk (data + header + footer) + size_t maxChunkSize = m_data.GetLength() + 20; // data + hex header + CRLF + if (m_chunkingBufferSize + maxChunkSize <= m_chunkingBuffer.GetLength()) { + // Try to read in a 64K chunk, if we cant we know the stream is over + m_stream->read(m_data.GetUnderlyingData(), m_data.GetLength()); + size_t bytesRead = static_cast(m_stream->gcount()); + writeChunk(bytesRead); + + // if we've read everything from the stream, we want to add the trailer + // to the underlying stream + if ((m_stream->peek() == EOF || m_stream->eof()) && !m_stream->bad()) { + writeTrailerToUnderlyingStream(); + } + } + } + + // if the chunking buffer is empty there is nothing to read + if (m_chunkingBufferPos >= m_chunkingBufferSize) { + return traits_type::eof(); + } + + // Set up buffer pointers to read from chunking buffer + size_t remainingBytes = m_chunkingBufferSize - m_chunkingBufferPos; + size_t bytesToRead = std::min(remainingBytes, DataBufferSize); + + setg(m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferPos, + m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferPos, + m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferPos + bytesToRead); + + m_chunkingBufferPos += bytesToRead; + + return traits_type::to_int_type(*gptr()); + } + +private: + void writeTrailerToUnderlyingStream() { + Aws::String trailer = "0\r\n"; + if (m_request->GetRequestHash().second != nullptr) { + trailer += "x-amz-checksum-" + m_request->GetRequestHash().first + ":" + + Aws::Utils::HashingUtils::Base64Encode(m_request->GetRequestHash().second->GetHash().GetResult()) + "\r\n"; + } + trailer += "\r\n"; + if (m_chunkingBufferSize + trailer.length() <= m_chunkingBuffer.GetLength()) { + std::memcpy(m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferSize, trailer.c_str(), trailer.length()); + m_chunkingBufferSize += trailer.length(); + } + } + + void writeChunk(size_t bytesRead) { + if (m_request->GetRequestHash().second != nullptr) { + m_request->GetRequestHash().second->Update(reinterpret_cast(m_data.GetUnderlyingData()), bytesRead); + } + + if (bytesRead > 0) { + Aws::String chunkHeader = Aws::Utils::StringUtils::ToHexString(bytesRead) + "\r\n"; + size_t totalSize = chunkHeader.length() + bytesRead + 2; + if (m_chunkingBufferSize + totalSize <= m_chunkingBuffer.GetLength()) { + std::memcpy(m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferSize, chunkHeader.c_str(), chunkHeader.length()); + m_chunkingBufferSize += chunkHeader.length(); + std::memcpy(m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferSize, m_data.GetUnderlyingData(), bytesRead); + m_chunkingBufferSize += bytesRead; + std::memcpy(m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferSize, "\r\n", 2); + m_chunkingBufferSize += 2; + } + } + } + + // Buffer for chunked data plus overhead for HTTP chunked encoding headers, trailers, and safety margin + Aws::Utils::Array m_chunkingBuffer{DataBufferSize + 128}; + size_t m_chunkingBufferSize{0}; + size_t m_chunkingBufferPos{0}; + Aws::Http::HttpRequest* m_request{nullptr}; + std::shared_ptr m_stream; + Aws::Utils::Array m_data; +}; + +class AwsChunkedIOStream : public Aws::IOStream { +public: + AwsChunkedIOStream(Aws::Http::HttpRequest* request, + const std::shared_ptr& originalBody, + size_t bufferSize = AWS_DATA_BUFFER_SIZE) + : Aws::IOStream(&m_buf), + m_buf(request, originalBody, bufferSize) {} + +private: + AwsChunkedStreamBuf<> m_buf; +}; + +/** + * Interceptor that handles chunked encoding for streaming requests with checksums. + * Wraps request body with chunked stream and sets appropriate headers. + */ +class ChunkingInterceptor : public smithy::interceptor::Interceptor { +public: + explicit ChunkingInterceptor(Aws::Client::HttpClientChunkedMode httpClientChunkedMode) + : m_httpClientChunkedMode(httpClientChunkedMode) {} + ~ChunkingInterceptor() override = default; + + ModifyRequestOutcome ModifyBeforeSigning(smithy::interceptor::InterceptorContext& context) override { + auto request = context.GetTransmitRequest(); + + if (!ShouldApplyChunking(request)) { + return request; + } + + auto originalBody = request->GetContentBody(); + if (!originalBody) { + return request; + } + + // Set up chunked encoding headers for checksum calculation + const auto& hashPair = request->GetRequestHash(); + if (hashPair.second != nullptr) { + Aws::String checksumHeaderValue = Aws::String(CHECKSUM_HEADER_PREFIX) + hashPair.first; + request->DeleteHeader(checksumHeaderValue.c_str()); + request->SetHeaderValue(Aws::Http::AWS_TRAILER_HEADER, checksumHeaderValue); + request->SetTransferEncoding(Aws::Http::CHUNKED_VALUE); + + if (!request->HasContentEncoding()) { + request->SetContentEncoding(Aws::Http::AWS_CHUNKED_VALUE); + } else { + Aws::String currentEncoding = request->GetContentEncoding(); + if (currentEncoding.find(Aws::Http::AWS_CHUNKED_VALUE) == Aws::String::npos) { + request->SetContentEncoding(Aws::String{Aws::Http::AWS_CHUNKED_VALUE} + "," + currentEncoding); + } + } + + if (request->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER)) { + request->SetHeaderValue(Aws::Http::DECODED_CONTENT_LENGTH_HEADER, request->GetHeaderValue(Aws::Http::CONTENT_LENGTH_HEADER)); + request->DeleteHeader(Aws::Http::CONTENT_LENGTH_HEADER); + } + } + + auto chunkedBody = Aws::MakeShared( + ALLOCATION_TAG, request.get(), originalBody); + + request->AddContentBody(chunkedBody); + return request; + } + + ModifyResponseOutcome ModifyBeforeDeserialization(smithy::interceptor::InterceptorContext& context) override { + return context.GetTransmitResponse(); + } + +private: + bool ShouldApplyChunking(const std::shared_ptr& request) const { + // Use configuration setting to determine chunking behavior + if (m_httpClientChunkedMode != Aws::Client::HttpClientChunkedMode::DEFAULT) { + return false; + } + + if (!request || !request->GetContentBody()) { + return false; + } + + // Check if request has checksum requirements + const auto& hashPair = request->GetRequestHash(); + return hashPair.second != nullptr; + } + + Aws::Client::HttpClientChunkedMode m_httpClientChunkedMode; +}; + +} // namespace features +} // namespace client +} // namespace smithy \ No newline at end of file diff --git a/src/aws-cpp-sdk-core/source/auth/signer/AWSAuthV4Signer.cpp b/src/aws-cpp-sdk-core/source/auth/signer/AWSAuthV4Signer.cpp index 1fc6094955c..bcca2d2e602 100644 --- a/src/aws-cpp-sdk-core/source/auth/signer/AWSAuthV4Signer.cpp +++ b/src/aws-cpp-sdk-core/source/auth/signer/AWSAuthV4Signer.cpp @@ -218,26 +218,10 @@ bool AWSAuthV4Signer::SignRequestWithCreds(Aws::Http::HttpRequest& request, cons request.SetAwsSessionToken(credentials.GetSessionToken()); } - // If the request checksum, set the signer to use a unsigned - // trailing payload. otherwise use it in the header - if (request.GetRequestHash().second != nullptr && !request.GetRequestHash().first.empty() && request.GetContentBody() != nullptr) { - AWS_LOGSTREAM_DEBUG(v4LogTag, "Note: Http payloads are not being signed. signPayloads=" - << signBody << " http scheme=" << Http::SchemeMapper::ToString(request.GetUri().GetScheme())); - if (request.GetRequestHash().second != nullptr) { + // If the request has checksum and chunking was applied by interceptor, use streaming payload + if (request.GetRequestHash().second != nullptr && !request.GetRequestHash().first.empty() && + request.GetContentBody() != nullptr && request.HasHeader(Http::AWS_TRAILER_HEADER)) { payloadHash = STREAMING_UNSIGNED_PAYLOAD_TRAILER; - Aws::String checksumHeaderValue = Aws::String("x-amz-checksum-") + request.GetRequestHash().first; - request.DeleteHeader(checksumHeaderValue.c_str()); - request.SetHeaderValue(Http::AWS_TRAILER_HEADER, checksumHeaderValue); - request.SetTransferEncoding(CHUNKED_VALUE); - request.HasContentEncoding() - ? request.SetContentEncoding(Aws::String{Http::AWS_CHUNKED_VALUE} + "," + request.GetContentEncoding()) - : request.SetContentEncoding(Http::AWS_CHUNKED_VALUE); - - if (request.HasHeader(Http::CONTENT_LENGTH_HEADER)) { - request.SetHeaderValue(Http::DECODED_CONTENT_LENGTH_HEADER, request.GetHeaderValue(Http::CONTENT_LENGTH_HEADER)); - request.DeleteHeader(Http::CONTENT_LENGTH_HEADER); - } - } } else { payloadHash = ComputePayloadHash(request); if (payloadHash.empty()) { diff --git a/src/aws-cpp-sdk-core/source/client/AWSClient.cpp b/src/aws-cpp-sdk-core/source/client/AWSClient.cpp index 1d4733f6eb6..90f83a58f33 100644 --- a/src/aws-cpp-sdk-core/source/client/AWSClient.cpp +++ b/src/aws-cpp-sdk-core/source/client/AWSClient.cpp @@ -46,6 +46,7 @@ #include #include +#include #include #include @@ -139,7 +140,8 @@ AWSClient::AWSClient(const Aws::Client::ClientConfiguration& configuration, m_enableClockSkewAdjustment(configuration.enableClockSkewAdjustment), m_requestCompressionConfig(configuration.requestCompressionConfig), m_userAgentInterceptor{Aws::MakeShared(AWS_CLIENT_LOG_TAG, configuration, m_retryStrategy->GetStrategyName(), m_serviceName)}, - m_interceptors{Aws::MakeShared(AWS_CLIENT_LOG_TAG), m_userAgentInterceptor} + m_interceptors{Aws::MakeShared(AWS_CLIENT_LOG_TAG), Aws::MakeShared(AWS_CLIENT_LOG_TAG, + m_httpClient->IsDefaultAwsHttpClient() ? Aws::Client::HttpClientChunkedMode::DEFAULT : configuration.httpClientChunkedMode), m_userAgentInterceptor} { } @@ -165,7 +167,8 @@ AWSClient::AWSClient(const Aws::Client::ClientConfiguration& configuration, m_enableClockSkewAdjustment(configuration.enableClockSkewAdjustment), m_requestCompressionConfig(configuration.requestCompressionConfig), m_userAgentInterceptor{Aws::MakeShared(AWS_CLIENT_LOG_TAG, configuration, m_retryStrategy->GetStrategyName(), m_serviceName)}, - m_interceptors{Aws::MakeShared(AWS_CLIENT_LOG_TAG, configuration), m_userAgentInterceptor} + m_interceptors{Aws::MakeShared(AWS_CLIENT_LOG_TAG, configuration), Aws::MakeShared(AWS_CLIENT_LOG_TAG, + m_httpClient->IsDefaultAwsHttpClient() ? Aws::Client::HttpClientChunkedMode::DEFAULT : configuration.httpClientChunkedMode), m_userAgentInterceptor} { } diff --git a/src/aws-cpp-sdk-core/source/client/ClientConfiguration.cpp b/src/aws-cpp-sdk-core/source/client/ClientConfiguration.cpp index 6eaf26e2a38..fe21ba5de11 100644 --- a/src/aws-cpp-sdk-core/source/client/ClientConfiguration.cpp +++ b/src/aws-cpp-sdk-core/source/client/ClientConfiguration.cpp @@ -220,6 +220,7 @@ void setLegacyClientConfigurationParameters(ClientConfiguration& clientConfig) clientConfig.writeRateLimiter = nullptr; clientConfig.readRateLimiter = nullptr; clientConfig.httpLibOverride = Aws::Http::TransferLibType::DEFAULT_CLIENT; + clientConfig.httpClientChunkedMode = HttpClientChunkedMode::CLIENT_IMPLEMENTATION; clientConfig.followRedirects = FollowRedirectsPolicy::DEFAULT; clientConfig.disableExpectHeader = false; clientConfig.enableClockSkewAdjustment = true; diff --git a/src/aws-cpp-sdk-core/source/client/UserAgent.cpp b/src/aws-cpp-sdk-core/source/client/UserAgent.cpp index 909184b447a..dd6fa87c3a7 100644 --- a/src/aws-cpp-sdk-core/source/client/UserAgent.cpp +++ b/src/aws-cpp-sdk-core/source/client/UserAgent.cpp @@ -183,6 +183,15 @@ Aws::String UserAgent::SerializeWithFeatures(const Aws::Set& f SerializeMetadata(METADATA, m_compilerMetadata); } + // Add HTTP client metadata +#if AWS_SDK_USE_CRT_HTTP + SerializeMetadata(METADATA, "http#crt"); +#elif ENABLE_CURL_CLIENT + SerializeMetadata(METADATA, "http#curl"); +#elif ENABLE_WINDOWS_CLIENT + SerializeMetadata(METADATA, "http#winhttp"); +#endif + // metrics Aws::Vector encodedMetrics{}; diff --git a/src/aws-cpp-sdk-core/source/http/crt/CRTHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/crt/CRTHttpClient.cpp index 4c392bdf280..14c1ef25b0f 100644 --- a/src/aws-cpp-sdk-core/source/http/crt/CRTHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/crt/CRTHttpClient.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #include #include @@ -379,11 +378,7 @@ namespace Aws if (request->GetContentBody()) { bool isStreaming = request->IsEventStreamRequest(); - if (request->HasHeader(Aws::Http::CONTENT_ENCODING_HEADER) && request->GetHeaderValue(Aws::Http::CONTENT_ENCODING_HEADER) == Aws::Http::AWS_CHUNKED_VALUE) { - crtRequest->SetBody(Aws::MakeShared>(CRT_HTTP_CLIENT_TAG, request.get(), request->GetContentBody())); - } else { - crtRequest->SetBody(Aws::MakeShared(CRT_HTTP_CLIENT_TAG, m_configuration.writeRateLimiter, request->GetContentBody(), *this, *request, isStreaming)); - } + crtRequest->SetBody(Aws::MakeShared(CRT_HTTP_CLIENT_TAG, m_configuration.writeRateLimiter, request->GetContentBody(), *this, *request, isStreaming)); } Crt::Http::HttpRequestOptions requestOptions; diff --git a/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp index 1ac37f63eaf..58fc56875de 100644 --- a/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp @@ -14,7 +14,6 @@ #include #include #include -#include #include #include @@ -155,21 +154,16 @@ static const char* CURL_HTTP_CLIENT_TAG = "CurlHttpClient"; struct CurlReadCallbackContext { CurlReadCallbackContext(const CurlHttpClient* client, CURL* curlHandle, HttpRequest* request, - Aws::Utils::RateLimits::RateLimiterInterface* limiter, - std::shared_ptr> chunkedStream = nullptr) + Aws::Utils::RateLimits::RateLimiterInterface* limiter) : m_client(client), m_curlHandle(curlHandle), m_rateLimiter(limiter), - m_request(request), - m_chunkEnd(false), - m_chunkedStream{std::move(chunkedStream)} {} + m_request(request) {} const CurlHttpClient* m_client; CURL* m_curlHandle; Aws::Utils::RateLimits::RateLimiterInterface* m_rateLimiter; HttpRequest* m_request; - bool m_chunkEnd; - std::shared_ptr> m_chunkedStream; }; static int64_t GetContentLengthFromHeader(CURL* connectionHandle, @@ -315,8 +309,6 @@ static size_t ReadBody(char* ptr, size_t size, size_t nmemb, void* userdata, boo const std::shared_ptr& ioStream = request->GetContentBody(); size_t amountToRead = size * nmemb; - bool isAwsChunked = request->HasHeader(Aws::Http::CONTENT_ENCODING_HEADER) && - request->GetHeaderValue(Aws::Http::CONTENT_ENCODING_HEADER).find(Aws::Http::AWS_CHUNKED_VALUE) != Aws::String::npos; if (ioStream != nullptr && amountToRead > 0) { @@ -334,8 +326,6 @@ static size_t ReadBody(char* ptr, size_t size, size_t nmemb, void* userdata, boo return 0; } amountRead = (size_t)ioStream->readsome(ptr, amountToRead); - } else if (isAwsChunked && context->m_chunkedStream != nullptr) { - amountRead = context->m_chunkedStream->BufferedRead(ptr, amountToRead); } else { ioStream->read(ptr, amountToRead); amountRead = static_cast(ioStream->gcount()); @@ -380,7 +370,7 @@ static size_t SeekBody(void* userdata, curl_off_t offset, int origin) return CURL_SEEKFUNC_FAIL; } - // fail seek for aws-chunk encoded body as the length and offset is unknown + // Fail seek for aws-chunk encoded body as the length and offset is unknown if (context->m_request && context->m_request->HasHeader(Aws::Http::CONTENT_ENCODING_HEADER) && context->m_request->GetHeaderValue(Aws::Http::CONTENT_ENCODING_HEADER).find(Aws::Http::AWS_CHUNKED_VALUE) != Aws::String::npos) @@ -388,6 +378,7 @@ static size_t SeekBody(void* userdata, curl_off_t offset, int origin) return CURL_SEEKFUNC_FAIL; } + HttpRequest* request = context->m_request; const std::shared_ptr& ioStream = request->GetContentBody(); @@ -713,13 +704,7 @@ std::shared_ptr CurlHttpClient::MakeRequest(const std::shared_ptr< CurlWriteCallbackContext writeContext(this, connectionHandle ,request.get(), response.get(), readLimiter); - const auto readContext = [this, &connectionHandle, &request, &writeLimiter]() -> CurlReadCallbackContext { - if (request->GetContentBody() != nullptr) { - auto chunkedBodyPtr = Aws::MakeShared>(CURL_HTTP_CLIENT_TAG, request.get(), request->GetContentBody()); - return {this, connectionHandle, request.get(), writeLimiter, std::move(chunkedBodyPtr)}; - } - return {this, connectionHandle, request.get(), writeLimiter}; - }(); + CurlReadCallbackContext readContext(this, connectionHandle, request.get(), writeLimiter); SetOptCodeForHttpMethod(connectionHandle, request); diff --git a/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp index 7677e02052f..ee35bb5a81f 100644 --- a/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #include #include @@ -100,21 +99,14 @@ bool WinSyncHttpClient::StreamPayloadToRequest(const std::shared_ptrHasTransferEncoding() && request->GetTransferEncoding() == Aws::Http::CHUNKED_VALUE; - bool isAwsChunked = request->HasHeader(Aws::Http::CONTENT_ENCODING_HEADER) && - request->GetHeaderValue(Aws::Http::CONTENT_ENCODING_HEADER).find(Aws::Http::AWS_CHUNKED_VALUE) != Aws::String::npos; auto payloadStream = request->GetContentBody(); - const char CRLF[] = "\r\n"; if(payloadStream) { uint64_t bytesWritten; uint64_t bytesToRead = HTTP_REQUEST_WRITE_BUFFER_LENGTH; auto startingPos = payloadStream->tellg(); bool done = false; - // aws-chunk = hex(chunk-size) + CRLF + chunk-data + CRLF - // Length of hex(HTTP_REQUEST_WRITE_BUFFER_LENGTH) is 4; - // Length of each CRLF is 2. - // Reserve 8 bytes in total, should the request be aws-chunked. - char streamBuffer[ HTTP_REQUEST_WRITE_BUFFER_LENGTH + 8 ]; + char streamBuffer[HTTP_REQUEST_WRITE_BUFFER_LENGTH]; while(success && !done) { payloadStream->read(streamBuffer, bytesToRead); @@ -124,21 +116,6 @@ bool WinSyncHttpClient::StreamPayloadToRequest(const std::shared_ptr 0) { - if (isAwsChunked) - { - if (request->GetRequestHash().second != nullptr) - { - request->GetRequestHash().second->Update(reinterpret_cast(streamBuffer), static_cast(bytesRead)); - } - - Aws::String hex = Aws::Utils::StringUtils::ToHexString(static_cast(bytesRead)); - memcpy(streamBuffer + hex.size() + 2, streamBuffer, static_cast(bytesRead)); - memcpy(streamBuffer + hex.size() + 2 + bytesRead, CRLF, 2); - memcpy(streamBuffer, hex.c_str(), hex.size()); - memcpy(streamBuffer + hex.size(), CRLF, 2); - bytesRead += hex.size() + 4; - } - bytesWritten = DoWriteData(hHttpRequest, streamBuffer, bytesRead, isChunked); if (!bytesWritten) { @@ -164,27 +141,6 @@ bool WinSyncHttpClient::StreamPayloadToRequest(const std::shared_ptrGetRequestHash().second != nullptr) - { - chunkedTrailer << "x-amz-checksum-" << request->GetRequestHash().first << ":" - << Aws::Utils::HashingUtils::Base64Encode(request->GetRequestHash().second->GetHash().GetResult()) << CRLF; - } - chunkedTrailer << CRLF; - bytesWritten = DoWriteData(hHttpRequest, const_cast(chunkedTrailer.str().c_str()), chunkedTrailer.str().size(), isChunked); - if (!bytesWritten) - { - success = false; - } - else if(writeLimiter) - { - writeLimiter->ApplyAndPayForCost(bytesWritten); - } - } - if (success && isChunked) { bytesWritten = FinalizeWriteData(hHttpRequest); diff --git a/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp b/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp index e799ab5a5a1..0d70b089e5c 100644 --- a/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp +++ b/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp @@ -25,6 +25,7 @@ #include using namespace smithy::client; +using namespace smithy::client::features; using namespace smithy::interceptor; using namespace smithy::components::tracing; @@ -102,7 +103,12 @@ void AwsSmithyClientBase::baseCopyAssign(const AwsSmithyClientBase& other, m_serviceUserAgentName = other.m_serviceUserAgentName; m_httpClient = std::move(httpClient); m_errorMarshaller = std::move(errorMarshaller); - m_interceptors = Aws::Vector>{Aws::MakeShared("AwsSmithyClientBase")}; + + m_interceptors = Aws::Vector>{ + Aws::MakeShared("AwsSmithyClientBase", *m_clientConfig), + Aws::MakeShared("AwsSmithyClientBase", + m_httpClient->IsDefaultAwsHttpClient() ? Aws::Client::HttpClientChunkedMode::DEFAULT : m_clientConfig->httpClientChunkedMode) + }; baseCopyInit(); } diff --git a/tests/aws-cpp-sdk-core-tests/utils/stream/ChunkingInterceptorTest.cpp b/tests/aws-cpp-sdk-core-tests/utils/stream/ChunkingInterceptorTest.cpp new file mode 100644 index 00000000000..742219e3505 --- /dev/null +++ b/tests/aws-cpp-sdk-core-tests/utils/stream/ChunkingInterceptorTest.cpp @@ -0,0 +1,162 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace Aws; +using namespace Aws::Http::Standard; +using namespace smithy::client::features; +using namespace Aws::Utils::Crypto; + +const char* CHUNKING_TEST_LOG_TAG = "CHUNKING_INTERCEPTOR_TEST"; + +// Mock implementation of AmazonWebServiceRequest +class MockRequest : public Aws::AmazonWebServiceRequest { +public: + std::shared_ptr GetBody() const override { return nullptr; } + Aws::Http::HeaderValueCollection GetHeaders() const override { return {}; } + const char* GetServiceRequestName() const override { return "MockRequest"; } +}; + +class ChunkingInterceptorTest : public Aws::Testing::AwsCppSdkGTestSuite { +protected: + template + void withChunkedStream(const std::string& input, size_t bufferSize, Fn&& fn) { + StandardHttpRequest request{"test.com", Http::HttpMethod::HTTP_GET}; + auto requestHash = Aws::MakeShared(CHUNKING_TEST_LOG_TAG); + request.SetRequestHash("crc32", requestHash); + auto inputStream = Aws::MakeShared(CHUNKING_TEST_LOG_TAG); + *inputStream << input; + + AwsChunkedIOStream wrapper{&request, inputStream, bufferSize}; + Aws::IOStream* stream = &wrapper; + + fn(*stream); + } +}; + +TEST_F(ChunkingInterceptorTest, ChunkedStreamShouldWork) { + withChunkedStream("1234567890123456789012345", 10, [](Aws::IOStream& chunkedStream) { + char buffer[100]; + std::stringstream output; + + // Read in 10-byte chunks + for (int i = 0; i < 4; i++) { + chunkedStream.read(buffer, 10); + auto bytesRead = chunkedStream.gcount(); + output.write(buffer, bytesRead); + } + + // Read trailing checksum (greater than 10 chars) + chunkedStream.read(buffer, 40); + auto bytesRead = static_cast(chunkedStream.gcount()); + EXPECT_EQ(36ul, bytesRead); + output.write(buffer, bytesRead); + + EXPECT_EQ("A\r\n1234567890\r\nA\r\n1234567890\r\n5\r\n12345\r\n0\r\nx-amz-checksum-crc32:78DeVw==\r\n\r\n", output.str()); + }); +} + +TEST_F(ChunkingInterceptorTest, ShouldNotRequireTwoReadsOnSmallChunk) { + withChunkedStream("12345", 100, [](Aws::IOStream& chunkedStream) { + char buffer[100]; + chunkedStream.read(buffer, 100); + auto bytesRead = static_cast(chunkedStream.gcount()); + EXPECT_EQ(46ul, bytesRead); + + std::string output(buffer, bytesRead); + EXPECT_EQ("5\r\n12345\r\n0\r\nx-amz-checksum-crc32:y/U6HA==\r\n\r\n", output); + }); +} + +TEST_F(ChunkingInterceptorTest, ShouldWorkOnSmallBuffer) { + withChunkedStream("1234567890", 5, [](Aws::IOStream& chunkedStream) { + char buffer[100]; + + // First read - explicitly ask for 10 bytes (first chunk: "5\r\n12345\r\n") + chunkedStream.read(buffer, 10); + auto bytesRead = static_cast(chunkedStream.gcount()); + EXPECT_EQ(10ul, bytesRead); + std::string firstRead(buffer, bytesRead); + EXPECT_EQ("5\r\n12345\r\n", firstRead); + + // Second read - now we expect the rest (46 bytes: second chunk + trailer) + chunkedStream.read(buffer, 100); + bytesRead = static_cast(chunkedStream.gcount()); + EXPECT_EQ(46ul, bytesRead); + std::string secondRead(buffer, bytesRead); + EXPECT_EQ("5\r\n67890\r\n0\r\nx-amz-checksum-crc32:Jh2u5Q==\r\n\r\n", secondRead); + + // Subsequent reads should return 0 + chunkedStream.read(buffer, 100); + bytesRead = static_cast(chunkedStream.gcount()); + EXPECT_EQ(0ul, bytesRead); + }); +} + +TEST_F(ChunkingInterceptorTest, ShouldWorkOnEmptyStream) { + withChunkedStream("", 5, [](Aws::IOStream& chunkedStream) { + char buffer[100]; + chunkedStream.read(buffer, 100); + auto bytesRead = static_cast(chunkedStream.gcount()); + EXPECT_EQ(36ul, bytesRead); + + std::string output(buffer, bytesRead); + EXPECT_EQ("0\r\nx-amz-checksum-crc32:AAAAAA==\r\n\r\n", output); + }); +} + +// Custom HTTP client (inherits default IsDefaultAwsHttpClient() = false from base class) +class CustomHttpClient : public Aws::Http::HttpClient { +public: + std::shared_ptr MakeRequest(const std::shared_ptr&, + Aws::Utils::RateLimits::RateLimiterInterface*, + Aws::Utils::RateLimits::RateLimiterInterface*) const override { + return nullptr; + } +}; + +TEST_F(ChunkingInterceptorTest, ShouldNotApplyChunkingForCustomHttpClient) { + // Simulate the GetChunkingConfig behavior from AWSClient.cpp + // When IsDefaultAwsHttpClient() returns false, httpClientChunkedMode is set to CLIENT_IMPLEMENTATION + Aws::Client::ClientConfiguration config; + auto customHttpClient = Aws::MakeShared(CHUNKING_TEST_LOG_TAG); + + // This simulates the logic in GetChunkingConfig function + if (!customHttpClient->IsDefaultAwsHttpClient()) { + config.httpClientChunkedMode = Aws::Client::HttpClientChunkedMode::CLIENT_IMPLEMENTATION; + } + + ChunkingInterceptor interceptor(config.httpClientChunkedMode); + + // Create request with checksum (would normally trigger chunking) + auto request = Aws::MakeShared(CHUNKING_TEST_LOG_TAG, "test.com", Http::HttpMethod::HTTP_POST); + auto requestHash = Aws::MakeShared(CHUNKING_TEST_LOG_TAG); + request->SetRequestHash("crc32", requestHash); + + auto inputStream = Aws::MakeShared(CHUNKING_TEST_LOG_TAG); + *inputStream << "test data"; + request->AddContentBody(inputStream); + + // Create interceptor context with a mock request + MockRequest mockRequest; + smithy::interceptor::InterceptorContext context(mockRequest); + context.SetTransmitRequest(request); + + // Apply interceptor + auto result = interceptor.ModifyBeforeSigning(context); + + // Verify chunking was NOT applied because custom HTTP client uses default IsDefaultAwsHttpClient() = false + EXPECT_EQ(request, result.GetResult()); + EXPECT_FALSE(request->HasHeader(Aws::Http::AWS_TRAILER_HEADER)); + EXPECT_FALSE(request->HasHeader(Aws::Http::TRANSFER_ENCODING_HEADER)); +} \ No newline at end of file