diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 553d7616c..fad65f0fa 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -380,7 +380,7 @@ struct string_caster { return false; } if (!PyUnicode_Check(load_src.ptr())) { - return load_bytes(load_src); + return load_raw(load_src); } // For UTF-8 we avoid the need for a temporary `bytes` object by using @@ -458,26 +458,37 @@ private: #endif } - // When loading into a std::string or char*, accept a bytes object as-is (i.e. + // When loading into a std::string or char*, accept a bytes/bytearray object as-is (i.e. // without any encoding/decoding attempt). For other C++ char sizes this is a no-op. // which supports loading a unicode from a str, doesn't take this path. template - bool load_bytes(enable_if_t::value, handle> src) { + bool load_raw(enable_if_t::value, handle> src) { if (PYBIND11_BYTES_CHECK(src.ptr())) { // We were passed raw bytes; accept it into a std::string or char* // without any encoding attempt. const char *bytes = PYBIND11_BYTES_AS_STRING(src.ptr()); - if (bytes) { - value = StringType(bytes, (size_t) PYBIND11_BYTES_SIZE(src.ptr())); - return true; + if (!bytes) { + pybind11_fail("Unexpected PYBIND11_BYTES_AS_STRING() failure."); } + value = StringType(bytes, (size_t) PYBIND11_BYTES_SIZE(src.ptr())); + return true; + } + if (PyByteArray_Check(src.ptr())) { + // We were passed a bytearray; accept it into a std::string or char* + // without any encoding attempt. + const char *bytearray = PyByteArray_AsString(src.ptr()); + if (!bytearray) { + pybind11_fail("Unexpected PyByteArray_AsString() failure."); + } + value = StringType(bytearray, (size_t) PyByteArray_Size(src.ptr())); + return true; } return false; } template - bool load_bytes(enable_if_t::value, handle>) { + bool load_raw(enable_if_t::value, handle>) { return false; } }; diff --git a/tests/test_builtin_casters.py b/tests/test_builtin_casters.py index 02207f24f..d38ae6802 100644 --- a/tests/test_builtin_casters.py +++ b/tests/test_builtin_casters.py @@ -133,6 +133,15 @@ def test_bytes_to_string(): assert m.string_length("💩".encode()) == 4 +def test_bytearray_to_string(): + """Tests the ability to pass bytearray to C++ string-accepting functions""" + assert m.string_length(bytearray(b"Hi")) == 2 + assert m.strlen(bytearray(b"bytearray")) == 9 + assert m.string_length(bytearray()) == 0 + assert m.string_length(bytearray("🦜", "utf-8", "strict")) == 4 + assert m.string_length(bytearray(b"\x80")) == 1 + + @pytest.mark.skipif(not hasattr(m, "has_string_view"), reason="no ") def test_string_view(capture): """Tests support for C++17 string_view arguments and return values"""