From e7390c14f1d22b5d5aa3d340850e93c96cd5e755 Mon Sep 17 00:00:00 2001 From: Jacob Dufault Date: Mon, 20 Mar 2017 22:40:14 -0700 Subject: [PATCH] Cleanup new MessageQueue impl --- src/message_queue.cc | 329 +++++++++++++++++++++++++++++++++---------- src/message_queue.h | 37 +++-- 2 files changed, 270 insertions(+), 96 deletions(-) diff --git a/src/message_queue.cc b/src/message_queue.cc index 31609bf2..9b0afac4 100644 --- a/src/message_queue.cc +++ b/src/message_queue.cc @@ -6,11 +6,26 @@ #include #include "platform.h" +#include "resizable_buffer.h" +#include "../utils.h" +#include "../third_party/doctest/doctest/doctest.h" + +// TODO: figure out a logging solution +//#define MESSAGE_QUEUE_LOG namespace { const int kMinimumPartialPayloadSize = 128; +struct MessageHeader { + MessageHeader(uint32_t partial_id, bool has_more_chunks, size_t size) + : partial_id(partial_id), has_more_chunks(has_more_chunks), size(size) {} + + uint32_t partial_id; + bool has_more_chunks; + size_t size; +}; + struct BufferMessageIterator { static BufferMessageIterator Begin(void* buffer, size_t bytes_used) { if (bytes_used == 0) @@ -24,22 +39,22 @@ struct BufferMessageIterator { // Start of buffer to iterate. - void* buffer; + uint8_t* buffer; // Number of bytes left in buffer to parse. size_t remaining_bytes; BufferMessageIterator(void* buffer, size_t remaining_bytes) - : buffer(buffer), remaining_bytes(remaining_bytes) {} + : buffer(reinterpret_cast(buffer)), remaining_bytes(remaining_bytes) {} - Message* get() const { - assert(buffer); - return reinterpret_cast(buffer); + + MessageHeader* get() const { + return reinterpret_cast(buffer); } - Message* operator*() const { return get(); } - Message* operator->() const { return get(); } + MessageHeader* operator*() const { return get(); } + MessageHeader* operator->() const { return get(); } void operator++() { - size_t next_message_offset = get()->total_size; + size_t next_message_offset = sizeof(MessageHeader) + get()->size; if (next_message_offset >= remaining_bytes) { assert(next_message_offset == remaining_bytes); buffer = nullptr; @@ -47,10 +62,14 @@ struct BufferMessageIterator { return; } - buffer = reinterpret_cast(buffer) + next_message_offset; + buffer = buffer + next_message_offset; remaining_bytes -= next_message_offset; } + void* message_data() const { + return reinterpret_cast(buffer + sizeof(MessageHeader)); + } + bool operator==(const BufferMessageIterator& other) const { return buffer == other.buffer && remaining_bytes == other.remaining_bytes; } @@ -73,7 +92,9 @@ void Repeat(std::function action) { if (!first) { if (log_iteration_count > 1000) { log_iteration_count = 0; - std::cerr << "[info]: shmem full, waiting (" << log_count++ << ")" << std::endl; // TODO: remove +#if defined(MESSAGE_QUEUE_LOG) + std::cerr << "[info]: Buffer full, waiting (" << log_count++ << ")" << std::endl; +#endif } ++log_iteration_count; // TODO: See if we can figure out a way to use condition variables cross-process. @@ -87,12 +108,32 @@ void Repeat(std::function action) { } } + +ResizableBuffer* CreateOrFindResizableBuffer(std::unordered_map>& resizable_buffers, uint32_t id) { + auto it = resizable_buffers.find(id); + if (it != resizable_buffers.end()) + return it->second.get(); + return (resizable_buffers[id] = MakeUnique()).get(); +} + +std::unique_ptr MakeBuffer(void* content, size_t size) { + auto buffer = Buffer::Create(size); + memcpy(buffer->data, content, size); + return std::move(buffer); +} + } // namespace +Message::Message(void* data, size_t size) : data(data), size(size) {} + struct MessageQueue::BufferMetadata { // Total number of used bytes exluding the sizeof this metadata object. - void set_total_messages_byte_count(size_t used_bytes) { - total_message_bytes_ = used_bytes; + void add_used_bytes(size_t used_bytes) { + total_message_bytes_ += used_bytes; + } + + void reset() { + total_message_bytes_ = 0; } // The total number of bytes in use. @@ -112,90 +153,145 @@ private: MessageQueue::MessageQueue(std::unique_ptr buffer, bool buffer_has_data) : buffer_(std::move(buffer)) { + + assert(buffer_->capacity >= (sizeof(BufferMetadata) + kMinimumPartialPayloadSize)); + if (!buffer_has_data) new (buffer_->data) BufferMetadata(); + + local_buffer_ = Buffer::Create(buffer_->capacity - sizeof(BufferMetadata)); + memset(local_buffer_->data, 0, local_buffer_->capacity); } void MessageQueue::Enqueue(const Message& message) { - //BufferMessageIterator::Begin(first_message_in_buffer(), metadata()->total_message_bytes); - - int partial_message_id = 0; +#if defined(MESSAGE_QUEUE_LOG) + int count = 0; +#endif + uint32_t partial_id = 0; + uint8_t* payload_data = reinterpret_cast(message.data); + size_t payload_size = message.size; Repeat([&]() { +#if defined(MESSAGE_QUEUE_LOG) + if (count++ > 500) { + std::cerr << "x500 Sending partial message payload_size=" << payload_size << std::endl; + count = 0; + } +#endif + auto lock = buffer_->WaitForExclusiveAccess(); // We cannot find the entire payload in the buffer. We have to send chunks // of it over time. - if (message.total_size >= BytesAvailableInBuffer()) { + if (payload_size >= BytesAvailableInBuffer()) { // There's not enough room for our minimum payload size, so try again later. - if ((sizeof(Message) + kMinimumPartialPayloadSize) > BytesAvailableInBuffer()) + if ((sizeof(MessageHeader) + kMinimumPartialPayloadSize) > BytesAvailableInBuffer()) return RepeatResult::RunAgain; - if (partial_message_id == 0) - partial_message_id = ++metadata()->next_partial_message_id; // note: pre-increment so we 1 as initial value - - - size_t sent_payload_size = BytesAvailableInBuffer() - sizeof(Message); - free_message_in_buffer()->CopyFrom(message); - metadata()->set_total_messages_byte_count( - metadata()->total_message_bytes() + sizeof(Message) + sent_payload_size); - - //shared_buffer->free_message()->Setup(message->ipc_id, partial_message_id, true /*has_more_chunks*/, sent_payload_size, payload); - //shared_buffer->metadata()->bytes_used += sizeof(JsonMessage) + sent_payload_size; - //shared_buffer->free_message()->ipc_id = IpcId::Invalid; // Note: free_message() may be past writable memory. - - if (count++ > 50) { - std::cerr << "x50 Sending partial message with payload_size=" << sent_payload_size << std::endl; - count = 0; + if (partial_id == 0) { + // note: pre-increment so we use 1 as the initial value + partial_id = ++metadata()->next_partial_message_id; } - // Prepare for next time. + size_t sent_payload_size = BytesAvailableInBuffer() - sizeof(MessageHeader); + // |sent_payload_size| must always be smaller than |payload_size|. If it is equal to + // |payload_size|, than we could have sent it as a normal, non-partial message. It's + // also an error if it is larger than payload_size (we're sending garbage data). + assert(sent_payload_size < payload_size); + + CopyPayloadToBuffer(partial_id, payload_data, sent_payload_size, true /*has_more_chunks*/); + payload_data += sent_payload_size; payload_size -= sent_payload_size; - payload += sent_payload_size; + + // Prepare for next time. return RepeatResult::RunAgain; } - return RepeatResult::Break; - -#if false - assert(payload_size > 0); - - // We cannot find the entire payload in the buffer. We - // have to send chunks of it over time. - if ((sizeof(JsonMessage) + payload_size) > shared_buffer->bytes_available()) { - if ((sizeof(JsonMessage) + kMinimumPartialPayloadSize) > shared_buffer->bytes_available()) - return DispatchResult::RunAgain; - - if (partial_message_id == 0) - partial_message_id = ++shared_buffer->metadata()->next_partial_message_id; // note: pre-increment so we 1 as initial value - - size_t sent_payload_size = shared_buffer->bytes_available() - sizeof(JsonMessage); - shared_buffer->free_message()->Setup(message->ipc_id, partial_message_id, true /*has_more_chunks*/, sent_payload_size, payload); - shared_buffer->metadata()->bytes_used += sizeof(JsonMessage) + sent_payload_size; - //shared_buffer->free_message()->ipc_id = IpcId::Invalid; // Note: free_message() may be past writable memory. - - if (count++ > 50) { - std::cerr << "x50 Sending partial message with payload_size=" << sent_payload_size << std::endl; - count = 0; - } - - // Prepare for next time. - payload_size -= sent_payload_size; - payload += sent_payload_size; - return RepeatResult::RunAgain; - } // The entire payload fits. Send it all now. else { // Include partial message id, as there could have been previous parts of this payload. - shared_buffer->free_message()->Setup(message->ipc_id, partial_message_id, false /*has_more_chunks*/, payload_size, payload); - shared_buffer->metadata()->bytes_used += sizeof(JsonMessage) + payload_size; - shared_buffer->free_message()->ipc_id = IpcId::Invalid; - //std::cerr << "Sending full message with payload_size=" << payload_size << std::endl; + CopyPayloadToBuffer(partial_id, payload_data, payload_size, false /*has_more_chunks*/); +#if defined(MESSAGE_QUEUE_LOG) + std::cerr << "Sending full message with payload_size=" << payload_size << std::endl; +#endif return RepeatResult::Break; } + }); +} + +std::vector> MessageQueue::DequeueAll() { + std::unordered_map> resizable_buffers; + + std::vector> result; + + do { + size_t local_buffer_size = 0; + + // Move data from shared memory into a local buffer. Do this + // before parsing the blocks so that other processes can begin + // posting data as soon as possible. + { + std::unique_ptr lock = buffer_->WaitForExclusiveAccess(); + assert(BytesAvailableInBuffer() >= 0); + + // note: Do not copy over buffer_ metadata. + local_buffer_size = metadata()->total_message_bytes(); + memcpy(local_buffer_->data, + first_message_in_buffer(), + local_buffer_size); + + metadata()->reset(); + } + + // Parse blocks from shared memory. + for (auto it = BufferMessageIterator::Begin(local_buffer_->data, local_buffer_size); + it != BufferMessageIterator::End(); + ++it) { +#if defined(MESSAGE_QUEUE_LOG) + std::cerr << "Got message with partial_id=" << it->partial_id << ", payload_size=" << it->size << ", has_more_chunks=" << it->has_more_chunks << std::endl; #endif -}); + + if (it->partial_id != 0) { + auto* buf = CreateOrFindResizableBuffer(resizable_buffers, it->partial_id); + buf->Append(it.message_data(), it->size); + + if (!it->has_more_chunks) { + // We can't remove the resizable buffer yet because we need to keep the data alive. + // We will remove it the next time this function is called. + result.push_back(MakeBuffer(buf->buffer, buf->size)); + resizable_buffers.erase(it->partial_id); + } + } + else { + // Note: we can't just return pointers to |local_buffer_| because if we + // read a partial message we will invalidate all of the existing + // pointers. We could jump through hoops to make it work (ie, if no + // partial messages return pointers to local_buffer_) but it is not + // worth the effort. + assert(!it->has_more_chunks); + result.push_back(MakeBuffer(it.message_data(), it->size)); + } + } + + // Let other threads run. We still want to run as fast as possible, though. + std::this_thread::sleep_for(std::chrono::microseconds(0)); + } while (resizable_buffers.size() > 0); + + return result; + +} + +void MessageQueue::CopyPayloadToBuffer(uint32_t partial_id, void* payload, size_t payload_size, bool has_more_chunks) { + assert(BytesAvailableInBuffer() >= (sizeof(MessageHeader) + payload_size)); + + // Copy header. + MessageHeader header(partial_id, has_more_chunks, payload_size); + memcpy(first_free_address_in_buffer(), &header, sizeof(MessageHeader)); + metadata()->add_used_bytes(sizeof(MessageHeader)); + // Copy payload. + memcpy(first_free_address_in_buffer(), payload, payload_size); + metadata()->add_used_bytes(payload_size); } MessageQueue::BufferMetadata* MessageQueue::metadata() const { @@ -208,12 +304,95 @@ size_t MessageQueue::BytesAvailableInBuffer() const { Message* MessageQueue::first_message_in_buffer() const { return reinterpret_cast( - reinterpret_cast(buffer_->data) + sizeof(BufferMetadata)); + reinterpret_cast(buffer_->data) + sizeof(BufferMetadata)); } -Message* MessageQueue::free_message_in_buffer() const { - if (metadata()->total_bytes_used_including_metadata >= buffer_->capacity) +void* MessageQueue::first_free_address_in_buffer() const { + if (metadata()->total_bytes_used_including_metadata() >= buffer_->capacity) return nullptr; - return reinterpret_cast( - reinterpret_cast(buffer_->data) + metadata()->total_bytes_used_including_metadata()); -} \ No newline at end of file + return reinterpret_cast( + reinterpret_cast(buffer_->data) + + metadata()->total_bytes_used_including_metadata()); +} + +TEST_SUITE("MessageQueue"); + +TEST_CASE("simple") { + MessageQueue queue(Buffer::Create(kMinimumPartialPayloadSize * 5), false /*buffer_has_data*/); + + int data = 0; + data = 1; + queue.Enqueue(Message(&data, sizeof(data))); + data = 2; + queue.Enqueue(Message(&data, sizeof(data))); + + int expected = 0; + for (std::unique_ptr& m : queue.DequeueAll()) { + ++expected; + + REQUIRE(m->capacity == sizeof(data)); + int* value = reinterpret_cast(m->data); + REQUIRE(expected == *value); + } +} + +TEST_CASE("large payload") { + MessageQueue queue(Buffer::Create(kMinimumPartialPayloadSize * 5), false /*buffer_has_data*/); + + // Allocate big buffer. + size_t num_ints = kMinimumPartialPayloadSize * 100; + int* sent_ints = reinterpret_cast(malloc(sizeof(int) * num_ints)); + for (int i = 0; i < num_ints; ++i) + sent_ints[i] = i; + + // Queue big buffer. Add surrounding messages to make sure they get sent correctly. + // Run in a separate thread because Enqueue will block. + volatile bool done_sending = false; + std::thread sender([&]() { + int small = 5; + queue.Enqueue(Message(&small, sizeof(small))); + queue.Enqueue(Message(sent_ints, sizeof(int) * num_ints)); + queue.Enqueue(Message(&small, sizeof(small))); + done_sending = true; + }); + + // Receive sent messages. + { + // Keep dequeuing messages until we have three. + std::vector> messages; + while (messages.size() != 3) { + for (auto& message : queue.DequeueAll()) + messages.emplace_back(std::move(message)); + } + sender.join(); + + // Small + { + REQUIRE(sizeof(int) == messages[0]->capacity); + int* value = reinterpret_cast(messages[0]->data); + REQUIRE(*value == 5); + } + + // Big + { + int* received_ints = reinterpret_cast(messages[1]->data); + REQUIRE(received_ints != sent_ints); + REQUIRE(messages[1]->capacity == (sizeof(int) * num_ints)); + for (int i = 0; i < num_ints; ++i) { + REQUIRE(received_ints[i] == i); + REQUIRE(received_ints[i] == sent_ints[i]); + } + } + + // Small + { + REQUIRE(sizeof(int) == messages[2]->capacity); + int* value = reinterpret_cast(messages[2]->data); + REQUIRE(*value == 5); + } + } + + free(sent_ints); +} + +TEST_SUITE_END(); \ No newline at end of file diff --git a/src/message_queue.h b/src/message_queue.h index e2ab030e..4bbea4d4 100644 --- a/src/message_queue.h +++ b/src/message_queue.h @@ -2,26 +2,27 @@ #include #include +#include #include "buffer.h" +class ResizableBuffer; + struct Message { - // Unique message identifier. - uint8_t message_id; + Message(void* data, size_t size); - // Total size of the message (including metadata that this object stores). - size_t total_size; - - // Size of the extra message data immediately following the message payload. - size_t message_size() const { return total_size - sizeof(Message); } + void* data; + size_t size; }; // A MessageQueue is a FIFO container storing messages in an arbitrary memory -// buffer. +// buffer that is cross-thread and cross-process safe. This means: // - Multiple separate MessageQueues instantiations can point to the -// same underlying buffer -// - Buffer is fully relocatable, ie, it can have multiple different +// same underlying buffer and use it at the same time. +// - The buffer is fully relocatable, ie, it can have multiple different // addresses (as is the case for memory shared across processes). +// +// There can be multiple writers, but there can only be one reader. struct MessageQueue { // Create a new MessageQueue using |buffer| as the backing data storage. // This does *not* take ownership over the memory stored in |buffer|. @@ -44,28 +45,22 @@ struct MessageQueue { void Enqueue(const Message& message); // Take all messages from the queue. - // - // note: - // We could make this allocation free by returning raw pointers to the - // internal process-local buffer, but that is pretty haphazard and likely - // to cause a very confusing crash. The extra memory allocations here from - // unique_ptr going to make a performance difference. - std::vector> DequeueAll(); - - // Take the first available message from the queue. - std::unique_ptr DequeueFirst(); + std::vector> DequeueAll(); private: struct BufferMetadata; + void CopyPayloadToBuffer(uint32_t partial_id, void* payload, size_t payload_size, bool has_more_chunks); + BufferMetadata* metadata() const; // Returns the number of bytes currently available in the buffer. size_t BytesAvailableInBuffer() const; Message* first_message_in_buffer() const; // First free message in the buffer. - Message* free_message_in_buffer() const; + void* first_free_address_in_buffer() const; std::unique_ptr buffer_; + std::unique_ptr local_buffer_; }; /*