Ver código fonte

Fix multipart write/read.

lganzzzo 4 anos atrás
pai
commit
ce46616075

+ 2 - 0
CMakeLists.txt

@@ -74,6 +74,8 @@ message("\n#####################################################################
 
 message("oatpp version: '${OATPP_THIS_MODULE_VERSION}'")
 
+#SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread")
+
 add_subdirectory(src)
 
 if(OATPP_BUILD_TESTS)

+ 5 - 46
src/oatpp/web/mime/multipart/Reader.cpp

@@ -136,56 +136,15 @@ void AsyncPartsParser::setDefaultPartReader(const std::shared_ptr<AsyncPartReade
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
 // InMemoryReader
 
-Reader::Reader(Multipart* multipart, data::stream::ReadCallback* readCallback)
+Reader::Reader(Multipart* multipart)
   : m_partsParser(std::make_shared<PartsParser>(multipart))
   , m_parser(multipart->getBoundary(), m_partsParser, nullptr)
-  , m_readCallback(readCallback)
 {}
 
-void Reader::readAll() {
-
-  data::v_io_size res = -1;
-  data::buffer::IOBuffer buffer;
-
-  while(!m_parser.finished()) {
-
-    async::Action action;
-    res = m_readCallback->read(buffer.getData(), buffer.getSize(), action);
-
-    if(!action.isNone()) {
-      throw std::runtime_error("[oatpp::web::mime::multipart::Reader::readAll()]: Error. Async action is unexpected.");
-    }
-
-    if(res > 0) {
-
-      data::buffer::InlineWriteData inlineData(buffer.getData(), res);
-      while(inlineData.bytesLeft > 0 && !m_parser.finished()) {
-        async::Action action;
-        m_parser.parseNext(inlineData, action);
-        if(!action.isNone()) {
-          throw std::runtime_error("[oatpp::web::mime::multipart::Reader::readAll()]: Error. Async action is unexpected.");
-        }
-      }
-
-    } else {
-
-      switch(res) {
-
-        case data::IOError::RETRY_READ:
-          continue;
-
-        case data::IOError::RETRY_WRITE:
-          continue;
-
-        default:
-          return;
-
-      }
-
-    }
-
-  }
-
+data::v_io_size Reader::write(const void *data, v_buff_size count, async::Action& action) {
+  data::buffer::InlineWriteData inlineData(data, count);
+  m_parser.parseNext(inlineData, action);
+  return count - inlineData.bytesLeft;
 }
 
 void Reader::setPartReader(const oatpp::String& partName, const std::shared_ptr<PartReader>& reader) {

+ 4 - 4
src/oatpp/web/mime/multipart/Reader.hpp

@@ -170,21 +170,21 @@ public:
 
 /**
  * In memory Multipart reader.
+ * Extends - &id:oatpp::data::stream::WriteCallback;.
  */
-class Reader {
+class Reader : public oatpp::data::stream::WriteCallback {
 private:
   std::shared_ptr<PartsParser> m_partsParser;
   StatefulParser m_parser;
-  data::stream::ReadCallback* m_readCallback;
 public:
 
   /**
    * Constructor.
    * @param multipart - Multipart object to save read data to.
    */
-  Reader(Multipart* multipart, data::stream::ReadCallback*);
+  Reader(Multipart* multipart);
 
-  void readAll();
+  data::v_io_size write(const void *data, v_buff_size count, async::Action& action) override;
 
   /**
    * Set named part reader. <br>

+ 4 - 6
src/oatpp/web/protocol/http/incoming/ResponseHeadersReader.cpp

@@ -24,7 +24,7 @@
 
 #include "ResponseHeadersReader.hpp"
 
-#include "oatpp/core/data/stream/ChunkedBuffer.hpp"
+#include "oatpp/core/data/stream/BufferStream.hpp"
 
 namespace oatpp { namespace web { namespace protocol { namespace http { namespace incoming {
 
@@ -56,6 +56,7 @@ data::v_io_size ResponseHeadersReader::readHeadersSectionIterative(ReadHeadersIt
         result.bufferPosStart = i + 1;
         result.bufferPosEnd = res;
         iteration.done = true;
+        return res;
       }
     }
 
@@ -72,7 +73,7 @@ ResponseHeadersReader::Result ResponseHeadersReader::readHeaders(const std::shar
   ReadHeadersIteration iteration;
   async::Action action;
   
-  oatpp::data::stream::ChunkedBuffer buffer;
+  oatpp::data::stream::BufferOutputStream buffer;
 
   while(!iteration.done) {
 
@@ -92,7 +93,6 @@ ResponseHeadersReader::Result ResponseHeadersReader::readHeaders(const std::shar
 
   }
 
-  
   if(error.ioStatus > 0) {
     auto headersText = buffer.toString();
     oatpp::parser::Caret caret (headersText);
@@ -115,18 +115,16 @@ ResponseHeadersReader::readHeadersAsync(const std::shared_ptr<oatpp::data::strea
   private:
     ResponseHeadersReader* m_this;
     std::shared_ptr<oatpp::data::stream::IOStream> m_connection;
-    v_word32 m_accumulator;
     v_buff_size m_progress;
     ReadHeadersIteration m_iteration;
     ResponseHeadersReader::Result m_result;
-    oatpp::data::stream::ChunkedBuffer m_bufferStream;
+    oatpp::data::stream::BufferOutputStream m_bufferStream;
   public:
     
     ReaderCoroutine(ResponseHeadersReader* _this,
                     const std::shared_ptr<oatpp::data::stream::IOStream>& connection)
     : m_this(_this)
     , m_connection(connection)
-    , m_accumulator(0)
     , m_progress(0)
     {}
     

+ 90 - 90
src/oatpp/web/protocol/http/outgoing/MultipartBody.cpp

@@ -38,77 +38,77 @@ MultipartBody::MultipartReadCallback::MultipartReadCallback(const std::shared_pt
   , m_readStream(nullptr, nullptr, 0)
 {}
 
-data::v_io_size MultipartBody::MultipartReadCallback::readBody(void *buffer, v_buff_size count) {
-//  auto& part = *m_iterator;
-//  const auto& stream = part->getInputStream();
-//  if(!stream) {
-//    OATPP_LOGW("[oatpp::web::protocol::http::outgoing::MultipartBody::MultipartReadCallback::readBody()]", "Warning. Part has no input stream", m_state);
-//    m_iterator ++;
-//    return 0;
-//  }
-//  auto res = stream->read(buffer, count);
-//  if(res == 0) {
-//    m_iterator ++;
-//  }
-//  return res;
+data::v_io_size MultipartBody::MultipartReadCallback::readBody(void *buffer, v_buff_size count, async::Action& action) {
+  auto& part = *m_iterator;
+  const auto& stream = part->getInputStream();
+  if(!stream) {
+    OATPP_LOGW("[oatpp::web::protocol::http::outgoing::MultipartBody::MultipartReadCallback::readBody()]", "Warning. Part has no input stream", m_state);
+    m_iterator ++;
+    return 0;
+  }
+  auto res = stream->read(buffer, count, action);
+  if(res == 0) {
+    m_iterator ++;
+  }
+  return res;
 }
 
 data::v_io_size MultipartBody::MultipartReadCallback::read(void *buffer, v_buff_size count, async::Action& action) {
-//
-//  if(m_state == STATE_FINISHED) {
-//    return 0;
-//  }
-//
-//  p_char8 currBufferPtr = (p_char8) buffer;
-//  data::v_io_size bytesLeft = count;
-//
-//  data::v_io_size res = 0;
-//
-//  while(bytesLeft > 0) {
-//
-//    switch (m_state) {
-//
-//      case STATE_BOUNDARY:
-//        res = readBoundary(m_multipart, m_iterator, m_readStream, currBufferPtr, bytesLeft);
-//        break;
-//
-//      case STATE_HEADERS:
-//        res = readHeaders(m_multipart, m_iterator, m_readStream, currBufferPtr, bytesLeft);
-//        break;
-//
-//      case STATE_BODY:
-//        res = readBody(currBufferPtr, bytesLeft);
-//        break;
-//
-//      default:
-//        OATPP_LOGE("[oatpp::web::protocol::http::outgoing::MultipartBody::MultipartReadCallback::read()]", "Error. Invalid state %d", m_state);
-//        return 0;
-//
-//    }
-//
-//    if(res > 0) {
-//      currBufferPtr = &currBufferPtr[res];
-//      bytesLeft -= res;
-//    } else if(res == 0) {
-//
-//      if(m_state == STATE_BOUNDARY && m_iterator == m_multipart->getAllParts().end()) {
-//        m_state = STATE_FINISHED;
-//        break;
-//      }
-//
-//      m_state += 1;
-//      if(m_state == STATE_ROUND) {
-//        m_state = 0;
-//      }
-//
-//    } else {
-//      OATPP_LOGE("[oatpp::web::protocol::http::outgoing::MultipartBody::MultipartReadCallback::read()]", "Error. Invalid read result %d. State=%d", res, m_state);
-//      return 0;
-//    }
-//
-//  }
-//
-//  return count - bytesLeft;
+
+  if(m_state == STATE_FINISHED) {
+    return 0;
+  }
+
+  p_char8 currBufferPtr = (p_char8) buffer;
+  data::v_io_size bytesLeft = count;
+
+  data::v_io_size res = 0;
+
+  while(bytesLeft > 0 && action.isNone()) {
+
+    switch (m_state) {
+
+      case STATE_BOUNDARY:
+        res = readBoundary(m_multipart, m_iterator, m_readStream, currBufferPtr, bytesLeft);
+        break;
+
+      case STATE_HEADERS:
+        res = readHeaders(m_multipart, m_iterator, m_readStream, currBufferPtr, bytesLeft);
+        break;
+
+      case STATE_BODY:
+        res = readBody(currBufferPtr, bytesLeft, action);
+        break;
+
+      default:
+        OATPP_LOGE("[oatpp::web::protocol::http::outgoing::MultipartBody::MultipartReadCallback::read()]", "Error. Invalid state %d", m_state);
+        return 0;
+
+    }
+
+    if(res > 0) {
+      currBufferPtr = &currBufferPtr[res];
+      bytesLeft -= res;
+    } else if(res == 0) {
+
+      if(m_state == STATE_BOUNDARY && m_iterator == m_multipart->getAllParts().end()) {
+        m_state = STATE_FINISHED;
+        break;
+      }
+
+      m_state += 1;
+      if(m_state == STATE_ROUND) {
+        m_state = 0;
+      }
+
+    } else if(action.isNone()) {
+      OATPP_LOGE("[oatpp::web::protocol::http::outgoing::MultipartBody::MultipartReadCallback::read()]", "Error. Invalid read result %d. State=%d", res, m_state);
+      return 0;
+    }
+
+  }
+
+  return count - bytesLeft;
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -120,28 +120,28 @@ data::v_io_size MultipartBody::readBoundary(const std::shared_ptr<Multipart>& mu
                                             void *buffer,
                                             v_buff_size count)
 {
-//  if (!readStream.getDataMemoryHandle()) {
-//
-//    oatpp::String boundary;
-//
-//    if (iterator == multipart->getAllParts().end()) {
-//      boundary = "\r\n--" + multipart->getBoundary() + "--\r\n";
-//    } else if (iterator == multipart->getAllParts().begin()) {
-//      boundary = "--" + multipart->getBoundary() + "\r\n";
-//    } else {
-//      boundary = "\r\n--" + multipart->getBoundary() + "\r\n";
-//    }
-//
-//    readStream.reset(boundary.getPtr(), boundary->getData(), boundary->getSize());
-//
-//  }
-//
-//  auto res = readStream.read(buffer, count);
-//  if(res == 0) {
-//    readStream.reset();
-//  }
-//
-//  return res;
+  if (!readStream.getDataMemoryHandle()) {
+
+    oatpp::String boundary;
+
+    if (iterator == multipart->getAllParts().end()) {
+      boundary = "\r\n--" + multipart->getBoundary() + "--\r\n";
+    } else if (iterator == multipart->getAllParts().begin()) {
+      boundary = "--" + multipart->getBoundary() + "\r\n";
+    } else {
+      boundary = "\r\n--" + multipart->getBoundary() + "\r\n";
+    }
+
+    readStream.reset(boundary.getPtr(), boundary->getData(), boundary->getSize());
+
+  }
+
+  auto res = readStream.readSimple(buffer, count);
+  if(res == 0) {
+    readStream.reset();
+  }
+
+  return res;
 }
 
 data::v_io_size MultipartBody::readHeaders(const std::shared_ptr<Multipart>& multipart,

+ 1 - 1
src/oatpp/web/protocol/http/outgoing/MultipartBody.hpp

@@ -82,7 +82,7 @@ private:
     v_int32 m_state;
     oatpp::data::stream::BufferInputStream m_readStream;
   private:
-    data::v_io_size readBody(void *buffer, v_buff_size count);
+    data::v_io_size readBody(void *buffer, v_buff_size count, async::Action& action);
   public:
 
     MultipartReadCallback(const std::shared_ptr<Multipart>& multipart);

+ 0 - 5
test/oatpp/AllTestsMain.cpp

@@ -75,12 +75,9 @@ void runTests() {
   OATPP_RUN_TEST(oatpp::test::core::data::share::MemoryLabelTest);
 
   OATPP_RUN_TEST(oatpp::test::core::data::share::LazyStringMapTest);
-*/
 
   OATPP_RUN_TEST(oatpp::test::core::data::buffer::ProcessorTest);
 
-/*
-
   OATPP_RUN_TEST(oatpp::test::core::data::stream::ChunkedBufferTest);
   OATPP_RUN_TEST(oatpp::test::core::data::stream::BufferStreamTest);
 
@@ -102,11 +99,9 @@ void runTests() {
 
   OATPP_RUN_TEST(oatpp::test::network::virtual_::PipeTest);
   OATPP_RUN_TEST(oatpp::test::network::virtual_::InterfaceTest);
-*/
 
   OATPP_RUN_TEST(oatpp::test::web::protocol::http::encoding::ChunkedTest);
 
-/*
   OATPP_RUN_TEST(oatpp::test::web::mime::multipart::StatefulParserTest);
 
   OATPP_RUN_TEST(oatpp::test::web::server::api::ApiControllerTest);

+ 2 - 2
test/oatpp/web/FullAsyncTest.cpp

@@ -240,11 +240,11 @@ void FullAsyncTest::onRun() {
 
         multipart = std::make_shared<oatpp::web::mime::multipart::Multipart>(response->getHeaders());
 
-        oatpp::web::mime::multipart::Reader multipartReader(multipart.get(), response->getBodyStream().get());
+        oatpp::web::mime::multipart::Reader multipartReader(multipart.get());
         multipartReader.setPartReader("value1", std::make_shared<oatpp::web::mime::multipart::InMemoryPartReader>(10));
         multipartReader.setPartReader("value2", std::make_shared<oatpp::web::mime::multipart::InMemoryPartReader>(10));
 
-        multipartReader.readAll();
+        response->transferBody(&multipartReader);
 
         OATPP_ASSERT(multipart->getAllParts().size() == 2);
         auto part1 = multipart->getNamedPart("value1");

+ 2 - 2
test/oatpp/web/FullTest.cpp

@@ -415,11 +415,11 @@ void FullTest::onRun() {
 
         multipart = std::make_shared<oatpp::web::mime::multipart::Multipart>(response->getHeaders());
 
-        oatpp::web::mime::multipart::Reader multipartReader(multipart.get(), response->getBodyStream().get());
+        oatpp::web::mime::multipart::Reader multipartReader(multipart.get());
         multipartReader.setPartReader("value1", std::make_shared<oatpp::web::mime::multipart::InMemoryPartReader>(10));
         multipartReader.setPartReader("value2", std::make_shared<oatpp::web::mime::multipart::InMemoryPartReader>(10));
 
-        multipartReader.readAll();
+        response->transferBody(&multipartReader);
 
         OATPP_ASSERT(multipart->getAllParts().size() == 2);
         auto part1 = multipart->getNamedPart("value1");

+ 4 - 4
test/oatpp/web/app/Controller.hpp

@@ -206,10 +206,10 @@ public:
 
     auto multipart = std::make_shared<oatpp::web::mime::multipart::Multipart>(request->getHeaders());
 
-    oatpp::web::mime::multipart::Reader multipartReader(multipart.get(), request->getBodyStream().get());
+    oatpp::web::mime::multipart::Reader multipartReader(multipart.get());
     multipartReader.setDefaultPartReader(std::make_shared<oatpp::web::mime::multipart::InMemoryPartReader>(10));
 
-    multipartReader.readAll();
+    request->transferBody(&multipartReader);
 
     auto responseBody = std::make_shared<oatpp::web::protocol::http::outgoing::MultipartBody>(multipart);
 
@@ -225,7 +225,7 @@ public:
     auto multipart = std::make_shared<multipart::Multipart>(request->getHeaders());
 
     /* Create multipart reader. */
-    multipart::Reader multipartReader(multipart.get(), request->getBodyStream().get());
+    multipart::Reader multipartReader(multipart.get());
 
     /* Configure to read part with name "part1" into memory */
     multipartReader.setPartReader("part1", multipart::createInMemoryPartReader(256 /* max-data-size */));
@@ -237,7 +237,7 @@ public:
     multipartReader.setDefaultPartReader(multipart::createInMemoryPartReader(16 * 1024 /* max-data-size */));
 
     /* Read multipart body */
-    multipartReader.readAll();
+    request->transferBody(&multipartReader);
 
     /* Print number of uploaded parts */
     OATPP_LOGD("Multipart", "parts_count=%d", multipart->count());

+ 31 - 0
test/oatpp/web/protocol/http/encoding/ChunkedTest.cpp

@@ -35,6 +35,37 @@ void ChunkedTest::onRun() {
   oatpp::String encoded;
   oatpp::String decoded;
 
+  { // Empty string
+    oatpp::data::stream::BufferInputStream inStream(oatpp::String(""));
+    oatpp::data::stream::BufferOutputStream outStream;
+
+    oatpp::web::protocol::http::encoding::EncoderChunked encoder;
+
+    const v_int32 bufferSize = 5;
+    v_char8 buffer[bufferSize];
+
+    auto count = oatpp::data::stream::transfer(&inStream, &outStream, 0, buffer, bufferSize, &encoder);
+    encoded = outStream.toString();
+
+    OATPP_ASSERT(count == 0);
+    OATPP_ASSERT(encoded == "0\r\n\r\n");
+  }
+
+  { // Empty string
+    oatpp::data::stream::BufferInputStream inStream(encoded);
+    oatpp::data::stream::BufferOutputStream outStream;
+
+    oatpp::web::protocol::http::encoding::DecoderChunked decoder;
+
+    const v_int32 bufferSize = 5;
+    v_char8 buffer[bufferSize];
+
+    auto count = oatpp::data::stream::transfer(&inStream, &outStream, 0, buffer, bufferSize, &decoder);
+    decoded = outStream.toString();
+    OATPP_ASSERT(count == encoded->getSize());
+    OATPP_ASSERT(decoded == "");
+  }
+
   {
     oatpp::data::stream::BufferInputStream inStream(data);
     oatpp::data::stream::BufferOutputStream outStream;