#include "message_queue.h"

// NOTE(review): the standard header names (and all template/cast arguments in
// this file) were lost when the file was extracted. The list below covers the
// standard-library names this file uses; confirm it against the upstream file.
// TEST_SUITE/TEST_CASE/REQUIRE are presumably provided through one of the
// project headers - verify.
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <functional>
#include <iostream>
#include <memory>
#include <thread>
#include <unordered_map>
#include <vector>

#include "platform.h"
#include "resizable_buffer.h"
#include "utils.h"

// TODO: figure out a logging solution
//#define MESSAGE_QUEUE_LOG

namespace {

// Smallest payload chunk Enqueue is willing to write when it has to split a
// message. If less than a header plus this many bytes are free, Enqueue waits
// instead of writing a tiny chunk.
const int kMinimumPartialPayloadSize = 128;

// Header written immediately before every payload (or payload chunk) in the
// shared buffer.
struct MessageHeader {
  MessageHeader(uint32_t partial_id, bool has_more_chunks, size_t size)
      : partial_id(partial_id), has_more_chunks(has_more_chunks), size(size) {}

  // 0 for a message sent in one piece; otherwise the id shared by every chunk
  // of one split message.
  uint32_t partial_id;
  // True when further chunks with the same |partial_id| will follow.
  bool has_more_chunks;
  // Number of payload bytes that follow this header.
  size_t size;
};

// Forward iterator over the (header, payload) records packed back to back in
// a byte buffer. The end iterator is represented by a null buffer with zero
// remaining bytes.
struct BufferMessageIterator {
  // Returns an iterator at the first record, or End() if the buffer is empty.
  static BufferMessageIterator Begin(void* buffer, size_t bytes_used) {
    if (bytes_used == 0)
      return End();
    return BufferMessageIterator(buffer, bytes_used);
  }
  static BufferMessageIterator End() {
    return BufferMessageIterator(nullptr, 0);
  }

  // Start of buffer to iterate.
  uint8_t* buffer;
  // Number of bytes left in buffer to parse.
  size_t remaining_bytes;

  BufferMessageIterator(void* buffer, size_t remaining_bytes)
      : buffer(reinterpret_cast<uint8_t*>(buffer)),
        remaining_bytes(remaining_bytes) {}

  // Header of the record the iterator currently points at.
  MessageHeader* get() const {
    return reinterpret_cast<MessageHeader*>(buffer);
  }

  MessageHeader* operator*() const { return get(); }

  MessageHeader* operator->() const { return get(); }

  // Advances to the next record, or becomes End() once the current record is
  // the last one.
  void operator++() {
    size_t next_message_offset = sizeof(MessageHeader) + get()->size;
    if (next_message_offset >= remaining_bytes) {
      // A well-formed buffer ends exactly at a record boundary.
      assert(next_message_offset == remaining_bytes);
      buffer = nullptr;
      remaining_bytes = 0;
      return;
    }

    buffer = buffer + next_message_offset;
    remaining_bytes -= next_message_offset;
  }

  // Payload bytes of the current record (directly after its header).
  void* message_data() const {
    return reinterpret_cast<void*>(buffer + sizeof(MessageHeader));
  }

  bool operator==(const BufferMessageIterator& other) const {
    return buffer == other.buffer && remaining_bytes == other.remaining_bytes;
  }
  bool operator!=(const BufferMessageIterator& other) const {
    return !(*this == other);
  }
};

// Result of one Repeat() step: run the action again, or stop.
enum class RepeatResult { RunAgain, Break };

// Run |action| an arbitrary number of times. The action is invoked until it
// returns RepeatResult::Break; between invocations the thread sleeps for a
// zero-length duration.
void Repeat(std::function<RepeatResult()> action) {
  bool first = true;
#if defined(MESSAGE_QUEUE_LOG)
  int log_iteration_count = 0;
  int log_count = 0;
#endif
  while (true) {
    if (!first) {
#if defined(MESSAGE_QUEUE_LOG)
      if (log_iteration_count > 1000) {
        log_iteration_count = 0;
        std::cerr << "[info]: Buffer full, waiting (" << log_count++ << ")"
                  << std::endl;
      }
      ++log_iteration_count;
#endif
      // TODO: See if we can figure out a way to use condition variables
      // cross-process.
      std::this_thread::sleep_for(std::chrono::microseconds(0));
    }
    first = false;

    if (action() == RepeatResult::RunAgain)
      continue;
    break;
  }
}

// Returns the reassembly buffer registered under |id|, creating an empty one
// if none exists yet. The map keeps ownership of the buffer.
ResizableBuffer* CreateOrFindResizableBuffer(
    std::unordered_map<uint32_t, std::unique_ptr<ResizableBuffer>>&
        resizable_buffers,
    uint32_t id) {
  auto it = resizable_buffers.find(id);
  if (it != resizable_buffers.end())
    return it->second.get();
  return (resizable_buffers[id] = MakeUnique<ResizableBuffer>()).get();
}

// Allocates a new Buffer of |size| bytes and copies |content| into it.
std::unique_ptr<Buffer> MakeBuffer(void* content, size_t size) {
  auto buffer = Buffer::Create(size);
  memcpy(buffer->data, content, size);
  return buffer;
}

}  // namespace

Message::Message(void* data, size_t size) : data(data), size(size) {}

// Bookkeeping object placed at the very start of the shared buffer. Message
// records are stored directly after it.
struct MessageQueue::BufferMetadata {
  // Reset buffer.
  void reset() { total_message_bytes_ = 0; }

  // Total number of used bytes excluding the sizeof this metadata object.
  void add_used_bytes(size_t used_bytes) {
    total_message_bytes_ += used_bytes;
  }

  // The total number of bytes in use.
  size_t total_bytes_used_including_metadata() {
    return total_message_bytes_ + sizeof(BufferMetadata);
  }

  // The total number of bytes currently used for messages. This does not
  // include the sizeof the buffer metadata.
  size_t total_message_bytes() { return total_message_bytes_; }

  // Last id handed out to a split message. Enqueue pre-increments it, so the
  // first id in use is 1 and 0 keeps meaning "not a partial message".
  int next_partial_message_id = 0;

 private:
  size_t total_message_bytes_ = 0;
};

// Wraps |buffer| as a message queue. When |buffer_has_data| is false a fresh
// BufferMetadata is constructed at the start of the buffer; otherwise the
// metadata already stored there is reused. A zero-filled local buffer large
// enough to hold every message byte of |buffer| is allocated for DequeueAll.
MessageQueue::MessageQueue(std::unique_ptr<Buffer> buffer,
                           bool buffer_has_data)
    : buffer_(std::move(buffer)) {
  // NOTE(review): Enqueue only writes a chunk when sizeof(MessageHeader) +
  // kMinimumPartialPayloadSize bytes are free, so a buffer that only just
  // satisfies this assert cannot accept chunks of a split message.
  assert(buffer_->capacity >=
         (sizeof(BufferMetadata) + kMinimumPartialPayloadSize));

  if (!buffer_has_data)
    new (buffer_->data) BufferMetadata();

  local_buffer_ = Buffer::Create(buffer_->capacity - sizeof(BufferMetadata));
  memset(local_buffer_->data, 0, local_buffer_->capacity);
}

// Copies |message| into the shared buffer, blocking (via Repeat) until every
// byte has been written. A message that does not fit into the currently free
// space is split into chunks that share a non-zero partial id; the last chunk
// is written with has_more_chunks == false.
void MessageQueue::Enqueue(const Message& message) {
#if defined(MESSAGE_QUEUE_LOG)
  int count = 0;
#endif

  uint32_t partial_id = 0;
  uint8_t* payload_data = reinterpret_cast<uint8_t*>(message.data);
  size_t payload_size = message.size;

  Repeat([&]() -> RepeatResult {
#if defined(MESSAGE_QUEUE_LOG)
    if (count++ > 500) {
      std::cerr << "x500 Sending partial message payload_size=" << payload_size
                << std::endl;
      count = 0;
    }
#endif

    auto lock = buffer_->WaitForExclusiveAccess();
    const size_t available = BytesAvailableInBuffer();

    // The header is written in front of the payload, so it has to be counted
    // when deciding whether the remaining payload fits. Comparing only
    // |payload_size| against the free space would let a payload that fits
    // without its header take the "send it all" path and overrun the buffer.
    if (sizeof(MessageHeader) + payload_size > available) {
      // We cannot fit the entire payload in the buffer. We have to send
      // chunks of it over time.

      // There's not enough room for our minimum payload size, so try again
      // later.
      if (sizeof(MessageHeader) + kMinimumPartialPayloadSize > available)
        return RepeatResult::RunAgain;

      if (partial_id == 0) {
        // note: pre-increment so we use 1 as the initial value
        partial_id = ++metadata()->next_partial_message_id;
      }

      const size_t sent_payload_size = available - sizeof(MessageHeader);
      // |sent_payload_size| must always be smaller than |payload_size|. If it
      // is equal to |payload_size|, then we could have sent it as a normal,
      // non-partial message. It's also an error if it is larger than
      // |payload_size| (we're sending garbage data).
      assert(sent_payload_size < payload_size);

      CopyPayloadToBuffer(partial_id, payload_data, sent_payload_size,
                          true /*has_more_chunks*/);
      payload_data += sent_payload_size;
      payload_size -= sent_payload_size;

      // Prepare for next time.
      return RepeatResult::RunAgain;
    }

    // The entire (remaining) payload fits. Send it all now. Include the
    // partial message id, as there could have been previous parts of this
    // payload.
    CopyPayloadToBuffer(partial_id, payload_data, payload_size,
                        false /*has_more_chunks*/);

#if defined(MESSAGE_QUEUE_LOG)
    std::cerr << "Sending full message with payload_size=" << payload_size
              << std::endl;
#endif

    return RepeatResult::Break;
  });
}

// Removes every message currently stored in the shared buffer and returns
// each one as its own Buffer. If a split message has been seen only in part,
// the buffer is polled again until its final chunk arrives, so the call can
// block.
std::vector<std::unique_ptr<Buffer>> MessageQueue::DequeueAll() {
  // Reassembly buffers for split messages, keyed by partial id.
  std::unordered_map<uint32_t, std::unique_ptr<ResizableBuffer>>
      resizable_buffers;

  std::vector<std::unique_ptr<Buffer>> result;

  while (true) {
    size_t local_buffer_size = 0;

    // Move data from shared memory into a local buffer. Do this
    // before parsing the blocks so that other processes can begin
    // posting data as soon as possible.
    {
      auto lock = buffer_->WaitForExclusiveAccess();

      // note: Do not copy over buffer_ metadata.
      local_buffer_size = metadata()->total_message_bytes();
      // The local buffer was sized to hold every message byte of |buffer_|.
      assert(local_buffer_size <= local_buffer_->capacity);
      memcpy(local_buffer_->data, first_message_in_buffer(),
             local_buffer_size);

      metadata()->reset();
    }

    // Parse blocks from shared memory.
    for (auto it = BufferMessageIterator::Begin(local_buffer_->data,
                                                local_buffer_size);
         it != BufferMessageIterator::End(); ++it) {
#if defined(MESSAGE_QUEUE_LOG)
      std::cerr << "Got message with partial_id=" << it->partial_id
                << ", payload_size=" << it->size
                << ", has_more_chunks=" << it->has_more_chunks << std::endl;
#endif

      if (it->partial_id != 0) {
        auto* buf =
            CreateOrFindResizableBuffer(resizable_buffers, it->partial_id);
        buf->Append(it.message_data(), it->size);
        if (!it->has_more_chunks) {
          // Final chunk: hand out the reassembled message and drop its
          // reassembly buffer.
          result.push_back(MakeBuffer(buf->buffer, buf->size));
          resizable_buffers.erase(it->partial_id);
        }
      } else {
        // Note: we can't just return pointers to |local_buffer_| because if we
        // read a partial message we will invalidate all of the existing
        // pointers. We could jump through hoops to make it work (ie, if no
        // partial messages return pointers to local_buffer_) but it is not
        // worth the effort.
        assert(!it->has_more_chunks);
        result.push_back(MakeBuffer(it.message_data(), it->size));
      }
    }

    // We're waiting for the remaining chunks of a split message to be posted.
    // Delay a little so we don't push the CPU so hard.
    if (!resizable_buffers.empty())
      std::this_thread::sleep_for(std::chrono::microseconds(0));
    else
      break;
  }

  return result;
}

// Appends one record (header followed by |payload_size| payload bytes) at the
// first free address of the shared buffer and updates the used-byte count.
// The caller must have verified that the record fits.
void MessageQueue::CopyPayloadToBuffer(uint32_t partial_id,
                                       void* payload,
                                       size_t payload_size,
                                       bool has_more_chunks) {
  assert(BytesAvailableInBuffer() >= (sizeof(MessageHeader) + payload_size));

  // Copy header.
  MessageHeader header(partial_id, has_more_chunks, payload_size);
  memcpy(first_free_address_in_buffer(), &header, sizeof(MessageHeader));
  metadata()->add_used_bytes(sizeof(MessageHeader));

  // Copy payload. first_free_address_in_buffer() now points just past the
  // header because the used-byte count was advanced above.
  memcpy(first_free_address_in_buffer(), payload, payload_size);
  metadata()->add_used_bytes(payload_size);
}

// The metadata object stored at the start of the shared buffer.
MessageQueue::BufferMetadata* MessageQueue::metadata() const {
  return reinterpret_cast<BufferMetadata*>(buffer_->data);
}

// Number of bytes of the shared buffer not yet occupied by the metadata or by
// message records.
size_t MessageQueue::BytesAvailableInBuffer() const {
  return buffer_->capacity - metadata()->total_bytes_used_including_metadata();
}

// Address of the first message record, directly after the metadata.
Message* MessageQueue::first_message_in_buffer() const {
  return reinterpret_cast<Message*>(
      reinterpret_cast<uint8_t*>(buffer_->data) + sizeof(BufferMetadata));
}

// Address where the next record would be written, or nullptr when the buffer
// is completely full.
void* MessageQueue::first_free_address_in_buffer() const {
  if (metadata()->total_bytes_used_including_metadata() >= buffer_->capacity)
    return nullptr;
  return reinterpret_cast<void*>(
      reinterpret_cast<uint8_t*>(buffer_->data) +
      metadata()->total_bytes_used_including_metadata());
}

TEST_SUITE("MessageQueue");

TEST_CASE("simple") {
  MessageQueue queue(Buffer::Create(kMinimumPartialPayloadSize * 5),
                     false /*buffer_has_data*/);

  int data = 0;
  data = 1;
  queue.Enqueue(Message(&data, sizeof(data)));
  data = 2;
  queue.Enqueue(Message(&data, sizeof(data)));

  int expected = 0;
  for (std::unique_ptr<Buffer>& m : queue.DequeueAll()) {
    ++expected;

    REQUIRE(m->capacity == sizeof(data));
    int* value = reinterpret_cast<int*>(m->data);
    REQUIRE(expected == *value);
  }
  // Both messages must have come back; without this the loop above would
  // pass vacuously on an empty result.
  REQUIRE(expected == 2);
}

TEST_CASE("large payload") {
  MessageQueue queue(Buffer::Create(kMinimumPartialPayloadSize * 5),
                     false /*buffer_has_data*/);

  // Allocate big buffer.
  const size_t num_ints = kMinimumPartialPayloadSize * 100;
  std::vector<int> sent_ints(num_ints);
  for (size_t i = 0; i < num_ints; ++i)
    sent_ints[i] = static_cast<int>(i);

  // Queue big buffer. Add surrounding messages to make sure they get sent
  // correctly.
  // Run in a separate thread because Enqueue will block.
  std::thread sender([&]() {
    int small = 5;
    queue.Enqueue(Message(&small, sizeof(small)));
    queue.Enqueue(Message(sent_ints.data(), sizeof(int) * num_ints));
    queue.Enqueue(Message(&small, sizeof(small)));
  });

  // Receive sent messages.
  {
    // Keep dequeuing messages until we have three.
    std::vector<std::unique_ptr<Buffer>> messages;
    while (messages.size() != 3) {
      for (auto& message : queue.DequeueAll())
        messages.emplace_back(std::move(message));
    }
    sender.join();

    // Small
    {
      REQUIRE(sizeof(int) == messages[0]->capacity);
      int* value = reinterpret_cast<int*>(messages[0]->data);
      REQUIRE(*value == 5);
    }

    // Big
    {
      int* received_ints = reinterpret_cast<int*>(messages[1]->data);
      REQUIRE(received_ints != sent_ints.data());
      REQUIRE(messages[1]->capacity == (sizeof(int) * num_ints));
      for (size_t i = 0; i < num_ints; ++i) {
        REQUIRE(received_ints[i] == static_cast<int>(i));
        REQUIRE(received_ints[i] == sent_ints[i]);
      }
    }

    // Small
    {
      REQUIRE(sizeof(int) == messages[2]->capacity);
      int* value = reinterpret_cast<int*>(messages[2]->data);
      REQUIRE(*value == 5);
    }
  }
}

TEST_SUITE_END();