diff --git a/src/aws-cpp-sdk-core/include/aws/core/client/CoreErrors.h b/src/aws-cpp-sdk-core/include/aws/core/client/CoreErrors.h index 6507bbc5d56e..1f245145b6f8 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/client/CoreErrors.h +++ b/src/aws-cpp-sdk-core/include/aws/core/client/CoreErrors.h @@ -47,6 +47,7 @@ namespace Aws REQUEST_TIMEOUT = 24, NOT_INITIALIZED = 25, MEMORY_ALLOCATION = 26, + NOT_IMPLEMENTED = 27, NETWORK_CONNECTION = 99, // General failure to send message to service diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h index d38c77f4dc07..2116d0d579ed 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h @@ -6,11 +6,14 @@ #pragma once #include +#include +#include +#include -#include #include -#include #include +#include +#include namespace Aws { @@ -77,6 +80,16 @@ namespace Aws return !m_bad; } + using AcquireConnectionOutcome = Aws::Utils::Outcome, + Aws::Client::AWSError>; + virtual AcquireConnectionOutcome AcquireConnection(const std::shared_ptr& request) { + AWS_UNREFERENCED_PARAM(request); + return Aws::Client::AWSError{Aws::Client::CoreErrors::NOT_IMPLEMENTED, + "NotImplemented", + "creating a connection is not supported on this http client", + false}; + } + protected: bool m_bad; diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/HttpClientStream.h b/src/aws-cpp-sdk-core/include/aws/core/http/HttpClientStream.h new file mode 100644 index 000000000000..5f48dec2063e --- /dev/null +++ b/src/aws-cpp-sdk-core/include/aws/core/http/HttpClientStream.h @@ -0,0 +1,25 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#pragma once +#include + +#include +#include + +namespace Aws { +namespace Http { +class HttpResponse; +class AWS_CORE_API ClientStream { + public: + virtual ~ClientStream() = default; + virtual bool Activate() = 0; + virtual int WriteData(std::shared_ptr stream, + const std::function& onComplete, + bool endStream = false) = 0; + virtual std::shared_ptr GetResponse() const = 0; +}; +} // namespace Http +} // namespace Aws diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/HttpConnection.h b/src/aws-cpp-sdk-core/include/aws/core/http/HttpConnection.h new file mode 100644 index 000000000000..869b88270b10 --- /dev/null +++ b/src/aws-cpp-sdk-core/include/aws/core/http/HttpConnection.h @@ -0,0 +1,23 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#pragma once +#include + +#include +#include + +namespace Aws { +namespace Http { +class HttpRequest; +class AWS_CORE_API Connection { + public: + virtual ~Connection() = default; + virtual std::shared_ptr NewClientStream( + const std::shared_ptr& request, + std::function onStreamComplete) = 0; +}; +} // namespace Http +} // namespace Aws diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h index a0a87619042b..31933915c32f 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h @@ -54,6 +54,8 @@ namespace Aws bool IsDefaultAwsHttpClient() const override { return true; } + AcquireConnectionOutcome AcquireConnection(const std::shared_ptr& request) override; + private: // Yeah, I know, but someone made MakeRequest() const and didn't think about the fact that // making an HTTP request most certainly mutates state. It was me. I'm the person that did that, and diff --git a/src/aws-cpp-sdk-core/source/http/crt/CRTHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/crt/CRTHttpClient.cpp index dded366705f9..38f845bfdc2f 100644 --- a/src/aws-cpp-sdk-core/source/http/crt/CRTHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/crt/CRTHttpClient.cpp @@ -12,8 +12,261 @@ #include #include +#include + static const char *const CRT_HTTP_CLIENT_TAG = "CRTHttpClient"; +namespace { +// Just a wrapper around a Condition Variable and a mutex, which handles wait and timed waits while protecting +// from spurious wakeups. +class AsyncWaiter { + public: + AsyncWaiter() = default; + AsyncWaiter(const AsyncWaiter&) = delete; + AsyncWaiter& operator=(const AsyncWaiter&) = delete; + + void Wakeup() { + std::lock_guard locker(m_lock); + m_wakeupIntentional = true; + m_cvar.notify_one(); + } + + void WaitOnCompletion() { + std::unique_lock uniqueLocker(m_lock); + m_cvar.wait(uniqueLocker, [this]() { return m_wakeupIntentional; }); + } + + bool WaitOnCompletionFor(const size_t ms) { + std::unique_lock uniqueLocker(m_lock); + return m_cvar.wait_for(uniqueLocker, std::chrono::milliseconds(ms), [this]() { return m_wakeupIntentional; }); + } + + private: + std::mutex m_lock; + std::condition_variable m_cvar; + bool m_wakeupIntentional{false}; +}; + +void AddRequestMetadataToCrtRequest(const std::shared_ptr& request, + const std::shared_ptr& crtRequest) { + const char* methodStr = Aws::Http::HttpMethodMapper::GetNameForHttpMethod(request->GetMethod()); + AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, "Making " << methodStr << " request to " << request->GetURIString()); + AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, "Including headers:"); + // Add http headers to the request. + for (const auto& header : request->GetHeaders()) { + Aws::Crt::Http::HttpHeader crtHeader; + AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, header.first << ": " << header.second); + crtHeader.name = Aws::Crt::ByteCursorFromArray((const uint8_t*)header.first.data(), header.first.length()); + crtHeader.value = Aws::Crt::ByteCursorFromArray((const uint8_t*)header.second.data(), header.second.length()); + crtRequest->AddHeader(crtHeader); + } + + // HTTP method, GET, PUT, DELETE, etc... + auto methodCursor = Aws::Crt::ByteCursorFromCString(methodStr); + crtRequest->SetMethod(methodCursor); + + // Path portion of the request + auto pathStrCpy = request->GetUri().GetURLEncodedPathRFC3986(); + auto queryStrCpy = request->GetUri().GetQueryString(); + Aws::StringStream ss; + + // CRT client has you pass the query string as part of the path. concatenate that here. + ss << pathStrCpy << queryStrCpy; + auto fullPathAndQueryCpy = ss.str(); + auto pathCursor = Aws::Crt::ByteCursorFromArray((uint8_t*)fullPathAndQueryCpy.c_str(), fullPathAndQueryCpy.length()); + crtRequest->SetPath(pathCursor); +} + +void OnResponseBodyReceived(Aws::Crt::Http::HttpStream&, const Aws::Crt::ByteCursor& body, + const std::shared_ptr& response, + const std::shared_ptr& request) { + assert(response); + for (const auto& hashIterator : request->GetResponseValidationHashes()) { + std::stringstream headerStr; + headerStr << "x-amz-checksum-" << hashIterator.first; + if (response->HasHeader(headerStr.str().c_str())) { + hashIterator.second->Update(reinterpret_cast(body.ptr), body.len); + break; + } + } + + // When data is received from the content body of the incoming response, just copy it to the output stream. + response->GetResponseBody().write((const char*)body.ptr, static_cast(body.len)); + if (response->GetResponseBody().fail()) { + const auto& ref = response->GetResponseBody(); + AWS_LOGSTREAM_ERROR(CRT_HTTP_CLIENT_TAG, "Failed to write " << body.len << " (eof: " << ref.eof() << ", bad: " << ref.bad() << ")"); + } + + if (request->IsEventStreamRequest() && !response->HasHeader(Aws::Http::X_AMZN_ERROR_TYPE)) { + response->GetResponseBody().flush(); + } + + auto& receivedHandler = request->GetDataReceivedEventHandler(); + if (receivedHandler) { + receivedHandler(request.get(), response.get(), static_cast(body.len)); + } + + AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, body.len << " bytes written to response."); +} + +// on response headers arriving, write them to the response. +void OnIncomingHeaders(Aws::Crt::Http::HttpStream&, enum aws_http_header_block block, const Aws::Crt::Http::HttpHeader* headersArray, + std::size_t headersCount, const std::shared_ptr& response) { + if (block == AWS_HTTP_HEADER_BLOCK_INFORMATIONAL) return; + + AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, "Received Headers: "); + + for (size_t i = 0; i < headersCount; ++i) { + const Aws::Crt::Http::HttpHeader* header = &headersArray[i]; + Aws::String headerNameStr((const char*)header->name.ptr, header->name.len); + Aws::String headerValueStr((const char*)header->value.ptr, header->value.len); + AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, headerNameStr << ": " << headerValueStr); + response->AddHeader(headerNameStr, std::move(headerValueStr)); + } +} + +void OnIncomingHeadersBlockDone(Aws::Crt::Http::HttpStream& stream, enum aws_http_header_block, + const std::shared_ptr& response) { + AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, "Received response code: " << stream.GetResponseStatusCode()); + response->SetResponseCode((Aws::Http::HttpResponseCode)stream.GetResponseStatusCode()); +} + +// Request is done. If there was an error set it, otherwise just wake up the cvar. +void OnStreamComplete(Aws::Crt::Http::HttpStream&, int errorCode, AsyncWaiter& waiter, + const std::shared_ptr& response) { + if (errorCode) { + // TODO: get the right error parsed out. + response->SetClientErrorType(Aws::Client::CoreErrors::NETWORK_CONNECTION); + response->SetClientErrorMessage(aws_error_debug_str(errorCode)); + } + + waiter.Wakeup(); +} + +// if the connection acquisition failed, go ahead and fail the request and wakeup the cvar. +// If it succeeded go ahead and make the request. +void OnClientConnectionAvailable(std::shared_ptr connection, int errorCode, + std::shared_ptr& connectionReference, + Aws::Crt::Http::HttpRequestOptions& requestOptions, AsyncWaiter& waiter, + const std::shared_ptr& request, + const std::shared_ptr& response, const Aws::Http::HttpClient& client) { + bool shouldContinueRequest = client.ContinueRequest(*request); + + if (!shouldContinueRequest) { + response->SetClientErrorType(Aws::Client::CoreErrors::USER_CANCELLED); + response->SetClientErrorMessage("Request cancelled by user's continuation handler"); + waiter.Wakeup(); + return; + } + + int finalErrorCode = errorCode; + if (connection) { + AWS_LOGSTREAM_DEBUG(CRT_HTTP_CLIENT_TAG, "Obtained connection handle " << (void*)connection.get()); + + auto clientStream = connection->NewClientStream(requestOptions); + connectionReference = connection; + + if (clientStream && clientStream->Activate()) { + return; + } + + finalErrorCode = aws_last_error(); + AWS_LOGSTREAM_ERROR(CRT_HTTP_CLIENT_TAG, "Initiation of request failed because " << aws_error_debug_str(finalErrorCode)); + } + + const char* errorMsg = aws_error_debug_str(finalErrorCode); + AWS_LOGSTREAM_ERROR(CRT_HTTP_CLIENT_TAG, "Obtaining connection failed because " << errorMsg); + response->SetClientErrorType(Aws::Client::CoreErrors::NETWORK_CONNECTION); + response->SetClientErrorMessage(errorMsg); + + waiter.Wakeup(); +} + +class CRTClientStream : public Aws::Http::ClientStream { + public: + CRTClientStream(std::shared_ptr stream, std::shared_ptr response, + std::shared_ptr crtRequest) + : m_stream(std::move(stream)), m_response(std::move(response)), m_crtRequest(std::move(crtRequest)) {} + ~CRTClientStream() override = default; + + bool Activate() override { + return m_stream->Activate(); + }; + + int WriteData(std::shared_ptr stream, const std::function& onComplete, bool endStream) override { + auto crtStream = std::make_shared(stream, Aws::Crt::ApiAllocator()); + return m_stream->WriteData(crtStream, + [onComplete](std::shared_ptr&, int errorCode) { + onComplete(errorCode); + }, + endStream); + } + + std::shared_ptr GetResponse() const override { return m_response; } + + private: + std::shared_ptr m_stream; + std::shared_ptr m_response; + // extend life of request for duration of client stream. Request must outlive + // the last call to WriteData. + std::shared_ptr m_crtRequest; +}; + +class CRTConnection : public Aws::Http::Connection { + public: + explicit CRTConnection(std::shared_ptr connection) : m_connection(std::move(connection)) {} + ~CRTConnection() override = default; + + std::shared_ptr NewClientStream(const std::shared_ptr& request, + std::function onStreamComplete) override { + auto crtRequest = Aws::Crt::MakeShared(Aws::Crt::g_allocator); + auto response = Aws::MakeShared(CRT_HTTP_CLIENT_TAG, request); + + AddRequestMetadataToCrtRequest(request, crtRequest); + + Aws::Crt::Http::HttpRequestOptions requestOptions{}; + requestOptions.request = crtRequest.get(); + + requestOptions.onIncomingHeaders = [response](Aws::Crt::Http::HttpStream& stream, enum aws_http_header_block block, + const Aws::Crt::Http::HttpHeader* headersArray, std::size_t headersCount) { + OnIncomingHeaders(stream, block, headersArray, headersCount, response); + }; + + requestOptions.onIncomingHeadersBlockDone = [request, response](Aws::Crt::Http::HttpStream& stream, enum aws_http_header_block block) { + OnIncomingHeadersBlockDone(stream, block, response); + auto& headersHandler = request->GetHeadersReceivedEventHandler(); + if (headersHandler) { + headersHandler(request.get(), response.get()); + } + }; + + requestOptions.onIncomingBody = [request, response](Aws::Crt::Http::HttpStream& stream, const Aws::Crt::ByteCursor& body) { + OnResponseBodyReceived(stream, body, response, request); + }; + + requestOptions.onStreamComplete = [response, onStreamComplete](Aws::Crt::Http::HttpStream &, int errorCode) -> void { + if (errorCode) { + // TODO: get the right error parsed out. + response->SetClientErrorType(Aws::Client::CoreErrors::NETWORK_CONNECTION); + response->SetClientErrorMessage(aws_error_debug_str(errorCode)); + } + onStreamComplete(errorCode); + }; + + requestOptions.UseManualDataWrites = true; + + auto crtStream = m_connection->NewClientStream(requestOptions); + if (!crtStream) { + return nullptr; + } + return Aws::MakeShared(CRT_HTTP_CLIENT_TAG, std::move(crtStream), std::move(response), std::move(crtRequest)); + } + + private: + std::shared_ptr m_connection; +}; +} // namespace + // Adapts AWS SDK input streams and rate limiters to the CRT input stream reading model. class SDKAdaptingInputStream : public Aws::Crt::Io::StdIOStreamInputStream { public: @@ -98,40 +351,6 @@ class SDKAdaptingInputStream : public Aws::Crt::Io::StdIOStreamInputStream { bool m_isStreaming; }; -// Just a wrapper around a Condition Variable and a mutex, which handles wait and timed waits while protecting -// from spurious wakeups. -class AsyncWaiter -{ -public: - AsyncWaiter() = default; - AsyncWaiter(const AsyncWaiter&) = delete; - AsyncWaiter& operator=(const AsyncWaiter&) = delete; - - void Wakeup() - { - std::lock_guard locker(m_lock); - m_wakeupIntentional = true; - m_cvar.notify_one(); - } - - void WaitOnCompletion() - { - std::unique_lock uniqueLocker(m_lock); - m_cvar.wait(uniqueLocker, [this](){return m_wakeupIntentional;}); - } - - bool WaitOnCompletionFor(const size_t ms) - { - std::unique_lock uniqueLocker(m_lock); - return m_cvar.wait_for(uniqueLocker, std::chrono::milliseconds(ms), [this](){return m_wakeupIntentional;}); - } - -private: - std::mutex m_lock; - std::condition_variable m_cvar; - bool m_wakeupIntentional{false}; -}; - namespace Aws { namespace Http @@ -202,158 +421,6 @@ namespace Aws m_connectionPools.clear(); } - static void AddRequestMetadataToCrtRequest(const std::shared_ptr& request, const std::shared_ptr& crtRequest) - { - const char* methodStr = Aws::Http::HttpMethodMapper::GetNameForHttpMethod(request->GetMethod()); - AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, "Making " << methodStr << " request to " << request->GetURIString()); - AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, "Including headers:"); - //Add http headers to the request. - for (const auto& header : request->GetHeaders()) - { - Crt::Http::HttpHeader crtHeader; - AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, header.first << ": " << header.second); - crtHeader.name = Crt::ByteCursorFromArray((const uint8_t *)header.first.data(), header.first.length()); - crtHeader.value = Crt::ByteCursorFromArray((const uint8_t *)header.second.data(), header.second.length()); - crtRequest->AddHeader(crtHeader); - } - - // HTTP method, GET, PUT, DELETE, etc... - auto methodCursor = Crt::ByteCursorFromCString(methodStr); - crtRequest->SetMethod(methodCursor); - - // Path portion of the request - auto pathStrCpy = request->GetUri().GetURLEncodedPathRFC3986(); - auto queryStrCpy = request->GetUri().GetQueryString(); - Aws::StringStream ss; - - //CRT client has you pass the query string as part of the path. concatenate that here. - ss << pathStrCpy << queryStrCpy; - auto fullPathAndQueryCpy = ss.str(); - auto pathCursor = Crt::ByteCursorFromArray((uint8_t *)fullPathAndQueryCpy.c_str(), fullPathAndQueryCpy.length()); - crtRequest->SetPath(pathCursor); - } - - static void OnResponseBodyReceived(Crt::Http::HttpStream& stream, const Crt::ByteCursor& body, const std::shared_ptr& response, const std::shared_ptr& request, const Http::HttpClient& client) - { - if (!client.ContinueRequest(*request) || !client.IsRequestProcessingEnabled()) - { - AWS_LOGSTREAM_INFO(CRT_HTTP_CLIENT_TAG, "Request canceled. Canceling request by closing the connection."); - stream.GetConnection().Close(); - return; - } - - //TODO: handle the read rate limiter here, once back pressure is setup. - assert(response); - for (const auto& hashIterator : request->GetResponseValidationHashes()) - { - std::stringstream headerStr; - headerStr<<"x-amz-checksum-"<HasHeader(headerStr.str().c_str())) - { - hashIterator.second->Update(reinterpret_cast(body.ptr), body.len); - break; - } - } - - // When data is received from the content body of the incoming response, just copy it to the output stream. - response->GetResponseBody().write((const char*)body.ptr, static_cast(body.len)); - if (response->GetResponseBody().fail()) { - const auto& ref = response->GetResponseBody(); - AWS_LOGSTREAM_ERROR(CRT_HTTP_CLIENT_TAG, "Failed to write " << body.len << " (eof: " << ref.eof() << ", bad: " << ref.bad() << ")"); - } - - if (request->IsEventStreamRequest() && !response->HasHeader(Aws::Http::X_AMZN_ERROR_TYPE)) - { - response->GetResponseBody().flush(); - } - - auto& receivedHandler = request->GetDataReceivedEventHandler(); - if (receivedHandler) - { - receivedHandler(request.get(), response.get(), static_cast(body.len)); - } - - AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, body.len << " bytes written to response."); - - } - - // on response headers arriving, write them to the response. - static void OnIncomingHeaders(Crt::Http::HttpStream&, enum aws_http_header_block block, const Crt::Http::HttpHeader* headersArray, std::size_t headersCount, const std::shared_ptr& response) - { - if (block == AWS_HTTP_HEADER_BLOCK_INFORMATIONAL) return; - - AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, "Received Headers: "); - - for (size_t i = 0; i < headersCount; ++i) - { - const Crt::Http::HttpHeader* header = &headersArray[i]; - Aws::String headerNameStr((const char*)header->name.ptr, header->name.len); - Aws::String headerValueStr((const char*)header->value.ptr, header->value.len); - AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, headerNameStr << ": " << headerValueStr); - response->AddHeader(headerNameStr, std::move(headerValueStr)); - } - } - - static void OnIncomingHeadersBlockDone(Crt::Http::HttpStream& stream, enum aws_http_header_block, const std::shared_ptr& response) - { - AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, "Received response code: " << stream.GetResponseStatusCode()); - response->SetResponseCode((HttpResponseCode)stream.GetResponseStatusCode()); - } - - // Request is done. If there was an error set it, otherwise just wake up the cvar. - static void OnStreamComplete(Crt::Http::HttpStream&, int errorCode, AsyncWaiter& waiter, const std::shared_ptr& response) - { - if (errorCode) - { - //TODO: get the right error parsed out. - response->SetClientErrorType(Aws::Client::CoreErrors::NETWORK_CONNECTION); - response->SetClientErrorMessage(aws_error_debug_str(errorCode)); - } - - waiter.Wakeup(); - } - - // if the connection acquisition failed, go ahead and fail the request and wakeup the cvar. - // If it succeeded go ahead and make the request. - static void OnClientConnectionAvailable(std::shared_ptr connection, int errorCode, std::shared_ptr& connectionReference, - Crt::Http::HttpRequestOptions& requestOptions, AsyncWaiter& waiter, const std::shared_ptr& request, - const std::shared_ptr& response, const HttpClient& client) - { - bool shouldContinueRequest = client.ContinueRequest(*request); - - if (!shouldContinueRequest) - { - response->SetClientErrorType(Client::CoreErrors::USER_CANCELLED); - response->SetClientErrorMessage("Request cancelled by user's continuation handler"); - waiter.Wakeup(); - return; - } - - int finalErrorCode = errorCode; - if (connection) - { - AWS_LOGSTREAM_DEBUG(CRT_HTTP_CLIENT_TAG, "Obtained connection handle " << (void*)connection.get()); - - auto clientStream = connection->NewClientStream(requestOptions); - connectionReference = connection; - - if (clientStream && clientStream->Activate()) { - return; - } - - finalErrorCode = aws_last_error(); - AWS_LOGSTREAM_ERROR(CRT_HTTP_CLIENT_TAG, "Initiation of request failed because " << aws_error_debug_str(finalErrorCode)); - - } - - const char *errorMsg = aws_error_debug_str(finalErrorCode); - AWS_LOGSTREAM_ERROR(CRT_HTTP_CLIENT_TAG, "Obtaining connection failed because " << errorMsg); - response->SetClientErrorType(Aws::Client::CoreErrors::NETWORK_CONNECTION); - response->SetClientErrorMessage(errorMsg); - - waiter.Wakeup(); - } - std::shared_ptr CRTHttpClient::MakeRequest(const std::shared_ptr& request, Aws::Utils::RateLimits::RateLimiterInterface*, Aws::Utils::RateLimits::RateLimiterInterface*) const @@ -385,7 +452,13 @@ namespace Aws requestOptions.onIncomingBody = [this, request, response](Crt::Http::HttpStream& stream, const Crt::ByteCursor& body) { - OnResponseBodyReceived(stream, body, response, request, *this); + if (!ContinueRequest(*request) || !IsRequestProcessingEnabled()) + { + AWS_LOGSTREAM_INFO(CRT_HTTP_CLIENT_TAG, "Request canceled. Canceling request by closing the connection."); + stream.GetConnection().Close(); + return; + } + OnResponseBodyReceived(stream, body, response, request); }; requestOptions.onIncomingHeaders = @@ -618,5 +691,40 @@ namespace Aws } } + Aws::Http::HttpClient::AcquireConnectionOutcome CRTHttpClient::AcquireConnection(const std::shared_ptr& request) { + auto requestConnOptions = CreateConnectionOptionsForRequest(request); + auto connectionManager = GetWithCreateConnectionManagerForRequest(request, requestConnOptions); + + if (!connectionManager) + { + return AcquireConnectionOutcome{Aws::Client::AWSError{ + Aws::Client::CoreErrors::INVALID_PARAMETER_COMBINATION, + "InvalidParameterCombination", + aws_error_debug_str(aws_last_error()), + false + }}; + } + + AcquireConnectionOutcome outcome{}; + AsyncWaiter waiter; + + connectionManager->AcquireConnection( + [&outcome, &waiter]( + std::shared_ptr acquiredConnection, int errorCode) -> void { + if (errorCode != AWS_OP_SUCCESS) { + outcome = AcquireConnectionOutcome{Aws::Client::AWSError{ + Aws::Client::CoreErrors::NETWORK_CONNECTION, + "CouldNotAcquireConnection", + aws_error_debug_str(errorCode), + false}}; + } else { + outcome = AcquireConnectionOutcome{Aws::MakeShared(CRT_HTTP_CLIENT_TAG, std::move(acquiredConnection))}; + } + waiter.Wakeup(); + }); + + waiter.WaitOnCompletion(); + return outcome; + } } } diff --git a/tests/aws-cpp-sdk-dynamodb-integration-tests/TableOperationTest.cpp b/tests/aws-cpp-sdk-dynamodb-integration-tests/TableOperationTest.cpp index a386472c23b1..33c7e3194086 100644 --- a/tests/aws-cpp-sdk-dynamodb-integration-tests/TableOperationTest.cpp +++ b/tests/aws-cpp-sdk-dynamodb-integration-tests/TableOperationTest.cpp @@ -10,7 +10,10 @@ #include #include #include +#include #include +#include +#include #include #include #include @@ -1539,5 +1542,85 @@ TEST_F(TableOperationTest, TestBatchGetItem) { EXPECT_EQ(1ul, batch_get_item_result.GetResult().GetResponses().size()); } +#if AWS_SDK_USE_CRT_HTTP +const char BASE_HTTP_WRITE_TEST_TABLE[] = "HTTP_WRITE"; +TEST_F(TableOperationTest, TestWriteDataApi) { + Aws::String tableName = BuildTableName(BASE_HTTP_WRITE_TEST_TABLE); + CreateTable(tableName, 10, 10); + + // Build a raw PutItem HTTP request + ClientConfiguration config; + config.scheme = Scheme::HTTPS; + auto httpClient = Aws::Http::CreateHttpClient(config); + + Aws::String uri = "https://dynamodb." + config.region + ".amazonaws.com"; + auto request = CreateHttpRequest(Aws::Http::URI(uri), + Aws::Http::HttpMethod::HTTP_POST, + Aws::Utils::Stream::DefaultResponseStreamFactoryMethod); + request->SetHeaderValue("content-type", "application/x-amz-json-1.0"); + request->SetHeaderValue("x-amz-target", "DynamoDB_20120810.PutItem"); + + Aws::String payload = "{\"TableName\":\"" + tableName + + "\",\"Item\":{\"" + HASH_KEY_NAME + "\":{\"S\":\"write-data-test\"}}}"; + + auto body = Aws::MakeShared(ALLOCATION_TAG); + *body << payload; + request->AddContentBody(body); + request->SetContentLength(std::to_string(payload.size())); + + // Sign the request + auto credProvider = Aws::MakeShared(ALLOCATION_TAG); + Aws::Client::AWSAuthV4Signer signer(credProvider, "dynamodb", Aws::Region::US_EAST_1); + ASSERT_TRUE(signer.SignRequest(*request)); + + // Acquire connection and create stream + auto connOutcome = httpClient->AcquireConnection(request); + ASSERT_TRUE(connOutcome.IsSuccess()); + + std::mutex mtx; + std::condition_variable cv; + bool streamDone = false; + + auto stream = connOutcome.GetResult()->NewClientStream(request, [&](int errorCode) { + AWS_UNREFERENCED_PARAM(errorCode); + std::lock_guard lk(mtx); + streamDone = true; + cv.notify_one(); + }); + ASSERT_NE(nullptr, stream); + ASSERT_TRUE(stream->Activate()); + + // Write the body + body->seekg(0, std::ios::beg); + bool writeDone = false; + int writeError = 0; + auto writeResult = stream->WriteData(body, + [&](int errorCode) { + std::lock_guard lk(mtx); + writeError = errorCode; + writeDone = true; + cv.notify_one(); + }, + true); + ASSERT_EQ(0, writeResult); + + { + std::unique_lock lk(mtx); + cv.wait(lk, [&]() { return writeDone; }); + } + ASSERT_EQ(0, writeError); + + // Wait for stream completion + { + std::unique_lock lk(mtx); + cv.wait(lk, [&]() { return streamDone; }); + } + + auto response = stream->GetResponse(); + ASSERT_NE(nullptr, response); + ASSERT_EQ(Aws::Http::HttpResponseCode::OK, response->GetResponseCode()); +} +#endif // AWS_SDK_USE_CRT_HTTP + } // anonymous namespace